Vectorise nullary constructors correctly
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPADictType, mkPArrayType,
7   parrayReprTyCon, parrayReprDataCon, mkVScrut,
8   paDictArgType, paDictOfType, paDFunType,
9   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
10   polyAbstract, polyApply, polyVApply,
11   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
12   buildClosure, buildClosures,
13   mkClosureApp
14 ) where
15
16 #include "HsVersions.h"
17
18 import VectCore
19 import VectMonad
20
21 import DsUtils
22 import CoreSyn
23 import CoreUtils
24 import Type
25 import TypeRep
26 import TyCon
27 import DataCon            ( DataCon, dataConWrapId, dataConTag )
28 import Var
29 import Id                 ( mkWildId )
30 import MkId               ( unwrapFamInstScrut )
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes         ( Boxity(..) )
34
35 import Outputable
36 import FastString
37
38 import Control.Monad         ( liftM, zipWithM_ )
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType t) = True
60 isAnnTypeArg _              = False
61
62 mkDataConTag :: DataCon -> CoreExpr
63 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
64
65 isClosureTyCon :: TyCon -> Bool
66 isClosureTyCon tc = tyConName tc == closureTyConName
67
68 splitClosureTy :: Type -> (Type, Type)
69 splitClosureTy ty
70   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
71   , isClosureTyCon tc
72   = (arg_ty, res_ty)
73
74   | otherwise = pprPanic "splitClosureTy" (ppr ty)
75
76 isPArrayTyCon :: TyCon -> Bool
77 isPArrayTyCon tc = tyConName tc == parrayTyConName
78
79 splitPArrayTy :: Type -> Type
80 splitPArrayTy ty
81   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
82   , isPArrayTyCon tc
83   = arg_ty
84
85   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
86
87 mkClosureType :: Type -> Type -> VM Type
88 mkClosureType arg_ty res_ty
89   = do
90       tc <- builtin closureTyCon
91       return $ mkTyConApp tc [arg_ty, res_ty]
92
93 mkClosureTypes :: [Type] -> Type -> VM Type
94 mkClosureTypes arg_tys res_ty
95   = do
96       tc <- builtin closureTyCon
97       return $ foldr (mk tc) res_ty arg_tys
98   where
99     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
100
101 mkPADictType :: Type -> VM Type
102 mkPADictType ty
103   = do
104       tc <- builtin paTyCon
105       return $ TyConApp tc [ty]
106
107 mkPArrayType :: Type -> VM Type
108 mkPArrayType ty
109   = do
110       tc <- builtin parrayTyCon
111       return $ TyConApp tc [ty]
112
113 parrayReprTyCon :: Type -> VM (TyCon, [Type])
114 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
115
116 parrayReprDataCon :: Type -> VM (DataCon, [Type])
117 parrayReprDataCon ty
118   = do
119       (tc, arg_tys) <- parrayReprTyCon ty
120       let [dc] = tyConDataCons tc
121       return (dc, arg_tys)
122
123 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
124 mkVScrut (ve, le)
125   = do
126       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
127       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
128
129 paDictArgType :: TyVar -> VM (Maybe Type)
130 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
131   where
132     go ty k | Just k' <- kindView k = go ty k'
133     go ty (FunTy k1 k2)
134       = do
135           tv   <- newTyVar FSLIT("a") k1
136           mty1 <- go (TyVarTy tv) k1
137           case mty1 of
138             Just ty1 -> do
139                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
140                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
141             Nothing  -> go ty k2
142
143     go ty k
144       | isLiftedTypeKind k
145       = liftM Just (mkPADictType ty)
146
147     go ty k = return Nothing
148
149 paDictOfType :: Type -> VM CoreExpr
150 paDictOfType ty = paDictOfTyApp ty_fn ty_args
151   where
152     (ty_fn, ty_args) = splitAppTys ty
153
154 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
155 paDictOfTyApp ty_fn ty_args
156   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
157 paDictOfTyApp (TyVarTy tv) ty_args
158   = do
159       dfun <- maybeV (lookupTyVarPA tv)
160       paDFunApply dfun ty_args
161 paDictOfTyApp (TyConApp tc _) ty_args
162   = do
163       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
164       paDFunApply (Var dfun) ty_args
165 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
166
167 paDFunType :: TyCon -> VM Type
168 paDFunType tc
169   = do
170       margs <- mapM paDictArgType tvs
171       res   <- mkPADictType (mkTyConApp tc arg_tys)
172       return . mkForAllTys tvs
173              $ mkFunTys [arg | Just arg <- margs] res
174   where
175     tvs = tyConTyVars tc
176     arg_tys = mkTyVarTys tvs
177
178 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
179 paDFunApply dfun tys
180   = do
181       dicts <- mapM paDictOfType tys
182       return $ mkApps (mkTyApps dfun tys) dicts
183
184 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
185 paMethod method ty
186   = do
187       fn   <- builtin method
188       dict <- paDictOfType ty
189       return $ mkApps (Var fn) [Type ty, dict]
190
191 lengthPA :: CoreExpr -> VM CoreExpr
192 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
193   where
194     ty = splitPArrayTy (exprType x)
195
196 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
197 replicatePA len x = liftM (`mkApps` [len,x])
198                           (paMethod replicatePAVar (exprType x))
199
200 emptyPA :: Type -> VM CoreExpr
201 emptyPA = paMethod emptyPAVar
202
203 liftPA :: CoreExpr -> VM CoreExpr
204 liftPA x
205   = do
206       lc <- builtin liftingContext
207       replicatePA (Var lc) x
208
209 newLocalVVar :: FastString -> Type -> VM VVar
210 newLocalVVar fs vty
211   = do
212       lty <- mkPArrayType vty
213       vv  <- newLocalVar fs vty
214       lv  <- newLocalVar fs lty
215       return (vv,lv)
216
217 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
218 polyAbstract tvs p
219   = localV
220   $ do
221       mdicts <- mapM mk_dict_var tvs
222       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
223       p (mk_lams mdicts)
224   where
225     mk_dict_var tv = do
226                        r <- paDictArgType tv
227                        case r of
228                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
229                          Nothing -> return Nothing
230
231     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
232
233 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
234 polyApply expr tys
235   = do
236       dicts <- mapM paDictOfType tys
237       return $ expr `mkTyApps` tys `mkApps` dicts
238
239 polyVApply :: VExpr -> [Type] -> VM VExpr
240 polyVApply expr tys
241   = do
242       dicts <- mapM paDictOfType tys
243       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
244
245 hoistBinding :: Var -> CoreExpr -> VM ()
246 hoistBinding v e = updGEnv $ \env ->
247   env { global_bindings = (v,e) : global_bindings env }
248
249 hoistExpr :: FastString -> CoreExpr -> VM Var
250 hoistExpr fs expr
251   = do
252       var <- newLocalVar fs (exprType expr)
253       hoistBinding var expr
254       return var
255
256 hoistVExpr :: VExpr -> VM VVar
257 hoistVExpr (ve, le)
258   = do
259       fs <- getBindName
260       vv <- hoistExpr ('v' `consFS` fs) ve
261       lv <- hoistExpr ('l' `consFS` fs) le
262       return (vv, lv)
263
264 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
265 hoistPolyVExpr tvs p
266   = do
267       expr <- closedV . polyAbstract tvs $ \abstract ->
268               liftM (mapVect abstract) p
269       fn   <- hoistVExpr expr
270       polyVApply (vVar fn) (mkTyVarTys tvs)
271
272 takeHoisted :: VM [(Var, CoreExpr)]
273 takeHoisted
274   = do
275       env <- readGEnv id
276       setGEnv $ env { global_bindings = [] }
277       return $ global_bindings env
278
279 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
280 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
281   = do
282       dict <- paDictOfType env_ty
283       mkv  <- builtin mkClosureVar
284       mkl  <- builtin mkClosurePVar
285       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
286               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
287
288 mkClosureApp :: VExpr -> VExpr -> VM VExpr
289 mkClosureApp (vclo, lclo) (varg, larg)
290   = do
291       vapply <- builtin applyClosureVar
292       lapply <- builtin applyClosurePVar
293       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
294               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
295   where
296     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
297
298 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
299 buildClosures tvs vars [] res_ty mk_body
300   = mk_body
301 buildClosures tvs vars [arg_ty] res_ty mk_body
302   = buildClosure tvs vars arg_ty res_ty mk_body
303 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
304   = do
305       res_ty' <- mkClosureTypes arg_tys res_ty
306       arg <- newLocalVVar FSLIT("x") arg_ty
307       buildClosure tvs vars arg_ty res_ty'
308         . hoistPolyVExpr tvs
309         $ do
310             lc <- builtin liftingContext
311             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
312             return $ vLams lc (vars ++ [arg]) clo
313
314 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
315 --   where
316 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
317 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
318 --
319 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
320 buildClosure tvs vars arg_ty res_ty mk_body
321   = do
322       (env_ty, env, bind) <- buildEnv vars
323       env_bndr <- newLocalVVar FSLIT("env") env_ty
324       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
325
326       fn <- hoistPolyVExpr tvs
327           $ do
328               lc    <- builtin liftingContext
329               body  <- mk_body
330               body' <- bind (vVar env_bndr)
331                             (vVarApps lc body (vars ++ [arg_bndr]))
332               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
333
334       mkClosure arg_ty res_ty env_ty fn env
335
336 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
337 buildEnv vvs
338   = do
339       lc <- builtin liftingContext
340       let (ty, venv, vbind) = mkVectEnv tys vs
341       (lenv, lbind) <- mkLiftEnv lc tys ls
342       return (ty, (venv, lenv),
343               \(venv,lenv) (vbody,lbody) ->
344               do
345                 let vbody' = vbind venv vbody
346                 lbody' <- lbind lenv lbody
347                 return (vbody', lbody'))
348   where
349     (vs,ls) = unzip vvs
350     tys     = map idType vs
351
352 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
353 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
354 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
355 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
356                         \env body -> Case env (mkWildId ty) (exprType body)
357                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
358   where
359     ty = mkCoreTupTy tys
360
361 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
362 mkLiftEnv lc [ty] [v]
363   = return (Var v, \env body ->
364                    do
365                      len <- lengthPA (Var v)
366                      return . Let (NonRec v env)
367                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
368
369 -- NOTE: this transparently deals with empty environments
370 mkLiftEnv lc tys vs
371   = do
372       (env_tc, env_tyargs) <- parrayReprTyCon vty
373       let [env_con] = tyConDataCons env_tc
374           
375           env = Var (dataConWrapId env_con)
376                 `mkTyApps`  env_tyargs
377                 `mkVarApps` (lc : vs)
378
379           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
380                           in
381                           return $ Case scrut (mkWildId (exprType scrut))
382                                         (exprType body)
383                                         [(DataAlt env_con, lc : bndrs, body)]
384       return (env, bind)
385   where
386     vty = mkCoreTupTy tys
387
388     bndrs | null vs   = [mkWildId unitTy]
389           | otherwise = vs
390