Move VectCore to Vectorise tree
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo, voidType, mkWrapType,
9   mkPADictType, mkPArrayType, mkPDataType, mkPReprType, mkPArray,
10
11   pdataReprTyCon, pdataReprDataCon, mkVScrut,
12   prDictOfType, prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, wrapPR, replicatePD, emptyPD, packByTagPD,
15   combinePD,
16   liftPD,
17   zipScalars, scalarClosure,
18   polyAbstract, polyApply, polyVApply, polyArity,
19   Inline(..), addInlineArity, inlineMe,
20   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
21   buildClosure, buildClosures,
22   mkClosureApp
23 ) where
24 import VectMonad
25 import Vectorise.Env
26 import Vectorise.Vect
27
28 import MkCore ( mkCoreTup, mkWildCase )
29 import CoreSyn
30 import CoreUtils
31 import CoreUnfold         ( mkInlineRule )
32 import Coercion
33 import Type
34 import TypeRep
35 import TyCon
36 import DataCon
37 import Var
38 import MkId               ( unwrapFamInstScrut )
39 import Id                 ( setIdUnfolding )
40 import TysWiredIn
41 import BasicTypes         ( Boxity(..), Arity )
42 import Literal            ( Literal, mkMachInt )
43
44
45 import Outputable
46 import FastString
47
48 import Control.Monad
49
50 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
51 collectAnnTypeArgs expr = go expr []
52   where
53     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
54     go e                             tys = (e, tys)
55
56 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
57 collectAnnTypeBinders expr = go [] expr
58   where
59     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
60     go bs e                           = (reverse bs, e)
61
62 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
63 collectAnnValBinders expr = go [] expr
64   where
65     go bs (_, AnnLam b e) | isId b = go (b:bs) e
66     go bs e                        = (reverse bs, e)
67
68 isAnnTypeArg :: AnnExpr b ann -> Bool
69 isAnnTypeArg (_, AnnType _) = True
70 isAnnTypeArg _              = False
71
72 dataConTagZ :: DataCon -> Int
73 dataConTagZ con = dataConTag con - fIRST_TAG
74
75 mkDataConTagLit :: DataCon -> Literal
76 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
77
78 mkDataConTag :: DataCon -> CoreExpr
79 mkDataConTag = mkIntLitInt . dataConTagZ
80
81 splitPrimTyCon :: Type -> Maybe TyCon
82 splitPrimTyCon ty
83   | Just (tycon, []) <- splitTyConApp_maybe ty
84   , isPrimTyCon tycon
85   = Just tycon
86
87   | otherwise = Nothing
88
89 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
90 mkBuiltinTyConApp get_tc tys
91   = do
92       tc <- builtin get_tc
93       return $ mkTyConApp tc tys
94
95 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
96 mkBuiltinTyConApps get_tc tys ty
97   = do
98       tc <- builtin get_tc
99       return $ foldr (mk tc) ty tys
100   where
101     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
102
103 voidType :: VM Type
104 voidType = mkBuiltinTyConApp VectMonad.voidTyCon []
105
106 mkWrapType :: Type -> VM Type
107 mkWrapType ty = mkBuiltinTyConApp wrapTyCon [ty]
108
109 mkClosureTypes :: [Type] -> Type -> VM Type
110 mkClosureTypes = mkBuiltinTyConApps closureTyCon
111
112 mkPReprType :: Type -> VM Type
113 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
114
115 mkPADictType :: Type -> VM Type
116 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
117
118 mkPArrayType :: Type -> VM Type
119 mkPArrayType ty
120   | Just tycon <- splitPrimTyCon ty
121   = do
122       r <- lookupPrimPArray tycon
123       case r of
124         Just arr -> return $ mkTyConApp arr []
125         Nothing  -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
126 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
127
128 mkPDataType :: Type -> VM Type
129 mkPDataType ty = mkBuiltinTyConApp pdataTyCon [ty]
130
131 mkPArray :: Type -> CoreExpr -> CoreExpr -> VM CoreExpr
132 mkPArray ty len dat = do
133                         tc <- builtin parrayTyCon
134                         let [dc] = tyConDataCons tc
135                         return $ mkConApp dc [Type ty, len, dat]
136
137 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
138 mkBuiltinCo get_tc
139   = do
140       tc <- builtin get_tc
141       return $ mkTyConApp tc []
142
143 pdataReprTyCon :: Type -> VM (TyCon, [Type])
144 pdataReprTyCon ty = builtin pdataTyCon >>= (`lookupFamInst` [ty])
145
146 pdataReprDataCon :: Type -> VM (DataCon, [Type])
147 pdataReprDataCon ty
148   = do
149       (tc, arg_tys) <- pdataReprTyCon ty
150       let [dc] = tyConDataCons tc
151       return (dc, arg_tys)
152
153 mkVScrut :: VExpr -> VM (CoreExpr, CoreExpr, TyCon, [Type])
154 mkVScrut (ve, le)
155   = do
156       (tc, arg_tys) <- pdataReprTyCon ty
157       return (ve, unwrapFamInstScrut tc arg_tys le, tc, arg_tys)
158   where
159     ty = exprType ve
160
161 prDFunOfTyCon :: TyCon -> VM CoreExpr
162 prDFunOfTyCon tycon
163   = liftM Var
164   . maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
165   $ lookupTyConPR tycon
166
167
168 paDictArgType :: TyVar -> VM (Maybe Type)
169 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
170   where
171     go ty k | Just k' <- kindView k = go ty k'
172     go ty (FunTy k1 k2)
173       = do
174           tv   <- newTyVar (fsLit "a") k1
175           mty1 <- go (TyVarTy tv) k1
176           case mty1 of
177             Just ty1 -> do
178                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
179                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
180             Nothing  -> go ty k2
181
182     go ty k
183       | isLiftedTypeKind k
184       = liftM Just (mkPADictType ty)
185
186     go _ _ = return Nothing
187
188
189 -- | Get the PA dictionary for some type, or `Nothing` if there isn't one.
190 paDictOfType :: Type -> VM (Maybe CoreExpr)
191 paDictOfType ty 
192   = paDictOfTyApp ty_fn ty_args
193   where
194     (ty_fn, ty_args) = splitAppTys ty
195
196     paDictOfTyApp :: Type -> [Type] -> VM (Maybe CoreExpr)
197     paDictOfTyApp ty_fn ty_args
198         | Just ty_fn' <- coreView ty_fn 
199         = paDictOfTyApp ty_fn' ty_args
200
201     paDictOfTyApp (TyVarTy tv) ty_args
202      = do dfun <- maybeV (lookupTyVarPA tv)
203           liftM Just $ paDFunApply dfun ty_args
204
205     paDictOfTyApp (TyConApp tc _) ty_args
206      = do mdfun <- lookupTyConPA tc
207           case mdfun of
208             Nothing     
209              -> pprTrace "VectUtils.paDictOfType"
210                          (vcat [ text "No PA dictionary"
211                                , text "for tycon: " <> ppr tc
212                                , text "in type:   " <> ppr ty])
213              $ return Nothing
214
215             Just dfun   -> liftM Just $ paDFunApply (Var dfun) ty_args
216
217     paDictOfTyApp ty _
218      = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
219
220
221
222 paDFunType :: TyCon -> VM Type
223 paDFunType tc
224   = do
225       margs <- mapM paDictArgType tvs
226       res   <- mkPADictType (mkTyConApp tc arg_tys)
227       return . mkForAllTys tvs
228              $ mkFunTys [arg | Just arg <- margs] res
229   where
230     tvs = tyConTyVars tc
231     arg_tys = mkTyVarTys tvs
232
233 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
234 paDFunApply dfun tys
235  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
236       return $ mkApps (mkTyApps dfun tys) dicts
237
238
239 paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
240 paMethod _ name ty
241   | Just tycon <- splitPrimTyCon ty
242   = liftM Var
243   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
244   $ lookupPrimMethod tycon name
245
246 paMethod method _ ty
247   = do
248       fn        <- builtin method
249       Just dict <- paDictOfType ty
250       return $ mkApps (Var fn) [Type ty, dict]
251
252 prDictOfType :: Type -> VM CoreExpr
253 prDictOfType ty = prDictOfTyApp ty_fn ty_args
254   where
255     (ty_fn, ty_args) = splitAppTys ty
256
257 prDictOfTyApp :: Type -> [Type] -> VM CoreExpr
258 prDictOfTyApp ty_fn ty_args
259   | Just ty_fn' <- coreView ty_fn = prDictOfTyApp ty_fn' ty_args
260 prDictOfTyApp (TyConApp tc _) ty_args
261   = do
262       dfun <- liftM Var $ maybeV (lookupTyConPR tc)
263       prDFunApply dfun ty_args
264 prDictOfTyApp _ _ = noV
265
266 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
267 prDFunApply dfun tys
268   = do
269       dicts <- mapM prDictOfType tys
270       return $ mkApps (mkTyApps dfun tys) dicts
271
272 wrapPR :: Type -> VM CoreExpr
273 wrapPR ty
274   = do
275       Just  pa_dict <- paDictOfType ty
276       pr_dfun       <- prDFunOfTyCon =<< builtin wrapTyCon
277       return $ mkApps pr_dfun [Type ty, pa_dict]
278
279 replicatePD :: CoreExpr -> CoreExpr -> VM CoreExpr
280 replicatePD len x = liftM (`mkApps` [len,x])
281                           (paMethod replicatePDVar "replicatePD" (exprType x))
282
283 emptyPD :: Type -> VM CoreExpr
284 emptyPD = paMethod emptyPDVar "emptyPD"
285
286 packByTagPD :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> CoreExpr
287                  -> VM CoreExpr
288 packByTagPD ty xs len tags t
289   = liftM (`mkApps` [xs, len, tags, t])
290           (paMethod packByTagPDVar "packByTagPD" ty)
291
292 combinePD :: Type -> CoreExpr -> CoreExpr -> [CoreExpr]
293           -> VM CoreExpr
294 combinePD ty len sel xs
295   = liftM (`mkApps` (len : sel : xs))
296           (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
297   where
298     n = length xs
299
300 -- | Like `replicatePD` but use the lifting context in the vectoriser state.
301 liftPD :: CoreExpr -> VM CoreExpr
302 liftPD x
303   = do
304       lc <- builtin liftingContext
305       replicatePD (Var lc) x
306
307 zipScalars :: [Type] -> Type -> VM CoreExpr
308 zipScalars arg_tys res_ty
309   = do
310       scalar <- builtin scalarClass
311       (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
312       zipf <- builtin (scalarZip $ length arg_tys)
313       return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
314     where
315       ty_args = arg_tys ++ [res_ty]
316
317 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
318 scalarClosure arg_tys res_ty scalar_fun array_fun
319   = do
320       ctr      <- builtin (closureCtrFun $ length arg_tys)
321       Just pas <- liftM sequence $ mapM paDictOfType (init arg_tys)
322       return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
323                        `mkApps`   (pas ++ [scalar_fun, array_fun])
324
325 newLocalVVar :: FastString -> Type -> VM VVar
326 newLocalVVar fs vty
327   = do
328       lty <- mkPDataType vty
329       vv  <- newLocalVar fs vty
330       lv  <- newLocalVar fs lty
331       return (vv,lv)
332
333 polyAbstract :: [TyVar] -> ([Var] -> VM a) -> VM a
334 polyAbstract tvs p
335   = localV
336   $ do
337       mdicts <- mapM mk_dict_var tvs
338       zipWithM_ (\tv -> maybe (defLocalTyVar tv)
339                               (defLocalTyVarWithPA tv . Var)) tvs mdicts
340       p (mk_args mdicts)
341   where
342     mk_dict_var tv = do
343                        r <- paDictArgType tv
344                        case r of
345                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
346                          Nothing -> return Nothing
347
348     mk_args mdicts = [dict | Just dict <- mdicts]
349
350 polyArity :: [TyVar] -> VM Int
351 polyArity tvs = do
352                   tys <- mapM paDictArgType tvs
353                   return $ length [() | Just _ <- tys]
354
355 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
356 polyApply expr tys
357  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
358       return $ expr `mkTyApps` tys `mkApps` dicts
359
360 polyVApply :: VExpr -> [Type] -> VM VExpr
361 polyVApply expr tys
362  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
363       return     $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
364
365 -- Inline ---------------------------------------------------------------------
366 -- | Records whether we should inline a particular binding.
367 data Inline 
368         = Inline Arity
369         | DontInline
370
371 -- | Add to the arity contained within an `Inline`, if any.
372 addInlineArity :: Inline -> Int -> Inline
373 addInlineArity (Inline m) n = Inline (m+n)
374 addInlineArity DontInline _ = DontInline
375
376 -- | Says to always inline a binding.
377 inlineMe :: Inline
378 inlineMe = Inline 0
379
380
381 -- Hoising --------------------------------------------------------------------
382 hoistBinding :: Var -> CoreExpr -> VM ()
383 hoistBinding v e = updGEnv $ \env ->
384   env { global_bindings = (v,e) : global_bindings env }
385
386 hoistExpr :: FastString -> CoreExpr -> Inline -> VM Var
387 hoistExpr fs expr inl
388   = do
389       var <- mk_inline `liftM` newLocalVar fs (exprType expr)
390       hoistBinding var expr
391       return var
392   where
393     mk_inline var = case inl of
394                       Inline arity -> var `setIdUnfolding`
395                                       mkInlineRule expr (Just arity)
396                       DontInline   -> var
397
398 hoistVExpr :: VExpr -> Inline -> VM VVar
399 hoistVExpr (ve, le) inl
400   = do
401       fs <- getBindName
402       vv <- hoistExpr ('v' `consFS` fs) ve inl
403       lv <- hoistExpr ('l' `consFS` fs) le (addInlineArity inl 1)
404       return (vv, lv)
405
406 hoistPolyVExpr :: [TyVar] -> Inline -> VM VExpr -> VM VExpr
407 hoistPolyVExpr tvs inline p
408   = do
409       inline' <- liftM (addInlineArity inline) (polyArity tvs)
410       expr <- closedV . polyAbstract tvs $ \args ->
411               liftM (mapVect (mkLams $ tvs ++ args)) p
412       fn   <- hoistVExpr expr inline'
413       polyVApply (vVar fn) (mkTyVarTys tvs)
414
415 takeHoisted :: VM [(Var, CoreExpr)]
416 takeHoisted
417   = do
418       env <- readGEnv id
419       setGEnv $ env { global_bindings = [] }
420       return $ global_bindings env
421
422 {-
423 boxExpr :: Type -> VExpr -> VM VExpr
424 boxExpr ty (vexpr, lexpr)
425   | Just (tycon, []) <- splitTyConApp_maybe ty
426   , isUnLiftedTyCon tycon
427   = do
428       r <- lookupBoxedTyCon tycon
429       case r of
430         Just tycon' -> let [dc] = tyConDataCons tycon'
431                        in
432                        return (mkConApp dc [vexpr], lexpr)
433         Nothing     -> return (vexpr, lexpr)
434 -}
435
436 -- Closures -------------------------------------------------------------------
437 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
438 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
439  = do Just dict <- paDictOfType env_ty
440       mkv       <- builtin closureVar
441       mkl       <- builtin liftedClosureVar
442       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
443               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
444
445
446 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
447 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
448  = do vapply <- builtin applyVar
449       lapply <- builtin liftedApplyVar
450       lc     <- builtin liftingContext
451       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
452               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [Var lc, lclo, larg])
453
454
455 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
456 buildClosures _   _    [] _ mk_body
457   = mk_body
458 buildClosures tvs vars [arg_ty] res_ty mk_body
459   = -- liftM vInlineMe $
460       buildClosure tvs vars arg_ty res_ty mk_body
461 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
462   = do
463       res_ty' <- mkClosureTypes arg_tys res_ty
464       arg <- newLocalVVar (fsLit "x") arg_ty
465       -- liftM vInlineMe
466       buildClosure tvs vars arg_ty res_ty'
467         . hoistPolyVExpr tvs (Inline (length vars + 1))
468         $ do
469             lc <- builtin liftingContext
470             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
471             return $ vLams lc (vars ++ [arg]) clo
472
473 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
474 --   where
475 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
476 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
477 --
478 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
479 buildClosure tvs vars arg_ty res_ty mk_body
480   = do
481       (env_ty, env, bind) <- buildEnv vars
482       env_bndr <- newLocalVVar (fsLit "env") env_ty
483       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
484
485       fn <- hoistPolyVExpr tvs (Inline 2)
486           $ do
487               lc    <- builtin liftingContext
488               body  <- mk_body
489               return -- . vInlineMe
490                      . vLams lc [env_bndr, arg_bndr]
491                      $ bind (vVar env_bndr)
492                             (vVarApps lc body (vars ++ [arg_bndr]))
493
494       mkClosure arg_ty res_ty env_ty fn env
495
496
497 -- Environments ---------------------------------------------------------------
498 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
499 buildEnv [] = do
500              ty    <- voidType
501              void  <- builtin voidVar
502              pvoid <- builtin pvoidVar
503              return (ty, vVar (void, pvoid), \_ body -> body)
504
505 buildEnv [v] = return (vVarType v, vVar v,
506                     \env body -> vLet (vNonRec v env) body)
507
508 buildEnv vs
509   = do
510       
511       (lenv_tc, lenv_tyargs) <- pdataReprTyCon ty
512
513       let venv_con   = tupleCon Boxed (length vs) 
514           [lenv_con] = tyConDataCons lenv_tc
515
516           venv       = mkCoreTup (map Var vvs)
517           lenv       = Var (dataConWrapId lenv_con)
518                        `mkTyApps` lenv_tyargs
519                        `mkApps`   map Var lvs
520
521           vbind env body = mkWildCase env ty (exprType body)
522                            [(DataAlt venv_con, vvs, body)]
523
524           lbind env body =
525             let scrut = unwrapFamInstScrut lenv_tc lenv_tyargs env
526             in
527             mkWildCase scrut (exprType scrut) (exprType body)
528               [(DataAlt lenv_con, lvs, body)]
529
530           bind (venv, lenv) (vbody, lbody) = (vbind venv vbody,
531                                               lbind lenv lbody)
532
533       return (ty, (venv, lenv), bind)
534   where
535     (vvs, lvs) = unzip vs
536     tys        = map vVarType vs
537     ty         = mkBoxedTupleTy tys
538