scheduleDoGC: if we're doing heapCensus(), do it *before* releasing
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp (
4
5   -- Vectorise a polymorphic expression
6   vectPolyExpr, 
7   
8   -- Vectorise a scalar expression of functional type
9   vectScalarFun
10 ) where
11
12 #include "HsVersions.h"
13
14 import Vectorise.Type.Type
15 import Vectorise.Var
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19 import Vectorise.Builtins
20 import Vectorise.Utils
21
22 import CoreSyn
23 import CoreUtils
24 import MkCore
25 import CoreFVs
26 import DataCon
27 import TyCon
28 import Type
29 import Var
30 import VarEnv
31 import VarSet
32 import Id
33 import BasicTypes( isLoopBreaker )
34 import Literal
35 import TysWiredIn
36 import TysPrim
37 import Outputable
38 import FastString
39 import Control.Monad
40 import Data.List
41
42
43 -- | Vectorise a polymorphic expression.
44 --
45 vectPolyExpr :: Bool            -- ^ When vectorising the RHS of a binding, whether that
46                                               --   binding is a loop breaker.
47                    -> [Var]                     
48                    -> CoreExprWithFVs
49                    -> VM (Inline, Bool, VExpr)
50 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
51  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
52       return (inline, isScalarFn, vNote note expr')
53 vectPolyExpr loop_breaker recFns expr
54  = do
55       arity <- polyArity tvs
56       polyAbstract tvs $ \args ->
57         do
58           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
59           return (addInlineArity inline arity, isScalarFn, 
60                   mapVect (mkLams $ tvs ++ args) mono')
61   where
62     (tvs, mono) = collectAnnTypeBinders expr
63
64
65 -- | Vectorise an expression.
66 vectExpr :: CoreExprWithFVs -> VM VExpr
67 vectExpr (_, AnnType ty)
68   = liftM vType (vectType ty)
69
70 vectExpr (_, AnnVar v) 
71   = vectVar v
72
73 vectExpr (_, AnnLit lit) 
74   = vectLiteral lit
75
76 vectExpr (_, AnnNote note expr)
77   = liftM (vNote note) (vectExpr expr)
78
79 vectExpr e@(_, AnnApp _ arg)
80   | isAnnTypeArg arg
81   = vectTyAppExpr fn tys
82   where
83     (fn, tys) = collectAnnTypeArgs e
84
85 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
86   | Just con <- isDataConId_maybe v
87   , is_special_con con
88   = do
89       let vexpr = App (Var v) (Lit lit)
90       lexpr <- liftPD vexpr
91       return (vexpr, lexpr)
92   where
93     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
94
95
96 -- TODO: Avoid using closure application for dictionaries.
97 -- vectExpr (_, AnnApp fn arg)
98 --  | if is application of dictionary 
99 --    just use regular app instead of closure app.
100
101 -- for lifted version. 
102 --      do liftPD (sub a dNumber)
103 --      lift the result of the selection, not sub and dNumber seprately. 
104
105 vectExpr (_, AnnApp fn arg)
106  = do
107       arg_ty' <- vectType arg_ty
108       res_ty' <- vectType res_ty
109
110       fn'     <- vectExpr fn
111       arg'    <- vectExpr arg
112
113       mkClosureApp arg_ty' res_ty' fn' arg'
114   where
115     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
116
117 vectExpr (_, AnnCase scrut bndr ty alts)
118   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
119   , isAlgTyCon tycon
120   = vectAlgCase tycon ty_args scrut bndr ty alts
121   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
122   where
123     scrut_ty = exprType (deAnnotate scrut)
124
125 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
126   = do
127       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
128       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
129       return $ vLet (vNonRec vbndr vrhs) vbody
130
131 vectExpr (_, AnnLet (AnnRec bs) body)
132   = do
133       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
134                                 $ liftM2 (,)
135                                   (zipWithM vect_rhs bndrs rhss)
136                                   (vectExpr body)
137       return $ vLet (vRec vbndrs vrhss) vbody
138   where
139     (bndrs, rhss) = unzip bs
140
141     vect_rhs bndr rhs = localV
142                       . inBind bndr
143                       . liftM (\(_,_,z)->z)
144                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
145
146 vectExpr e@(_, AnnLam bndr _)
147   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
148 {-
149 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
150                 `orElseV` vectLam True fvs bs body
151   where
152     (bs,body) = collectAnnValBinders e
153 -}
154
155 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
156
157 -- | Vectorise an expression with an outer lambda abstraction.
158 --
159 vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
160                                --   be inlined
161            -> Bool             -- ^ Whether the binding is a loop breaker
162            -> [Var]            -- ^ Names of function in same recursive binding group
163            -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
164            -> VM (Inline, Bool, VExpr)
165 vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
166   | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
167                 `orElseV` 
168                 mark inlineMe False (vectLam inline loop_breaker expr)
169 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
170
171 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
172 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
173
174 -- |Vectorise an expression of functional type, where all arguments and the result are of scalar
175 -- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
176 -- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
177 -- transformation; instead, they can be lifted by application of a member of the zipWith family
178 -- (i.e., 'map', 'zipWith', zipWith3', etc.)
179 --
180 vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
181               -> [Var]      -- ^ Functions names in same recursive binding group
182               -> CoreExpr   -- ^ Expression to be vectorised
183               -> VM VExpr
184 vectScalarFun forceScalar recFns expr
185  = do { gscalars <- globalScalars
186       ; let scalars = gscalars `extendVarSetList` recFns
187             (arg_tys, res_ty) = splitFunTys (exprType expr)
188       ; MASSERT( not $ null arg_tys )
189       ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
190                  ||
191                  all is_prim_ty arg_tys         -- check whether the function is scalar
192                   && is_prim_ty res_ty
193                   && is_scalar scalars expr
194                   && uses scalars expr)
195         $ mkScalarFun arg_tys res_ty expr
196       }
197   where
198     -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
199     is_prim_ty ty 
200         | Just (tycon, [])   <- splitTyConApp_maybe ty
201         =    tycon == intTyCon
202           || tycon == floatTyCon
203           || tycon == doubleTyCon
204         | otherwise = False
205
206     -- Checks whether an expression contain a non-scalar subexpression. 
207     --
208     -- Precodition: The variables in the first argument are scalar.
209     --
210     -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
211     -- them to the list of scalar variables) and then check them.  If one of them turns out not to
212     -- be scalar, the entire group is regarded as not being scalar.
213     --
214     -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
215     --        data constructor as scalar.  Should be changed once scalar types are passed
216     --        through VectInfo.
217     --
218     is_scalar :: VarSet -> CoreExpr -> Bool
219     is_scalar scalars  (Var v)         = v `elemVarSet` scalars
220     is_scalar _scalars (Lit _)         = True
221     is_scalar scalars  e@(App e1 e2) 
222       | maybe_parr_ty  (exprType e)    = False
223       | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
224     is_scalar scalars  (Lam var body)  
225       | maybe_parr_ty  (varType var)   = False
226       | otherwise                      = is_scalar (scalars `extendVarSet` var) body
227     is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
228       where
229         (bindsAreScalar, scalars') = is_scalar_bind scalars bind
230     is_scalar scalars  (Case e var ty alts)
231       | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
232       | otherwise                      = False
233       where
234         scalars' = scalars `extendVarSet` var
235     is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
236     is_scalar scalars  (Note _ e   )   = is_scalar scalars e
237     is_scalar _scalars (Type _)        = True
238
239     -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
240     is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
241     is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
242       where
243         (vars, es) = unzip bnds
244         scalars'   = scalars `extendVarSetList` vars
245
246     is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
247
248     -- Checks whether the type might be a parallel array type.  In particular, if the outermost
249     -- constructor is a type family, we conservatively assume that it may be a parallel array type.
250     maybe_parr_ty :: Type -> Bool
251     maybe_parr_ty ty 
252       | Just ty'        <- coreView ty            = maybe_parr_ty ty'
253       | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
254     maybe_parr_ty _                               = False
255
256     -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
257     --        is called by some other function that is otherwise scalar, it would be very bad
258     --        that just this call to the identity makes it not be scalar.
259     -- A scalar function has to actually compute something. Without the check,
260     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
261     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
262     -- (\n# x -> x) which is what we want.
263     uses funs (Var v)       = v `elemVarSet` funs 
264     uses funs (App e1 e2)   = uses funs e1 || uses funs e2
265     uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
266     uses funs (Let (NonRec _b letExpr) body) 
267                             = uses funs letExpr || uses funs  body
268     uses funs (Case e _eId _ty alts) 
269                             = uses funs e || any (uses_alt funs) alts
270     uses _ _                = False
271
272     uses_alt funs (_, _bs, e) = uses funs e 
273
274 mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
275 mkScalarFun arg_tys res_ty expr
276   = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
277        ; zipf    <- zipScalars arg_tys res_ty
278        ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
279        ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
280        ; lclo    <- liftPD (Var clo_var)
281        ; return (Var clo_var, lclo)
282        }
283
284 -- | Vectorise a lambda abstraction.
285 --
286 vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
287         -> Bool             -- ^ Whether the binding is a loop breaker.
288         -> CoreExprWithFVs  -- ^ Body of abstraction.
289         -> VM VExpr
290 vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
291  = do let (bs, body) = collectAnnValBinders expr
292  
293       tyvars    <- localTyVars
294       (vs, vvs) <- readLEnv $ \env ->
295                    unzip [(var, vv) | var <- varSetElems fvs
296                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
297
298       arg_tys   <- mapM (vectType . idType) bs
299       res_ty    <- vectType (exprType $ deAnnotate body)
300
301       buildClosures tyvars vvs arg_tys res_ty
302         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
303         $ do
304             lc              <- builtin liftingContext
305             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
306
307             vbody' <- break_loop lc res_ty vbody
308             return $ vLams lc vbndrs vbody'
309   where
310     maybe_inline n | inline    = Inline n
311                    | otherwise = DontInline
312
313     break_loop lc ty (ve, le)
314       | loop_breaker
315       = do
316           empty <- emptyPD ty
317           lty <- mkPDataType ty
318           return (ve, mkWildCase (Var lc) intPrimTy lty
319                         [(DEFAULT, [], le),
320                          (LitAlt (mkMachInt 0), [], empty)])
321
322       | otherwise = return (ve, le)
323 vectLam _ _ _ = panic "vectLam"
324  
325
326 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
327 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
328 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
329                         (ppr $ deAnnotate e `mkTyApps` tys)
330
331
332 -- | Vectorise an algebraic case expression.
333 --   We convert
334 --
335 --   case e :: t of v { ... }
336 --
337 -- to
338 --
339 --   V:    let v' = e in case v' of _ { ... }
340 --   L:    let v' = e in case v' `cast` ... of _ { ... }
341 --
342 --   When lifting, we have to do it this way because v must have the type
343 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
344 --   have to handle the case where v is a wild var correctly.
345 --
346
347 -- FIXME: this is too lazy
348 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
349             -> [(AltCon, [Var], CoreExprWithFVs)]
350             -> VM VExpr
351 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
352   = do
353       vscrut         <- vectExpr scrut
354       (vty, lty)     <- vectAndLiftType ty
355       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
356       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
357
358 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
359   = do
360       vscrut         <- vectExpr scrut
361       (vty, lty)     <- vectAndLiftType ty
362       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
363       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
364
365 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
366   = do
367       (vty, lty) <- vectAndLiftType ty
368       vexpr      <- vectExpr scrut
369       (vbndr, (vbndrs, (vect_body, lift_body)))
370          <- vect_scrut_bndr
371           . vectBndrsIn bndrs
372           $ vectExpr body
373       let (vect_bndrs, lift_bndrs) = unzip vbndrs
374       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
375       vect_dc <- maybeV (lookupDataCon dc)
376       let [pdata_dc] = tyConDataCons pdata_tc
377
378       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
379           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
380
381       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
382   where
383     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
384                     | otherwise         = vectBndrIn bndr
385
386     mk_wild_case expr ty dc bndrs body
387       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
388
389 vectAlgCase tycon _ty_args scrut bndr ty alts
390   = do
391       vect_tc     <- maybeV (lookupTyCon tycon)
392       (vty, lty)  <- vectAndLiftType ty
393
394       let arity = length (tyConDataCons vect_tc)
395       sel_ty <- builtin (selTy arity)
396       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
397       let sel = Var sel_bndr
398
399       (vbndr, valts) <- vect_scrut_bndr
400                       $ mapM (proc_alt arity sel vty lty) alts'
401       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
402
403       vexpr <- vectExpr scrut
404       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
405       let [pdata_dc] = tyConDataCons pdata_tc
406
407       let (vect_bodies, lift_bodies) = unzip vbodies
408
409       vdummy <- newDummyVar (exprType vect_scrut)
410       ldummy <- newDummyVar (exprType lift_scrut)
411       let vect_case = Case vect_scrut vdummy vty
412                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
413
414       lc <- builtin liftingContext
415       lbody <- combinePD vty (Var lc) sel lift_bodies
416       let lift_case = Case lift_scrut ldummy lty
417                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
418                              lbody)]
419
420       return . vLet (vNonRec vbndr vexpr)
421              $ (vect_case, lift_case)
422   where
423     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
424                     | otherwise         = vectBndrIn bndr
425
426     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
427
428     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
429     cmp DEFAULT       DEFAULT       = EQ
430     cmp DEFAULT       _             = LT
431     cmp _             DEFAULT       = GT
432     cmp _             _             = panic "vectAlgCase/cmp"
433
434     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
435       = do
436           vect_dc <- maybeV (lookupDataCon dc)
437           let ntag = dataConTagZ vect_dc
438               tag  = mkDataConTag vect_dc
439               fvs  = freeVarsOf body `delVarSetList` bndrs
440
441           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
442           lc        <- builtin liftingContext
443           elems     <- builtin (selElements arity ntag)
444
445           (vbndrs, vbody)
446             <- vectBndrsIn bndrs
447              . localV
448              $ do
449                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
450                            . filter isLocalId
451                            $ varSetElems fvs
452                  (ve, le) <- vectExpr body
453                  return (ve, Case (elems `App` sel) lc lty
454                              [(DEFAULT, [], (mkLets (concat binds) le))])
455                  -- empty    <- emptyPD vty
456                  -- return (ve, Case (elems `App` sel) lc lty
457                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
458                  --                             $ mkLets (concat binds) le),
459                  --               (LitAlt (mkMachInt 0), [], empty)])
460           let (vect_bndrs, lift_bndrs) = unzip vbndrs
461           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
462
463     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
464
465     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
466
467     pack_var len tags t v
468       = do
469           r <- lookupVar v
470           case r of
471             Local (vv, lv) ->
472               do
473                 lv'  <- cloneVar lv
474                 expr <- packByTagPD (idType vv) (Var lv) len tags t
475                 updLEnv (\env -> env { local_vars = extendVarEnv
476                                                 (local_vars env) v (vv, lv') })
477                 return [(NonRec lv' expr)]
478
479             _ -> return []
480