Complete PA dictionary generation for product types
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   mkBuiltinCo,
8   mkPADictType, mkPArrayType, mkPReprType,
9
10   parrayReprTyCon, parrayReprDataCon, mkVScrut,
11   prDFunOfTyCon,
12   paDictArgType, paDictOfType, paDFunType,
13   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14   polyAbstract, polyApply, polyVApply,
15   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16   buildClosure, buildClosures,
17   mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id                 ( mkWildId )
35 import MkId               ( unwrapFamInstScrut )
36 import Name               ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim            ( intPrimTy )
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTag :: DataCon -> CoreExpr
71 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
72
73 splitUnTy :: String -> Name -> Type -> Type
74 splitUnTy s name ty
75   | Just (tc, [ty']) <- splitTyConApp_maybe ty
76   , tyConName tc == name
77   = ty'
78
79   | otherwise = pprPanic s (ppr ty)
80
81 splitBinTy :: String -> Name -> Type -> (Type, Type)
82 splitBinTy s name ty
83   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
84   , tyConName tc == name
85   = (ty1, ty2)
86
87   | otherwise = pprPanic s (ppr ty)
88
89 splitFixedTyConApp :: TyCon -> Type -> [Type]
90 splitFixedTyConApp tc ty
91   | Just (tc', tys) <- splitTyConApp_maybe ty
92   , tc == tc'
93   = tys
94
95   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
97 splitClosureTy :: Type -> (Type, Type)
98 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
99
100 splitPArrayTy :: Type -> Type
101 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
102
103 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
104 mkBuiltinTyConApp get_tc tys
105   = do
106       tc <- builtin get_tc
107       return $ mkTyConApp tc tys
108
109 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
110 mkBuiltinTyConApps get_tc tys ty
111   = do
112       tc <- builtin get_tc
113       return $ foldr (mk tc) ty tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
118 mkBuiltinTyConApps1 get_tc dft [] = return dft
119 mkBuiltinTyConApps1 get_tc dft tys
120   = do
121       tc <- builtin get_tc
122       case tys of
123         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
124         _  -> return $ foldr1 (mk tc) tys
125   where
126     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
127
128 mkClosureType :: Type -> Type -> VM Type
129 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
130
131 mkClosureTypes :: [Type] -> Type -> VM Type
132 mkClosureTypes = mkBuiltinTyConApps closureTyCon
133
134 mkPReprType :: Type -> VM Type
135 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
136
137 mkPADictType :: Type -> VM Type
138 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
139
140 mkPArrayType :: Type -> VM Type
141 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
142
143 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
144 mkBuiltinCo get_tc
145   = do
146       tc <- builtin get_tc
147       return $ mkTyConApp tc []
148
149 parrayReprTyCon :: Type -> VM (TyCon, [Type])
150 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
151
152 parrayReprDataCon :: Type -> VM (DataCon, [Type])
153 parrayReprDataCon ty
154   = do
155       (tc, arg_tys) <- parrayReprTyCon ty
156       let [dc] = tyConDataCons tc
157       return (dc, arg_tys)
158
159 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
160 mkVScrut (ve, le)
161   = do
162       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
163       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
164
165 prDFunOfTyCon :: TyCon -> VM CoreExpr
166 prDFunOfTyCon tycon
167   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
168
169 paDictArgType :: TyVar -> VM (Maybe Type)
170 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
171   where
172     go ty k | Just k' <- kindView k = go ty k'
173     go ty (FunTy k1 k2)
174       = do
175           tv   <- newTyVar FSLIT("a") k1
176           mty1 <- go (TyVarTy tv) k1
177           case mty1 of
178             Just ty1 -> do
179                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
180                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
181             Nothing  -> go ty k2
182
183     go ty k
184       | isLiftedTypeKind k
185       = liftM Just (mkPADictType ty)
186
187     go ty k = return Nothing
188
189 paDictOfType :: Type -> VM CoreExpr
190 paDictOfType ty = paDictOfTyApp ty_fn ty_args
191   where
192     (ty_fn, ty_args) = splitAppTys ty
193
194 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
195 paDictOfTyApp ty_fn ty_args
196   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
197 paDictOfTyApp (TyVarTy tv) ty_args
198   = do
199       dfun <- maybeV (lookupTyVarPA tv)
200       paDFunApply dfun ty_args
201 paDictOfTyApp (TyConApp tc _) ty_args
202   = do
203       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
204       paDFunApply (Var dfun) ty_args
205 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
206
207 paDFunType :: TyCon -> VM Type
208 paDFunType tc
209   = do
210       margs <- mapM paDictArgType tvs
211       res   <- mkPADictType (mkTyConApp tc arg_tys)
212       return . mkForAllTys tvs
213              $ mkFunTys [arg | Just arg <- margs] res
214   where
215     tvs = tyConTyVars tc
216     arg_tys = mkTyVarTys tvs
217
218 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
219 paDFunApply dfun tys
220   = do
221       dicts <- mapM paDictOfType tys
222       return $ mkApps (mkTyApps dfun tys) dicts
223
224 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
225 paMethod method ty
226   = do
227       fn   <- builtin method
228       dict <- paDictOfType ty
229       return $ mkApps (Var fn) [Type ty, dict]
230
231 mkPR :: Type -> VM CoreExpr
232 mkPR = paMethod mkPRVar
233
234 lengthPA :: CoreExpr -> VM CoreExpr
235 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
236   where
237     ty = splitPArrayTy (exprType x)
238
239 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
240 replicatePA len x = liftM (`mkApps` [len,x])
241                           (paMethod replicatePAVar (exprType x))
242
243 emptyPA :: Type -> VM CoreExpr
244 emptyPA = paMethod emptyPAVar
245
246 liftPA :: CoreExpr -> VM CoreExpr
247 liftPA x
248   = do
249       lc <- builtin liftingContext
250       replicatePA (Var lc) x
251
252 newLocalVVar :: FastString -> Type -> VM VVar
253 newLocalVVar fs vty
254   = do
255       lty <- mkPArrayType vty
256       vv  <- newLocalVar fs vty
257       lv  <- newLocalVar fs lty
258       return (vv,lv)
259
260 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
261 polyAbstract tvs p
262   = localV
263   $ do
264       mdicts <- mapM mk_dict_var tvs
265       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
266       p (mk_lams mdicts)
267   where
268     mk_dict_var tv = do
269                        r <- paDictArgType tv
270                        case r of
271                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
272                          Nothing -> return Nothing
273
274     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
275
276 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
277 polyApply expr tys
278   = do
279       dicts <- mapM paDictOfType tys
280       return $ expr `mkTyApps` tys `mkApps` dicts
281
282 polyVApply :: VExpr -> [Type] -> VM VExpr
283 polyVApply expr tys
284   = do
285       dicts <- mapM paDictOfType tys
286       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
287
288 hoistBinding :: Var -> CoreExpr -> VM ()
289 hoistBinding v e = updGEnv $ \env ->
290   env { global_bindings = (v,e) : global_bindings env }
291
292 hoistExpr :: FastString -> CoreExpr -> VM Var
293 hoistExpr fs expr
294   = do
295       var <- newLocalVar fs (exprType expr)
296       hoistBinding var expr
297       return var
298
299 hoistVExpr :: VExpr -> VM VVar
300 hoistVExpr (ve, le)
301   = do
302       fs <- getBindName
303       vv <- hoistExpr ('v' `consFS` fs) ve
304       lv <- hoistExpr ('l' `consFS` fs) le
305       return (vv, lv)
306
307 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
308 hoistPolyVExpr tvs p
309   = do
310       expr <- closedV . polyAbstract tvs $ \abstract ->
311               liftM (mapVect abstract) p
312       fn   <- hoistVExpr expr
313       polyVApply (vVar fn) (mkTyVarTys tvs)
314
315 takeHoisted :: VM [(Var, CoreExpr)]
316 takeHoisted
317   = do
318       env <- readGEnv id
319       setGEnv $ env { global_bindings = [] }
320       return $ global_bindings env
321
322 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
323 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
324   = do
325       dict <- paDictOfType env_ty
326       mkv  <- builtin mkClosureVar
327       mkl  <- builtin mkClosurePVar
328       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
329               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
330
331 mkClosureApp :: VExpr -> VExpr -> VM VExpr
332 mkClosureApp (vclo, lclo) (varg, larg)
333   = do
334       vapply <- builtin applyClosureVar
335       lapply <- builtin applyClosurePVar
336       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
337               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
338   where
339     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
340
341 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
342 buildClosures tvs vars [] res_ty mk_body
343   = mk_body
344 buildClosures tvs vars [arg_ty] res_ty mk_body
345   = buildClosure tvs vars arg_ty res_ty mk_body
346 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
347   = do
348       res_ty' <- mkClosureTypes arg_tys res_ty
349       arg <- newLocalVVar FSLIT("x") arg_ty
350       buildClosure tvs vars arg_ty res_ty'
351         . hoistPolyVExpr tvs
352         $ do
353             lc <- builtin liftingContext
354             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
355             return $ vLams lc (vars ++ [arg]) clo
356
357 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
358 --   where
359 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
360 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
361 --
362 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
363 buildClosure tvs vars arg_ty res_ty mk_body
364   = do
365       (env_ty, env, bind) <- buildEnv vars
366       env_bndr <- newLocalVVar FSLIT("env") env_ty
367       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
368
369       fn <- hoistPolyVExpr tvs
370           $ do
371               lc    <- builtin liftingContext
372               body  <- mk_body
373               body' <- bind (vVar env_bndr)
374                             (vVarApps lc body (vars ++ [arg_bndr]))
375               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
376
377       mkClosure arg_ty res_ty env_ty fn env
378
379 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
380 buildEnv vvs
381   = do
382       lc <- builtin liftingContext
383       let (ty, venv, vbind) = mkVectEnv tys vs
384       (lenv, lbind) <- mkLiftEnv lc tys ls
385       return (ty, (venv, lenv),
386               \(venv,lenv) (vbody,lbody) ->
387               do
388                 let vbody' = vbind venv vbody
389                 lbody' <- lbind lenv lbody
390                 return (vbody', lbody'))
391   where
392     (vs,ls) = unzip vvs
393     tys     = map idType vs
394
395 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
396 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
397 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
398 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
399                         \env body -> Case env (mkWildId ty) (exprType body)
400                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
401   where
402     ty = mkCoreTupTy tys
403
404 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
405 mkLiftEnv lc [ty] [v]
406   = return (Var v, \env body ->
407                    do
408                      len <- lengthPA (Var v)
409                      return . Let (NonRec v env)
410                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
411
412 -- NOTE: this transparently deals with empty environments
413 mkLiftEnv lc tys vs
414   = do
415       (env_tc, env_tyargs) <- parrayReprTyCon vty
416       let [env_con] = tyConDataCons env_tc
417           
418           env = Var (dataConWrapId env_con)
419                 `mkTyApps`  env_tyargs
420                 `mkVarApps` (lc : vs)
421
422           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
423                           in
424                           return $ Case scrut (mkWildId (exprType scrut))
425                                         (exprType body)
426                                         [(DataAlt env_con, lc : bndrs, body)]
427       return (env, bind)
428   where
429     vty = mkCoreTupTy tys
430
431     bndrs | null vs   = [mkWildId unitTy]
432           | otherwise = vs
433