Add code for looking up PA methods of primitive TyCons
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6
7   mkBuiltinCo,
8   mkPADictType, mkPArrayType, mkPReprType,
9
10   parrayReprTyCon, parrayReprDataCon, mkVScrut,
11   prDFunOfTyCon,
12   paDictArgType, paDictOfType, paDFunType,
13   paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14   polyAbstract, polyApply, polyVApply,
15   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16   buildClosure, buildClosures,
17   mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id                 ( mkWildId )
35 import MkId               ( unwrapFamInstScrut )
36 import Name               ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim            ( intPrimTy )
40 import BasicTypes         ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List             ( zipWith4 )
46 import Control.Monad         ( liftM, liftM2, zipWithM_ )
47
48 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
49 collectAnnTypeArgs expr = go expr []
50   where
51     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
52     go e                             tys = (e, tys)
53
54 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
55 collectAnnTypeBinders expr = go [] expr
56   where
57     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
58     go bs e                           = (reverse bs, e)
59
60 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
61 collectAnnValBinders expr = go [] expr
62   where
63     go bs (_, AnnLam b e) | isId b = go (b:bs) e
64     go bs e                        = (reverse bs, e)
65
66 isAnnTypeArg :: AnnExpr b ann -> Bool
67 isAnnTypeArg (_, AnnType t) = True
68 isAnnTypeArg _              = False
69
70 mkDataConTag :: DataCon -> CoreExpr
71 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
72
73 splitUnTy :: String -> Name -> Type -> Type
74 splitUnTy s name ty
75   | Just (tc, [ty']) <- splitTyConApp_maybe ty
76   , tyConName tc == name
77   = ty'
78
79   | otherwise = pprPanic s (ppr ty)
80
81 splitBinTy :: String -> Name -> Type -> (Type, Type)
82 splitBinTy s name ty
83   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
84   , tyConName tc == name
85   = (ty1, ty2)
86
87   | otherwise = pprPanic s (ppr ty)
88
89 splitFixedTyConApp :: TyCon -> Type -> [Type]
90 splitFixedTyConApp tc ty
91   | Just (tc', tys) <- splitTyConApp_maybe ty
92   , tc == tc'
93   = tys
94
95   | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
97 splitClosureTy :: Type -> (Type, Type)
98 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
99
100 splitPArrayTy :: Type -> Type
101 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
102
103 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
104 mkBuiltinTyConApp get_tc tys
105   = do
106       tc <- builtin get_tc
107       return $ mkTyConApp tc tys
108
109 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
110 mkBuiltinTyConApps get_tc tys ty
111   = do
112       tc <- builtin get_tc
113       return $ foldr (mk tc) ty tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
118 mkBuiltinTyConApps1 get_tc dft [] = return dft
119 mkBuiltinTyConApps1 get_tc dft tys
120   = do
121       tc <- builtin get_tc
122       case tys of
123         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
124         _  -> return $ foldr1 (mk tc) tys
125   where
126     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
127
128 mkClosureType :: Type -> Type -> VM Type
129 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
130
131 mkClosureTypes :: [Type] -> Type -> VM Type
132 mkClosureTypes = mkBuiltinTyConApps closureTyCon
133
134 mkPReprType :: Type -> VM Type
135 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
136
137 mkPADictType :: Type -> VM Type
138 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
139
140 mkPArrayType :: Type -> VM Type
141 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
142
143 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
144 mkBuiltinCo get_tc
145   = do
146       tc <- builtin get_tc
147       return $ mkTyConApp tc []
148
149 parrayReprTyCon :: Type -> VM (TyCon, [Type])
150 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
151
152 parrayReprDataCon :: Type -> VM (DataCon, [Type])
153 parrayReprDataCon ty
154   = do
155       (tc, arg_tys) <- parrayReprTyCon ty
156       let [dc] = tyConDataCons tc
157       return (dc, arg_tys)
158
159 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
160 mkVScrut (ve, le)
161   = do
162       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
163       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
164
165 prDFunOfTyCon :: TyCon -> VM CoreExpr
166 prDFunOfTyCon tycon
167   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
168
169 paDictArgType :: TyVar -> VM (Maybe Type)
170 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
171   where
172     go ty k | Just k' <- kindView k = go ty k'
173     go ty (FunTy k1 k2)
174       = do
175           tv   <- newTyVar FSLIT("a") k1
176           mty1 <- go (TyVarTy tv) k1
177           case mty1 of
178             Just ty1 -> do
179                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
180                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
181             Nothing  -> go ty k2
182
183     go ty k
184       | isLiftedTypeKind k
185       = liftM Just (mkPADictType ty)
186
187     go ty k = return Nothing
188
189 paDictOfType :: Type -> VM CoreExpr
190 paDictOfType ty = paDictOfTyApp ty_fn ty_args
191   where
192     (ty_fn, ty_args) = splitAppTys ty
193
194 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
195 paDictOfTyApp ty_fn ty_args
196   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
197 paDictOfTyApp (TyVarTy tv) ty_args
198   = do
199       dfun <- maybeV (lookupTyVarPA tv)
200       paDFunApply dfun ty_args
201 paDictOfTyApp (TyConApp tc _) ty_args
202   = do
203       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
204       paDFunApply (Var dfun) ty_args
205 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
206
207 paDFunType :: TyCon -> VM Type
208 paDFunType tc
209   = do
210       margs <- mapM paDictArgType tvs
211       res   <- mkPADictType (mkTyConApp tc arg_tys)
212       return . mkForAllTys tvs
213              $ mkFunTys [arg | Just arg <- margs] res
214   where
215     tvs = tyConTyVars tc
216     arg_tys = mkTyVarTys tvs
217
218 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
219 paDFunApply dfun tys
220   = do
221       dicts <- mapM paDictOfType tys
222       return $ mkApps (mkTyApps dfun tys) dicts
223
224 type PAMethod = (Builtins -> Var, String)
225
226 pa_length    = (lengthPAVar,    "lengthPA")
227 pa_replicate = (replicatePAVar, "replicatePA")
228 pa_empty     = (emptyPAVar,     "emptyPA")
229
230 paMethod :: PAMethod -> Type -> VM CoreExpr
231 paMethod (method, name) ty
232   | Just (tycon, []) <- splitTyConApp_maybe ty
233   , isPrimTyCon tycon
234   = do
235       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
236           $ lookupPrimMethod tycon name
237       return (Var fn)
238
239 paMethod (method, name) ty
240   = do
241       fn   <- builtin method
242       dict <- paDictOfType ty
243       return $ mkApps (Var fn) [Type ty, dict]
244
245 mkPR :: Type -> VM CoreExpr
246 mkPR ty
247   = do
248       fn   <- builtin mkPRVar
249       dict <- paDictOfType ty
250       return $ mkApps (Var fn) [Type ty, dict]
251
252 lengthPA :: CoreExpr -> VM CoreExpr
253 lengthPA x = liftM (`App` x) (paMethod pa_length ty)
254   where
255     ty = splitPArrayTy (exprType x)
256
257 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
258 replicatePA len x = liftM (`mkApps` [len,x])
259                           (paMethod pa_replicate (exprType x))
260
261 emptyPA :: Type -> VM CoreExpr
262 emptyPA = paMethod pa_empty
263
264 liftPA :: CoreExpr -> VM CoreExpr
265 liftPA x
266   = do
267       lc <- builtin liftingContext
268       replicatePA (Var lc) x
269
270 newLocalVVar :: FastString -> Type -> VM VVar
271 newLocalVVar fs vty
272   = do
273       lty <- mkPArrayType vty
274       vv  <- newLocalVar fs vty
275       lv  <- newLocalVar fs lty
276       return (vv,lv)
277
278 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
279 polyAbstract tvs p
280   = localV
281   $ do
282       mdicts <- mapM mk_dict_var tvs
283       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
284       p (mk_lams mdicts)
285   where
286     mk_dict_var tv = do
287                        r <- paDictArgType tv
288                        case r of
289                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
290                          Nothing -> return Nothing
291
292     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
293
294 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
295 polyApply expr tys
296   = do
297       dicts <- mapM paDictOfType tys
298       return $ expr `mkTyApps` tys `mkApps` dicts
299
300 polyVApply :: VExpr -> [Type] -> VM VExpr
301 polyVApply expr tys
302   = do
303       dicts <- mapM paDictOfType tys
304       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
305
306 hoistBinding :: Var -> CoreExpr -> VM ()
307 hoistBinding v e = updGEnv $ \env ->
308   env { global_bindings = (v,e) : global_bindings env }
309
310 hoistExpr :: FastString -> CoreExpr -> VM Var
311 hoistExpr fs expr
312   = do
313       var <- newLocalVar fs (exprType expr)
314       hoistBinding var expr
315       return var
316
317 hoistVExpr :: VExpr -> VM VVar
318 hoistVExpr (ve, le)
319   = do
320       fs <- getBindName
321       vv <- hoistExpr ('v' `consFS` fs) ve
322       lv <- hoistExpr ('l' `consFS` fs) le
323       return (vv, lv)
324
325 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
326 hoistPolyVExpr tvs p
327   = do
328       expr <- closedV . polyAbstract tvs $ \abstract ->
329               liftM (mapVect abstract) p
330       fn   <- hoistVExpr expr
331       polyVApply (vVar fn) (mkTyVarTys tvs)
332
333 takeHoisted :: VM [(Var, CoreExpr)]
334 takeHoisted
335   = do
336       env <- readGEnv id
337       setGEnv $ env { global_bindings = [] }
338       return $ global_bindings env
339
340 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
341 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
342   = do
343       dict <- paDictOfType env_ty
344       mkv  <- builtin mkClosureVar
345       mkl  <- builtin mkClosurePVar
346       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
347               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
348
349 mkClosureApp :: VExpr -> VExpr -> VM VExpr
350 mkClosureApp (vclo, lclo) (varg, larg)
351   = do
352       vapply <- builtin applyClosureVar
353       lapply <- builtin applyClosurePVar
354       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
355               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
356   where
357     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
358
359 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
360 buildClosures tvs vars [] res_ty mk_body
361   = mk_body
362 buildClosures tvs vars [arg_ty] res_ty mk_body
363   = buildClosure tvs vars arg_ty res_ty mk_body
364 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
365   = do
366       res_ty' <- mkClosureTypes arg_tys res_ty
367       arg <- newLocalVVar FSLIT("x") arg_ty
368       buildClosure tvs vars arg_ty res_ty'
369         . hoistPolyVExpr tvs
370         $ do
371             lc <- builtin liftingContext
372             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
373             return $ vLams lc (vars ++ [arg]) clo
374
375 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
376 --   where
377 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
378 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
379 --
380 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
381 buildClosure tvs vars arg_ty res_ty mk_body
382   = do
383       (env_ty, env, bind) <- buildEnv vars
384       env_bndr <- newLocalVVar FSLIT("env") env_ty
385       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
386
387       fn <- hoistPolyVExpr tvs
388           $ do
389               lc    <- builtin liftingContext
390               body  <- mk_body
391               body' <- bind (vVar env_bndr)
392                             (vVarApps lc body (vars ++ [arg_bndr]))
393               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
394
395       mkClosure arg_ty res_ty env_ty fn env
396
397 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
398 buildEnv vvs
399   = do
400       lc <- builtin liftingContext
401       let (ty, venv, vbind) = mkVectEnv tys vs
402       (lenv, lbind) <- mkLiftEnv lc tys ls
403       return (ty, (venv, lenv),
404               \(venv,lenv) (vbody,lbody) ->
405               do
406                 let vbody' = vbind venv vbody
407                 lbody' <- lbind lenv lbody
408                 return (vbody', lbody'))
409   where
410     (vs,ls) = unzip vvs
411     tys     = map idType vs
412
413 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
414 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
415 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
416 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
417                         \env body -> Case env (mkWildId ty) (exprType body)
418                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
419   where
420     ty = mkCoreTupTy tys
421
422 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
423 mkLiftEnv lc [ty] [v]
424   = return (Var v, \env body ->
425                    do
426                      len <- lengthPA (Var v)
427                      return . Let (NonRec v env)
428                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
429
430 -- NOTE: this transparently deals with empty environments
431 mkLiftEnv lc tys vs
432   = do
433       (env_tc, env_tyargs) <- parrayReprTyCon vty
434       let [env_con] = tyConDataCons env_tc
435           
436           env = Var (dataConWrapId env_con)
437                 `mkTyApps`  env_tyargs
438                 `mkVarApps` (lc : vs)
439
440           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
441                           in
442                           return $ Case scrut (mkWildId (exprType scrut))
443                                         (exprType body)
444                                         [(DataAlt env_con, lc : bndrs, body)]
445       return (env, bind)
446   where
447     vty = mkCoreTupTy tys
448
449     bndrs | null vs   = [mkWildId unitTy]
450           | otherwise = vs
451