Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPADictType, mkPArrayType,
7   parrayReprTyCon, parrayReprDataCon, mkVScrut,
8   paDictArgType, paDictOfType, paDFunType,
9   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
10   polyAbstract, polyApply, polyVApply,
11   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
12   buildClosure, buildClosures,
13   mkClosureApp
14 ) where
15
16 #include "HsVersions.h"
17
18 import VectCore
19 import VectMonad
20
21 import DsUtils
22 import CoreSyn
23 import CoreUtils
24 import Type
25 import TypeRep
26 import TyCon
27 import DataCon            ( DataCon, dataConWrapId, dataConTag )
28 import Var
29 import Id                 ( mkWildId )
30 import MkId               ( unwrapFamInstScrut )
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes         ( Boxity(..) )
34
35 import Outputable
36 import FastString
37
38 import Control.Monad         ( liftM, zipWithM_ )
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType t) = True
60 isAnnTypeArg _              = False
61
62 mkDataConTag :: DataCon -> CoreExpr
63 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
64
65 isClosureTyCon :: TyCon -> Bool
66 isClosureTyCon tc = tyConName tc == closureTyConName
67
68 splitClosureTy :: Type -> (Type, Type)
69 splitClosureTy ty
70   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
71   , isClosureTyCon tc
72   = (arg_ty, res_ty)
73
74   | otherwise = pprPanic "splitClosureTy" (ppr ty)
75
76 isPArrayTyCon :: TyCon -> Bool
77 isPArrayTyCon tc = tyConName tc == parrayTyConName
78
79 splitPArrayTy :: Type -> Type
80 splitPArrayTy ty
81   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
82   , isPArrayTyCon tc
83   = arg_ty
84
85   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
86
87 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
88 mkBuiltinTyConApp get_tc tys
89   = do
90       tc <- builtin get_tc
91       return $ mkTyConApp tc tys
92
93 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
94 mkBuiltinTyConApps get_tc tys ty
95   = do
96       tc <- builtin get_tc
97       return $ foldr (mk tc) ty tys
98   where
99     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
100
101 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> [Type] -> VM Type
102 mkBuiltinTyConApps1 get_tc tys
103   = do
104       tc <- builtin get_tc
105       case tys of
106         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
107         _  -> return $ foldr1 (mk tc) tys
108   where
109     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
110
111 mkClosureType :: Type -> Type -> VM Type
112 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
113
114 mkClosureTypes :: [Type] -> Type -> VM Type
115 mkClosureTypes = mkBuiltinTyConApps closureTyCon
116
117 mkPADictType :: Type -> VM Type
118 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
119
120 mkPArrayType :: Type -> VM Type
121 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
122
123 parrayReprTyCon :: Type -> VM (TyCon, [Type])
124 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
125
126 parrayReprDataCon :: Type -> VM (DataCon, [Type])
127 parrayReprDataCon ty
128   = do
129       (tc, arg_tys) <- parrayReprTyCon ty
130       let [dc] = tyConDataCons tc
131       return (dc, arg_tys)
132
133 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
134 mkVScrut (ve, le)
135   = do
136       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
137       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
138
139 paDictArgType :: TyVar -> VM (Maybe Type)
140 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
141   where
142     go ty k | Just k' <- kindView k = go ty k'
143     go ty (FunTy k1 k2)
144       = do
145           tv   <- newTyVar FSLIT("a") k1
146           mty1 <- go (TyVarTy tv) k1
147           case mty1 of
148             Just ty1 -> do
149                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
150                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
151             Nothing  -> go ty k2
152
153     go ty k
154       | isLiftedTypeKind k
155       = liftM Just (mkPADictType ty)
156
157     go ty k = return Nothing
158
159 paDictOfType :: Type -> VM CoreExpr
160 paDictOfType ty = paDictOfTyApp ty_fn ty_args
161   where
162     (ty_fn, ty_args) = splitAppTys ty
163
164 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
165 paDictOfTyApp ty_fn ty_args
166   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
167 paDictOfTyApp (TyVarTy tv) ty_args
168   = do
169       dfun <- maybeV (lookupTyVarPA tv)
170       paDFunApply dfun ty_args
171 paDictOfTyApp (TyConApp tc _) ty_args
172   = do
173       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
174       paDFunApply (Var dfun) ty_args
175 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
176
177 paDFunType :: TyCon -> VM Type
178 paDFunType tc
179   = do
180       margs <- mapM paDictArgType tvs
181       res   <- mkPADictType (mkTyConApp tc arg_tys)
182       return . mkForAllTys tvs
183              $ mkFunTys [arg | Just arg <- margs] res
184   where
185     tvs = tyConTyVars tc
186     arg_tys = mkTyVarTys tvs
187
188 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
189 paDFunApply dfun tys
190   = do
191       dicts <- mapM paDictOfType tys
192       return $ mkApps (mkTyApps dfun tys) dicts
193
194 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
195 paMethod method ty
196   = do
197       fn   <- builtin method
198       dict <- paDictOfType ty
199       return $ mkApps (Var fn) [Type ty, dict]
200
201 lengthPA :: CoreExpr -> VM CoreExpr
202 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
203   where
204     ty = splitPArrayTy (exprType x)
205
206 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
207 replicatePA len x = liftM (`mkApps` [len,x])
208                           (paMethod replicatePAVar (exprType x))
209
210 emptyPA :: Type -> VM CoreExpr
211 emptyPA = paMethod emptyPAVar
212
213 liftPA :: CoreExpr -> VM CoreExpr
214 liftPA x
215   = do
216       lc <- builtin liftingContext
217       replicatePA (Var lc) x
218
219 newLocalVVar :: FastString -> Type -> VM VVar
220 newLocalVVar fs vty
221   = do
222       lty <- mkPArrayType vty
223       vv  <- newLocalVar fs vty
224       lv  <- newLocalVar fs lty
225       return (vv,lv)
226
227 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
228 polyAbstract tvs p
229   = localV
230   $ do
231       mdicts <- mapM mk_dict_var tvs
232       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
233       p (mk_lams mdicts)
234   where
235     mk_dict_var tv = do
236                        r <- paDictArgType tv
237                        case r of
238                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
239                          Nothing -> return Nothing
240
241     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
242
243 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
244 polyApply expr tys
245   = do
246       dicts <- mapM paDictOfType tys
247       return $ expr `mkTyApps` tys `mkApps` dicts
248
249 polyVApply :: VExpr -> [Type] -> VM VExpr
250 polyVApply expr tys
251   = do
252       dicts <- mapM paDictOfType tys
253       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
254
255 hoistBinding :: Var -> CoreExpr -> VM ()
256 hoistBinding v e = updGEnv $ \env ->
257   env { global_bindings = (v,e) : global_bindings env }
258
259 hoistExpr :: FastString -> CoreExpr -> VM Var
260 hoistExpr fs expr
261   = do
262       var <- newLocalVar fs (exprType expr)
263       hoistBinding var expr
264       return var
265
266 hoistVExpr :: VExpr -> VM VVar
267 hoistVExpr (ve, le)
268   = do
269       fs <- getBindName
270       vv <- hoistExpr ('v' `consFS` fs) ve
271       lv <- hoistExpr ('l' `consFS` fs) le
272       return (vv, lv)
273
274 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
275 hoistPolyVExpr tvs p
276   = do
277       expr <- closedV . polyAbstract tvs $ \abstract ->
278               liftM (mapVect abstract) p
279       fn   <- hoistVExpr expr
280       polyVApply (vVar fn) (mkTyVarTys tvs)
281
282 takeHoisted :: VM [(Var, CoreExpr)]
283 takeHoisted
284   = do
285       env <- readGEnv id
286       setGEnv $ env { global_bindings = [] }
287       return $ global_bindings env
288
289 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
290 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
291   = do
292       dict <- paDictOfType env_ty
293       mkv  <- builtin mkClosureVar
294       mkl  <- builtin mkClosurePVar
295       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
296               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
297
298 mkClosureApp :: VExpr -> VExpr -> VM VExpr
299 mkClosureApp (vclo, lclo) (varg, larg)
300   = do
301       vapply <- builtin applyClosureVar
302       lapply <- builtin applyClosurePVar
303       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
304               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
305   where
306     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
307
308 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
309 buildClosures tvs vars [] res_ty mk_body
310   = mk_body
311 buildClosures tvs vars [arg_ty] res_ty mk_body
312   = buildClosure tvs vars arg_ty res_ty mk_body
313 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
314   = do
315       res_ty' <- mkClosureTypes arg_tys res_ty
316       arg <- newLocalVVar FSLIT("x") arg_ty
317       buildClosure tvs vars arg_ty res_ty'
318         . hoistPolyVExpr tvs
319         $ do
320             lc <- builtin liftingContext
321             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
322             return $ vLams lc (vars ++ [arg]) clo
323
324 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
325 --   where
326 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
327 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
328 --
329 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
330 buildClosure tvs vars arg_ty res_ty mk_body
331   = do
332       (env_ty, env, bind) <- buildEnv vars
333       env_bndr <- newLocalVVar FSLIT("env") env_ty
334       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
335
336       fn <- hoistPolyVExpr tvs
337           $ do
338               lc    <- builtin liftingContext
339               body  <- mk_body
340               body' <- bind (vVar env_bndr)
341                             (vVarApps lc body (vars ++ [arg_bndr]))
342               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
343
344       mkClosure arg_ty res_ty env_ty fn env
345
346 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
347 buildEnv vvs
348   = do
349       lc <- builtin liftingContext
350       let (ty, venv, vbind) = mkVectEnv tys vs
351       (lenv, lbind) <- mkLiftEnv lc tys ls
352       return (ty, (venv, lenv),
353               \(venv,lenv) (vbody,lbody) ->
354               do
355                 let vbody' = vbind venv vbody
356                 lbody' <- lbind lenv lbody
357                 return (vbody', lbody'))
358   where
359     (vs,ls) = unzip vvs
360     tys     = map idType vs
361
362 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
363 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
364 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
365 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
366                         \env body -> Case env (mkWildId ty) (exprType body)
367                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
368   where
369     ty = mkCoreTupTy tys
370
371 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
372 mkLiftEnv lc [ty] [v]
373   = return (Var v, \env body ->
374                    do
375                      len <- lengthPA (Var v)
376                      return . Let (NonRec v env)
377                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
378
379 -- NOTE: this transparently deals with empty environments
380 mkLiftEnv lc tys vs
381   = do
382       (env_tc, env_tyargs) <- parrayReprTyCon vty
383       let [env_con] = tyConDataCons env_tc
384           
385           env = Var (dataConWrapId env_con)
386                 `mkTyApps`  env_tyargs
387                 `mkVarApps` (lc : vs)
388
389           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
390                           in
391                           return $ Case scrut (mkWildId (exprType scrut))
392                                         (exprType body)
393                                         [(DataAlt env_con, lc : bndrs, body)]
394       return (env, bind)
395   where
396     vty = mkCoreTupTy tys
397
398     bndrs | null vs   = [mkWildId unitTy]
399           | otherwise = vs
400