Adapt PArray instance generation to new scheme
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   TyConRepr(..), mkTyConRepr,
8   mkToArrPRepr, mkFromArrPRepr,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon, prCoerce,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 #include "HsVersions.h"
22
23 import VectCore
24 import VectMonad
25
26 import DsUtils
27 import CoreSyn
28 import CoreUtils
29 import Coercion
30 import Type
31 import TypeRep
32 import TyCon
33 import DataCon
34 import Var
35 import Id                 ( mkWildId )
36 import MkId               ( unwrapFamInstScrut )
37 import Name               ( Name )
38 import PrelNames
39 import TysWiredIn
40 import TysPrim            ( intPrimTy )
41 import BasicTypes         ( Boxity(..) )
42
43 import Outputable
44 import FastString
45 import Maybes             ( orElse )
46
47 import Data.List             ( zipWith4 )
48 import Control.Monad         ( liftM, liftM2, zipWithM_ )
49
50 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
51 collectAnnTypeArgs expr = go expr []
52   where
53     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
54     go e                             tys = (e, tys)
55
56 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
57 collectAnnTypeBinders expr = go [] expr
58   where
59     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
60     go bs e                           = (reverse bs, e)
61
62 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
63 collectAnnValBinders expr = go [] expr
64   where
65     go bs (_, AnnLam b e) | isId b = go (b:bs) e
66     go bs e                        = (reverse bs, e)
67
68 isAnnTypeArg :: AnnExpr b ann -> Bool
69 isAnnTypeArg (_, AnnType t) = True
70 isAnnTypeArg _              = False
71
72 mkDataConTag :: DataCon -> CoreExpr
73 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
74
75 splitUnTy :: String -> Name -> Type -> Type
76 splitUnTy s name ty
77   | Just (tc, [ty']) <- splitTyConApp_maybe ty
78   , tyConName tc == name
79   = ty'
80
81   | otherwise = pprPanic s (ppr ty)
82
83 splitBinTy :: String -> Name -> Type -> (Type, Type)
84 splitBinTy s name ty
85   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
86   , tyConName tc == name
87   = (ty1, ty2)
88
89   | otherwise = pprPanic s (ppr ty)
90
91 splitFixedTyConApp :: TyCon -> Type -> [Type]
92 splitFixedTyConApp tc ty
93   | Just (tc', tys) <- splitTyConApp_maybe ty
94   , tc == tc'
95   = tys
96
97   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
98
99 splitClosureTy :: Type -> (Type, Type)
100 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
101
102 splitPArrayTy :: Type -> Type
103 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
104
105 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
106 mkBuiltinTyConApp get_tc tys
107   = do
108       tc <- builtin get_tc
109       return $ mkTyConApp tc tys
110
111 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
112 mkBuiltinTyConApps get_tc tys ty
113   = do
114       tc <- builtin get_tc
115       return $ foldr (mk tc) ty tys
116   where
117     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
118
119 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
120 mkBuiltinTyConApps1 get_tc dft [] = return dft
121 mkBuiltinTyConApps1 get_tc dft tys
122   = do
123       tc <- builtin get_tc
124       case tys of
125         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
126         _  -> return $ foldr1 (mk tc) tys
127   where
128     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
129
130 data TyConRepr = TyConRepr {
131                    repr_tyvars         :: [TyVar]
132                  , repr_tys            :: [[Type]]
133                  , arr_shape_tys       :: [Type]
134                  , arr_repr_tys        :: [[Type]]
135
136                  , repr_prod_tycons    :: [Maybe TyCon]
137                  , repr_prod_data_cons :: [Maybe DataCon]
138                  , repr_prod_tys       :: [Type]
139                  , repr_sum_tycon      :: Maybe TyCon
140                  , repr_sum_data_cons  :: [DataCon]
141                  , repr_type           :: Type
142                  }
143
144 mkTyConRepr :: TyCon -> VM TyConRepr
145 mkTyConRepr vect_tc
146   = do
147       uarr <- builtin uarrTyCon
148       prod_tycons  <- mapM (mk_tycon prodTyCon) rep_tys
149       let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys
150       sum_tycon    <- mk_tycon sumTyCon prod_tys
151       arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) rep_tys
152
153       return $ TyConRepr {
154                  repr_tyvars         = tyvars
155                , repr_tys            = rep_tys
156                , arr_shape_tys       = mk_shape uarr
157                , arr_repr_tys        = arr_repr_tys
158
159                , repr_prod_tycons    = prod_tycons
160                , repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
161                , repr_prod_tys       = prod_tys
162                , repr_sum_tycon      = sum_tycon
163                , repr_sum_data_cons  = fmap tyConDataCons sum_tycon `orElse` []
164                , repr_type           = mk_tc_app_maybe sum_tycon prod_tys
165                }
166   where
167     tyvars = tyConTyVars vect_tc
168     data_cons = tyConDataCons vect_tc
169     rep_tys   = map dataConRepArgTys data_cons
170
171     is_product | [_] <- data_cons = True
172                | otherwise        = False
173
174     mk_shape uarr = intPrimTy : mk_sel uarr
175
176     mk_sel uarr | is_product = []
177                 | otherwise  = [uarr_int, uarr_int]
178       where
179         uarr_int = mkTyConApp uarr [intTy]
180
181     mk_tycon get_tc tys
182       | n > 1     = builtin (Just . get_tc n)
183       | otherwise = return Nothing
184       where n = length tys
185
186     mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
187
188     mk_tc_app_maybe Nothing   []   = unitTy
189     mk_tc_app_maybe Nothing   [ty] = ty
190     mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys
191
192     arr_repr_elem_tys []  = [unitTy]
193     arr_repr_elem_tys tys = tys
194
195 mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
196 mkToArrPRepr len sel ess
197   = do
198       let mk_sum [(expr, ty)] = return (expr, ty)
199           mk_sum es
200             = do
201                 sum_tc <- builtin . sumTyCon $ length es
202                 (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
203                 let [sum_rdc] = tyConDataCons sum_rtc
204
205                 return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
206                         mkTyConApp sum_tc tys)
207             where
208               (exprs, tys) = unzip es
209
210           mk_prod [expr] = return (expr, splitPArrayTy (exprType expr))
211           mk_prod exprs
212             = do
213                 prod_tc <- builtin . prodTyCon $ length exprs
214                 (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
215                 let [prod_rdc] = tyConDataCons prod_rtc
216
217                 return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
218                         mkTyConApp prod_tc tys)
219             where
220               tys = map (splitPArrayTy . exprType) exprs
221
222       liftM fst (mk_sum =<< mapM mk_prod ess)
223
224 mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
225                -> VM CoreExpr
226 mkFromArrPRepr scrut res_ty len sel vars res
227   = return (Var unitDataConId)
228
229 mkClosureType :: Type -> Type -> VM Type
230 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
231
232 mkClosureTypes :: [Type] -> Type -> VM Type
233 mkClosureTypes = mkBuiltinTyConApps closureTyCon
234
235 mkPReprType :: Type -> VM Type
236 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
237
238 mkPADictType :: Type -> VM Type
239 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
240
241 mkPArrayType :: Type -> VM Type
242 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
243
244 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
245 parrayCoerce repr_tc args expr
246   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
247   = do
248       parray <- builtin parrayTyCon
249
250       let co = mkAppCoercion (mkTyConApp parray [])
251                              (mkSymCoercion (mkTyConApp arg_co args))
252
253       return $ mkCoerce co expr
254
255 parrayReprTyCon :: Type -> VM (TyCon, [Type])
256 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
257
258 parrayReprDataCon :: Type -> VM (DataCon, [Type])
259 parrayReprDataCon ty
260   = do
261       (tc, arg_tys) <- parrayReprTyCon ty
262       let [dc] = tyConDataCons tc
263       return (dc, arg_tys)
264
265 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
266 mkVScrut (ve, le)
267   = do
268       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
269       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
270
271 prDFunOfTyCon :: TyCon -> VM CoreExpr
272 prDFunOfTyCon tycon
273   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
274
275 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
276 prCoerce repr_tc args expr
277   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
278   = do
279       pr_tc <- builtin prTyCon
280
281       let co = mkAppCoercion (mkTyConApp pr_tc [])
282                              (mkSymCoercion (mkTyConApp arg_co args))
283
284       return $ mkCoerce co expr
285
286 paDictArgType :: TyVar -> VM (Maybe Type)
287 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
288   where
289     go ty k | Just k' <- kindView k = go ty k'
290     go ty (FunTy k1 k2)
291       = do
292           tv   <- newTyVar FSLIT("a") k1
293           mty1 <- go (TyVarTy tv) k1
294           case mty1 of
295             Just ty1 -> do
296                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
297                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
298             Nothing  -> go ty k2
299
300     go ty k
301       | isLiftedTypeKind k
302       = liftM Just (mkPADictType ty)
303
304     go ty k = return Nothing
305
306 paDictOfType :: Type -> VM CoreExpr
307 paDictOfType ty = paDictOfTyApp ty_fn ty_args
308   where
309     (ty_fn, ty_args) = splitAppTys ty
310
311 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
312 paDictOfTyApp ty_fn ty_args
313   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
314 paDictOfTyApp (TyVarTy tv) ty_args
315   = do
316       dfun <- maybeV (lookupTyVarPA tv)
317       paDFunApply dfun ty_args
318 paDictOfTyApp (TyConApp tc _) ty_args
319   = do
320       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
321       paDFunApply (Var dfun) ty_args
322 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
323
324 paDFunType :: TyCon -> VM Type
325 paDFunType tc
326   = do
327       margs <- mapM paDictArgType tvs
328       res   <- mkPADictType (mkTyConApp tc arg_tys)
329       return . mkForAllTys tvs
330              $ mkFunTys [arg | Just arg <- margs] res
331   where
332     tvs = tyConTyVars tc
333     arg_tys = mkTyVarTys tvs
334
335 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
336 paDFunApply dfun tys
337   = do
338       dicts <- mapM paDictOfType tys
339       return $ mkApps (mkTyApps dfun tys) dicts
340
341 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
342 paMethod method ty
343   = do
344       fn   <- builtin method
345       dict <- paDictOfType ty
346       return $ mkApps (Var fn) [Type ty, dict]
347
348 mkPR :: Type -> VM CoreExpr
349 mkPR = paMethod mkPRVar
350
351 lengthPA :: CoreExpr -> VM CoreExpr
352 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
353   where
354     ty = splitPArrayTy (exprType x)
355
356 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
357 replicatePA len x = liftM (`mkApps` [len,x])
358                           (paMethod replicatePAVar (exprType x))
359
360 emptyPA :: Type -> VM CoreExpr
361 emptyPA = paMethod emptyPAVar
362
363 liftPA :: CoreExpr -> VM CoreExpr
364 liftPA x
365   = do
366       lc <- builtin liftingContext
367       replicatePA (Var lc) x
368
369 newLocalVVar :: FastString -> Type -> VM VVar
370 newLocalVVar fs vty
371   = do
372       lty <- mkPArrayType vty
373       vv  <- newLocalVar fs vty
374       lv  <- newLocalVar fs lty
375       return (vv,lv)
376
377 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
378 polyAbstract tvs p
379   = localV
380   $ do
381       mdicts <- mapM mk_dict_var tvs
382       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
383       p (mk_lams mdicts)
384   where
385     mk_dict_var tv = do
386                        r <- paDictArgType tv
387                        case r of
388                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
389                          Nothing -> return Nothing
390
391     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
392
393 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
394 polyApply expr tys
395   = do
396       dicts <- mapM paDictOfType tys
397       return $ expr `mkTyApps` tys `mkApps` dicts
398
399 polyVApply :: VExpr -> [Type] -> VM VExpr
400 polyVApply expr tys
401   = do
402       dicts <- mapM paDictOfType tys
403       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
404
405 hoistBinding :: Var -> CoreExpr -> VM ()
406 hoistBinding v e = updGEnv $ \env ->
407   env { global_bindings = (v,e) : global_bindings env }
408
409 hoistExpr :: FastString -> CoreExpr -> VM Var
410 hoistExpr fs expr
411   = do
412       var <- newLocalVar fs (exprType expr)
413       hoistBinding var expr
414       return var
415
416 hoistVExpr :: VExpr -> VM VVar
417 hoistVExpr (ve, le)
418   = do
419       fs <- getBindName
420       vv <- hoistExpr ('v' `consFS` fs) ve
421       lv <- hoistExpr ('l' `consFS` fs) le
422       return (vv, lv)
423
424 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
425 hoistPolyVExpr tvs p
426   = do
427       expr <- closedV . polyAbstract tvs $ \abstract ->
428               liftM (mapVect abstract) p
429       fn   <- hoistVExpr expr
430       polyVApply (vVar fn) (mkTyVarTys tvs)
431
432 takeHoisted :: VM [(Var, CoreExpr)]
433 takeHoisted
434   = do
435       env <- readGEnv id
436       setGEnv $ env { global_bindings = [] }
437       return $ global_bindings env
438
439 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
440 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
441   = do
442       dict <- paDictOfType env_ty
443       mkv  <- builtin mkClosureVar
444       mkl  <- builtin mkClosurePVar
445       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
446               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
447
448 mkClosureApp :: VExpr -> VExpr -> VM VExpr
449 mkClosureApp (vclo, lclo) (varg, larg)
450   = do
451       vapply <- builtin applyClosureVar
452       lapply <- builtin applyClosurePVar
453       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
454               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
455   where
456     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
457
458 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
459 buildClosures tvs vars [] res_ty mk_body
460   = mk_body
461 buildClosures tvs vars [arg_ty] res_ty mk_body
462   = buildClosure tvs vars arg_ty res_ty mk_body
463 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
464   = do
465       res_ty' <- mkClosureTypes arg_tys res_ty
466       arg <- newLocalVVar FSLIT("x") arg_ty
467       buildClosure tvs vars arg_ty res_ty'
468         . hoistPolyVExpr tvs
469         $ do
470             lc <- builtin liftingContext
471             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
472             return $ vLams lc (vars ++ [arg]) clo
473
474 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
475 --   where
476 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
477 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
478 --
479 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
480 buildClosure tvs vars arg_ty res_ty mk_body
481   = do
482       (env_ty, env, bind) <- buildEnv vars
483       env_bndr <- newLocalVVar FSLIT("env") env_ty
484       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
485
486       fn <- hoistPolyVExpr tvs
487           $ do
488               lc    <- builtin liftingContext
489               body  <- mk_body
490               body' <- bind (vVar env_bndr)
491                             (vVarApps lc body (vars ++ [arg_bndr]))
492               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
493
494       mkClosure arg_ty res_ty env_ty fn env
495
496 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
497 buildEnv vvs
498   = do
499       lc <- builtin liftingContext
500       let (ty, venv, vbind) = mkVectEnv tys vs
501       (lenv, lbind) <- mkLiftEnv lc tys ls
502       return (ty, (venv, lenv),
503               \(venv,lenv) (vbody,lbody) ->
504               do
505                 let vbody' = vbind venv vbody
506                 lbody' <- lbind lenv lbody
507                 return (vbody', lbody'))
508   where
509     (vs,ls) = unzip vvs
510     tys     = map idType vs
511
512 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
513 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
514 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
515 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
516                         \env body -> Case env (mkWildId ty) (exprType body)
517                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
518   where
519     ty = mkCoreTupTy tys
520
521 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
522 mkLiftEnv lc [ty] [v]
523   = return (Var v, \env body ->
524                    do
525                      len <- lengthPA (Var v)
526                      return . Let (NonRec v env)
527                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
528
529 -- NOTE: this transparently deals with empty environments
530 mkLiftEnv lc tys vs
531   = do
532       (env_tc, env_tyargs) <- parrayReprTyCon vty
533       let [env_con] = tyConDataCons env_tc
534           
535           env = Var (dataConWrapId env_con)
536                 `mkTyApps`  env_tyargs
537                 `mkVarApps` (lc : vs)
538
539           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
540                           in
541                           return $ Case scrut (mkWildId (exprType scrut))
542                                         (exprType body)
543                                         [(DataAlt env_con, lc : bndrs, body)]
544       return (env, bind)
545   where
546     vty = mkCoreTupTy tys
547
548     bndrs | null vs   = [mkWildId unitTy]
549           | otherwise = vs
550