More refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   splitClosureTy,
5   mkPADictType, mkPArrayType,
6   paDictArgType, paDictOfType,
7   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
8   polyAbstract, polyApply, polyVApply,
9   lookupPArrayFamInst,
10   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
11   buildClosure, buildClosures,
12   mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon            ( dataConWrapId )
27 import Var
28 import Id                 ( mkWildId )
29 import MkId               ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes         ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad         ( liftM, zipWithM_ )
38
39 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
40 collectAnnTypeArgs expr = go expr []
41   where
42     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
43     go e                             tys = (e, tys)
44
45 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
46 collectAnnTypeBinders expr = go [] expr
47   where
48     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
49     go bs e                           = (reverse bs, e)
50
51 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnValBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isId b = go (b:bs) e
55     go bs e                        = (reverse bs, e)
56
57 isAnnTypeArg :: AnnExpr b ann -> Bool
58 isAnnTypeArg (_, AnnType t) = True
59 isAnnTypeArg _              = False
60
61 isClosureTyCon :: TyCon -> Bool
62 isClosureTyCon tc = tyConName tc == closureTyConName
63
64 splitClosureTy :: Type -> (Type, Type)
65 splitClosureTy ty
66   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
67   , isClosureTyCon tc
68   = (arg_ty, res_ty)
69
70   | otherwise = pprPanic "splitClosureTy" (ppr ty)
71
72 isPArrayTyCon :: TyCon -> Bool
73 isPArrayTyCon tc = tyConName tc == parrayTyConName
74
75 splitPArrayTy :: Type -> Type
76 splitPArrayTy ty
77   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
78   , isPArrayTyCon tc
79   = arg_ty
80
81   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
82
83 mkClosureType :: Type -> Type -> VM Type
84 mkClosureType arg_ty res_ty
85   = do
86       tc <- builtin closureTyCon
87       return $ mkTyConApp tc [arg_ty, res_ty]
88
89 mkClosureTypes :: [Type] -> Type -> VM Type
90 mkClosureTypes arg_tys res_ty
91   = do
92       tc <- builtin closureTyCon
93       return $ foldr (mk tc) res_ty arg_tys
94   where
95     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
96
97 mkPADictType :: Type -> VM Type
98 mkPADictType ty
99   = do
100       tc <- builtin paDictTyCon
101       return $ TyConApp tc [ty]
102
103 mkPArrayType :: Type -> VM Type
104 mkPArrayType ty
105   = do
106       tc <- builtin parrayTyCon
107       return $ TyConApp tc [ty]
108
109 paDictArgType :: TyVar -> VM (Maybe Type)
110 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
111   where
112     go ty k | Just k' <- kindView k = go ty k'
113     go ty (FunTy k1 k2)
114       = do
115           tv   <- newTyVar FSLIT("a") k1
116           mty1 <- go (TyVarTy tv) k1
117           case mty1 of
118             Just ty1 -> do
119                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
120                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
121             Nothing  -> go ty k2
122
123     go ty k
124       | isLiftedTypeKind k
125       = liftM Just (mkPADictType ty)
126
127     go ty k = return Nothing
128
129 paDictOfType :: Type -> VM CoreExpr
130 paDictOfType ty = paDictOfTyApp ty_fn ty_args
131   where
132     (ty_fn, ty_args) = splitAppTys ty
133
134 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
135 paDictOfTyApp ty_fn ty_args
136   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
137 paDictOfTyApp (TyVarTy tv) ty_args
138   = do
139       dfun <- maybeV (lookupTyVarPA tv)
140       paDFunApply dfun ty_args
141 paDictOfTyApp (TyConApp tc _) ty_args
142   = do
143       pa_class <- builtin paClass
144       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
145       paDFunApply (Var dfun) ty_args'
146 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
147
148 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
149 paDFunApply dfun tys
150   = do
151       dicts <- mapM paDictOfType tys
152       return $ mkApps (mkTyApps dfun tys) dicts
153
154 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
155 paMethod method ty
156   = do
157       fn   <- builtin method
158       dict <- paDictOfType ty
159       return $ mkApps (Var fn) [Type ty, dict]
160
161 lengthPA :: CoreExpr -> VM CoreExpr
162 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
163   where
164     ty = splitPArrayTy (exprType x)
165
166 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
167 replicatePA len x = liftM (`mkApps` [len,x])
168                           (paMethod replicatePAVar (exprType x))
169
170 emptyPA :: Type -> VM CoreExpr
171 emptyPA = paMethod emptyPAVar
172
173 liftPA :: CoreExpr -> VM CoreExpr
174 liftPA x
175   = do
176       lc <- builtin liftingContext
177       replicatePA (Var lc) x
178
179 newLocalVVar :: FastString -> Type -> VM VVar
180 newLocalVVar fs vty
181   = do
182       lty <- mkPArrayType vty
183       vv  <- newLocalVar fs vty
184       lv  <- newLocalVar fs lty
185       return (vv,lv)
186
187 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
188 polyAbstract tvs p
189   = localV
190   $ do
191       mdicts <- mapM mk_dict_var tvs
192       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
193       p (mk_lams mdicts)
194   where
195     mk_dict_var tv = do
196                        r <- paDictArgType tv
197                        case r of
198                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
199                          Nothing -> return Nothing
200
201     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
202
203 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
204 polyApply expr tys
205   = do
206       dicts <- mapM paDictOfType tys
207       return $ expr `mkTyApps` tys `mkApps` dicts
208
209 polyVApply :: VExpr -> [Type] -> VM VExpr
210 polyVApply expr tys
211   = do
212       dicts <- mapM paDictOfType tys
213       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
214
215 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
216 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
217
218 hoistBinding :: Var -> CoreExpr -> VM ()
219 hoistBinding v e = updGEnv $ \env ->
220   env { global_bindings = (v,e) : global_bindings env }
221
222 hoistExpr :: FastString -> CoreExpr -> VM Var
223 hoistExpr fs expr
224   = do
225       var <- newLocalVar fs (exprType expr)
226       hoistBinding var expr
227       return var
228
229 hoistVExpr :: VExpr -> VM VVar
230 hoistVExpr (ve, le)
231   = do
232       fs <- getBindName
233       vv <- hoistExpr ('v' `consFS` fs) ve
234       lv <- hoistExpr ('l' `consFS` fs) le
235       return (vv, lv)
236
237 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
238 hoistPolyVExpr tvs p
239   = do
240       expr <- closedV . polyAbstract tvs $ \abstract ->
241               liftM (mapVect abstract) p
242       fn   <- hoistVExpr expr
243       polyVApply (vVar fn) (mkTyVarTys tvs)
244
245 takeHoisted :: VM [(Var, CoreExpr)]
246 takeHoisted
247   = do
248       env <- readGEnv id
249       setGEnv $ env { global_bindings = [] }
250       return $ global_bindings env
251
252 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
253 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
254   = do
255       dict <- paDictOfType env_ty
256       mkv  <- builtin mkClosureVar
257       mkl  <- builtin mkClosurePVar
258       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
259               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
260
261 mkClosureApp :: VExpr -> VExpr -> VM VExpr
262 mkClosureApp (vclo, lclo) (varg, larg)
263   = do
264       vapply <- builtin applyClosureVar
265       lapply <- builtin applyClosurePVar
266       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
267               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
268   where
269     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
270
271 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
272 buildClosures tvs vars [arg_ty] res_ty mk_body
273   = buildClosure tvs vars arg_ty res_ty mk_body
274 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
275   = do
276       res_ty' <- mkClosureTypes arg_tys res_ty
277       arg <- newLocalVVar FSLIT("x") arg_ty
278       buildClosure tvs vars arg_ty res_ty'
279         . hoistPolyVExpr tvs
280         $ do
281             lc <- builtin liftingContext
282             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
283             return $ vLams lc (vars ++ [arg]) clo
284
285 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
286 --   where
287 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
288 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
289 --
290 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
291 buildClosure tvs vars arg_ty res_ty mk_body
292   = do
293       (env_ty, env, bind) <- buildEnv vars
294       env_bndr <- newLocalVVar FSLIT("env") env_ty
295       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
296
297       fn <- hoistPolyVExpr tvs
298           $ do
299               lc    <- builtin liftingContext
300               body  <- mk_body
301               body' <- bind (vVar env_bndr)
302                             (vVarApps lc body (vars ++ [arg_bndr]))
303               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
304
305       mkClosure arg_ty res_ty env_ty fn env
306
307 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
308 buildEnv vvs
309   = do
310       lc <- builtin liftingContext
311       let (ty, venv, vbind) = mkVectEnv tys vs
312       (lenv, lbind) <- mkLiftEnv lc tys ls
313       return (ty, (venv, lenv),
314               \(venv,lenv) (vbody,lbody) ->
315               do
316                 let vbody' = vbind venv vbody
317                 lbody' <- lbind lenv lbody
318                 return (vbody', lbody'))
319   where
320     (vs,ls) = unzip vvs
321     tys     = map idType vs
322
323 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
324 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
325 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
326 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
327                         \env body -> Case env (mkWildId ty) (exprType body)
328                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
329   where
330     ty = mkCoreTupTy tys
331
332 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
333 mkLiftEnv lc [ty] [v]
334   = return (Var v, \env body ->
335                    do
336                      len <- lengthPA (Var v)
337                      return . Let (NonRec v env)
338                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
339
340 -- NOTE: this transparently deals with empty environments
341 mkLiftEnv lc tys vs
342   = do
343       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
344       let [env_con] = tyConDataCons env_tc
345           
346           env = Var (dataConWrapId env_con)
347                 `mkTyApps`  env_tyargs
348                 `mkVarApps` (lc : vs)
349
350           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
351                           in
352                           return $ Case scrut (mkWildId (exprType scrut))
353                                         (exprType body)
354                                         [(DataAlt env_con, lc : bndrs, body)]
355       return (env, bind)
356   where
357     vty = mkCoreTupTy tys
358
359     bndrs | null vs   = [mkWildId unitTy]
360           | otherwise = vs
361