Move code
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   splitClosureTy,
5   mkPADictType, mkPArrayType,
6   paDictArgType, paDictOfType,
7   paMethod, lengthPA, replicatePA, emptyPA,
8   polyAbstract, polyApply, polyVApply,
9   lookupPArrayFamInst,
10   hoistExpr, hoistPolyVExpr, takeHoisted,
11   buildClosure, buildClosures,
12   mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon            ( dataConWrapId )
27 import Var
28 import Id                 ( mkWildId )
29 import MkId               ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes         ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad         ( liftM, zipWithM_ )
38
39 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
40 collectAnnTypeArgs expr = go expr []
41   where
42     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
43     go e                             tys = (e, tys)
44
45 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
46 collectAnnTypeBinders expr = go [] expr
47   where
48     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
49     go bs e                           = (reverse bs, e)
50
51 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnValBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isId b = go (b:bs) e
55     go bs e                        = (reverse bs, e)
56
57 isAnnTypeArg :: AnnExpr b ann -> Bool
58 isAnnTypeArg (_, AnnType t) = True
59 isAnnTypeArg _              = False
60
61 isClosureTyCon :: TyCon -> Bool
62 isClosureTyCon tc = tyConName tc == closureTyConName
63
64 splitClosureTy :: Type -> (Type, Type)
65 splitClosureTy ty
66   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
67   , isClosureTyCon tc
68   = (arg_ty, res_ty)
69
70   | otherwise = pprPanic "splitClosureTy" (ppr ty)
71
72 isPArrayTyCon :: TyCon -> Bool
73 isPArrayTyCon tc = tyConName tc == parrayTyConName
74
75 splitPArrayTy :: Type -> Type
76 splitPArrayTy ty
77   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
78   , isPArrayTyCon tc
79   = arg_ty
80
81   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
82
83 mkClosureType :: Type -> Type -> VM Type
84 mkClosureType arg_ty res_ty
85   = do
86       tc <- builtin closureTyCon
87       return $ mkTyConApp tc [arg_ty, res_ty]
88
89 mkClosureTypes :: [Type] -> Type -> VM Type
90 mkClosureTypes arg_tys res_ty
91   = do
92       tc <- builtin closureTyCon
93       return $ foldr (mk tc) res_ty arg_tys
94   where
95     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
96
97 mkPADictType :: Type -> VM Type
98 mkPADictType ty
99   = do
100       tc <- builtin paDictTyCon
101       return $ TyConApp tc [ty]
102
103 mkPArrayType :: Type -> VM Type
104 mkPArrayType ty
105   = do
106       tc <- builtin parrayTyCon
107       return $ TyConApp tc [ty]
108
109 paDictArgType :: TyVar -> VM (Maybe Type)
110 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
111   where
112     go ty k | Just k' <- kindView k = go ty k'
113     go ty (FunTy k1 k2)
114       = do
115           tv   <- newTyVar FSLIT("a") k1
116           mty1 <- go (TyVarTy tv) k1
117           case mty1 of
118             Just ty1 -> do
119                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
120                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
121             Nothing  -> go ty k2
122
123     go ty k
124       | isLiftedTypeKind k
125       = liftM Just (mkPADictType ty)
126
127     go ty k = return Nothing
128
129 paDictOfType :: Type -> VM CoreExpr
130 paDictOfType ty = paDictOfTyApp ty_fn ty_args
131   where
132     (ty_fn, ty_args) = splitAppTys ty
133
134 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
135 paDictOfTyApp ty_fn ty_args
136   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
137 paDictOfTyApp (TyVarTy tv) ty_args
138   = do
139       dfun <- maybeV (lookupTyVarPA tv)
140       paDFunApply dfun ty_args
141 paDictOfTyApp (TyConApp tc _) ty_args
142   = do
143       pa_class <- builtin paClass
144       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
145       paDFunApply (Var dfun) ty_args'
146 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
147
148 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
149 paDFunApply dfun tys
150   = do
151       dicts <- mapM paDictOfType tys
152       return $ mkApps (mkTyApps dfun tys) dicts
153
154 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
155 paMethod method ty
156   = do
157       fn   <- builtin method
158       dict <- paDictOfType ty
159       return $ mkApps (Var fn) [Type ty, dict]
160
161 lengthPA :: CoreExpr -> VM CoreExpr
162 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
163   where
164     ty = splitPArrayTy (exprType x)
165
166 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
167 replicatePA len x = liftM (`mkApps` [len,x])
168                           (paMethod replicatePAVar (exprType x))
169
170 emptyPA :: Type -> VM CoreExpr
171 emptyPA = paMethod emptyPAVar
172
173 newLocalVVar :: FastString -> Type -> VM VVar
174 newLocalVVar fs vty
175   = do
176       lty <- mkPArrayType vty
177       vv  <- newLocalVar fs vty
178       lv  <- newLocalVar fs lty
179       return (vv,lv)
180
181 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
182 polyAbstract tvs p
183   = localV
184   $ do
185       mdicts <- mapM mk_dict_var tvs
186       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
187       p (mk_lams mdicts)
188   where
189     mk_dict_var tv = do
190                        r <- paDictArgType tv
191                        case r of
192                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
193                          Nothing -> return Nothing
194
195     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
196
197 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
198 polyApply expr tys
199   = do
200       dicts <- mapM paDictOfType tys
201       return $ expr `mkTyApps` tys `mkApps` dicts
202
203 polyVApply :: VExpr -> [Type] -> VM VExpr
204 polyVApply expr tys
205   = do
206       dicts <- mapM paDictOfType tys
207       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
208
209 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
210 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
211
212 hoistExpr :: FastString -> CoreExpr -> VM Var
213 hoistExpr fs expr
214   = do
215       var <- newLocalVar fs (exprType expr)
216       updGEnv $ \env ->
217         env { global_bindings = (var, expr) : global_bindings env }
218       return var
219
220 hoistVExpr :: VExpr -> VM VVar
221 hoistVExpr (ve, le)
222   = do
223       fs <- getBindName
224       vv <- hoistExpr ('v' `consFS` fs) ve
225       lv <- hoistExpr ('l' `consFS` fs) le
226       return (vv, lv)
227
228 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
229 hoistPolyVExpr tvs p
230   = do
231       expr <- closedV . polyAbstract tvs $ \abstract ->
232               liftM (mapVect abstract) p
233       fn   <- hoistVExpr expr
234       polyVApply (vVar fn) (mkTyVarTys tvs)
235
236 takeHoisted :: VM [(Var, CoreExpr)]
237 takeHoisted
238   = do
239       env <- readGEnv id
240       setGEnv $ env { global_bindings = [] }
241       return $ global_bindings env
242
243 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
244 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
245   = do
246       dict <- paDictOfType env_ty
247       mkv  <- builtin mkClosureVar
248       mkl  <- builtin mkClosurePVar
249       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
250               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
251
252 mkClosureApp :: VExpr -> VExpr -> VM VExpr
253 mkClosureApp (vclo, lclo) (varg, larg)
254   = do
255       vapply <- builtin applyClosureVar
256       lapply <- builtin applyClosurePVar
257       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
258               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
259   where
260     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
261
262 buildClosures :: [TyVar] -> Var -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
263 buildClosures tvs lc vars [arg_ty] res_ty mk_body
264   = buildClosure tvs lc vars arg_ty res_ty mk_body
265 buildClosures tvs lc vars (arg_ty : arg_tys) res_ty mk_body
266   = do
267       res_ty' <- mkClosureTypes arg_tys res_ty
268       arg <- newLocalVVar FSLIT("x") arg_ty
269       buildClosure tvs lc vars arg_ty res_ty'
270         . hoistPolyVExpr tvs
271         $ do
272             clo <- buildClosures tvs lc (vars ++ [arg]) arg_tys res_ty mk_body
273             return $ vLams lc (vars ++ [arg]) clo
274
275 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
276 --   where
277 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
278 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
279 --
280 buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
281 buildClosure tvs lv vars arg_ty res_ty mk_body
282   = do
283       (env_ty, env, bind) <- buildEnv lv vars
284       env_bndr <- newLocalVVar FSLIT("env") env_ty
285       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
286
287       fn <- hoistPolyVExpr tvs
288           $ do
289               body  <- mk_body
290               body' <- bind (vVar env_bndr)
291                             (vVarApps lv body (vars ++ [arg_bndr]))
292               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
293
294       mkClosure arg_ty res_ty env_ty fn env
295
296 buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
297 buildEnv lv vvs
298   = do
299       let (ty, venv, vbind) = mkVectEnv tys vs
300       (lenv, lbind) <- mkLiftEnv lv tys ls
301       return (ty, (venv, lenv),
302               \(venv,lenv) (vbody,lbody) ->
303               do
304                 let vbody' = vbind venv vbody
305                 lbody' <- lbind lenv lbody
306                 return (vbody', lbody'))
307   where
308     (vs,ls) = unzip vvs
309     tys     = map idType vs
310
311 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
312 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
313 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
314 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
315                         \env body -> Case env (mkWildId ty) (exprType body)
316                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
317   where
318     ty = mkCoreTupTy tys
319
320 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
321 mkLiftEnv lv [ty] [v]
322   = return (Var v, \env body ->
323                    do
324                      len <- lengthPA (Var v)
325                      return . Let (NonRec v env)
326                             $ Case len lv (exprType body) [(DEFAULT, [], body)])
327
328 -- NOTE: this transparently deals with empty environments
329 mkLiftEnv lv tys vs
330   = do
331       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
332       let [env_con] = tyConDataCons env_tc
333           
334           env = Var (dataConWrapId env_con)
335                 `mkTyApps`  env_tyargs
336                 `mkVarApps` (lv : vs)
337
338           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
339                           in
340                           return $ Case scrut (mkWildId (exprType scrut))
341                                         (exprType body)
342                                         [(DataAlt env_con, lv : bndrs, body)]
343       return (env, bind)
344   where
345     vty = mkCoreTupTy tys
346
347     bndrs | null vs   = [mkWildId unitTy]
348           | otherwise = vs
349