Break out hoisting utils into their own module
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import VectUtils
7 import VectType
8 import Vectorise.Utils.Closure
9 import Vectorise.Utils.Hoisting
10 import Vectorise.Var
11 import Vectorise.Vect
12 import Vectorise.Env
13 import Vectorise.Monad
14 import Vectorise.Builtins
15
16 import CoreSyn
17 import CoreUtils
18 import MkCore
19 import CoreFVs
20 import DataCon
21 import TyCon
22 import Type
23 import Var
24 import VarEnv
25 import VarSet
26 import Id
27 import BasicTypes
28 import Literal
29 import TysWiredIn
30 import TysPrim
31 import Outputable
32 import FastString
33 import Control.Monad
34 import Data.List
35
36
37 -- | Vectorise a polymorphic expression.
38 vectPolyExpr 
39         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
40                                 --   binding is a loop breaker.
41         -> CoreExprWithFVs
42         -> VM (Inline, VExpr)
43
44 vectPolyExpr loop_breaker (_, AnnNote note expr)
45  = do (inline, expr') <- vectPolyExpr loop_breaker expr
46       return (inline, vNote note expr')
47
48 vectPolyExpr loop_breaker expr
49  = do
50       arity <- polyArity tvs
51       polyAbstract tvs $ \args ->
52         do
53           (inline, mono') <- vectFnExpr False loop_breaker mono
54           return (addInlineArity inline arity,
55                   mapVect (mkLams $ tvs ++ args) mono')
56   where
57     (tvs, mono) = collectAnnTypeBinders expr
58
59
60 -- | Vectorise an expression.
61 vectExpr :: CoreExprWithFVs -> VM VExpr
62 vectExpr (_, AnnType ty)
63   = liftM vType (vectType ty)
64
65 vectExpr (_, AnnVar v) 
66   = vectVar v
67
68 vectExpr (_, AnnLit lit) 
69   = vectLiteral lit
70
71 vectExpr (_, AnnNote note expr)
72   = liftM (vNote note) (vectExpr expr)
73
74 vectExpr e@(_, AnnApp _ arg)
75   | isAnnTypeArg arg
76   = vectTyAppExpr fn tys
77   where
78     (fn, tys) = collectAnnTypeArgs e
79
80 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
81   | Just con <- isDataConId_maybe v
82   , is_special_con con
83   = do
84       let vexpr = App (Var v) (Lit lit)
85       lexpr <- liftPD vexpr
86       return (vexpr, lexpr)
87   where
88     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
89
90
91 -- TODO: Avoid using closure application for dictionaries.
92 -- vectExpr (_, AnnApp fn arg)
93 --  | if is application of dictionary 
94 --    just use regular app instead of closure app.
95
96 -- for lifted version. 
97 --      do liftPD (sub a dNumber)
98 --      lift the result of the selection, not sub and dNumber seprately. 
99
100 vectExpr (_, AnnApp fn arg)
101  = do
102       arg_ty' <- vectType arg_ty
103       res_ty' <- vectType res_ty
104
105       fn'     <- vectExpr fn
106       arg'    <- vectExpr arg
107
108       mkClosureApp arg_ty' res_ty' fn' arg'
109   where
110     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
111
112 vectExpr (_, AnnCase scrut bndr ty alts)
113   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
114   , isAlgTyCon tycon
115   = vectAlgCase tycon ty_args scrut bndr ty alts
116   where
117     scrut_ty = exprType (deAnnotate scrut)
118
119 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
120   = do
121       vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
122       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
123       return $ vLet (vNonRec vbndr vrhs) vbody
124
125 vectExpr (_, AnnLet (AnnRec bs) body)
126   = do
127       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
128                                 $ liftM2 (,)
129                                   (zipWithM vect_rhs bndrs rhss)
130                                   (vectExpr body)
131       return $ vLet (vRec vbndrs vrhss) vbody
132   where
133     (bndrs, rhss) = unzip bs
134
135     vect_rhs bndr rhs = localV
136                       . inBind bndr
137                       . liftM snd
138                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
139
140 vectExpr e@(_, AnnLam bndr _)
141   | isId bndr = liftM snd $ vectFnExpr True False e
142 {-
143 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
144                 `orElseV` vectLam True fvs bs body
145   where
146     (bs,body) = collectAnnValBinders e
147 -}
148
149 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
150
151
152 -- | Vectorise an expression with an outer lambda abstraction.
153 vectFnExpr 
154         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
155         -> Bool                 -- ^ Whether the binding is a loop breaker.
156         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
157         -> VM (Inline, VExpr)
158
159 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
160   | isId bndr = onlyIfV (isEmptyVarSet fvs)
161                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
162                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
163   where
164     (bs,body) = collectAnnValBinders e
165
166 vectFnExpr _ _ e = mark DontInline $ vectExpr e
167
168 mark :: Inline -> VM a -> VM (Inline, a)
169 mark b p = do { x <- p; return (b,x) }
170
171
172 -- | Vectorise a function where are the args have scalar type,
173 --   that is Int, Float, Double etc.
174 vectScalarLam 
175         :: [Var]        -- ^ Bound variables of function.
176         -> CoreExpr     -- ^ Function body.
177         -> VM VExpr
178         
179 vectScalarLam args body
180  = do scalars <- globalScalars
181       onlyIfV (all is_scalar_ty arg_tys
182                && is_scalar_ty res_ty
183                && is_scalar (extendVarSetList scalars args) body
184                && uses scalars body)
185         $ do
186             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
187             zipf    <- zipScalars arg_tys res_ty
188             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
189                                                 (zipf `App` Var fn_var)
190             clo_var <- hoistExpr (fsLit "clo") clo DontInline
191             lclo    <- liftPD (Var clo_var)
192             return (Var clo_var, lclo)
193   where
194     arg_tys = map idType args
195     res_ty  = exprType body
196
197     is_scalar_ty ty 
198         | Just (tycon, [])   <- splitTyConApp_maybe ty
199         =    tycon == intTyCon
200           || tycon == floatTyCon
201           || tycon == doubleTyCon
202
203         | otherwise = False
204
205     is_scalar vs (Var v)     = v `elemVarSet` vs
206     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
207     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
208     is_scalar _ _            = False
209
210     -- A scalar function has to actually compute something. Without the check,
211     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
212     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
213     -- (\n# x -> x) which is what we want.
214     uses funs (Var v)     = v `elemVarSet` funs 
215     uses funs (App e1 e2) = uses funs e1 || uses funs e2
216     uses _ _              = False
217
218
219 -- | Vectorise a lambda abstraction.
220 vectLam 
221         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
222         -> Bool                 -- ^ Whether the binding is a loop breaker.
223         -> VarSet               -- ^ The free variables in the body.
224         -> [Var]                -- ^ Binding variables.
225         -> CoreExprWithFVs      -- ^ Body of abstraction.
226         -> VM VExpr
227
228 vectLam inline loop_breaker fvs bs body
229  = do tyvars    <- localTyVars
230       (vs, vvs) <- readLEnv $ \env ->
231                    unzip [(var, vv) | var <- varSetElems fvs
232                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
233
234       arg_tys   <- mapM (vectType . idType) bs
235       res_ty    <- vectType (exprType $ deAnnotate body)
236
237       buildClosures tyvars vvs arg_tys res_ty
238         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
239         $ do
240             lc              <- builtin liftingContext
241             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
242
243             vbody' <- break_loop lc res_ty vbody
244             return $ vLams lc vbndrs vbody'
245   where
246     maybe_inline n | inline    = Inline n
247                    | otherwise = DontInline
248
249     break_loop lc ty (ve, le)
250       | loop_breaker
251       = do
252           empty <- emptyPD ty
253           lty <- mkPDataType ty
254           return (ve, mkWildCase (Var lc) intPrimTy lty
255                         [(DEFAULT, [], le),
256                          (LitAlt (mkMachInt 0), [], empty)])
257
258       | otherwise = return (ve, le)
259  
260
261 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
262 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
263 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
264                         (ppr $ deAnnotate e `mkTyApps` tys)
265
266
267 -- | Vectorise an algebraic case expression.
268 --   We convert
269 --
270 --   case e :: t of v { ... }
271 --
272 -- to
273 --
274 --   V:    let v' = e in case v' of _ { ... }
275 --   L:    let v' = e in case v' `cast` ... of _ { ... }
276 --
277 --   When lifting, we have to do it this way because v must have the type
278 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
279 --   have to handle the case where v is a wild var correctly.
280 --
281
282 -- FIXME: this is too lazy
283 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
284             -> [(AltCon, [Var], CoreExprWithFVs)]
285             -> VM VExpr
286 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
287   = do
288       vscrut         <- vectExpr scrut
289       (vty, lty)     <- vectAndLiftType ty
290       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
291       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
292
293 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
294   = do
295       vscrut         <- vectExpr scrut
296       (vty, lty)     <- vectAndLiftType ty
297       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
298       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
299
300 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
301   = do
302       (vty, lty) <- vectAndLiftType ty
303       vexpr      <- vectExpr scrut
304       (vbndr, (vbndrs, (vect_body, lift_body)))
305          <- vect_scrut_bndr
306           . vectBndrsIn bndrs
307           $ vectExpr body
308       let (vect_bndrs, lift_bndrs) = unzip vbndrs
309       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
310       vect_dc <- maybeV (lookupDataCon dc)
311       let [pdata_dc] = tyConDataCons pdata_tc
312
313       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
314           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
315
316       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
317   where
318     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
319                     | otherwise         = vectBndrIn bndr
320
321     mk_wild_case expr ty dc bndrs body
322       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
323
324 vectAlgCase tycon _ty_args scrut bndr ty alts
325   = do
326       vect_tc     <- maybeV (lookupTyCon tycon)
327       (vty, lty)  <- vectAndLiftType ty
328
329       let arity = length (tyConDataCons vect_tc)
330       sel_ty <- builtin (selTy arity)
331       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
332       let sel = Var sel_bndr
333
334       (vbndr, valts) <- vect_scrut_bndr
335                       $ mapM (proc_alt arity sel vty lty) alts'
336       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
337
338       vexpr <- vectExpr scrut
339       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
340       let [pdata_dc] = tyConDataCons pdata_tc
341
342       let (vect_bodies, lift_bodies) = unzip vbodies
343
344       vdummy <- newDummyVar (exprType vect_scrut)
345       ldummy <- newDummyVar (exprType lift_scrut)
346       let vect_case = Case vect_scrut vdummy vty
347                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
348
349       lc <- builtin liftingContext
350       lbody <- combinePD vty (Var lc) sel lift_bodies
351       let lift_case = Case lift_scrut ldummy lty
352                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
353                              lbody)]
354
355       return . vLet (vNonRec vbndr vexpr)
356              $ (vect_case, lift_case)
357   where
358     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
359                     | otherwise         = vectBndrIn bndr
360
361     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
362
363     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
364     cmp DEFAULT       DEFAULT       = EQ
365     cmp DEFAULT       _             = LT
366     cmp _             DEFAULT       = GT
367     cmp _             _             = panic "vectAlgCase/cmp"
368
369     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
370       = do
371           vect_dc <- maybeV (lookupDataCon dc)
372           let ntag = dataConTagZ vect_dc
373               tag  = mkDataConTag vect_dc
374               fvs  = freeVarsOf body `delVarSetList` bndrs
375
376           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
377           lc        <- builtin liftingContext
378           elems     <- builtin (selElements arity ntag)
379
380           (vbndrs, vbody)
381             <- vectBndrsIn bndrs
382              . localV
383              $ do
384                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
385                            . filter isLocalId
386                            $ varSetElems fvs
387                  (ve, le) <- vectExpr body
388                  return (ve, Case (elems `App` sel) lc lty
389                              [(DEFAULT, [], (mkLets (concat binds) le))])
390                  -- empty    <- emptyPD vty
391                  -- return (ve, Case (elems `App` sel) lc lty
392                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
393                  --                             $ mkLets (concat binds) le),
394                  --               (LitAlt (mkMachInt 0), [], empty)])
395           let (vect_bndrs, lift_bndrs) = unzip vbndrs
396           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
397
398     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
399
400     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
401
402     pack_var len tags t v
403       = do
404           r <- lookupVar v
405           case r of
406             Local (vv, lv) ->
407               do
408                 lv'  <- cloneVar lv
409                 expr <- packByTagPD (idType vv) (Var lv) len tags t
410                 updLEnv (\env -> env { local_vars = extendVarEnv
411                                                 (local_vars env) v (vv, lv') })
412                 return [(NonRec lv' expr)]
413
414             _ -> return []
415