Rename functions
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType,
6   paMethod, lengthPA, replicatePA, emptyPA,
7   polyAbstract, polyApply, polyVApply,
8   lookupPArrayFamInst,
9   hoistExpr, hoistPolyVExpr, takeHoisted,
10   buildClosure
11 ) where
12
13 #include "HsVersions.h"
14
15 import VectCore
16 import VectMonad
17
18 import DsUtils
19 import CoreSyn
20 import CoreUtils
21 import Type
22 import TypeRep
23 import TyCon
24 import DataCon            ( dataConWrapId )
25 import Var
26 import Id                 ( mkWildId )
27 import MkId               ( unwrapFamInstScrut )
28 import PrelNames
29 import TysWiredIn
30 import BasicTypes         ( Boxity(..) )
31
32 import Outputable
33 import FastString
34
35 import Control.Monad         ( liftM, zipWithM_ )
36
37 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
38 collectAnnTypeArgs expr = go expr []
39   where
40     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
41     go e                             tys = (e, tys)
42
43 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
44 collectAnnTypeBinders expr = go [] expr
45   where
46     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
47     go bs e                           = (reverse bs, e)
48
49 isAnnTypeArg :: AnnExpr b ann -> Bool
50 isAnnTypeArg (_, AnnType t) = True
51 isAnnTypeArg _              = False
52
53 isClosureTyCon :: TyCon -> Bool
54 isClosureTyCon tc = tyConName tc == closureTyConName
55
56 splitClosureTy :: Type -> (Type, Type)
57 splitClosureTy ty
58   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
59   , isClosureTyCon tc
60   = (arg_ty, res_ty)
61
62   | otherwise = pprPanic "splitClosureTy" (ppr ty)
63
64 isPArrayTyCon :: TyCon -> Bool
65 isPArrayTyCon tc = tyConName tc == parrayTyConName
66
67 splitPArrayTy :: Type -> Type
68 splitPArrayTy ty
69   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
70   , isPArrayTyCon tc
71   = arg_ty
72
73   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
74
75 mkPADictType :: Type -> VM Type
76 mkPADictType ty
77   = do
78       tc <- builtin paDictTyCon
79       return $ TyConApp tc [ty]
80
81 mkPArrayType :: Type -> VM Type
82 mkPArrayType ty
83   = do
84       tc <- builtin parrayTyCon
85       return $ TyConApp tc [ty]
86
87 paDictArgType :: TyVar -> VM (Maybe Type)
88 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
89   where
90     go ty k | Just k' <- kindView k = go ty k'
91     go ty (FunTy k1 k2)
92       = do
93           tv   <- newTyVar FSLIT("a") k1
94           mty1 <- go (TyVarTy tv) k1
95           case mty1 of
96             Just ty1 -> do
97                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
98                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
99             Nothing  -> go ty k2
100
101     go ty k
102       | isLiftedTypeKind k
103       = liftM Just (mkPADictType ty)
104
105     go ty k = return Nothing
106
107 paDictOfType :: Type -> VM CoreExpr
108 paDictOfType ty = paDictOfTyApp ty_fn ty_args
109   where
110     (ty_fn, ty_args) = splitAppTys ty
111
112 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
113 paDictOfTyApp ty_fn ty_args
114   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
115 paDictOfTyApp (TyVarTy tv) ty_args
116   = do
117       dfun <- maybeV (lookupTyVarPA tv)
118       paDFunApply dfun ty_args
119 paDictOfTyApp (TyConApp tc _) ty_args
120   = do
121       pa_class <- builtin paClass
122       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
123       paDFunApply (Var dfun) ty_args'
124 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
125
126 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
127 paDFunApply dfun tys
128   = do
129       dicts <- mapM paDictOfType tys
130       return $ mkApps (mkTyApps dfun tys) dicts
131
132 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
133 paMethod method ty
134   = do
135       fn   <- builtin method
136       dict <- paDictOfType ty
137       return $ mkApps (Var fn) [Type ty, dict]
138
139 lengthPA :: CoreExpr -> VM CoreExpr
140 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
141   where
142     ty = splitPArrayTy (exprType x)
143
144 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
145 replicatePA len x = liftM (`mkApps` [len,x])
146                           (paMethod replicatePAVar (exprType x))
147
148 emptyPA :: Type -> VM CoreExpr
149 emptyPA = paMethod emptyPAVar
150
151 newLocalVVar :: FastString -> Type -> VM VVar
152 newLocalVVar fs vty
153   = do
154       lty <- mkPArrayType vty
155       vv  <- newLocalVar fs vty
156       lv  <- newLocalVar fs lty
157       return (vv,lv)
158
159 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
160 polyAbstract tvs p
161   = localV
162   $ do
163       mdicts <- mapM mk_dict_var tvs
164       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
165       p (mk_lams mdicts)
166   where
167     mk_dict_var tv = do
168                        r <- paDictArgType tv
169                        case r of
170                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
171                          Nothing -> return Nothing
172
173     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
174
175 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
176 polyApply expr tys
177   = do
178       dicts <- mapM paDictOfType tys
179       return $ expr `mkTyApps` tys `mkApps` dicts
180
181 polyVApply :: VExpr -> [Type] -> VM VExpr
182 polyVApply expr tys
183   = do
184       dicts <- mapM paDictOfType tys
185       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
186
187 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
188 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
189
190 hoistExpr :: FastString -> CoreExpr -> VM Var
191 hoistExpr fs expr
192   = do
193       var <- newLocalVar fs (exprType expr)
194       updGEnv $ \env ->
195         env { global_bindings = (var, expr) : global_bindings env }
196       return var
197
198 hoistVExpr :: FastString -> VExpr -> VM VVar
199 hoistVExpr fs (ve, le)
200   = do
201       vv <- hoistExpr ('v' `consFS` fs) ve
202       lv <- hoistExpr ('l' `consFS` fs) le
203       return (vv, lv)
204
205 hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
206 hoistPolyVExpr fs tvs p
207   = do
208       expr <- closedV . polyAbstract tvs $ \abstract ->
209               liftM (mapVect abstract) p
210       fn   <- hoistVExpr fs expr
211       polyVApply (vVar fn) (mkTyVarTys tvs)
212
213 takeHoisted :: VM [(Var, CoreExpr)]
214 takeHoisted
215   = do
216       env <- readGEnv id
217       setGEnv $ env { global_bindings = [] }
218       return $ global_bindings env
219
220
221 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
222 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
223   = do
224       dict <- paDictOfType env_ty
225       mkv  <- builtin mkClosureVar
226       mkl  <- builtin mkClosurePVar
227       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
228               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
229
230 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
231 --   where
232 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
233 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
234
235 buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
236 buildClosure tvs lv vars arg_ty res_ty mk_body
237   = do
238       (env_ty, env, bind) <- buildEnv lv vars
239       env_bndr <- newLocalVVar FSLIT("env") env_ty
240       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
241
242       fn <- hoistPolyVExpr FSLIT("fn") tvs
243           $ do
244               body  <- mk_body
245               body' <- bind (vVar env_bndr)
246                             (vVarApps lv body (vars ++ [arg_bndr]))
247               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
248
249       mkClosure arg_ty res_ty env_ty fn env
250
251 buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
252 buildEnv lv vvs
253   = do
254       let (ty, venv, vbind) = mkVectEnv tys vs
255       (lenv, lbind) <- mkLiftEnv lv tys ls
256       return (ty, (venv, lenv),
257               \(venv,lenv) (vbody,lbody) ->
258               do
259                 let vbody' = vbind venv vbody
260                 lbody' <- lbind lenv lbody
261                 return (vbody', lbody'))
262   where
263     (vs,ls) = unzip vvs
264     tys     = map idType vs
265
266 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
267 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
268 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
269 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
270                         \env body -> Case env (mkWildId ty) (exprType body)
271                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
272   where
273     ty = mkCoreTupTy tys
274
275 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
276 mkLiftEnv lv [ty] [v]
277   = return (Var v, \env body ->
278                    do
279                      len <- lengthPA (Var v)
280                      return . Let (NonRec v env)
281                             $ Case len lv (exprType body) [(DEFAULT, [], body)])
282
283 -- NOTE: this transparently deals with empty environments
284 mkLiftEnv lv tys vs
285   = do
286       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
287       let [env_con] = tyConDataCons env_tc
288           
289           env = Var (dataConWrapId env_con)
290                 `mkTyApps`  env_tyargs
291                 `mkVarApps` (lv : vs)
292
293           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
294                           in
295                           return $ Case scrut (mkWildId (exprType scrut))
296                                         (exprType body)
297                                         [(DataAlt env_con, lv : bndrs, body)]
298       return (env, bind)
299   where
300     vty = mkCoreTupTy tys
301
302     bndrs | null vs   = [mkWildId unitTy]
303           | otherwise = vs
304