Tidy up the treatment of dead binders
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15   polyAbstract, polyApply, polyVApply,
16   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17   buildClosure, buildClosures,
18   mkClosureApp
19 ) where
20
21 import VectCore
22 import VectMonad
23
24 import MkCore
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import MkId               ( unwrapFamInstScrut )
34 import TysWiredIn
35 import BasicTypes         ( Boxity(..) )
36 import Literal            ( Literal, mkMachInt )
37
38 import Outputable
39 import FastString
40
41 import Control.Monad
42
43
44 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
45 collectAnnTypeArgs expr = go expr []
46   where
47     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
48     go e                             tys = (e, tys)
49
50 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
51 collectAnnTypeBinders expr = go [] expr
52   where
53     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
54     go bs e                           = (reverse bs, e)
55
56 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
57 collectAnnValBinders expr = go [] expr
58   where
59     go bs (_, AnnLam b e) | isIdVar b = go (b:bs) e
60     go bs e                           = (reverse bs, e)
61
62 isAnnTypeArg :: AnnExpr b ann -> Bool
63 isAnnTypeArg (_, AnnType _) = True
64 isAnnTypeArg _              = False
65
66 dataConTagZ :: DataCon -> Int
67 dataConTagZ con = dataConTag con - fIRST_TAG
68
69 mkDataConTagLit :: DataCon -> Literal
70 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
71
72 mkDataConTag :: DataCon -> CoreExpr
73 mkDataConTag = mkIntLitInt . dataConTagZ
74
75 splitPrimTyCon :: Type -> Maybe TyCon
76 splitPrimTyCon ty
77   | Just (tycon, []) <- splitTyConApp_maybe ty
78   , isPrimTyCon tycon
79   = Just tycon
80
81   | otherwise = Nothing
82
83 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
84 mkBuiltinTyConApp get_tc tys
85   = do
86       tc <- builtin get_tc
87       return $ mkTyConApp tc tys
88
89 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
90 mkBuiltinTyConApps get_tc tys ty
91   = do
92       tc <- builtin get_tc
93       return $ foldr (mk tc) ty tys
94   where
95     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
96
97 {-
98 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
99 mkBuiltinTyConApps1 _      dft [] = return dft
100 mkBuiltinTyConApps1 get_tc _   tys
101   = do
102       tc <- builtin get_tc
103       case tys of
104         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
105         _  -> return $ foldr1 (mk tc) tys
106   where
107     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
108
109 mkClosureType :: Type -> Type -> VM Type
110 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
111 -}
112
113 mkClosureTypes :: [Type] -> Type -> VM Type
114 mkClosureTypes = mkBuiltinTyConApps closureTyCon
115
116 mkPReprType :: Type -> VM Type
117 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
118
119 mkPADictType :: Type -> VM Type
120 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
121
122 mkPArrayType :: Type -> VM Type
123 mkPArrayType ty
124   | Just tycon <- splitPrimTyCon ty
125   = do
126       r <- lookupPrimPArray tycon
127       case r of
128         Just arr -> return $ mkTyConApp arr []
129         Nothing  -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
130 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
131
132 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
133 mkBuiltinCo get_tc
134   = do
135       tc <- builtin get_tc
136       return $ mkTyConApp tc []
137
138 parrayReprTyCon :: Type -> VM (TyCon, [Type])
139 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
140
141 parrayReprDataCon :: Type -> VM (DataCon, [Type])
142 parrayReprDataCon ty
143   = do
144       (tc, arg_tys) <- parrayReprTyCon ty
145       let [dc] = tyConDataCons tc
146       return (dc, arg_tys)
147
148 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
149 mkVScrut (ve, le)
150   = do
151       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
152       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
153
154 prDFunOfTyCon :: TyCon -> VM CoreExpr
155 prDFunOfTyCon tycon
156   = liftM Var
157   . maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
158   $ lookupTyConPR tycon
159
160 paDictArgType :: TyVar -> VM (Maybe Type)
161 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
162   where
163     go ty k | Just k' <- kindView k = go ty k'
164     go ty (FunTy k1 k2)
165       = do
166           tv   <- newTyVar (fsLit "a") k1
167           mty1 <- go (TyVarTy tv) k1
168           case mty1 of
169             Just ty1 -> do
170                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
171                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
172             Nothing  -> go ty k2
173
174     go ty k
175       | isLiftedTypeKind k
176       = liftM Just (mkPADictType ty)
177
178     go _ _ = return Nothing
179
180 paDictOfType :: Type -> VM CoreExpr
181 paDictOfType ty = paDictOfTyApp ty_fn ty_args
182   where
183     (ty_fn, ty_args) = splitAppTys ty
184
185 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
186 paDictOfTyApp ty_fn ty_args
187   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
188 paDictOfTyApp (TyVarTy tv) ty_args
189   = do
190       dfun <- maybeV (lookupTyVarPA tv)
191       paDFunApply dfun ty_args
192 paDictOfTyApp (TyConApp tc _) ty_args
193   = do
194       dfun <- maybeCantVectoriseM "No PA dictionary for tycon" (ppr tc)
195             $ lookupTyConPA tc
196       paDFunApply (Var dfun) ty_args
197 paDictOfTyApp ty _
198   = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
199
200 paDFunType :: TyCon -> VM Type
201 paDFunType tc
202   = do
203       margs <- mapM paDictArgType tvs
204       res   <- mkPADictType (mkTyConApp tc arg_tys)
205       return . mkForAllTys tvs
206              $ mkFunTys [arg | Just arg <- margs] res
207   where
208     tvs = tyConTyVars tc
209     arg_tys = mkTyVarTys tvs
210
211 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
212 paDFunApply dfun tys
213   = do
214       dicts <- mapM paDictOfType tys
215       return $ mkApps (mkTyApps dfun tys) dicts
216
217 type PAMethod = (Builtins -> Var, String)
218
219 pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
220 pa_length    = (lengthPAVar,    "lengthPA")
221 pa_replicate = (replicatePAVar, "replicatePA")
222 pa_empty     = (emptyPAVar,     "emptyPA")
223 pa_pack      = (packPAVar,      "packPA")
224
225 paMethod :: PAMethod -> Type -> VM CoreExpr
226 paMethod (_method, name) ty
227   | Just tycon <- splitPrimTyCon ty
228   = liftM Var
229   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
230   $ lookupPrimMethod tycon name
231
232 paMethod (method, _name) ty
233   = do
234       fn   <- builtin method
235       dict <- paDictOfType ty
236       return $ mkApps (Var fn) [Type ty, dict]
237
238 mkPR :: Type -> VM CoreExpr
239 mkPR ty
240   = do
241       fn   <- builtin mkPRVar
242       dict <- paDictOfType ty
243       return $ mkApps (Var fn) [Type ty, dict]
244
245 lengthPA :: Type -> CoreExpr -> VM CoreExpr
246 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
247
248 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
249 replicatePA len x = liftM (`mkApps` [len,x])
250                           (paMethod pa_replicate (exprType x))
251
252 emptyPA :: Type -> VM CoreExpr
253 emptyPA = paMethod pa_empty
254
255 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
256 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
257                              (paMethod pa_pack ty)
258
259 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
260           -> VM CoreExpr
261 combinePA ty len sel is xs
262   = liftM (`mkApps` (len : sel : is : xs))
263           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
264   where
265     n = length xs
266
267 liftPA :: CoreExpr -> VM CoreExpr
268 liftPA x
269   = do
270       lc <- builtin liftingContext
271       replicatePA (Var lc) x
272
273 newLocalVVar :: FastString -> Type -> VM VVar
274 newLocalVVar fs vty
275   = do
276       lty <- mkPArrayType vty
277       vv  <- newLocalVar fs vty
278       lv  <- newLocalVar fs lty
279       return (vv,lv)
280
281 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
282 polyAbstract tvs p
283   = localV
284   $ do
285       mdicts <- mapM mk_dict_var tvs
286       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
287       p (mk_lams mdicts)
288   where
289     mk_dict_var tv = do
290                        r <- paDictArgType tv
291                        case r of
292                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
293                          Nothing -> return Nothing
294
295     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
296
297 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
298 polyApply expr tys
299   = do
300       dicts <- mapM paDictOfType tys
301       return $ expr `mkTyApps` tys `mkApps` dicts
302
303 polyVApply :: VExpr -> [Type] -> VM VExpr
304 polyVApply expr tys
305   = do
306       dicts <- mapM paDictOfType tys
307       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
308
309 hoistBinding :: Var -> CoreExpr -> VM ()
310 hoistBinding v e = updGEnv $ \env ->
311   env { global_bindings = (v,e) : global_bindings env }
312
313 hoistExpr :: FastString -> CoreExpr -> VM Var
314 hoistExpr fs expr
315   = do
316       var <- newLocalVar fs (exprType expr)
317       hoistBinding var expr
318       return var
319
320 hoistVExpr :: VExpr -> VM VVar
321 hoistVExpr (ve, le)
322   = do
323       fs <- getBindName
324       vv <- hoistExpr ('v' `consFS` fs) ve
325       lv <- hoistExpr ('l' `consFS` fs) le
326       return (vv, lv)
327
328 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
329 hoistPolyVExpr tvs p
330   = do
331       expr <- closedV . polyAbstract tvs $ \abstract ->
332               liftM (mapVect abstract) p
333       fn   <- hoistVExpr expr
334       polyVApply (vVar fn) (mkTyVarTys tvs)
335
336 takeHoisted :: VM [(Var, CoreExpr)]
337 takeHoisted
338   = do
339       env <- readGEnv id
340       setGEnv $ env { global_bindings = [] }
341       return $ global_bindings env
342
343 {-
344 boxExpr :: Type -> VExpr -> VM VExpr
345 boxExpr ty (vexpr, lexpr)
346   | Just (tycon, []) <- splitTyConApp_maybe ty
347   , isUnLiftedTyCon tycon
348   = do
349       r <- lookupBoxedTyCon tycon
350       case r of
351         Just tycon' -> let [dc] = tyConDataCons tycon'
352                        in
353                        return (mkConApp dc [vexpr], lexpr)
354         Nothing     -> return (vexpr, lexpr)
355 -}
356
357 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
358 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
359   = do
360       dict <- paDictOfType env_ty
361       mkv  <- builtin mkClosureVar
362       mkl  <- builtin mkClosurePVar
363       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
364               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
365
366 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
367 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
368   = do
369       vapply <- builtin applyClosureVar
370       lapply <- builtin applyClosurePVar
371       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
372               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
373
374 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
375 buildClosures _   _    [] _ mk_body
376   = mk_body
377 buildClosures tvs vars [arg_ty] res_ty mk_body
378   = buildClosure tvs vars arg_ty res_ty mk_body
379 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
380   = do
381       res_ty' <- mkClosureTypes arg_tys res_ty
382       arg <- newLocalVVar (fsLit "x") arg_ty
383       buildClosure tvs vars arg_ty res_ty'
384         . hoistPolyVExpr tvs
385         $ do
386             lc <- builtin liftingContext
387             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
388             return $ vLams lc (vars ++ [arg]) clo
389
390 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
391 --   where
392 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
393 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
394 --
395 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
396 buildClosure tvs vars arg_ty res_ty mk_body
397   = do
398       (env_ty, env, bind) <- buildEnv vars
399       env_bndr <- newLocalVVar (fsLit "env") env_ty
400       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
401
402       fn <- hoistPolyVExpr tvs
403           $ do
404               lc    <- builtin liftingContext
405               body  <- mk_body
406               body' <- bind (vVar env_bndr)
407                             (vVarApps lc body (vars ++ [arg_bndr]))
408               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
409
410       mkClosure arg_ty res_ty env_ty fn env
411
412 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
413 buildEnv vvs
414   = do
415       lc <- builtin liftingContext
416       let (ty, venv, vbind) = mkVectEnv tys vs
417       (lenv, lbind) <- mkLiftEnv lc tys ls
418       return (ty, (venv, lenv),
419               \(venv,lenv) (vbody,lbody) ->
420               do
421                 let vbody' = vbind venv vbody
422                 lbody' <- lbind lenv lbody
423                 return (vbody', lbody'))
424   where
425     (vs,ls) = unzip vvs
426     tys     = map varType vs
427
428 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
429 mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
430 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
431 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
432                         \env body -> mkWildCase env ty (exprType body)
433                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
434   where
435     ty = mkCoreTupTy tys
436
437 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
438 mkLiftEnv lc [ty] [v]
439   = return (Var v, \env body ->
440                    do
441                      len <- lengthPA ty (Var v)
442                      return . Let (NonRec v env)
443                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
444
445 -- NOTE: this transparently deals with empty environments
446 mkLiftEnv lc tys vs
447   = do
448       (env_tc, env_tyargs) <- parrayReprTyCon vty
449
450       bndrs <- if null vs then do
451                                  v <- newDummyVar unitTy
452                                  return [v]
453                           else return vs
454       let [env_con] = tyConDataCons env_tc
455           
456           env = Var (dataConWrapId env_con)
457                 `mkTyApps`  env_tyargs
458                 `mkApps`    (Var lc : args)
459
460           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
461                           in
462                           return $ mkWildCase scrut (exprType scrut)
463                                         (exprType body)
464                                         [(DataAlt env_con, lc : bndrs, body)]
465       return (env, bind)
466   where
467     vty = mkCoreTupTy tys
468
469     args  | null vs   = [Var unitDataConId]
470           | otherwise = map Var vs
471