Improve closure generation for functions with multiple parameters
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   splitClosureTy,
5   mkPADictType, mkPArrayType,
6   paDictArgType, paDictOfType,
7   paMethod, lengthPA, replicatePA, emptyPA,
8   polyAbstract, polyApply, polyVApply,
9   lookupPArrayFamInst,
10   hoistExpr, hoistPolyVExpr, takeHoisted,
11   buildClosure, buildClosures
12 ) where
13
14 #include "HsVersions.h"
15
16 import VectCore
17 import VectMonad
18
19 import DsUtils
20 import CoreSyn
21 import CoreUtils
22 import Type
23 import TypeRep
24 import TyCon
25 import DataCon            ( dataConWrapId )
26 import Var
27 import Id                 ( mkWildId )
28 import MkId               ( unwrapFamInstScrut )
29 import PrelNames
30 import TysWiredIn
31 import BasicTypes         ( Boxity(..) )
32
33 import Outputable
34 import FastString
35
36 import Control.Monad         ( liftM, zipWithM_ )
37
38 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
39 collectAnnTypeArgs expr = go expr []
40   where
41     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
42     go e                             tys = (e, tys)
43
44 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
45 collectAnnTypeBinders expr = go [] expr
46   where
47     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
48     go bs e                           = (reverse bs, e)
49
50 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
51 collectAnnValBinders expr = go [] expr
52   where
53     go bs (_, AnnLam b e) | isId b = go (b:bs) e
54     go bs e                        = (reverse bs, e)
55
56 isAnnTypeArg :: AnnExpr b ann -> Bool
57 isAnnTypeArg (_, AnnType t) = True
58 isAnnTypeArg _              = False
59
60 isClosureTyCon :: TyCon -> Bool
61 isClosureTyCon tc = tyConName tc == closureTyConName
62
63 splitClosureTy :: Type -> (Type, Type)
64 splitClosureTy ty
65   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
66   , isClosureTyCon tc
67   = (arg_ty, res_ty)
68
69   | otherwise = pprPanic "splitClosureTy" (ppr ty)
70
71 isPArrayTyCon :: TyCon -> Bool
72 isPArrayTyCon tc = tyConName tc == parrayTyConName
73
74 splitPArrayTy :: Type -> Type
75 splitPArrayTy ty
76   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
77   , isPArrayTyCon tc
78   = arg_ty
79
80   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
81
82 mkClosureType :: Type -> Type -> VM Type
83 mkClosureType arg_ty res_ty
84   = do
85       tc <- builtin closureTyCon
86       return $ mkTyConApp tc [arg_ty, res_ty]
87
88 mkClosureTypes :: [Type] -> Type -> VM Type
89 mkClosureTypes arg_tys res_ty
90   = do
91       tc <- builtin closureTyCon
92       return $ foldr (mk tc) res_ty arg_tys
93   where
94     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
95
96 mkPADictType :: Type -> VM Type
97 mkPADictType ty
98   = do
99       tc <- builtin paDictTyCon
100       return $ TyConApp tc [ty]
101
102 mkPArrayType :: Type -> VM Type
103 mkPArrayType ty
104   = do
105       tc <- builtin parrayTyCon
106       return $ TyConApp tc [ty]
107
108 paDictArgType :: TyVar -> VM (Maybe Type)
109 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
110   where
111     go ty k | Just k' <- kindView k = go ty k'
112     go ty (FunTy k1 k2)
113       = do
114           tv   <- newTyVar FSLIT("a") k1
115           mty1 <- go (TyVarTy tv) k1
116           case mty1 of
117             Just ty1 -> do
118                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
119                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
120             Nothing  -> go ty k2
121
122     go ty k
123       | isLiftedTypeKind k
124       = liftM Just (mkPADictType ty)
125
126     go ty k = return Nothing
127
128 paDictOfType :: Type -> VM CoreExpr
129 paDictOfType ty = paDictOfTyApp ty_fn ty_args
130   where
131     (ty_fn, ty_args) = splitAppTys ty
132
133 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
134 paDictOfTyApp ty_fn ty_args
135   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
136 paDictOfTyApp (TyVarTy tv) ty_args
137   = do
138       dfun <- maybeV (lookupTyVarPA tv)
139       paDFunApply dfun ty_args
140 paDictOfTyApp (TyConApp tc _) ty_args
141   = do
142       pa_class <- builtin paClass
143       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
144       paDFunApply (Var dfun) ty_args'
145 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
146
147 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
148 paDFunApply dfun tys
149   = do
150       dicts <- mapM paDictOfType tys
151       return $ mkApps (mkTyApps dfun tys) dicts
152
153 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
154 paMethod method ty
155   = do
156       fn   <- builtin method
157       dict <- paDictOfType ty
158       return $ mkApps (Var fn) [Type ty, dict]
159
160 lengthPA :: CoreExpr -> VM CoreExpr
161 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
162   where
163     ty = splitPArrayTy (exprType x)
164
165 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
166 replicatePA len x = liftM (`mkApps` [len,x])
167                           (paMethod replicatePAVar (exprType x))
168
169 emptyPA :: Type -> VM CoreExpr
170 emptyPA = paMethod emptyPAVar
171
172 newLocalVVar :: FastString -> Type -> VM VVar
173 newLocalVVar fs vty
174   = do
175       lty <- mkPArrayType vty
176       vv  <- newLocalVar fs vty
177       lv  <- newLocalVar fs lty
178       return (vv,lv)
179
180 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
181 polyAbstract tvs p
182   = localV
183   $ do
184       mdicts <- mapM mk_dict_var tvs
185       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
186       p (mk_lams mdicts)
187   where
188     mk_dict_var tv = do
189                        r <- paDictArgType tv
190                        case r of
191                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
192                          Nothing -> return Nothing
193
194     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
195
196 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
197 polyApply expr tys
198   = do
199       dicts <- mapM paDictOfType tys
200       return $ expr `mkTyApps` tys `mkApps` dicts
201
202 polyVApply :: VExpr -> [Type] -> VM VExpr
203 polyVApply expr tys
204   = do
205       dicts <- mapM paDictOfType tys
206       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
207
208 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
209 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
210
211 hoistExpr :: FastString -> CoreExpr -> VM Var
212 hoistExpr fs expr
213   = do
214       var <- newLocalVar fs (exprType expr)
215       updGEnv $ \env ->
216         env { global_bindings = (var, expr) : global_bindings env }
217       return var
218
219 hoistVExpr :: FastString -> VExpr -> VM VVar
220 hoistVExpr fs (ve, le)
221   = do
222       vv <- hoistExpr ('v' `consFS` fs) ve
223       lv <- hoistExpr ('l' `consFS` fs) le
224       return (vv, lv)
225
226 hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
227 hoistPolyVExpr fs tvs p
228   = do
229       expr <- closedV . polyAbstract tvs $ \abstract ->
230               liftM (mapVect abstract) p
231       fn   <- hoistVExpr fs expr
232       polyVApply (vVar fn) (mkTyVarTys tvs)
233
234 takeHoisted :: VM [(Var, CoreExpr)]
235 takeHoisted
236   = do
237       env <- readGEnv id
238       setGEnv $ env { global_bindings = [] }
239       return $ global_bindings env
240
241
242 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
243 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
244   = do
245       dict <- paDictOfType env_ty
246       mkv  <- builtin mkClosureVar
247       mkl  <- builtin mkClosurePVar
248       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
249               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
250
251 buildClosures :: [TyVar] -> Var -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
252 buildClosures tvs lc vars [arg_ty] res_ty mk_body
253   = buildClosure tvs lc vars arg_ty res_ty mk_body
254 buildClosures tvs lc vars (arg_ty : arg_tys) res_ty mk_body
255   = do
256       res_ty' <- mkClosureTypes arg_tys res_ty
257       arg <- newLocalVVar FSLIT("x") arg_ty
258       buildClosure tvs lc vars arg_ty res_ty'
259         . hoistPolyVExpr FSLIT("fn") tvs
260         $ do
261             clo <- buildClosures tvs lc (vars ++ [arg]) arg_tys res_ty mk_body
262             return $ vLams lc (vars ++ [arg]) clo
263
264 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
265 --   where
266 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
267 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
268 --
269 buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
270 buildClosure tvs lv vars arg_ty res_ty mk_body
271   = do
272       (env_ty, env, bind) <- buildEnv lv vars
273       env_bndr <- newLocalVVar FSLIT("env") env_ty
274       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
275
276       fn <- hoistPolyVExpr FSLIT("fn") tvs
277           $ do
278               body  <- mk_body
279               body' <- bind (vVar env_bndr)
280                             (vVarApps lv body (vars ++ [arg_bndr]))
281               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
282
283       mkClosure arg_ty res_ty env_ty fn env
284
285 buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
286 buildEnv lv vvs
287   = do
288       let (ty, venv, vbind) = mkVectEnv tys vs
289       (lenv, lbind) <- mkLiftEnv lv tys ls
290       return (ty, (venv, lenv),
291               \(venv,lenv) (vbody,lbody) ->
292               do
293                 let vbody' = vbind venv vbody
294                 lbody' <- lbind lenv lbody
295                 return (vbody', lbody'))
296   where
297     (vs,ls) = unzip vvs
298     tys     = map idType vs
299
300 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
301 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
302 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
303 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
304                         \env body -> Case env (mkWildId ty) (exprType body)
305                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
306   where
307     ty = mkCoreTupTy tys
308
309 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
310 mkLiftEnv lv [ty] [v]
311   = return (Var v, \env body ->
312                    do
313                      len <- lengthPA (Var v)
314                      return . Let (NonRec v env)
315                             $ Case len lv (exprType body) [(DEFAULT, [], body)])
316
317 -- NOTE: this transparently deals with empty environments
318 mkLiftEnv lv tys vs
319   = do
320       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
321       let [env_con] = tyConDataCons env_tc
322           
323           env = Var (dataConWrapId env_con)
324                 `mkTyApps`  env_tyargs
325                 `mkVarApps` (lv : vs)
326
327           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
328                           in
329                           return $ Case scrut (mkWildId (exprType scrut))
330                                         (exprType body)
331                                         [(DataAlt env_con, lv : bndrs, body)]
332       return (env, bind)
333   where
334     vty = mkCoreTupTy tys
335
336     bndrs | null vs   = [mkWildId unitTy]
337           | otherwise = vs
338