42bcab37bb8c1d4287bc3b2ae26e5dc4d0835e13
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag, mkDataConTagLit,
5
6   mkBuiltinCo,
7   mkPADictType, mkPArrayType, mkPReprType,
8
9   parrayReprTyCon, parrayReprDataCon, mkVScrut,
10   prDFunOfTyCon,
11   paDictArgType, paDictOfType, paDFunType,
12   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
13   polyAbstract, polyApply, polyVApply,
14   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
15   buildClosure, buildClosures,
16   mkClosureApp
17 ) where
18
19 #include "HsVersions.h"
20
21 import VectCore
22 import VectMonad
23
24 import DsUtils
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import Id                 ( mkWildId )
34 import MkId               ( unwrapFamInstScrut )
35 import Name               ( Name )
36 import PrelNames
37 import TysWiredIn
38 import TysPrim            ( intPrimTy )
39 import BasicTypes         ( Boxity(..) )
40 import Literal            ( Literal, mkMachInt )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTagLit :: DataCon -> Literal
71 mkDataConTagLit con
72   = mkMachInt . toInteger $ dataConTag con - fIRST_TAG
73
74 mkDataConTag :: DataCon -> CoreExpr
75 mkDataConTag con = mkIntLitInt (dataConTag con - fIRST_TAG)
76
77 splitPrimTyCon :: Type -> Maybe TyCon
78 splitPrimTyCon ty
79   | Just (tycon, []) <- splitTyConApp_maybe ty
80   , isPrimTyCon tycon
81   = Just tycon
82
83   | otherwise = Nothing
84
85 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
86 mkBuiltinTyConApp get_tc tys
87   = do
88       tc <- builtin get_tc
89       return $ mkTyConApp tc tys
90
91 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
92 mkBuiltinTyConApps get_tc tys ty
93   = do
94       tc <- builtin get_tc
95       return $ foldr (mk tc) ty tys
96   where
97     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
98
99 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
100 mkBuiltinTyConApps1 get_tc dft [] = return dft
101 mkBuiltinTyConApps1 get_tc dft tys
102   = do
103       tc <- builtin get_tc
104       case tys of
105         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
106         _  -> return $ foldr1 (mk tc) tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkClosureType :: Type -> Type -> VM Type
111 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
112
113 mkClosureTypes :: [Type] -> Type -> VM Type
114 mkClosureTypes = mkBuiltinTyConApps closureTyCon
115
116 mkPReprType :: Type -> VM Type
117 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
118
119 mkPADictType :: Type -> VM Type
120 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
121
122 mkPArrayType :: Type -> VM Type
123 mkPArrayType ty
124   | Just tycon <- splitPrimTyCon ty
125   = do
126       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
127            $ lookupPrimPArray tycon
128       return $ mkTyConApp arr []
129 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
130
131 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
132 mkBuiltinCo get_tc
133   = do
134       tc <- builtin get_tc
135       return $ mkTyConApp tc []
136
137 parrayReprTyCon :: Type -> VM (TyCon, [Type])
138 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
139
140 parrayReprDataCon :: Type -> VM (DataCon, [Type])
141 parrayReprDataCon ty
142   = do
143       (tc, arg_tys) <- parrayReprTyCon ty
144       let [dc] = tyConDataCons tc
145       return (dc, arg_tys)
146
147 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
148 mkVScrut (ve, le)
149   = do
150       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
151       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
152
153 prDFunOfTyCon :: TyCon -> VM CoreExpr
154 prDFunOfTyCon tycon
155   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
156
157 paDictArgType :: TyVar -> VM (Maybe Type)
158 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
159   where
160     go ty k | Just k' <- kindView k = go ty k'
161     go ty (FunTy k1 k2)
162       = do
163           tv   <- newTyVar FSLIT("a") k1
164           mty1 <- go (TyVarTy tv) k1
165           case mty1 of
166             Just ty1 -> do
167                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
168                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
169             Nothing  -> go ty k2
170
171     go ty k
172       | isLiftedTypeKind k
173       = liftM Just (mkPADictType ty)
174
175     go ty k = return Nothing
176
177 paDictOfType :: Type -> VM CoreExpr
178 paDictOfType ty = paDictOfTyApp ty_fn ty_args
179   where
180     (ty_fn, ty_args) = splitAppTys ty
181
182 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
183 paDictOfTyApp ty_fn ty_args
184   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
185 paDictOfTyApp (TyVarTy tv) ty_args
186   = do
187       dfun <- maybeV (lookupTyVarPA tv)
188       paDFunApply dfun ty_args
189 paDictOfTyApp (TyConApp tc _) ty_args
190   = do
191       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
192       paDFunApply (Var dfun) ty_args
193 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
194
195 paDFunType :: TyCon -> VM Type
196 paDFunType tc
197   = do
198       margs <- mapM paDictArgType tvs
199       res   <- mkPADictType (mkTyConApp tc arg_tys)
200       return . mkForAllTys tvs
201              $ mkFunTys [arg | Just arg <- margs] res
202   where
203     tvs = tyConTyVars tc
204     arg_tys = mkTyVarTys tvs
205
206 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
207 paDFunApply dfun tys
208   = do
209       dicts <- mapM paDictOfType tys
210       return $ mkApps (mkTyApps dfun tys) dicts
211
212 type PAMethod = (Builtins -> Var, String)
213
214 pa_length    = (lengthPAVar,    "lengthPA")
215 pa_replicate = (replicatePAVar, "replicatePA")
216 pa_empty     = (emptyPAVar,     "emptyPA")
217
218 paMethod :: PAMethod -> Type -> VM CoreExpr
219 paMethod (method, name) ty
220   | Just tycon <- splitPrimTyCon ty
221   = do
222       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
223           $ lookupPrimMethod tycon name
224       return (Var fn)
225
226 paMethod (method, name) ty
227   = do
228       fn   <- builtin method
229       dict <- paDictOfType ty
230       return $ mkApps (Var fn) [Type ty, dict]
231
232 mkPR :: Type -> VM CoreExpr
233 mkPR ty
234   = do
235       fn   <- builtin mkPRVar
236       dict <- paDictOfType ty
237       return $ mkApps (Var fn) [Type ty, dict]
238
239 lengthPA :: Type -> CoreExpr -> VM CoreExpr
240 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
241
242 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
243 replicatePA len x = liftM (`mkApps` [len,x])
244                           (paMethod pa_replicate (exprType x))
245
246 emptyPA :: Type -> VM CoreExpr
247 emptyPA = paMethod pa_empty
248
249 liftPA :: CoreExpr -> VM CoreExpr
250 liftPA x
251   = do
252       lc <- builtin liftingContext
253       replicatePA (Var lc) x
254
255 newLocalVVar :: FastString -> Type -> VM VVar
256 newLocalVVar fs vty
257   = do
258       lty <- mkPArrayType vty
259       vv  <- newLocalVar fs vty
260       lv  <- newLocalVar fs lty
261       return (vv,lv)
262
263 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
264 polyAbstract tvs p
265   = localV
266   $ do
267       mdicts <- mapM mk_dict_var tvs
268       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
269       p (mk_lams mdicts)
270   where
271     mk_dict_var tv = do
272                        r <- paDictArgType tv
273                        case r of
274                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
275                          Nothing -> return Nothing
276
277     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
278
279 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
280 polyApply expr tys
281   = do
282       dicts <- mapM paDictOfType tys
283       return $ expr `mkTyApps` tys `mkApps` dicts
284
285 polyVApply :: VExpr -> [Type] -> VM VExpr
286 polyVApply expr tys
287   = do
288       dicts <- mapM paDictOfType tys
289       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
290
291 hoistBinding :: Var -> CoreExpr -> VM ()
292 hoistBinding v e = updGEnv $ \env ->
293   env { global_bindings = (v,e) : global_bindings env }
294
295 hoistExpr :: FastString -> CoreExpr -> VM Var
296 hoistExpr fs expr
297   = do
298       var <- newLocalVar fs (exprType expr)
299       hoistBinding var expr
300       return var
301
302 hoistVExpr :: VExpr -> VM VVar
303 hoistVExpr (ve, le)
304   = do
305       fs <- getBindName
306       vv <- hoistExpr ('v' `consFS` fs) ve
307       lv <- hoistExpr ('l' `consFS` fs) le
308       return (vv, lv)
309
310 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
311 hoistPolyVExpr tvs p
312   = do
313       expr <- closedV . polyAbstract tvs $ \abstract ->
314               liftM (mapVect abstract) p
315       fn   <- hoistVExpr expr
316       polyVApply (vVar fn) (mkTyVarTys tvs)
317
318 takeHoisted :: VM [(Var, CoreExpr)]
319 takeHoisted
320   = do
321       env <- readGEnv id
322       setGEnv $ env { global_bindings = [] }
323       return $ global_bindings env
324
325 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
326 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
327   = do
328       dict <- paDictOfType env_ty
329       mkv  <- builtin mkClosureVar
330       mkl  <- builtin mkClosurePVar
331       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
332               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
333
334 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
335 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
336   = do
337       vapply <- builtin applyClosureVar
338       lapply <- builtin applyClosurePVar
339       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
340               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
341
342 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
343 buildClosures tvs vars [] res_ty mk_body
344   = mk_body
345 buildClosures tvs vars [arg_ty] res_ty mk_body
346   = buildClosure tvs vars arg_ty res_ty mk_body
347 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
348   = do
349       res_ty' <- mkClosureTypes arg_tys res_ty
350       arg <- newLocalVVar FSLIT("x") arg_ty
351       buildClosure tvs vars arg_ty res_ty'
352         . hoistPolyVExpr tvs
353         $ do
354             lc <- builtin liftingContext
355             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
356             return $ vLams lc (vars ++ [arg]) clo
357
358 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
359 --   where
360 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
361 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
362 --
363 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
364 buildClosure tvs vars arg_ty res_ty mk_body
365   = do
366       (env_ty, env, bind) <- buildEnv vars
367       env_bndr <- newLocalVVar FSLIT("env") env_ty
368       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
369
370       fn <- hoistPolyVExpr tvs
371           $ do
372               lc    <- builtin liftingContext
373               body  <- mk_body
374               body' <- bind (vVar env_bndr)
375                             (vVarApps lc body (vars ++ [arg_bndr]))
376               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
377
378       mkClosure arg_ty res_ty env_ty fn env
379
380 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
381 buildEnv vvs
382   = do
383       lc <- builtin liftingContext
384       let (ty, venv, vbind) = mkVectEnv tys vs
385       (lenv, lbind) <- mkLiftEnv lc tys ls
386       return (ty, (venv, lenv),
387               \(venv,lenv) (vbody,lbody) ->
388               do
389                 let vbody' = vbind venv vbody
390                 lbody' <- lbind lenv lbody
391                 return (vbody', lbody'))
392   where
393     (vs,ls) = unzip vvs
394     tys     = map idType vs
395
396 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
397 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
398 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
399 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
400                         \env body -> Case env (mkWildId ty) (exprType body)
401                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
402   where
403     ty = mkCoreTupTy tys
404
405 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
406 mkLiftEnv lc [ty] [v]
407   = return (Var v, \env body ->
408                    do
409                      len <- lengthPA ty (Var v)
410                      return . Let (NonRec v env)
411                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
412
413 -- NOTE: this transparently deals with empty environments
414 mkLiftEnv lc tys vs
415   = do
416       (env_tc, env_tyargs) <- parrayReprTyCon vty
417       let [env_con] = tyConDataCons env_tc
418           
419           env = Var (dataConWrapId env_con)
420                 `mkTyApps`  env_tyargs
421                 `mkVarApps` (lc : vs)
422
423           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
424                           in
425                           return $ Case scrut (mkWildId (exprType scrut))
426                                         (exprType body)
427                                         [(DataAlt env_con, lc : bndrs, body)]
428       return (env, bind)
429   where
430     vty = mkCoreTupTy tys
431
432     bndrs | null vs   = [mkWildId unitTy]
433           | otherwise = vs
434