3bf97fa7ffbe2f2cad7e29c38b90926c44c09cd7
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 import VectCore
22 import VectMonad
23
24 import MkCore
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import Id                 ( mkWildId )
34 import MkId               ( unwrapFamInstScrut )
35 import TysWiredIn
36 import BasicTypes         ( Boxity(..) )
37 import Literal            ( Literal, mkMachInt )
38
39 import Outputable
40 import FastString
41
42 import Control.Monad
43
44
45 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
46 collectAnnTypeArgs expr = go expr []
47   where
48     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
49     go e                             tys = (e, tys)
50
51 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnTypeBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
55     go bs e                           = (reverse bs, e)
56
57 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
58 collectAnnValBinders expr = go [] expr
59   where
60     go bs (_, AnnLam b e) | isIdVar b = go (b:bs) e
61     go bs e                           = (reverse bs, e)
62
63 isAnnTypeArg :: AnnExpr b ann -> Bool
64 isAnnTypeArg (_, AnnType _) = True
65 isAnnTypeArg _              = False
66
67 dataConTagZ :: DataCon -> Int
68 dataConTagZ con = dataConTag con - fIRST_TAG
69
70 mkDataConTagLit :: DataCon -> Literal
71 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
72
73 mkDataConTag :: DataCon -> CoreExpr
74 mkDataConTag = mkIntLitInt . dataConTagZ
75
76 splitPrimTyCon :: Type -> Maybe TyCon
77 splitPrimTyCon ty
78   | Just (tycon, []) <- splitTyConApp_maybe ty
79   , isPrimTyCon tycon
80   = Just tycon
81
82   | otherwise = Nothing
83
84 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
85 mkBuiltinTyConApp get_tc tys
86   = do
87       tc <- builtin get_tc
88       return $ mkTyConApp tc tys
89
90 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
91 mkBuiltinTyConApps get_tc tys ty
92   = do
93       tc <- builtin get_tc
94       return $ foldr (mk tc) ty tys
95   where
96     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
97
98 {-
99 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
100 mkBuiltinTyConApps1 _      dft [] = return dft
101 mkBuiltinTyConApps1 get_tc _   tys
102   = do
103       tc <- builtin get_tc
104       case tys of
105         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
106         _  -> return $ foldr1 (mk tc) tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkClosureType :: Type -> Type -> VM Type
111 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
112 -}
113
114 mkClosureTypes :: [Type] -> Type -> VM Type
115 mkClosureTypes = mkBuiltinTyConApps closureTyCon
116
117 mkPReprType :: Type -> VM Type
118 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
119
120 mkPADictType :: Type -> VM Type
121 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
122
123 mkPArrayType :: Type -> VM Type
124 mkPArrayType ty
125   | Just tycon <- splitPrimTyCon ty
126   = do
127       r <- lookupPrimPArray tycon
128       case r of
129         Just arr -> return $ mkTyConApp arr []
130         Nothing  -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
131 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
132
133 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
134 mkBuiltinCo get_tc
135   = do
136       tc <- builtin get_tc
137       return $ mkTyConApp tc []
138
139 parrayReprTyCon :: Type -> VM (TyCon, [Type])
140 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
141
142 parrayReprDataCon :: Type -> VM (DataCon, [Type])
143 parrayReprDataCon ty
144   = do
145       (tc, arg_tys) <- parrayReprTyCon ty
146       let [dc] = tyConDataCons tc
147       return (dc, arg_tys)
148
149 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
150 mkVScrut (ve, le)
151   = do
152       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
153       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
154
155 prDFunOfTyCon :: TyCon -> VM CoreExpr
156 prDFunOfTyCon tycon
157   = liftM Var
158   . maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
159   $ lookupTyConPR tycon
160
161 paDictArgType :: TyVar -> VM (Maybe Type)
162 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
163   where
164     go ty k | Just k' <- kindView k = go ty k'
165     go ty (FunTy k1 k2)
166       = do
167           tv   <- newTyVar (fsLit "a") k1
168           mty1 <- go (TyVarTy tv) k1
169           case mty1 of
170             Just ty1 -> do
171                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
172                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
173             Nothing  -> go ty k2
174
175     go ty k
176       | isLiftedTypeKind k
177       = liftM Just (mkPADictType ty)
178
179     go _ _ = return Nothing
180
181 paDictOfType :: Type -> VM CoreExpr
182 paDictOfType ty = paDictOfTyApp ty_fn ty_args
183   where
184     (ty_fn, ty_args) = splitAppTys ty
185
186 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
187 paDictOfTyApp ty_fn ty_args
188   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
189 paDictOfTyApp (TyVarTy tv) ty_args
190   = do
191       dfun <- maybeV (lookupTyVarPA tv)
192       paDFunApply dfun ty_args
193 paDictOfTyApp (TyConApp tc _) ty_args
194   = do
195       dfun <- maybeCantVectoriseM "No PA dictionary for tycon" (ppr tc)
196             $ lookupTyConPA tc
197       paDFunApply (Var dfun) ty_args
198 paDictOfTyApp ty _
199   = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
200
201 paDFunType :: TyCon -> VM Type
202 paDFunType tc
203   = do
204       margs <- mapM paDictArgType tvs
205       res   <- mkPADictType (mkTyConApp tc arg_tys)
206       return . mkForAllTys tvs
207              $ mkFunTys [arg | Just arg <- margs] res
208   where
209     tvs = tyConTyVars tc
210     arg_tys = mkTyVarTys tvs
211
212 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
213 paDFunApply dfun tys
214   = do
215       dicts <- mapM paDictOfType tys
216       return $ mkApps (mkTyApps dfun tys) dicts
217
218 type PAMethod = (Builtins -> Var, String)
219
220 pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
221 pa_length    = (lengthPAVar,    "lengthPA")
222 pa_replicate = (replicatePAVar, "replicatePA")
223 pa_empty     = (emptyPAVar,     "emptyPA")
224 pa_pack      = (packPAVar,      "packPA")
225
226 paMethod :: PAMethod -> Type -> VM CoreExpr
227 paMethod (_method, name) ty
228   | Just tycon <- splitPrimTyCon ty
229   = liftM Var
230   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
231   $ lookupPrimMethod tycon name
232
233 paMethod (method, _name) ty
234   = do
235       fn   <- builtin method
236       dict <- paDictOfType ty
237       return $ mkApps (Var fn) [Type ty, dict]
238
239 mkPR :: Type -> VM CoreExpr
240 mkPR ty
241   = do
242       fn   <- builtin mkPRVar
243       dict <- paDictOfType ty
244       return $ mkApps (Var fn) [Type ty, dict]
245
246 lengthPA :: Type -> CoreExpr -> VM CoreExpr
247 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
248
249 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
250 replicatePA len x = liftM (`mkApps` [len,x])
251                           (paMethod pa_replicate (exprType x))
252
253 emptyPA :: Type -> VM CoreExpr
254 emptyPA = paMethod pa_empty
255
256 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
257 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
258                              (paMethod pa_pack ty)
259
260 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
261           -> VM CoreExpr
262 combinePA ty len sel is xs
263   = liftM (`mkApps` (len : sel : is : xs))
264           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
265   where
266     n = length xs
267
268 liftPA :: CoreExpr -> VM CoreExpr
269 liftPA x
270   = do
271       lc <- builtin liftingContext
272       replicatePA (Var lc) x
273
274 newLocalVVar :: FastString -> Type -> VM VVar
275 newLocalVVar fs vty
276   = do
277       lty <- mkPArrayType vty
278       vv  <- newLocalVar fs vty
279       lv  <- newLocalVar fs lty
280       return (vv,lv)
281
282 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
283 polyAbstract tvs p
284   = localV
285   $ do
286       mdicts <- mapM mk_dict_var tvs
287       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
288       p (mk_lams mdicts)
289   where
290     mk_dict_var tv = do
291                        r <- paDictArgType tv
292                        case r of
293                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
294                          Nothing -> return Nothing
295
296     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
297
298 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
299 polyApply expr tys
300   = do
301       dicts <- mapM paDictOfType tys
302       return $ expr `mkTyApps` tys `mkApps` dicts
303
304 polyVApply :: VExpr -> [Type] -> VM VExpr
305 polyVApply expr tys
306   = do
307       dicts <- mapM paDictOfType tys
308       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
309
310 hoistBinding :: Var -> CoreExpr -> VM ()
311 hoistBinding v e = updGEnv $ \env ->
312   env { global_bindings = (v,e) : global_bindings env }
313
314 hoistExpr :: FastString -> CoreExpr -> VM Var
315 hoistExpr fs expr
316   = do
317       var <- newLocalVar fs (exprType expr)
318       hoistBinding var expr
319       return var
320
321 hoistVExpr :: VExpr -> VM VVar
322 hoistVExpr (ve, le)
323   = do
324       fs <- getBindName
325       vv <- hoistExpr ('v' `consFS` fs) ve
326       lv <- hoistExpr ('l' `consFS` fs) le
327       return (vv, lv)
328
329 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
330 hoistPolyVExpr tvs p
331   = do
332       expr <- closedV . polyAbstract tvs $ \abstract ->
333               liftM (mapVect abstract) p
334       fn   <- hoistVExpr expr
335       polyVApply (vVar fn) (mkTyVarTys tvs)
336
337 takeHoisted :: VM [(Var, CoreExpr)]
338 takeHoisted
339   = do
340       env <- readGEnv id
341       setGEnv $ env { global_bindings = [] }
342       return $ global_bindings env
343
344 {-
345 boxExpr :: Type -> VExpr -> VM VExpr
346 boxExpr ty (vexpr, lexpr)
347   | Just (tycon, []) <- splitTyConApp_maybe ty
348   , isUnLiftedTyCon tycon
349   = do
350       r <- lookupBoxedTyCon tycon
351       case r of
352         Just tycon' -> let [dc] = tyConDataCons tycon'
353                        in
354                        return (mkConApp dc [vexpr], lexpr)
355         Nothing     -> return (vexpr, lexpr)
356 -}
357
358 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
359 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
360   = do
361       dict <- paDictOfType env_ty
362       mkv  <- builtin mkClosureVar
363       mkl  <- builtin mkClosurePVar
364       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
365               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
366
367 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
368 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
369   = do
370       vapply <- builtin applyClosureVar
371       lapply <- builtin applyClosurePVar
372       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
373               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
374
375 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
376 buildClosures _   _    [] _ mk_body
377   = mk_body
378 buildClosures tvs vars [arg_ty] res_ty mk_body
379   = buildClosure tvs vars arg_ty res_ty mk_body
380 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
381   = do
382       res_ty' <- mkClosureTypes arg_tys res_ty
383       arg <- newLocalVVar (fsLit "x") arg_ty
384       buildClosure tvs vars arg_ty res_ty'
385         . hoistPolyVExpr tvs
386         $ do
387             lc <- builtin liftingContext
388             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
389             return $ vLams lc (vars ++ [arg]) clo
390
391 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
392 --   where
393 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
394 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
395 --
396 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
397 buildClosure tvs vars arg_ty res_ty mk_body
398   = do
399       (env_ty, env, bind) <- buildEnv vars
400       env_bndr <- newLocalVVar (fsLit "env") env_ty
401       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
402
403       fn <- hoistPolyVExpr tvs
404           $ do
405               lc    <- builtin liftingContext
406               body  <- mk_body
407               body' <- bind (vVar env_bndr)
408                             (vVarApps lc body (vars ++ [arg_bndr]))
409               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
410
411       mkClosure arg_ty res_ty env_ty fn env
412
413 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
414 buildEnv vvs
415   = do
416       lc <- builtin liftingContext
417       let (ty, venv, vbind) = mkVectEnv tys vs
418       (lenv, lbind) <- mkLiftEnv lc tys ls
419       return (ty, (venv, lenv),
420               \(venv,lenv) (vbody,lbody) ->
421               do
422                 let vbody' = vbind venv vbody
423                 lbody' <- lbind lenv lbody
424                 return (vbody', lbody'))
425   where
426     (vs,ls) = unzip vvs
427     tys     = map varType vs
428
429 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
430 mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
431 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
432 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
433                         \env body -> Case env (mkWildId ty) (exprType body)
434                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
435   where
436     ty = mkCoreTupTy tys
437
438 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
439 mkLiftEnv lc [ty] [v]
440   = return (Var v, \env body ->
441                    do
442                      len <- lengthPA ty (Var v)
443                      return . Let (NonRec v env)
444                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
445
446 -- NOTE: this transparently deals with empty environments
447 mkLiftEnv lc tys vs
448   = do
449       (env_tc, env_tyargs) <- parrayReprTyCon vty
450
451       bndrs <- if null vs then do
452                                  v <- newDummyVar unitTy
453                                  return [v]
454                           else return vs
455       let [env_con] = tyConDataCons env_tc
456           
457           env = Var (dataConWrapId env_con)
458                 `mkTyApps`  env_tyargs
459                 `mkApps`    (Var lc : args)
460
461           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
462                           in
463                           return $ Case scrut (mkWildId (exprType scrut))
464                                         (exprType body)
465                                         [(DataAlt env_con, lc : bndrs, body)]
466       return (env, bind)
467   where
468     vty = mkCoreTupTy tys
469
470     args  | null vs   = [Var unitDataConId]
471           | otherwise = map Var vs
472