vectScalarLam handles int, float, and double now
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import Vectorise.Utils
7 import Vectorise.Type.Type
8 import Vectorise.Var
9 import Vectorise.Vect
10 import Vectorise.Env
11 import Vectorise.Monad
12 import Vectorise.Builtins
13
14 import CoreSyn
15 import CoreUtils
16 import MkCore
17 import CoreFVs
18 import DataCon
19 import TyCon
20 import Type
21 import Var
22 import VarEnv
23 import VarSet
24 import Id
25 import BasicTypes( isLoopBreaker )
26 import Literal
27 import TysWiredIn
28 import TysPrim
29 import Outputable
30 import FastString
31 import Control.Monad
32 import Data.List
33
34
35 -- | Vectorise a polymorphic expression.
36 vectPolyExpr 
37         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
38                                 --   binding is a loop breaker.
39         -> CoreExprWithFVs
40         -> VM (Inline, VExpr)
41
42 vectPolyExpr loop_breaker (_, AnnNote note expr)
43  = do (inline, expr') <- vectPolyExpr loop_breaker expr
44       return (inline, vNote note expr')
45
46 vectPolyExpr loop_breaker expr
47  = do
48       arity <- polyArity tvs
49       polyAbstract tvs $ \args ->
50         do
51           (inline, mono') <- vectFnExpr False loop_breaker mono
52           return (addInlineArity inline arity,
53                   mapVect (mkLams $ tvs ++ args) mono')
54   where
55     (tvs, mono) = collectAnnTypeBinders expr
56
57
58 -- | Vectorise an expression.
59 vectExpr :: CoreExprWithFVs -> VM VExpr
60 vectExpr (_, AnnType ty)
61   = liftM vType (vectType ty)
62
63 vectExpr (_, AnnVar v) 
64   = vectVar v
65
66 vectExpr (_, AnnLit lit) 
67   = vectLiteral lit
68
69 vectExpr (_, AnnNote note expr)
70   = liftM (vNote note) (vectExpr expr)
71
72 vectExpr e@(_, AnnApp _ arg)
73   | isAnnTypeArg arg
74   = vectTyAppExpr fn tys
75   where
76     (fn, tys) = collectAnnTypeArgs e
77
78 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
79   | Just con <- isDataConId_maybe v
80   , is_special_con con
81   = do
82       let vexpr = App (Var v) (Lit lit)
83       lexpr <- liftPD vexpr
84       return (vexpr, lexpr)
85   where
86     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
87
88
89 -- TODO: Avoid using closure application for dictionaries.
90 -- vectExpr (_, AnnApp fn arg)
91 --  | if is application of dictionary 
92 --    just use regular app instead of closure app.
93
94 -- for lifted version. 
95 --      do liftPD (sub a dNumber)
96 --      lift the result of the selection, not sub and dNumber seprately. 
97
98 vectExpr (_, AnnApp fn arg)
99  = do
100       arg_ty' <- vectType arg_ty
101       res_ty' <- vectType res_ty
102
103       fn'     <- vectExpr fn
104       arg'    <- vectExpr arg
105
106       mkClosureApp arg_ty' res_ty' fn' arg'
107   where
108     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
109
110 vectExpr (_, AnnCase scrut bndr ty alts)
111   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
112   , isAlgTyCon tycon
113   = vectAlgCase tycon ty_args scrut bndr ty alts
114   where
115     scrut_ty = exprType (deAnnotate scrut)
116
117 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
118   = do
119       vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
120       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
121       return $ vLet (vNonRec vbndr vrhs) vbody
122
123 vectExpr (_, AnnLet (AnnRec bs) body)
124   = do
125       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
126                                 $ liftM2 (,)
127                                   (zipWithM vect_rhs bndrs rhss)
128                                   (vectExpr body)
129       return $ vLet (vRec vbndrs vrhss) vbody
130   where
131     (bndrs, rhss) = unzip bs
132
133     vect_rhs bndr rhs = localV
134                       . inBind bndr
135                       . liftM snd
136                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
137
138 vectExpr e@(_, AnnLam bndr _)
139   | isId bndr = liftM snd $ vectFnExpr True False e
140 {-
141 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
142                 `orElseV` vectLam True fvs bs body
143   where
144     (bs,body) = collectAnnValBinders e
145 -}
146
147 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
148
149
150 -- | Vectorise an expression with an outer lambda abstraction.
151 vectFnExpr 
152         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
153         -> Bool                 -- ^ Whether the binding is a loop breaker.
154         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
155         -> VM (Inline, VExpr)
156
157 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
158   | isId bndr = onlyIfV (isEmptyVarSet fvs)
159                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
160                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
161   where
162     (bs,body) = collectAnnValBinders e
163
164 vectFnExpr _ _ e = mark DontInline $ vectExpr e
165
166 mark :: Inline -> VM a -> VM (Inline, a)
167 mark b p = do { x <- p; return (b,x) }
168
169
170 -- | Vectorise a function where are the args have scalar type,
171 --   that is Int, Float, Double etc.
172 vectScalarLam 
173         :: [Var]        -- ^ Bound variables of function.
174         -> CoreExpr     -- ^ Function body.
175         -> VM VExpr
176         
177 vectScalarLam args body
178  = do scalars <- globalScalars
179       onlyIfV (all is_scalar_ty arg_tys
180                && is_scalar_ty res_ty
181                && is_scalar (extendVarSetList scalars args) body
182                && uses scalars body)
183         $ do
184             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
185             zipf    <- zipScalars arg_tys res_ty
186             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
187                                                 (zipf `App` Var fn_var)
188             clo_var <- hoistExpr (fsLit "clo") clo DontInline
189             lclo    <- liftPD (Var clo_var)
190             return (Var clo_var, lclo)
191   where
192     arg_tys = map idType args
193     res_ty  = exprType body
194
195     is_scalar_ty ty 
196         | Just (tycon, [])   <- splitTyConApp_maybe ty
197         =    tycon == intTyCon
198           || tycon == floatTyCon
199           || tycon == doubleTyCon
200           || tycon == boolTyCon
201
202         | otherwise = False
203
204     is_scalar vs (Var v)     = v `elemVarSet` vs
205     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
206     
207     is_scalar _ (App (Var v) (Lit lit)) 
208        | Just con <- isDataConId_maybe v = con `elem` [intDataCon, floatDataCon, doubleDataCon]
209
210     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2    
211     is_scalar vs (Let (NonRec b letExpr) body) 
212                              = is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
213     is_scalar vs (Let (Rec bnds) body) 
214                              =  let vs' = extendVarSetList vs (map fst bnds)
215                                 in all (is_scalar vs') (map snd bnds) && is_scalar vs' body
216     is_scalar vs (Case e eId ty alts)  
217                              = let vs' = extendVarSet vs eId
218                                    in is_scalar_ty ty &&
219                                   is_scalar vs' e   &&
220                                   (all (is_scalar_alt vs') alts)
221                                     
222     is_scalar _ e            = False
223
224     is_scalar_alt vs (_, bs, e) 
225                              = is_scalar (extendVarSetList vs bs) e
226
227     -- A scalar function has to actually compute something. Without the check,
228     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
229     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
230     -- (\n# x -> x) which is what we want.
231     uses funs (Var v)     = v `elemVarSet` funs 
232     uses funs (App e1 e2) = uses funs e1 || uses funs e2
233     uses funs (Let (NonRec b letExpr) body) 
234                           = uses funs letExpr || uses funs  body
235     uses funs (Case e eId ty alts) 
236                           = uses funs e || any (uses_alt funs) alts
237     uses _ _              = False
238
239     uses_alt funs (_, bs, e)   
240                           = uses funs e 
241
242 -- | Vectorise a lambda abstraction.
243 vectLam 
244         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
245         -> Bool                 -- ^ Whether the binding is a loop breaker.
246         -> VarSet               -- ^ The free variables in the body.
247         -> [Var]                -- ^ Binding variables.
248         -> CoreExprWithFVs      -- ^ Body of abstraction.
249         -> VM VExpr
250
251 vectLam inline loop_breaker fvs bs body
252  = do tyvars    <- localTyVars
253       (vs, vvs) <- readLEnv $ \env ->
254                    unzip [(var, vv) | var <- varSetElems fvs
255                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
256
257       arg_tys   <- mapM (vectType . idType) bs
258       res_ty    <- vectType (exprType $ deAnnotate body)
259
260       buildClosures tyvars vvs arg_tys res_ty
261         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
262         $ do
263             lc              <- builtin liftingContext
264             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
265
266             vbody' <- break_loop lc res_ty vbody
267             return $ vLams lc vbndrs vbody'
268   where
269     maybe_inline n | inline    = Inline n
270                    | otherwise = DontInline
271
272     break_loop lc ty (ve, le)
273       | loop_breaker
274       = do
275           empty <- emptyPD ty
276           lty <- mkPDataType ty
277           return (ve, mkWildCase (Var lc) intPrimTy lty
278                         [(DEFAULT, [], le),
279                          (LitAlt (mkMachInt 0), [], empty)])
280
281       | otherwise = return (ve, le)
282  
283
284 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
285 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
286 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
287                         (ppr $ deAnnotate e `mkTyApps` tys)
288
289
290 -- | Vectorise an algebraic case expression.
291 --   We convert
292 --
293 --   case e :: t of v { ... }
294 --
295 -- to
296 --
297 --   V:    let v' = e in case v' of _ { ... }
298 --   L:    let v' = e in case v' `cast` ... of _ { ... }
299 --
300 --   When lifting, we have to do it this way because v must have the type
301 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
302 --   have to handle the case where v is a wild var correctly.
303 --
304
305 -- FIXME: this is too lazy
306 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
307             -> [(AltCon, [Var], CoreExprWithFVs)]
308             -> VM VExpr
309 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
310   = do
311       vscrut         <- vectExpr scrut
312       (vty, lty)     <- vectAndLiftType ty
313       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
314       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
315
316 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
317   = do
318       vscrut         <- vectExpr scrut
319       (vty, lty)     <- vectAndLiftType ty
320       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
321       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
322
323 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
324   = do
325       (vty, lty) <- vectAndLiftType ty
326       vexpr      <- vectExpr scrut
327       (vbndr, (vbndrs, (vect_body, lift_body)))
328          <- vect_scrut_bndr
329           . vectBndrsIn bndrs
330           $ vectExpr body
331       let (vect_bndrs, lift_bndrs) = unzip vbndrs
332       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
333       vect_dc <- maybeV (lookupDataCon dc)
334       let [pdata_dc] = tyConDataCons pdata_tc
335
336       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
337           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
338
339       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
340   where
341     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
342                     | otherwise         = vectBndrIn bndr
343
344     mk_wild_case expr ty dc bndrs body
345       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
346
347 vectAlgCase tycon _ty_args scrut bndr ty alts
348   = do
349       vect_tc     <- maybeV (lookupTyCon tycon)
350       (vty, lty)  <- vectAndLiftType ty
351
352       let arity = length (tyConDataCons vect_tc)
353       sel_ty <- builtin (selTy arity)
354       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
355       let sel = Var sel_bndr
356
357       (vbndr, valts) <- vect_scrut_bndr
358                       $ mapM (proc_alt arity sel vty lty) alts'
359       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
360
361       vexpr <- vectExpr scrut
362       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
363       let [pdata_dc] = tyConDataCons pdata_tc
364
365       let (vect_bodies, lift_bodies) = unzip vbodies
366
367       vdummy <- newDummyVar (exprType vect_scrut)
368       ldummy <- newDummyVar (exprType lift_scrut)
369       let vect_case = Case vect_scrut vdummy vty
370                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
371
372       lc <- builtin liftingContext
373       lbody <- combinePD vty (Var lc) sel lift_bodies
374       let lift_case = Case lift_scrut ldummy lty
375                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
376                              lbody)]
377
378       return . vLet (vNonRec vbndr vexpr)
379              $ (vect_case, lift_case)
380   where
381     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
382                     | otherwise         = vectBndrIn bndr
383
384     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
385
386     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
387     cmp DEFAULT       DEFAULT       = EQ
388     cmp DEFAULT       _             = LT
389     cmp _             DEFAULT       = GT
390     cmp _             _             = panic "vectAlgCase/cmp"
391
392     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
393       = do
394           vect_dc <- maybeV (lookupDataCon dc)
395           let ntag = dataConTagZ vect_dc
396               tag  = mkDataConTag vect_dc
397               fvs  = freeVarsOf body `delVarSetList` bndrs
398
399           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
400           lc        <- builtin liftingContext
401           elems     <- builtin (selElements arity ntag)
402
403           (vbndrs, vbody)
404             <- vectBndrsIn bndrs
405              . localV
406              $ do
407                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
408                            . filter isLocalId
409                            $ varSetElems fvs
410                  (ve, le) <- vectExpr body
411                  return (ve, Case (elems `App` sel) lc lty
412                              [(DEFAULT, [], (mkLets (concat binds) le))])
413                  -- empty    <- emptyPD vty
414                  -- return (ve, Case (elems `App` sel) lc lty
415                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
416                  --                             $ mkLets (concat binds) le),
417                  --               (LitAlt (mkMachInt 0), [], empty)])
418           let (vect_bndrs, lift_bndrs) = unzip vbndrs
419           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
420
421     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
422
423     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
424
425     pack_var len tags t v
426       = do
427           r <- lookupVar v
428           case r of
429             Local (vv, lv) ->
430               do
431                 lv'  <- cloneVar lv
432                 expr <- packByTagPD (idType vv) (Var lv) len tags t
433                 updLEnv (\env -> env { local_vars = extendVarEnv
434                                                 (local_vars env) v (vv, lv') })
435                 return [(NonRec lv' expr)]
436
437             _ -> return []
438