Generate lots of __inline_me during vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo,
9   mkPADictType, mkPArrayType, mkPReprType,
10
11   parrayReprTyCon, parrayReprDataCon, mkVScrut,
12   prDFunOfTyCon,
13   paDictArgType, paDictOfType, paDFunType,
14   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15   zipScalars, scalarClosure,
16   polyAbstract, polyApply, polyVApply,
17   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
18   buildClosure, buildClosures,
19   mkClosureApp
20 ) where
21
22 import VectCore
23 import VectMonad
24
25 import MkCore
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import MkId               ( unwrapFamInstScrut )
35 import TysWiredIn
36 import BasicTypes         ( Boxity(..) )
37 import Literal            ( Literal, mkMachInt )
38
39 import Outputable
40 import FastString
41
42 import Control.Monad
43
44
45 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
46 collectAnnTypeArgs expr = go expr []
47   where
48     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
49     go e                             tys = (e, tys)
50
51 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
52 collectAnnTypeBinders expr = go [] expr
53   where
54     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
55     go bs e                           = (reverse bs, e)
56
57 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
58 collectAnnValBinders expr = go [] expr
59   where
60     go bs (_, AnnLam b e) | isId b = go (b:bs) e
61     go bs e                        = (reverse bs, e)
62
63 isAnnTypeArg :: AnnExpr b ann -> Bool
64 isAnnTypeArg (_, AnnType _) = True
65 isAnnTypeArg _              = False
66
67 dataConTagZ :: DataCon -> Int
68 dataConTagZ con = dataConTag con - fIRST_TAG
69
70 mkDataConTagLit :: DataCon -> Literal
71 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
72
73 mkDataConTag :: DataCon -> CoreExpr
74 mkDataConTag = mkIntLitInt . dataConTagZ
75
76 splitPrimTyCon :: Type -> Maybe TyCon
77 splitPrimTyCon ty
78   | Just (tycon, []) <- splitTyConApp_maybe ty
79   , isPrimTyCon tycon
80   = Just tycon
81
82   | otherwise = Nothing
83
84 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
85 mkBuiltinTyConApp get_tc tys
86   = do
87       tc <- builtin get_tc
88       return $ mkTyConApp tc tys
89
90 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
91 mkBuiltinTyConApps get_tc tys ty
92   = do
93       tc <- builtin get_tc
94       return $ foldr (mk tc) ty tys
95   where
96     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
97
98 {-
99 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
100 mkBuiltinTyConApps1 _      dft [] = return dft
101 mkBuiltinTyConApps1 get_tc _   tys
102   = do
103       tc <- builtin get_tc
104       case tys of
105         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
106         _  -> return $ foldr1 (mk tc) tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkClosureType :: Type -> Type -> VM Type
111 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
112 -}
113
114 mkClosureTypes :: [Type] -> Type -> VM Type
115 mkClosureTypes = mkBuiltinTyConApps closureTyCon
116
117 mkPReprType :: Type -> VM Type
118 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
119
120 mkPADictType :: Type -> VM Type
121 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
122
123 mkPArrayType :: Type -> VM Type
124 mkPArrayType ty
125   | Just tycon <- splitPrimTyCon ty
126   = do
127       r <- lookupPrimPArray tycon
128       case r of
129         Just arr -> return $ mkTyConApp arr []
130         Nothing  -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
131 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
132
133 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
134 mkBuiltinCo get_tc
135   = do
136       tc <- builtin get_tc
137       return $ mkTyConApp tc []
138
139 parrayReprTyCon :: Type -> VM (TyCon, [Type])
140 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
141
142 parrayReprDataCon :: Type -> VM (DataCon, [Type])
143 parrayReprDataCon ty
144   = do
145       (tc, arg_tys) <- parrayReprTyCon ty
146       let [dc] = tyConDataCons tc
147       return (dc, arg_tys)
148
149 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
150 mkVScrut (ve, le)
151   = do
152       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
153       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
154
155 prDFunOfTyCon :: TyCon -> VM CoreExpr
156 prDFunOfTyCon tycon
157   = liftM Var
158   . maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
159   $ lookupTyConPR tycon
160
161 paDictArgType :: TyVar -> VM (Maybe Type)
162 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
163   where
164     go ty k | Just k' <- kindView k = go ty k'
165     go ty (FunTy k1 k2)
166       = do
167           tv   <- newTyVar (fsLit "a") k1
168           mty1 <- go (TyVarTy tv) k1
169           case mty1 of
170             Just ty1 -> do
171                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
172                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
173             Nothing  -> go ty k2
174
175     go ty k
176       | isLiftedTypeKind k
177       = liftM Just (mkPADictType ty)
178
179     go _ _ = return Nothing
180
181 paDictOfType :: Type -> VM CoreExpr
182 paDictOfType ty = paDictOfTyApp ty_fn ty_args
183   where
184     (ty_fn, ty_args) = splitAppTys ty
185
186 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
187 paDictOfTyApp ty_fn ty_args
188   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
189 paDictOfTyApp (TyVarTy tv) ty_args
190   = do
191       dfun <- maybeV (lookupTyVarPA tv)
192       paDFunApply dfun ty_args
193 paDictOfTyApp (TyConApp tc _) ty_args
194   = do
195       dfun <- maybeCantVectoriseM "No PA dictionary for tycon" (ppr tc)
196             $ lookupTyConPA tc
197       paDFunApply (Var dfun) ty_args
198 paDictOfTyApp ty _
199   = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
200
201 paDFunType :: TyCon -> VM Type
202 paDFunType tc
203   = do
204       margs <- mapM paDictArgType tvs
205       res   <- mkPADictType (mkTyConApp tc arg_tys)
206       return . mkForAllTys tvs
207              $ mkFunTys [arg | Just arg <- margs] res
208   where
209     tvs = tyConTyVars tc
210     arg_tys = mkTyVarTys tvs
211
212 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
213 paDFunApply dfun tys
214   = do
215       dicts <- mapM paDictOfType tys
216       return $ mkApps (mkTyApps dfun tys) dicts
217
218 type PAMethod = (Builtins -> Var, String)
219
220 pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
221 pa_length    = (lengthPAVar,    "lengthPA")
222 pa_replicate = (replicatePAVar, "replicatePA")
223 pa_empty     = (emptyPAVar,     "emptyPA")
224 pa_pack      = (packPAVar,      "packPA")
225
226 paMethod :: PAMethod -> Type -> VM CoreExpr
227 paMethod (_method, name) ty
228   | Just tycon <- splitPrimTyCon ty
229   = liftM Var
230   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
231   $ lookupPrimMethod tycon name
232
233 paMethod (method, _name) ty
234   = do
235       fn   <- builtin method
236       dict <- paDictOfType ty
237       return $ mkApps (Var fn) [Type ty, dict]
238
239 mkPR :: Type -> VM CoreExpr
240 mkPR ty
241   = do
242       fn   <- builtin mkPRVar
243       dict <- paDictOfType ty
244       return $ mkApps (Var fn) [Type ty, dict]
245
246 lengthPA :: Type -> CoreExpr -> VM CoreExpr
247 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
248
249 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
250 replicatePA len x = liftM (`mkApps` [len,x])
251                           (paMethod pa_replicate (exprType x))
252
253 emptyPA :: Type -> VM CoreExpr
254 emptyPA = paMethod pa_empty
255
256 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
257 packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
258                              (paMethod pa_pack ty)
259
260 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
261           -> VM CoreExpr
262 combinePA ty len sel is xs
263   = liftM (`mkApps` (len : sel : is : xs))
264           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
265   where
266     n = length xs
267
268 liftPA :: CoreExpr -> VM CoreExpr
269 liftPA x
270   = do
271       lc <- builtin liftingContext
272       replicatePA (Var lc) x
273
274 zipScalars :: [Type] -> Type -> VM CoreExpr
275 zipScalars arg_tys res_ty
276   = do
277       scalar <- builtin scalarClass
278       (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
279       zipf <- builtin (scalarZip $ length arg_tys)
280       return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
281     where
282       ty_args = arg_tys ++ [res_ty]
283
284 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
285 scalarClosure arg_tys res_ty scalar_fun array_fun
286   = do
287       ctr <- builtin (closureCtrFun $ length arg_tys)
288       pas <- mapM paDictOfType (init arg_tys)
289       return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
290                        `mkApps`   (pas ++ [scalar_fun, array_fun])
291
292 newLocalVVar :: FastString -> Type -> VM VVar
293 newLocalVVar fs vty
294   = do
295       lty <- mkPArrayType vty
296       vv  <- newLocalVar fs vty
297       lv  <- newLocalVar fs lty
298       return (vv,lv)
299
300 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
301 polyAbstract tvs p
302   = localV
303   $ do
304       mdicts <- mapM mk_dict_var tvs
305       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
306       p (mk_lams mdicts)
307   where
308     mk_dict_var tv = do
309                        r <- paDictArgType tv
310                        case r of
311                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
312                          Nothing -> return Nothing
313
314     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
315
316 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
317 polyApply expr tys
318   = do
319       dicts <- mapM paDictOfType tys
320       return $ expr `mkTyApps` tys `mkApps` dicts
321
322 polyVApply :: VExpr -> [Type] -> VM VExpr
323 polyVApply expr tys
324   = do
325       dicts <- mapM paDictOfType tys
326       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
327
328 hoistBinding :: Var -> CoreExpr -> VM ()
329 hoistBinding v e = updGEnv $ \env ->
330   env { global_bindings = (v,e) : global_bindings env }
331
332 hoistExpr :: FastString -> CoreExpr -> VM Var
333 hoistExpr fs expr
334   = do
335       var <- newLocalVar fs (exprType expr)
336       hoistBinding var expr
337       return var
338
339 hoistVExpr :: VExpr -> VM VVar
340 hoistVExpr (ve, le)
341   = do
342       fs <- getBindName
343       vv <- hoistExpr ('v' `consFS` fs) ve
344       lv <- hoistExpr ('l' `consFS` fs) le
345       return (vv, lv)
346
347 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
348 hoistPolyVExpr tvs p
349   = do
350       expr <- closedV . polyAbstract tvs $ \abstract ->
351               liftM (mapVect abstract) p
352       fn   <- hoistVExpr expr
353       polyVApply (vVar fn) (mkTyVarTys tvs)
354
355 takeHoisted :: VM [(Var, CoreExpr)]
356 takeHoisted
357   = do
358       env <- readGEnv id
359       setGEnv $ env { global_bindings = [] }
360       return $ global_bindings env
361
362 {-
363 boxExpr :: Type -> VExpr -> VM VExpr
364 boxExpr ty (vexpr, lexpr)
365   | Just (tycon, []) <- splitTyConApp_maybe ty
366   , isUnLiftedTyCon tycon
367   = do
368       r <- lookupBoxedTyCon tycon
369       case r of
370         Just tycon' -> let [dc] = tyConDataCons tycon'
371                        in
372                        return (mkConApp dc [vexpr], lexpr)
373         Nothing     -> return (vexpr, lexpr)
374 -}
375
376 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
377 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
378   = do
379       dict <- paDictOfType env_ty
380       mkv  <- builtin mkClosureVar
381       mkl  <- builtin mkClosurePVar
382       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
383               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
384
385 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
386 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
387   = do
388       vapply <- builtin applyClosureVar
389       lapply <- builtin applyClosurePVar
390       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
391               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
392
393 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
394 buildClosures _   _    [] _ mk_body
395   = mk_body
396 buildClosures tvs vars [arg_ty] res_ty mk_body
397   = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body)
398 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
399   = do
400       res_ty' <- mkClosureTypes arg_tys res_ty
401       arg <- newLocalVVar (fsLit "x") arg_ty
402       liftM vInlineMe
403         . buildClosure tvs vars arg_ty res_ty'
404         . hoistPolyVExpr tvs
405         $ do
406             lc <- builtin liftingContext
407             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
408             return $ vLams lc (vars ++ [arg]) clo
409
410 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
411 --   where
412 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
413 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
414 --
415 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
416 buildClosure tvs vars arg_ty res_ty mk_body
417   = do
418       (env_ty, env, bind) <- buildEnv vars
419       env_bndr <- newLocalVVar (fsLit "env") env_ty
420       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
421
422       fn <- hoistPolyVExpr tvs
423           $ do
424               lc    <- builtin liftingContext
425               body  <- mk_body
426               body' <- bind (vVar env_bndr)
427                             (vVarApps lc body (vars ++ [arg_bndr]))
428               return . vInlineMe $ vLamsWithoutLC [env_bndr, arg_bndr] body'
429
430       mkClosure arg_ty res_ty env_ty fn env
431
432 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
433 buildEnv vvs
434   = do
435       lc <- builtin liftingContext
436       let (ty, venv, vbind) = mkVectEnv tys vs
437       (lenv, lbind) <- mkLiftEnv lc tys ls
438       return (ty, (venv, lenv),
439               \(venv,lenv) (vbody,lbody) ->
440               do
441                 let vbody' = vbind venv vbody
442                 lbody' <- lbind lenv lbody
443                 return (vbody', lbody'))
444   where
445     (vs,ls) = unzip vvs
446     tys     = map varType vs
447
448 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
449 mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
450 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
451 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
452                         \env body -> mkWildCase env ty (exprType body)
453                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
454   where
455     ty = mkCoreTupTy tys
456
457 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
458 mkLiftEnv lc [ty] [v]
459   = return (Var v, \env body ->
460                    do
461                      len <- lengthPA ty (Var v)
462                      return . Let (NonRec v env)
463                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
464
465 -- NOTE: this transparently deals with empty environments
466 mkLiftEnv lc tys vs
467   = do
468       (env_tc, env_tyargs) <- parrayReprTyCon vty
469
470       bndrs <- if null vs then do
471                                  v <- newDummyVar unitTy
472                                  return [v]
473                           else return vs
474       let [env_con] = tyConDataCons env_tc
475           
476           env = Var (dataConWrapId env_con)
477                 `mkTyApps`  env_tyargs
478                 `mkApps`    (Var lc : args)
479
480           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
481                           in
482                           return $ mkWildCase scrut (exprType scrut)
483                                         (exprType body)
484                                         [(DataAlt env_con, lc : bndrs, body)]
485       return (env, bind)
486   where
487     vty = mkCoreTupTy tys
488
489     args  | null vs   = [Var unitDataConId]
490           | otherwise = map Var vs
491