More refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPlusType, mkPlusTypes, mkCrossType, mkCrossTypes, mkEmbedType,
7   mkPADictType, mkPArrayType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   paDictArgType, paDictOfType, paDFunType,
10   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
11   polyAbstract, polyApply, polyVApply,
12   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
13   buildClosure, buildClosures,
14   mkClosureApp
15 ) where
16
17 #include "HsVersions.h"
18
19 import VectCore
20 import VectMonad
21
22 import DsUtils
23 import CoreSyn
24 import CoreUtils
25 import Type
26 import TypeRep
27 import TyCon
28 import DataCon            ( DataCon, dataConWrapId, dataConTag )
29 import Var
30 import Id                 ( mkWildId )
31 import MkId               ( unwrapFamInstScrut )
32 import PrelNames
33 import TysWiredIn
34 import BasicTypes         ( Boxity(..) )
35
36 import Outputable
37 import FastString
38
39 import Control.Monad         ( liftM, zipWithM_ )
40
41 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
42 collectAnnTypeArgs expr = go expr []
43   where
44     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
45     go e                             tys = (e, tys)
46
47 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
48 collectAnnTypeBinders expr = go [] expr
49   where
50     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
51     go bs e                           = (reverse bs, e)
52
53 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
54 collectAnnValBinders expr = go [] expr
55   where
56     go bs (_, AnnLam b e) | isId b = go (b:bs) e
57     go bs e                        = (reverse bs, e)
58
59 isAnnTypeArg :: AnnExpr b ann -> Bool
60 isAnnTypeArg (_, AnnType t) = True
61 isAnnTypeArg _              = False
62
63 mkDataConTag :: DataCon -> CoreExpr
64 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
65
66 isClosureTyCon :: TyCon -> Bool
67 isClosureTyCon tc = tyConName tc == closureTyConName
68
69 splitClosureTy :: Type -> (Type, Type)
70 splitClosureTy ty
71   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
72   , isClosureTyCon tc
73   = (arg_ty, res_ty)
74
75   | otherwise = pprPanic "splitClosureTy" (ppr ty)
76
77 isPArrayTyCon :: TyCon -> Bool
78 isPArrayTyCon tc = tyConName tc == parrayTyConName
79
80 splitPArrayTy :: Type -> Type
81 splitPArrayTy ty
82   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
83   , isPArrayTyCon tc
84   = arg_ty
85
86   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
87
88 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
89 mkBuiltinTyConApp get_tc tys
90   = do
91       tc <- builtin get_tc
92       return $ mkTyConApp tc tys
93
94 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
95 mkBuiltinTyConApps get_tc tys ty
96   = do
97       tc <- builtin get_tc
98       return $ foldr (mk tc) ty tys
99   where
100     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
101
102 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
103 mkBuiltinTyConApps1 get_tc dft [] = return dft
104 mkBuiltinTyConApps1 get_tc dft tys
105   = do
106       tc <- builtin get_tc
107       case tys of
108         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
109         _  -> return $ foldr1 (mk tc) tys
110   where
111     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
112
113 mkPlusType :: Type -> Type -> VM Type
114 mkPlusType ty1 ty2 = mkBuiltinTyConApp plusTyCon [ty1, ty2]
115
116 mkPlusTypes :: Type -> [Type] -> VM Type
117 mkPlusTypes = mkBuiltinTyConApps1 plusTyCon
118
119 mkCrossType :: Type -> Type -> VM Type
120 mkCrossType ty1 ty2 = mkBuiltinTyConApp crossTyCon [ty1, ty2]
121
122 mkCrossTypes :: Type -> [Type] -> VM Type
123 mkCrossTypes = mkBuiltinTyConApps1 crossTyCon
124
125 mkEmbedType :: Type -> VM Type
126 mkEmbedType ty = mkBuiltinTyConApp embedTyCon [ty]
127
128 mkClosureType :: Type -> Type -> VM Type
129 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
130
131 mkClosureTypes :: [Type] -> Type -> VM Type
132 mkClosureTypes = mkBuiltinTyConApps closureTyCon
133
134 mkPADictType :: Type -> VM Type
135 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
136
137 mkPArrayType :: Type -> VM Type
138 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
139
140 parrayReprTyCon :: Type -> VM (TyCon, [Type])
141 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
142
143 parrayReprDataCon :: Type -> VM (DataCon, [Type])
144 parrayReprDataCon ty
145   = do
146       (tc, arg_tys) <- parrayReprTyCon ty
147       let [dc] = tyConDataCons tc
148       return (dc, arg_tys)
149
150 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
151 mkVScrut (ve, le)
152   = do
153       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
154       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
155
156 paDictArgType :: TyVar -> VM (Maybe Type)
157 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
158   where
159     go ty k | Just k' <- kindView k = go ty k'
160     go ty (FunTy k1 k2)
161       = do
162           tv   <- newTyVar FSLIT("a") k1
163           mty1 <- go (TyVarTy tv) k1
164           case mty1 of
165             Just ty1 -> do
166                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
167                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
168             Nothing  -> go ty k2
169
170     go ty k
171       | isLiftedTypeKind k
172       = liftM Just (mkPADictType ty)
173
174     go ty k = return Nothing
175
176 paDictOfType :: Type -> VM CoreExpr
177 paDictOfType ty = paDictOfTyApp ty_fn ty_args
178   where
179     (ty_fn, ty_args) = splitAppTys ty
180
181 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
182 paDictOfTyApp ty_fn ty_args
183   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
184 paDictOfTyApp (TyVarTy tv) ty_args
185   = do
186       dfun <- maybeV (lookupTyVarPA tv)
187       paDFunApply dfun ty_args
188 paDictOfTyApp (TyConApp tc _) ty_args
189   = do
190       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
191       paDFunApply (Var dfun) ty_args
192 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
193
194 paDFunType :: TyCon -> VM Type
195 paDFunType tc
196   = do
197       margs <- mapM paDictArgType tvs
198       res   <- mkPADictType (mkTyConApp tc arg_tys)
199       return . mkForAllTys tvs
200              $ mkFunTys [arg | Just arg <- margs] res
201   where
202     tvs = tyConTyVars tc
203     arg_tys = mkTyVarTys tvs
204
205 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
206 paDFunApply dfun tys
207   = do
208       dicts <- mapM paDictOfType tys
209       return $ mkApps (mkTyApps dfun tys) dicts
210
211 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
212 paMethod method ty
213   = do
214       fn   <- builtin method
215       dict <- paDictOfType ty
216       return $ mkApps (Var fn) [Type ty, dict]
217
218 lengthPA :: CoreExpr -> VM CoreExpr
219 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
220   where
221     ty = splitPArrayTy (exprType x)
222
223 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
224 replicatePA len x = liftM (`mkApps` [len,x])
225                           (paMethod replicatePAVar (exprType x))
226
227 emptyPA :: Type -> VM CoreExpr
228 emptyPA = paMethod emptyPAVar
229
230 liftPA :: CoreExpr -> VM CoreExpr
231 liftPA x
232   = do
233       lc <- builtin liftingContext
234       replicatePA (Var lc) x
235
236 newLocalVVar :: FastString -> Type -> VM VVar
237 newLocalVVar fs vty
238   = do
239       lty <- mkPArrayType vty
240       vv  <- newLocalVar fs vty
241       lv  <- newLocalVar fs lty
242       return (vv,lv)
243
244 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
245 polyAbstract tvs p
246   = localV
247   $ do
248       mdicts <- mapM mk_dict_var tvs
249       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
250       p (mk_lams mdicts)
251   where
252     mk_dict_var tv = do
253                        r <- paDictArgType tv
254                        case r of
255                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
256                          Nothing -> return Nothing
257
258     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
259
260 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
261 polyApply expr tys
262   = do
263       dicts <- mapM paDictOfType tys
264       return $ expr `mkTyApps` tys `mkApps` dicts
265
266 polyVApply :: VExpr -> [Type] -> VM VExpr
267 polyVApply expr tys
268   = do
269       dicts <- mapM paDictOfType tys
270       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
271
272 hoistBinding :: Var -> CoreExpr -> VM ()
273 hoistBinding v e = updGEnv $ \env ->
274   env { global_bindings = (v,e) : global_bindings env }
275
276 hoistExpr :: FastString -> CoreExpr -> VM Var
277 hoistExpr fs expr
278   = do
279       var <- newLocalVar fs (exprType expr)
280       hoistBinding var expr
281       return var
282
283 hoistVExpr :: VExpr -> VM VVar
284 hoistVExpr (ve, le)
285   = do
286       fs <- getBindName
287       vv <- hoistExpr ('v' `consFS` fs) ve
288       lv <- hoistExpr ('l' `consFS` fs) le
289       return (vv, lv)
290
291 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
292 hoistPolyVExpr tvs p
293   = do
294       expr <- closedV . polyAbstract tvs $ \abstract ->
295               liftM (mapVect abstract) p
296       fn   <- hoistVExpr expr
297       polyVApply (vVar fn) (mkTyVarTys tvs)
298
299 takeHoisted :: VM [(Var, CoreExpr)]
300 takeHoisted
301   = do
302       env <- readGEnv id
303       setGEnv $ env { global_bindings = [] }
304       return $ global_bindings env
305
306 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
307 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
308   = do
309       dict <- paDictOfType env_ty
310       mkv  <- builtin mkClosureVar
311       mkl  <- builtin mkClosurePVar
312       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
313               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
314
315 mkClosureApp :: VExpr -> VExpr -> VM VExpr
316 mkClosureApp (vclo, lclo) (varg, larg)
317   = do
318       vapply <- builtin applyClosureVar
319       lapply <- builtin applyClosurePVar
320       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
321               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
322   where
323     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
324
325 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
326 buildClosures tvs vars [] res_ty mk_body
327   = mk_body
328 buildClosures tvs vars [arg_ty] res_ty mk_body
329   = buildClosure tvs vars arg_ty res_ty mk_body
330 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
331   = do
332       res_ty' <- mkClosureTypes arg_tys res_ty
333       arg <- newLocalVVar FSLIT("x") arg_ty
334       buildClosure tvs vars arg_ty res_ty'
335         . hoistPolyVExpr tvs
336         $ do
337             lc <- builtin liftingContext
338             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
339             return $ vLams lc (vars ++ [arg]) clo
340
341 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
342 --   where
343 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
344 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
345 --
346 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
347 buildClosure tvs vars arg_ty res_ty mk_body
348   = do
349       (env_ty, env, bind) <- buildEnv vars
350       env_bndr <- newLocalVVar FSLIT("env") env_ty
351       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
352
353       fn <- hoistPolyVExpr tvs
354           $ do
355               lc    <- builtin liftingContext
356               body  <- mk_body
357               body' <- bind (vVar env_bndr)
358                             (vVarApps lc body (vars ++ [arg_bndr]))
359               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
360
361       mkClosure arg_ty res_ty env_ty fn env
362
363 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
364 buildEnv vvs
365   = do
366       lc <- builtin liftingContext
367       let (ty, venv, vbind) = mkVectEnv tys vs
368       (lenv, lbind) <- mkLiftEnv lc tys ls
369       return (ty, (venv, lenv),
370               \(venv,lenv) (vbody,lbody) ->
371               do
372                 let vbody' = vbind venv vbody
373                 lbody' <- lbind lenv lbody
374                 return (vbody', lbody'))
375   where
376     (vs,ls) = unzip vvs
377     tys     = map idType vs
378
379 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
380 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
381 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
382 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
383                         \env body -> Case env (mkWildId ty) (exprType body)
384                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
385   where
386     ty = mkCoreTupTy tys
387
388 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
389 mkLiftEnv lc [ty] [v]
390   = return (Var v, \env body ->
391                    do
392                      len <- lengthPA (Var v)
393                      return . Let (NonRec v env)
394                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
395
396 -- NOTE: this transparently deals with empty environments
397 mkLiftEnv lc tys vs
398   = do
399       (env_tc, env_tyargs) <- parrayReprTyCon vty
400       let [env_con] = tyConDataCons env_tc
401           
402           env = Var (dataConWrapId env_con)
403                 `mkTyApps`  env_tyargs
404                 `mkVarApps` (lc : vs)
405
406           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
407                           in
408                           return $ Case scrut (mkWildId (exprType scrut))
409                                         (exprType body)
410                                         [(DataAlt env_con, lc : bndrs, body)]
411       return (env, bind)
412   where
413     vty = mkCoreTupTy tys
414
415     bndrs | null vs   = [mkWildId unitTy]
416           | otherwise = vs
417