Break out hoisting utils into their own module
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Utils / Closure.hs
1
2 module Vectorise.Utils.Closure (
3         mkClosure,
4         mkClosureApp,
5         buildClosure,
6         buildClosures,
7         buildEnv
8 )
9 where
10 import VectUtils
11 import Vectorise.Utils.Hoisting
12 import Vectorise.Builtins
13 import Vectorise.Vect
14 import Vectorise.Monad
15
16 import CoreSyn
17 import Type
18 import Var
19 import MkCore
20 import CoreUtils
21 import TyCon
22 import DataCon
23 import MkId
24 import TysWiredIn
25 import BasicTypes
26 import FastString
27
28
29 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
30 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
31  = do Just dict <- paDictOfType env_ty
32       mkv       <- builtin closureVar
33       mkl       <- builtin liftedClosureVar
34       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
35               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
36
37
38 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
39 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
40  = do vapply <- builtin applyVar
41       lapply <- builtin liftedApplyVar
42       lc     <- builtin liftingContext
43       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
44               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [Var lc, lclo, larg])
45
46
47 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
48 buildClosures _   _    [] _ mk_body
49   = mk_body
50 buildClosures tvs vars [arg_ty] res_ty mk_body
51   = -- liftM vInlineMe $
52       buildClosure tvs vars arg_ty res_ty mk_body
53 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
54   = do
55       res_ty' <- mkClosureTypes arg_tys res_ty
56       arg <- newLocalVVar (fsLit "x") arg_ty
57       -- liftM vInlineMe
58       buildClosure tvs vars arg_ty res_ty'
59         . hoistPolyVExpr tvs (Inline (length vars + 1))
60         $ do
61             lc <- builtin liftingContext
62             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
63             return $ vLams lc (vars ++ [arg]) clo
64
65
66 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
67 --   where
68 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
69 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
70 --
71 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
72 buildClosure tvs vars arg_ty res_ty mk_body
73   = do
74       (env_ty, env, bind) <- buildEnv vars
75       env_bndr <- newLocalVVar (fsLit "env") env_ty
76       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
77
78       fn <- hoistPolyVExpr tvs (Inline 2)
79           $ do
80               lc    <- builtin liftingContext
81               body  <- mk_body
82               return -- . vInlineMe
83                      . vLams lc [env_bndr, arg_bndr]
84                      $ bind (vVar env_bndr)
85                             (vVarApps lc body (vars ++ [arg_bndr]))
86
87       mkClosure arg_ty res_ty env_ty fn env
88
89
90 -- Environments ---------------------------------------------------------------
91 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
92 buildEnv [] = do
93              ty    <- voidType
94              void  <- builtin voidVar
95              pvoid <- builtin pvoidVar
96              return (ty, vVar (void, pvoid), \_ body -> body)
97
98 buildEnv [v] = return (vVarType v, vVar v,
99                     \env body -> vLet (vNonRec v env) body)
100
101 buildEnv vs
102   = do
103       
104       (lenv_tc, lenv_tyargs) <- pdataReprTyCon ty
105
106       let venv_con   = tupleCon Boxed (length vs) 
107           [lenv_con] = tyConDataCons lenv_tc
108
109           venv       = mkCoreTup (map Var vvs)
110           lenv       = Var (dataConWrapId lenv_con)
111                        `mkTyApps` lenv_tyargs
112                        `mkApps`   map Var lvs
113
114           vbind env body = mkWildCase env ty (exprType body)
115                            [(DataAlt venv_con, vvs, body)]
116
117           lbind env body =
118             let scrut = unwrapFamInstScrut lenv_tc lenv_tyargs env
119             in
120             mkWildCase scrut (exprType scrut) (exprType body)
121               [(DataAlt lenv_con, lvs, body)]
122
123           bind (venv, lenv) (vbody, lbody) = (vbind venv vbody,
124                                               lbind lenv lbody)
125
126       return (ty, (venv, lenv), bind)
127   where
128     (vvs, lvs) = unzip vs
129     tys        = map vVarType vs
130     ty         = mkBoxedTupleTy tys