Nicer names for hoisted functions
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   splitClosureTy,
5   mkPADictType, mkPArrayType,
6   paDictArgType, paDictOfType,
7   paMethod, lengthPA, replicatePA, emptyPA,
8   polyAbstract, polyApply, polyVApply,
9   lookupPArrayFamInst,
10   hoistExpr, hoistPolyVExpr, takeHoisted,
11   buildClosure, buildClosures
12 ) where
13
14 #include "HsVersions.h"
15
16 import VectCore
17 import VectMonad
18
19 import DsUtils
20 import CoreSyn
21 import CoreUtils
22 import Type
23 import TypeRep
24 import TyCon
25 import DataCon            ( dataConWrapId )
26 import Var
27 import Id                 ( mkWildId )
28 import MkId               ( unwrapFamInstScrut )
29 import PrelNames
30 import TysWiredIn
31 import BasicTypes         ( Boxity(..) )
32
33 import Outputable
34 import FastString
35
36 import Control.Monad         ( liftM, zipWithM_ )
37
38 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
39 collectAnnTypeArgs expr = go expr []
40   where
41     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
42     go e                             tys = (e, tys)
43
44 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
45 collectAnnTypeBinders expr = go [] expr
46   where
47     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
48     go bs e                           = (reverse bs, e)
49
50 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
51 collectAnnValBinders expr = go [] expr
52   where
53     go bs (_, AnnLam b e) | isId b = go (b:bs) e
54     go bs e                        = (reverse bs, e)
55
56 isAnnTypeArg :: AnnExpr b ann -> Bool
57 isAnnTypeArg (_, AnnType t) = True
58 isAnnTypeArg _              = False
59
60 isClosureTyCon :: TyCon -> Bool
61 isClosureTyCon tc = tyConName tc == closureTyConName
62
63 splitClosureTy :: Type -> (Type, Type)
64 splitClosureTy ty
65   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
66   , isClosureTyCon tc
67   = (arg_ty, res_ty)
68
69   | otherwise = pprPanic "splitClosureTy" (ppr ty)
70
71 isPArrayTyCon :: TyCon -> Bool
72 isPArrayTyCon tc = tyConName tc == parrayTyConName
73
74 splitPArrayTy :: Type -> Type
75 splitPArrayTy ty
76   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
77   , isPArrayTyCon tc
78   = arg_ty
79
80   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
81
82 mkClosureType :: Type -> Type -> VM Type
83 mkClosureType arg_ty res_ty
84   = do
85       tc <- builtin closureTyCon
86       return $ mkTyConApp tc [arg_ty, res_ty]
87
88 mkClosureTypes :: [Type] -> Type -> VM Type
89 mkClosureTypes arg_tys res_ty
90   = do
91       tc <- builtin closureTyCon
92       return $ foldr (mk tc) res_ty arg_tys
93   where
94     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
95
96 mkPADictType :: Type -> VM Type
97 mkPADictType ty
98   = do
99       tc <- builtin paDictTyCon
100       return $ TyConApp tc [ty]
101
102 mkPArrayType :: Type -> VM Type
103 mkPArrayType ty
104   = do
105       tc <- builtin parrayTyCon
106       return $ TyConApp tc [ty]
107
108 paDictArgType :: TyVar -> VM (Maybe Type)
109 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
110   where
111     go ty k | Just k' <- kindView k = go ty k'
112     go ty (FunTy k1 k2)
113       = do
114           tv   <- newTyVar FSLIT("a") k1
115           mty1 <- go (TyVarTy tv) k1
116           case mty1 of
117             Just ty1 -> do
118                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
119                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
120             Nothing  -> go ty k2
121
122     go ty k
123       | isLiftedTypeKind k
124       = liftM Just (mkPADictType ty)
125
126     go ty k = return Nothing
127
128 paDictOfType :: Type -> VM CoreExpr
129 paDictOfType ty = paDictOfTyApp ty_fn ty_args
130   where
131     (ty_fn, ty_args) = splitAppTys ty
132
133 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
134 paDictOfTyApp ty_fn ty_args
135   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
136 paDictOfTyApp (TyVarTy tv) ty_args
137   = do
138       dfun <- maybeV (lookupTyVarPA tv)
139       paDFunApply dfun ty_args
140 paDictOfTyApp (TyConApp tc _) ty_args
141   = do
142       pa_class <- builtin paClass
143       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
144       paDFunApply (Var dfun) ty_args'
145 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
146
147 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
148 paDFunApply dfun tys
149   = do
150       dicts <- mapM paDictOfType tys
151       return $ mkApps (mkTyApps dfun tys) dicts
152
153 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
154 paMethod method ty
155   = do
156       fn   <- builtin method
157       dict <- paDictOfType ty
158       return $ mkApps (Var fn) [Type ty, dict]
159
160 lengthPA :: CoreExpr -> VM CoreExpr
161 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
162   where
163     ty = splitPArrayTy (exprType x)
164
165 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
166 replicatePA len x = liftM (`mkApps` [len,x])
167                           (paMethod replicatePAVar (exprType x))
168
169 emptyPA :: Type -> VM CoreExpr
170 emptyPA = paMethod emptyPAVar
171
172 newLocalVVar :: FastString -> Type -> VM VVar
173 newLocalVVar fs vty
174   = do
175       lty <- mkPArrayType vty
176       vv  <- newLocalVar fs vty
177       lv  <- newLocalVar fs lty
178       return (vv,lv)
179
180 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
181 polyAbstract tvs p
182   = localV
183   $ do
184       mdicts <- mapM mk_dict_var tvs
185       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
186       p (mk_lams mdicts)
187   where
188     mk_dict_var tv = do
189                        r <- paDictArgType tv
190                        case r of
191                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
192                          Nothing -> return Nothing
193
194     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
195
196 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
197 polyApply expr tys
198   = do
199       dicts <- mapM paDictOfType tys
200       return $ expr `mkTyApps` tys `mkApps` dicts
201
202 polyVApply :: VExpr -> [Type] -> VM VExpr
203 polyVApply expr tys
204   = do
205       dicts <- mapM paDictOfType tys
206       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
207
208 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
209 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
210
211 hoistExpr :: FastString -> CoreExpr -> VM Var
212 hoistExpr fs expr
213   = do
214       var <- newLocalVar fs (exprType expr)
215       updGEnv $ \env ->
216         env { global_bindings = (var, expr) : global_bindings env }
217       return var
218
219 hoistVExpr :: VExpr -> VM VVar
220 hoistVExpr (ve, le)
221   = do
222       fs <- getBindName
223       vv <- hoistExpr ('v' `consFS` fs) ve
224       lv <- hoistExpr ('l' `consFS` fs) le
225       return (vv, lv)
226
227 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
228 hoistPolyVExpr tvs p
229   = do
230       expr <- closedV . polyAbstract tvs $ \abstract ->
231               liftM (mapVect abstract) p
232       fn   <- hoistVExpr expr
233       polyVApply (vVar fn) (mkTyVarTys tvs)
234
235 takeHoisted :: VM [(Var, CoreExpr)]
236 takeHoisted
237   = do
238       env <- readGEnv id
239       setGEnv $ env { global_bindings = [] }
240       return $ global_bindings env
241
242
243 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
244 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
245   = do
246       dict <- paDictOfType env_ty
247       mkv  <- builtin mkClosureVar
248       mkl  <- builtin mkClosurePVar
249       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
250               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
251
252 buildClosures :: [TyVar] -> Var -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
253 buildClosures tvs lc vars [arg_ty] res_ty mk_body
254   = buildClosure tvs lc vars arg_ty res_ty mk_body
255 buildClosures tvs lc vars (arg_ty : arg_tys) res_ty mk_body
256   = do
257       res_ty' <- mkClosureTypes arg_tys res_ty
258       arg <- newLocalVVar FSLIT("x") arg_ty
259       buildClosure tvs lc vars arg_ty res_ty'
260         . hoistPolyVExpr tvs
261         $ do
262             clo <- buildClosures tvs lc (vars ++ [arg]) arg_tys res_ty mk_body
263             return $ vLams lc (vars ++ [arg]) clo
264
265 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
266 --   where
267 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
268 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
269 --
270 buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
271 buildClosure tvs lv vars arg_ty res_ty mk_body
272   = do
273       (env_ty, env, bind) <- buildEnv lv vars
274       env_bndr <- newLocalVVar FSLIT("env") env_ty
275       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
276
277       fn <- hoistPolyVExpr tvs
278           $ do
279               body  <- mk_body
280               body' <- bind (vVar env_bndr)
281                             (vVarApps lv body (vars ++ [arg_bndr]))
282               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
283
284       mkClosure arg_ty res_ty env_ty fn env
285
286 buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
287 buildEnv lv vvs
288   = do
289       let (ty, venv, vbind) = mkVectEnv tys vs
290       (lenv, lbind) <- mkLiftEnv lv tys ls
291       return (ty, (venv, lenv),
292               \(venv,lenv) (vbody,lbody) ->
293               do
294                 let vbody' = vbind venv vbody
295                 lbody' <- lbind lenv lbody
296                 return (vbody', lbody'))
297   where
298     (vs,ls) = unzip vvs
299     tys     = map idType vs
300
301 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
302 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
303 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
304 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
305                         \env body -> Case env (mkWildId ty) (exprType body)
306                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
307   where
308     ty = mkCoreTupTy tys
309
310 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
311 mkLiftEnv lv [ty] [v]
312   = return (Var v, \env body ->
313                    do
314                      len <- lengthPA (Var v)
315                      return . Let (NonRec v env)
316                             $ Case len lv (exprType body) [(DEFAULT, [], body)])
317
318 -- NOTE: this transparently deals with empty environments
319 mkLiftEnv lv tys vs
320   = do
321       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
322       let [env_con] = tyConDataCons env_tc
323           
324           env = Var (dataConWrapId env_con)
325                 `mkTyApps`  env_tyargs
326                 `mkVarApps` (lv : vs)
327
328           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
329                           in
330                           return $ Case scrut (mkWildId (exprType scrut))
331                                         (exprType body)
332                                         [(DataAlt env_con, lv : bndrs, body)]
333       return (env, bind)
334   where
335     vty = mkCoreTupTy tys
336
337     bndrs | null vs   = [mkWildId unitTy]
338           | otherwise = vs
339