Break out closure utils into own module
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import VectUtils
7 import VectType
8 import Vectorise.Utils.Closure
9 import Vectorise.Var
10 import Vectorise.Vect
11 import Vectorise.Env
12 import Vectorise.Monad
13 import Vectorise.Builtins
14
15 import CoreSyn
16 import CoreUtils
17 import MkCore
18 import CoreFVs
19 import DataCon
20 import TyCon
21 import Type
22 import Var
23 import VarEnv
24 import VarSet
25 import Id
26 import BasicTypes
27 import Literal
28 import TysWiredIn
29 import TysPrim
30 import Outputable
31 import FastString
32 import Control.Monad
33 import Data.List
34
35
36 -- | Vectorise a polymorphic expression.
37 vectPolyExpr 
38         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
39                                 --   binding is a loop breaker.
40         -> CoreExprWithFVs
41         -> VM (Inline, VExpr)
42
43 vectPolyExpr loop_breaker (_, AnnNote note expr)
44  = do (inline, expr') <- vectPolyExpr loop_breaker expr
45       return (inline, vNote note expr')
46
47 vectPolyExpr loop_breaker expr
48  = do
49       arity <- polyArity tvs
50       polyAbstract tvs $ \args ->
51         do
52           (inline, mono') <- vectFnExpr False loop_breaker mono
53           return (addInlineArity inline arity,
54                   mapVect (mkLams $ tvs ++ args) mono')
55   where
56     (tvs, mono) = collectAnnTypeBinders expr
57
58
59 -- | Vectorise an expression.
60 vectExpr :: CoreExprWithFVs -> VM VExpr
61 vectExpr (_, AnnType ty)
62   = liftM vType (vectType ty)
63
64 vectExpr (_, AnnVar v) 
65   = vectVar v
66
67 vectExpr (_, AnnLit lit) 
68   = vectLiteral lit
69
70 vectExpr (_, AnnNote note expr)
71   = liftM (vNote note) (vectExpr expr)
72
73 vectExpr e@(_, AnnApp _ arg)
74   | isAnnTypeArg arg
75   = vectTyAppExpr fn tys
76   where
77     (fn, tys) = collectAnnTypeArgs e
78
79 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
80   | Just con <- isDataConId_maybe v
81   , is_special_con con
82   = do
83       let vexpr = App (Var v) (Lit lit)
84       lexpr <- liftPD vexpr
85       return (vexpr, lexpr)
86   where
87     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
88
89
90 -- TODO: Avoid using closure application for dictionaries.
91 -- vectExpr (_, AnnApp fn arg)
92 --  | if is application of dictionary 
93 --    just use regular app instead of closure app.
94
95 -- for lifted version. 
96 --      do liftPD (sub a dNumber)
97 --      lift the result of the selection, not sub and dNumber seprately. 
98
99 vectExpr (_, AnnApp fn arg)
100  = do
101       arg_ty' <- vectType arg_ty
102       res_ty' <- vectType res_ty
103
104       fn'     <- vectExpr fn
105       arg'    <- vectExpr arg
106
107       mkClosureApp arg_ty' res_ty' fn' arg'
108   where
109     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
110
111 vectExpr (_, AnnCase scrut bndr ty alts)
112   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
113   , isAlgTyCon tycon
114   = vectAlgCase tycon ty_args scrut bndr ty alts
115   where
116     scrut_ty = exprType (deAnnotate scrut)
117
118 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
119   = do
120       vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
121       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
122       return $ vLet (vNonRec vbndr vrhs) vbody
123
124 vectExpr (_, AnnLet (AnnRec bs) body)
125   = do
126       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
127                                 $ liftM2 (,)
128                                   (zipWithM vect_rhs bndrs rhss)
129                                   (vectExpr body)
130       return $ vLet (vRec vbndrs vrhss) vbody
131   where
132     (bndrs, rhss) = unzip bs
133
134     vect_rhs bndr rhs = localV
135                       . inBind bndr
136                       . liftM snd
137                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
138
139 vectExpr e@(_, AnnLam bndr _)
140   | isId bndr = liftM snd $ vectFnExpr True False e
141 {-
142 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
143                 `orElseV` vectLam True fvs bs body
144   where
145     (bs,body) = collectAnnValBinders e
146 -}
147
148 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
149
150
151 -- | Vectorise an expression with an outer lambda abstraction.
152 vectFnExpr 
153         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
154         -> Bool                 -- ^ Whether the binding is a loop breaker.
155         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
156         -> VM (Inline, VExpr)
157
158 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
159   | isId bndr = onlyIfV (isEmptyVarSet fvs)
160                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
161                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
162   where
163     (bs,body) = collectAnnValBinders e
164
165 vectFnExpr _ _ e = mark DontInline $ vectExpr e
166
167 mark :: Inline -> VM a -> VM (Inline, a)
168 mark b p = do { x <- p; return (b,x) }
169
170
171 -- | Vectorise a function where are the args have scalar type,
172 --   that is Int, Float, Double etc.
173 vectScalarLam 
174         :: [Var]        -- ^ Bound variables of function.
175         -> CoreExpr     -- ^ Function body.
176         -> VM VExpr
177         
178 vectScalarLam args body
179  = do scalars <- globalScalars
180       onlyIfV (all is_scalar_ty arg_tys
181                && is_scalar_ty res_ty
182                && is_scalar (extendVarSetList scalars args) body
183                && uses scalars body)
184         $ do
185             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
186             zipf    <- zipScalars arg_tys res_ty
187             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
188                                                 (zipf `App` Var fn_var)
189             clo_var <- hoistExpr (fsLit "clo") clo DontInline
190             lclo    <- liftPD (Var clo_var)
191             return (Var clo_var, lclo)
192   where
193     arg_tys = map idType args
194     res_ty  = exprType body
195
196     is_scalar_ty ty 
197         | Just (tycon, [])   <- splitTyConApp_maybe ty
198         =    tycon == intTyCon
199           || tycon == floatTyCon
200           || tycon == doubleTyCon
201
202         | otherwise = False
203
204     is_scalar vs (Var v)     = v `elemVarSet` vs
205     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
206     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
207     is_scalar _ _            = False
208
209     -- A scalar function has to actually compute something. Without the check,
210     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
211     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
212     -- (\n# x -> x) which is what we want.
213     uses funs (Var v)     = v `elemVarSet` funs 
214     uses funs (App e1 e2) = uses funs e1 || uses funs e2
215     uses _ _              = False
216
217
218 -- | Vectorise a lambda abstraction.
219 vectLam 
220         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
221         -> Bool                 -- ^ Whether the binding is a loop breaker.
222         -> VarSet               -- ^ The free variables in the body.
223         -> [Var]                -- ^ Binding variables.
224         -> CoreExprWithFVs      -- ^ Body of abstraction.
225         -> VM VExpr
226
227 vectLam inline loop_breaker fvs bs body
228  = do tyvars    <- localTyVars
229       (vs, vvs) <- readLEnv $ \env ->
230                    unzip [(var, vv) | var <- varSetElems fvs
231                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
232
233       arg_tys   <- mapM (vectType . idType) bs
234       res_ty    <- vectType (exprType $ deAnnotate body)
235
236       buildClosures tyvars vvs arg_tys res_ty
237         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
238         $ do
239             lc              <- builtin liftingContext
240             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
241
242             vbody' <- break_loop lc res_ty vbody
243             return $ vLams lc vbndrs vbody'
244   where
245     maybe_inline n | inline    = Inline n
246                    | otherwise = DontInline
247
248     break_loop lc ty (ve, le)
249       | loop_breaker
250       = do
251           empty <- emptyPD ty
252           lty <- mkPDataType ty
253           return (ve, mkWildCase (Var lc) intPrimTy lty
254                         [(DEFAULT, [], le),
255                          (LitAlt (mkMachInt 0), [], empty)])
256
257       | otherwise = return (ve, le)
258  
259
260 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
261 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
262 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
263                         (ppr $ deAnnotate e `mkTyApps` tys)
264
265
266 -- | Vectorise an algebraic case expression.
267 --   We convert
268 --
269 --   case e :: t of v { ... }
270 --
271 -- to
272 --
273 --   V:    let v' = e in case v' of _ { ... }
274 --   L:    let v' = e in case v' `cast` ... of _ { ... }
275 --
276 --   When lifting, we have to do it this way because v must have the type
277 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
278 --   have to handle the case where v is a wild var correctly.
279 --
280
281 -- FIXME: this is too lazy
282 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
283             -> [(AltCon, [Var], CoreExprWithFVs)]
284             -> VM VExpr
285 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
286   = do
287       vscrut         <- vectExpr scrut
288       (vty, lty)     <- vectAndLiftType ty
289       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
290       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
291
292 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
293   = do
294       vscrut         <- vectExpr scrut
295       (vty, lty)     <- vectAndLiftType ty
296       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
297       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
298
299 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
300   = do
301       (vty, lty) <- vectAndLiftType ty
302       vexpr      <- vectExpr scrut
303       (vbndr, (vbndrs, (vect_body, lift_body)))
304          <- vect_scrut_bndr
305           . vectBndrsIn bndrs
306           $ vectExpr body
307       let (vect_bndrs, lift_bndrs) = unzip vbndrs
308       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
309       vect_dc <- maybeV (lookupDataCon dc)
310       let [pdata_dc] = tyConDataCons pdata_tc
311
312       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
313           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
314
315       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
316   where
317     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
318                     | otherwise         = vectBndrIn bndr
319
320     mk_wild_case expr ty dc bndrs body
321       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
322
323 vectAlgCase tycon _ty_args scrut bndr ty alts
324   = do
325       vect_tc     <- maybeV (lookupTyCon tycon)
326       (vty, lty)  <- vectAndLiftType ty
327
328       let arity = length (tyConDataCons vect_tc)
329       sel_ty <- builtin (selTy arity)
330       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
331       let sel = Var sel_bndr
332
333       (vbndr, valts) <- vect_scrut_bndr
334                       $ mapM (proc_alt arity sel vty lty) alts'
335       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
336
337       vexpr <- vectExpr scrut
338       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
339       let [pdata_dc] = tyConDataCons pdata_tc
340
341       let (vect_bodies, lift_bodies) = unzip vbodies
342
343       vdummy <- newDummyVar (exprType vect_scrut)
344       ldummy <- newDummyVar (exprType lift_scrut)
345       let vect_case = Case vect_scrut vdummy vty
346                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
347
348       lc <- builtin liftingContext
349       lbody <- combinePD vty (Var lc) sel lift_bodies
350       let lift_case = Case lift_scrut ldummy lty
351                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
352                              lbody)]
353
354       return . vLet (vNonRec vbndr vexpr)
355              $ (vect_case, lift_case)
356   where
357     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
358                     | otherwise         = vectBndrIn bndr
359
360     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
361
362     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
363     cmp DEFAULT       DEFAULT       = EQ
364     cmp DEFAULT       _             = LT
365     cmp _             DEFAULT       = GT
366     cmp _             _             = panic "vectAlgCase/cmp"
367
368     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
369       = do
370           vect_dc <- maybeV (lookupDataCon dc)
371           let ntag = dataConTagZ vect_dc
372               tag  = mkDataConTag vect_dc
373               fvs  = freeVarsOf body `delVarSetList` bndrs
374
375           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
376           lc        <- builtin liftingContext
377           elems     <- builtin (selElements arity ntag)
378
379           (vbndrs, vbody)
380             <- vectBndrsIn bndrs
381              . localV
382              $ do
383                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
384                            . filter isLocalId
385                            $ varSetElems fvs
386                  (ve, le) <- vectExpr body
387                  return (ve, Case (elems `App` sel) lc lty
388                              [(DEFAULT, [], (mkLets (concat binds) le))])
389                  -- empty    <- emptyPD vty
390                  -- return (ve, Case (elems `App` sel) lc lty
391                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
392                  --                             $ mkLets (concat binds) le),
393                  --               (LitAlt (mkMachInt 0), [], empty)])
394           let (vect_bndrs, lift_bndrs) = unzip vbndrs
395           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
396
397     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
398
399     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
400
401     pack_var len tags t v
402       = do
403           r <- lookupVar v
404           case r of
405             Local (vv, lv) ->
406               do
407                 lv'  <- cloneVar lv
408                 expr <- packByTagPD (idType vv) (Var lv) len tags t
409                 updLEnv (\env -> env { local_vars = extendVarEnv
410                                                 (local_vars env) v (vv, lv') })
411                 return [(NonRec lv' expr)]
412
413             _ -> return []
414