Vectorise Case on products
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPADictType, mkPArrayType,
7   parrayReprTyCon, parrayReprDataCon, mkVScrut,
8   paDictArgType, paDictOfType, paDFunType,
9   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
10   polyAbstract, polyApply, polyVApply,
11   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
12   buildClosure, buildClosures,
13   mkClosureApp
14 ) where
15
16 #include "HsVersions.h"
17
18 import VectCore
19 import VectMonad
20
21 import DsUtils
22 import CoreSyn
23 import CoreUtils
24 import Type
25 import TypeRep
26 import TyCon
27 import DataCon            ( DataCon, dataConWrapId, dataConTag )
28 import Var
29 import Id                 ( mkWildId )
30 import MkId               ( unwrapFamInstScrut )
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes         ( Boxity(..) )
34
35 import Outputable
36 import FastString
37
38 import Control.Monad         ( liftM, zipWithM_ )
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType t) = True
60 isAnnTypeArg _              = False
61
62 mkDataConTag :: DataCon -> CoreExpr
63 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
64
65 isClosureTyCon :: TyCon -> Bool
66 isClosureTyCon tc = tyConName tc == closureTyConName
67
68 splitClosureTy :: Type -> (Type, Type)
69 splitClosureTy ty
70   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
71   , isClosureTyCon tc
72   = (arg_ty, res_ty)
73
74   | otherwise = pprPanic "splitClosureTy" (ppr ty)
75
76 isPArrayTyCon :: TyCon -> Bool
77 isPArrayTyCon tc = tyConName tc == parrayTyConName
78
79 splitPArrayTy :: Type -> Type
80 splitPArrayTy ty
81   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
82   , isPArrayTyCon tc
83   = arg_ty
84
85   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
86
87 mkClosureType :: Type -> Type -> VM Type
88 mkClosureType arg_ty res_ty
89   = do
90       tc <- builtin closureTyCon
91       return $ mkTyConApp tc [arg_ty, res_ty]
92
93 mkClosureTypes :: [Type] -> Type -> VM Type
94 mkClosureTypes arg_tys res_ty
95   = do
96       tc <- builtin closureTyCon
97       return $ foldr (mk tc) res_ty arg_tys
98   where
99     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
100
101 mkPADictType :: Type -> VM Type
102 mkPADictType ty
103   = do
104       tc <- builtin paTyCon
105       return $ TyConApp tc [ty]
106
107 mkPArrayType :: Type -> VM Type
108 mkPArrayType ty
109   = do
110       tc <- builtin parrayTyCon
111       return $ TyConApp tc [ty]
112
113 parrayReprTyCon :: Type -> VM (TyCon, [Type])
114 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
115
116 parrayReprDataCon :: Type -> VM (DataCon, [Type])
117 parrayReprDataCon ty
118   = do
119       (tc, arg_tys) <- parrayReprTyCon ty
120       let [dc] = tyConDataCons tc
121       return (dc, arg_tys)
122
123 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
124 mkVScrut (ve, le)
125   = do
126       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
127       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
128
129 paDictArgType :: TyVar -> VM (Maybe Type)
130 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
131   where
132     go ty k | Just k' <- kindView k = go ty k'
133     go ty (FunTy k1 k2)
134       = do
135           tv   <- newTyVar FSLIT("a") k1
136           mty1 <- go (TyVarTy tv) k1
137           case mty1 of
138             Just ty1 -> do
139                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
140                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
141             Nothing  -> go ty k2
142
143     go ty k
144       | isLiftedTypeKind k
145       = liftM Just (mkPADictType ty)
146
147     go ty k = return Nothing
148
149 paDictOfType :: Type -> VM CoreExpr
150 paDictOfType ty = paDictOfTyApp ty_fn ty_args
151   where
152     (ty_fn, ty_args) = splitAppTys ty
153
154 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
155 paDictOfTyApp ty_fn ty_args
156   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
157 paDictOfTyApp (TyVarTy tv) ty_args
158   = do
159       dfun <- maybeV (lookupTyVarPA tv)
160       paDFunApply dfun ty_args
161 paDictOfTyApp (TyConApp tc _) ty_args
162   = do
163       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
164       paDFunApply (Var dfun) ty_args
165 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
166
167 paDFunType :: TyCon -> VM Type
168 paDFunType tc
169   = do
170       margs <- mapM paDictArgType tvs
171       res   <- mkPADictType (mkTyConApp tc arg_tys)
172       return . mkForAllTys tvs
173              $ mkFunTys [arg | Just arg <- margs] res
174   where
175     tvs = tyConTyVars tc
176     arg_tys = mkTyVarTys tvs
177
178 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
179 paDFunApply dfun tys
180   = do
181       dicts <- mapM paDictOfType tys
182       return $ mkApps (mkTyApps dfun tys) dicts
183
184 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
185 paMethod method ty
186   = do
187       fn   <- builtin method
188       dict <- paDictOfType ty
189       return $ mkApps (Var fn) [Type ty, dict]
190
191 lengthPA :: CoreExpr -> VM CoreExpr
192 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
193   where
194     ty = splitPArrayTy (exprType x)
195
196 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
197 replicatePA len x = liftM (`mkApps` [len,x])
198                           (paMethod replicatePAVar (exprType x))
199
200 emptyPA :: Type -> VM CoreExpr
201 emptyPA = paMethod emptyPAVar
202
203 liftPA :: CoreExpr -> VM CoreExpr
204 liftPA x
205   = do
206       lc <- builtin liftingContext
207       replicatePA (Var lc) x
208
209 newLocalVVar :: FastString -> Type -> VM VVar
210 newLocalVVar fs vty
211   = do
212       lty <- mkPArrayType vty
213       vv  <- newLocalVar fs vty
214       lv  <- newLocalVar fs lty
215       return (vv,lv)
216
217 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
218 polyAbstract tvs p
219   = localV
220   $ do
221       mdicts <- mapM mk_dict_var tvs
222       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
223       p (mk_lams mdicts)
224   where
225     mk_dict_var tv = do
226                        r <- paDictArgType tv
227                        case r of
228                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
229                          Nothing -> return Nothing
230
231     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
232
233 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
234 polyApply expr tys
235   = do
236       dicts <- mapM paDictOfType tys
237       return $ expr `mkTyApps` tys `mkApps` dicts
238
239 polyVApply :: VExpr -> [Type] -> VM VExpr
240 polyVApply expr tys
241   = do
242       dicts <- mapM paDictOfType tys
243       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
244
245 hoistBinding :: Var -> CoreExpr -> VM ()
246 hoistBinding v e = updGEnv $ \env ->
247   env { global_bindings = (v,e) : global_bindings env }
248
249 hoistExpr :: FastString -> CoreExpr -> VM Var
250 hoistExpr fs expr
251   = do
252       var <- newLocalVar fs (exprType expr)
253       hoistBinding var expr
254       return var
255
256 hoistVExpr :: VExpr -> VM VVar
257 hoistVExpr (ve, le)
258   = do
259       fs <- getBindName
260       vv <- hoistExpr ('v' `consFS` fs) ve
261       lv <- hoistExpr ('l' `consFS` fs) le
262       return (vv, lv)
263
264 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
265 hoistPolyVExpr tvs p
266   = do
267       expr <- closedV . polyAbstract tvs $ \abstract ->
268               liftM (mapVect abstract) p
269       fn   <- hoistVExpr expr
270       polyVApply (vVar fn) (mkTyVarTys tvs)
271
272 takeHoisted :: VM [(Var, CoreExpr)]
273 takeHoisted
274   = do
275       env <- readGEnv id
276       setGEnv $ env { global_bindings = [] }
277       return $ global_bindings env
278
279 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
280 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
281   = do
282       dict <- paDictOfType env_ty
283       mkv  <- builtin mkClosureVar
284       mkl  <- builtin mkClosurePVar
285       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
286               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
287
288 mkClosureApp :: VExpr -> VExpr -> VM VExpr
289 mkClosureApp (vclo, lclo) (varg, larg)
290   = do
291       vapply <- builtin applyClosureVar
292       lapply <- builtin applyClosurePVar
293       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
294               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
295   where
296     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
297
298 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
299 buildClosures tvs vars [arg_ty] res_ty mk_body
300   = buildClosure tvs vars arg_ty res_ty mk_body
301 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
302   = do
303       res_ty' <- mkClosureTypes arg_tys res_ty
304       arg <- newLocalVVar FSLIT("x") arg_ty
305       buildClosure tvs vars arg_ty res_ty'
306         . hoistPolyVExpr tvs
307         $ do
308             lc <- builtin liftingContext
309             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
310             return $ vLams lc (vars ++ [arg]) clo
311
312 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
313 --   where
314 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
315 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
316 --
317 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
318 buildClosure tvs vars arg_ty res_ty mk_body
319   = do
320       (env_ty, env, bind) <- buildEnv vars
321       env_bndr <- newLocalVVar FSLIT("env") env_ty
322       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
323
324       fn <- hoistPolyVExpr tvs
325           $ do
326               lc    <- builtin liftingContext
327               body  <- mk_body
328               body' <- bind (vVar env_bndr)
329                             (vVarApps lc body (vars ++ [arg_bndr]))
330               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
331
332       mkClosure arg_ty res_ty env_ty fn env
333
334 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
335 buildEnv vvs
336   = do
337       lc <- builtin liftingContext
338       let (ty, venv, vbind) = mkVectEnv tys vs
339       (lenv, lbind) <- mkLiftEnv lc tys ls
340       return (ty, (venv, lenv),
341               \(venv,lenv) (vbody,lbody) ->
342               do
343                 let vbody' = vbind venv vbody
344                 lbody' <- lbind lenv lbody
345                 return (vbody', lbody'))
346   where
347     (vs,ls) = unzip vvs
348     tys     = map idType vs
349
350 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
351 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
352 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
353 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
354                         \env body -> Case env (mkWildId ty) (exprType body)
355                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
356   where
357     ty = mkCoreTupTy tys
358
359 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
360 mkLiftEnv lc [ty] [v]
361   = return (Var v, \env body ->
362                    do
363                      len <- lengthPA (Var v)
364                      return . Let (NonRec v env)
365                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
366
367 -- NOTE: this transparently deals with empty environments
368 mkLiftEnv lc tys vs
369   = do
370       (env_tc, env_tyargs) <- parrayReprTyCon vty
371       let [env_con] = tyConDataCons env_tc
372           
373           env = Var (dataConWrapId env_con)
374                 `mkTyApps`  env_tyargs
375                 `mkVarApps` (lc : vs)
376
377           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
378                           in
379                           return $ Case scrut (mkWildId (exprType scrut))
380                                         (exprType body)
381                                         [(DataAlt env_con, lc : bndrs, body)]
382       return (env, bind)
383   where
384     vty = mkCoreTupTy tys
385
386     bndrs | null vs   = [mkWildId unitTy]
387           | otherwise = vs
388