Simplify generation of PR dictionaries for products
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   mkPADictType, mkPArrayType, mkPReprType,
8
9   parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
10   prDFunOfTyCon, prCoerce,
11   paDictArgType, paDictOfType, paDFunType,
12   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
13   polyAbstract, polyApply, polyVApply,
14   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
15   buildClosure, buildClosures,
16   mkClosureApp
17 ) where
18
19 #include "HsVersions.h"
20
21 import VectCore
22 import VectMonad
23
24 import DsUtils
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import Id                 ( mkWildId )
34 import MkId               ( unwrapFamInstScrut )
35 import Name               ( Name )
36 import PrelNames
37 import TysWiredIn
38 import TysPrim            ( intPrimTy )
39 import BasicTypes         ( Boxity(..) )
40
41 import Outputable
42 import FastString
43
44 import Data.List             ( zipWith4 )
45 import Control.Monad         ( liftM, liftM2, zipWithM_ )
46
47 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
48 collectAnnTypeArgs expr = go expr []
49   where
50     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
51     go e                             tys = (e, tys)
52
53 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
54 collectAnnTypeBinders expr = go [] expr
55   where
56     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
57     go bs e                           = (reverse bs, e)
58
59 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
60 collectAnnValBinders expr = go [] expr
61   where
62     go bs (_, AnnLam b e) | isId b = go (b:bs) e
63     go bs e                        = (reverse bs, e)
64
65 isAnnTypeArg :: AnnExpr b ann -> Bool
66 isAnnTypeArg (_, AnnType t) = True
67 isAnnTypeArg _              = False
68
69 mkDataConTag :: DataCon -> CoreExpr
70 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
71
72 splitUnTy :: String -> Name -> Type -> Type
73 splitUnTy s name ty
74   | Just (tc, [ty']) <- splitTyConApp_maybe ty
75   , tyConName tc == name
76   = ty'
77
78   | otherwise = pprPanic s (ppr ty)
79
80 splitBinTy :: String -> Name -> Type -> (Type, Type)
81 splitBinTy s name ty
82   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
83   , tyConName tc == name
84   = (ty1, ty2)
85
86   | otherwise = pprPanic s (ppr ty)
87
88 splitFixedTyConApp :: TyCon -> Type -> [Type]
89 splitFixedTyConApp tc ty
90   | Just (tc', tys) <- splitTyConApp_maybe ty
91   , tc == tc'
92   = tys
93
94   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
95
96 splitClosureTy :: Type -> (Type, Type)
97 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
98
99 splitPArrayTy :: Type -> Type
100 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
101
102 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
103 mkBuiltinTyConApp get_tc tys
104   = do
105       tc <- builtin get_tc
106       return $ mkTyConApp tc tys
107
108 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
109 mkBuiltinTyConApps get_tc tys ty
110   = do
111       tc <- builtin get_tc
112       return $ foldr (mk tc) ty tys
113   where
114     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
115
116 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
117 mkBuiltinTyConApps1 get_tc dft [] = return dft
118 mkBuiltinTyConApps1 get_tc dft tys
119   = do
120       tc <- builtin get_tc
121       case tys of
122         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
123         _  -> return $ foldr1 (mk tc) tys
124   where
125     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
126
127 mkClosureType :: Type -> Type -> VM Type
128 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
129
130 mkClosureTypes :: [Type] -> Type -> VM Type
131 mkClosureTypes = mkBuiltinTyConApps closureTyCon
132
133 mkPReprType :: Type -> VM Type
134 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
135
136 mkPADictType :: Type -> VM Type
137 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
138
139 mkPArrayType :: Type -> VM Type
140 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
141
142 parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
143 parrayCoerce repr_tc args expr
144   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
145   = do
146       parray <- builtin parrayTyCon
147
148       let co = mkAppCoercion (mkTyConApp parray [])
149                              (mkSymCoercion (mkTyConApp arg_co args))
150
151       return $ mkCoerce co expr
152
153 parrayReprTyCon :: Type -> VM (TyCon, [Type])
154 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
155
156 parrayReprDataCon :: Type -> VM (DataCon, [Type])
157 parrayReprDataCon ty
158   = do
159       (tc, arg_tys) <- parrayReprTyCon ty
160       let [dc] = tyConDataCons tc
161       return (dc, arg_tys)
162
163 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
164 mkVScrut (ve, le)
165   = do
166       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
167       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
168
169 prDFunOfTyCon :: TyCon -> VM CoreExpr
170 prDFunOfTyCon tycon
171   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
172
173 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
174 prCoerce repr_tc args expr
175   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
176   = do
177       pr_tc <- builtin prTyCon
178
179       let co = mkAppCoercion (mkTyConApp pr_tc [])
180                              (mkSymCoercion (mkTyConApp arg_co args))
181
182       return $ mkCoerce co expr
183
184 paDictArgType :: TyVar -> VM (Maybe Type)
185 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
186   where
187     go ty k | Just k' <- kindView k = go ty k'
188     go ty (FunTy k1 k2)
189       = do
190           tv   <- newTyVar FSLIT("a") k1
191           mty1 <- go (TyVarTy tv) k1
192           case mty1 of
193             Just ty1 -> do
194                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
195                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
196             Nothing  -> go ty k2
197
198     go ty k
199       | isLiftedTypeKind k
200       = liftM Just (mkPADictType ty)
201
202     go ty k = return Nothing
203
204 paDictOfType :: Type -> VM CoreExpr
205 paDictOfType ty = paDictOfTyApp ty_fn ty_args
206   where
207     (ty_fn, ty_args) = splitAppTys ty
208
209 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
210 paDictOfTyApp ty_fn ty_args
211   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
212 paDictOfTyApp (TyVarTy tv) ty_args
213   = do
214       dfun <- maybeV (lookupTyVarPA tv)
215       paDFunApply dfun ty_args
216 paDictOfTyApp (TyConApp tc _) ty_args
217   = do
218       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
219       paDFunApply (Var dfun) ty_args
220 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
221
222 paDFunType :: TyCon -> VM Type
223 paDFunType tc
224   = do
225       margs <- mapM paDictArgType tvs
226       res   <- mkPADictType (mkTyConApp tc arg_tys)
227       return . mkForAllTys tvs
228              $ mkFunTys [arg | Just arg <- margs] res
229   where
230     tvs = tyConTyVars tc
231     arg_tys = mkTyVarTys tvs
232
233 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
234 paDFunApply dfun tys
235   = do
236       dicts <- mapM paDictOfType tys
237       return $ mkApps (mkTyApps dfun tys) dicts
238
239 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
240 paMethod method ty
241   = do
242       fn   <- builtin method
243       dict <- paDictOfType ty
244       return $ mkApps (Var fn) [Type ty, dict]
245
246 mkPR :: Type -> VM CoreExpr
247 mkPR = paMethod mkPRVar
248
249 lengthPA :: CoreExpr -> VM CoreExpr
250 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
251   where
252     ty = splitPArrayTy (exprType x)
253
254 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
255 replicatePA len x = liftM (`mkApps` [len,x])
256                           (paMethod replicatePAVar (exprType x))
257
258 emptyPA :: Type -> VM CoreExpr
259 emptyPA = paMethod emptyPAVar
260
261 liftPA :: CoreExpr -> VM CoreExpr
262 liftPA x
263   = do
264       lc <- builtin liftingContext
265       replicatePA (Var lc) x
266
267 newLocalVVar :: FastString -> Type -> VM VVar
268 newLocalVVar fs vty
269   = do
270       lty <- mkPArrayType vty
271       vv  <- newLocalVar fs vty
272       lv  <- newLocalVar fs lty
273       return (vv,lv)
274
275 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
276 polyAbstract tvs p
277   = localV
278   $ do
279       mdicts <- mapM mk_dict_var tvs
280       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
281       p (mk_lams mdicts)
282   where
283     mk_dict_var tv = do
284                        r <- paDictArgType tv
285                        case r of
286                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
287                          Nothing -> return Nothing
288
289     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
290
291 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
292 polyApply expr tys
293   = do
294       dicts <- mapM paDictOfType tys
295       return $ expr `mkTyApps` tys `mkApps` dicts
296
297 polyVApply :: VExpr -> [Type] -> VM VExpr
298 polyVApply expr tys
299   = do
300       dicts <- mapM paDictOfType tys
301       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
302
303 hoistBinding :: Var -> CoreExpr -> VM ()
304 hoistBinding v e = updGEnv $ \env ->
305   env { global_bindings = (v,e) : global_bindings env }
306
307 hoistExpr :: FastString -> CoreExpr -> VM Var
308 hoistExpr fs expr
309   = do
310       var <- newLocalVar fs (exprType expr)
311       hoistBinding var expr
312       return var
313
314 hoistVExpr :: VExpr -> VM VVar
315 hoistVExpr (ve, le)
316   = do
317       fs <- getBindName
318       vv <- hoistExpr ('v' `consFS` fs) ve
319       lv <- hoistExpr ('l' `consFS` fs) le
320       return (vv, lv)
321
322 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
323 hoistPolyVExpr tvs p
324   = do
325       expr <- closedV . polyAbstract tvs $ \abstract ->
326               liftM (mapVect abstract) p
327       fn   <- hoistVExpr expr
328       polyVApply (vVar fn) (mkTyVarTys tvs)
329
330 takeHoisted :: VM [(Var, CoreExpr)]
331 takeHoisted
332   = do
333       env <- readGEnv id
334       setGEnv $ env { global_bindings = [] }
335       return $ global_bindings env
336
337 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
338 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
339   = do
340       dict <- paDictOfType env_ty
341       mkv  <- builtin mkClosureVar
342       mkl  <- builtin mkClosurePVar
343       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
344               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
345
346 mkClosureApp :: VExpr -> VExpr -> VM VExpr
347 mkClosureApp (vclo, lclo) (varg, larg)
348   = do
349       vapply <- builtin applyClosureVar
350       lapply <- builtin applyClosurePVar
351       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
352               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
353   where
354     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
355
356 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
357 buildClosures tvs vars [] res_ty mk_body
358   = mk_body
359 buildClosures tvs vars [arg_ty] res_ty mk_body
360   = buildClosure tvs vars arg_ty res_ty mk_body
361 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
362   = do
363       res_ty' <- mkClosureTypes arg_tys res_ty
364       arg <- newLocalVVar FSLIT("x") arg_ty
365       buildClosure tvs vars arg_ty res_ty'
366         . hoistPolyVExpr tvs
367         $ do
368             lc <- builtin liftingContext
369             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
370             return $ vLams lc (vars ++ [arg]) clo
371
372 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
373 --   where
374 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
375 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
376 --
377 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
378 buildClosure tvs vars arg_ty res_ty mk_body
379   = do
380       (env_ty, env, bind) <- buildEnv vars
381       env_bndr <- newLocalVVar FSLIT("env") env_ty
382       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
383
384       fn <- hoistPolyVExpr tvs
385           $ do
386               lc    <- builtin liftingContext
387               body  <- mk_body
388               body' <- bind (vVar env_bndr)
389                             (vVarApps lc body (vars ++ [arg_bndr]))
390               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
391
392       mkClosure arg_ty res_ty env_ty fn env
393
394 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
395 buildEnv vvs
396   = do
397       lc <- builtin liftingContext
398       let (ty, venv, vbind) = mkVectEnv tys vs
399       (lenv, lbind) <- mkLiftEnv lc tys ls
400       return (ty, (venv, lenv),
401               \(venv,lenv) (vbody,lbody) ->
402               do
403                 let vbody' = vbind venv vbody
404                 lbody' <- lbind lenv lbody
405                 return (vbody', lbody'))
406   where
407     (vs,ls) = unzip vvs
408     tys     = map idType vs
409
410 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
411 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
412 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
413 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
414                         \env body -> Case env (mkWildId ty) (exprType body)
415                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
416   where
417     ty = mkCoreTupTy tys
418
419 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
420 mkLiftEnv lc [ty] [v]
421   = return (Var v, \env body ->
422                    do
423                      len <- lengthPA (Var v)
424                      return . Let (NonRec v env)
425                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
426
427 -- NOTE: this transparently deals with empty environments
428 mkLiftEnv lc tys vs
429   = do
430       (env_tc, env_tyargs) <- parrayReprTyCon vty
431       let [env_con] = tyConDataCons env_tc
432           
433           env = Var (dataConWrapId env_con)
434                 `mkTyApps`  env_tyargs
435                 `mkVarApps` (lc : vs)
436
437           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
438                           in
439                           return $ Case scrut (mkWildId (exprType scrut))
440                                         (exprType body)
441                                         [(DataAlt env_con, lc : bndrs, body)]
442       return (env, bind)
443   where
444     vty = mkCoreTupTy tys
445
446     bndrs | null vs   = [mkWildId unitTy]
447           | otherwise = vs
448