9cd34e3ac3f54d0a2e0403b5f025a048660ea8da
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import Vectorise.Utils
7 import Vectorise.Type.Type
8 import Vectorise.Var
9 import Vectorise.Vect
10 import Vectorise.Env
11 import Vectorise.Monad
12 import Vectorise.Builtins
13
14 import CoreSyn
15 import CoreUtils
16 import MkCore
17 import CoreFVs
18 import DataCon
19 import TyCon
20 import Type
21 import Var
22 import VarEnv
23 import VarSet
24 import Id
25 import BasicTypes( isLoopBreaker )
26 import Literal
27 import TysWiredIn
28 import TysPrim
29 import Outputable
30 import FastString
31 import Control.Monad
32 import Data.List
33
34
35 -- | Vectorise a polymorphic expression.
36 vectPolyExpr 
37         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
38                                     --   binding is a loop breaker.
39         -> [Var]                        
40         -> CoreExprWithFVs
41         -> VM (Inline, Bool, VExpr)
42
43 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
44  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
45       return (inline, isScalarFn, vNote note expr')
46
47 vectPolyExpr loop_breaker recFns expr
48  = do
49       arity <- polyArity tvs
50       polyAbstract tvs $ \args ->
51         do
52           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
53           return (addInlineArity inline arity, isScalarFn, 
54                   mapVect (mkLams $ tvs ++ args) mono')
55   where
56     (tvs, mono) = collectAnnTypeBinders expr
57
58
59 -- | Vectorise an expression.
60 vectExpr :: CoreExprWithFVs -> VM VExpr
61 vectExpr (_, AnnType ty)
62   = liftM vType (vectType ty)
63
64 vectExpr (_, AnnVar v) 
65   = vectVar v
66
67 vectExpr (_, AnnLit lit) 
68   = vectLiteral lit
69
70 vectExpr (_, AnnNote note expr)
71   = liftM (vNote note) (vectExpr expr)
72
73 vectExpr e@(_, AnnApp _ arg)
74   | isAnnTypeArg arg
75   = vectTyAppExpr fn tys
76   where
77     (fn, tys) = collectAnnTypeArgs e
78
79 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
80   | Just con <- isDataConId_maybe v
81   , is_special_con con
82   = do
83       let vexpr = App (Var v) (Lit lit)
84       lexpr <- liftPD vexpr
85       return (vexpr, lexpr)
86   where
87     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
88
89
90 -- TODO: Avoid using closure application for dictionaries.
91 -- vectExpr (_, AnnApp fn arg)
92 --  | if is application of dictionary 
93 --    just use regular app instead of closure app.
94
95 -- for lifted version. 
96 --      do liftPD (sub a dNumber)
97 --      lift the result of the selection, not sub and dNumber seprately. 
98
99 vectExpr (_, AnnApp fn arg)
100  = do
101       arg_ty' <- vectType arg_ty
102       res_ty' <- vectType res_ty
103
104       fn'     <- vectExpr fn
105       arg'    <- vectExpr arg
106
107       mkClosureApp arg_ty' res_ty' fn' arg'
108   where
109     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
110
111 vectExpr (_, AnnCase scrut bndr ty alts)
112   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
113   , isAlgTyCon tycon
114   = vectAlgCase tycon ty_args scrut bndr ty alts
115   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
116   where
117     scrut_ty = exprType (deAnnotate scrut)
118
119 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
120   = do
121       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
122       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
123       return $ vLet (vNonRec vbndr vrhs) vbody
124
125 vectExpr (_, AnnLet (AnnRec bs) body)
126   = do
127       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
128                                 $ liftM2 (,)
129                                   (zipWithM vect_rhs bndrs rhss)
130                                   (vectExpr body)
131       return $ vLet (vRec vbndrs vrhss) vbody
132   where
133     (bndrs, rhss) = unzip bs
134
135     vect_rhs bndr rhs = localV
136                       . inBind bndr
137                       . liftM (\(_,_,z)->z)
138                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
139
140 vectExpr e@(_, AnnLam bndr _)
141   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
142 {-
143 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
144                 `orElseV` vectLam True fvs bs body
145   where
146     (bs,body) = collectAnnValBinders e
147 -}
148
149 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
150
151
152 -- | Vectorise an expression with an outer lambda abstraction.
153 vectFnExpr 
154         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
155         -> Bool                 -- ^ Whether the binding is a loop breaker.
156         -> [Var]
157         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
158         -> VM (Inline, Bool, VExpr)
159
160 vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
161   | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
162                         (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
163                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
164   where
165     (bs,body) = collectAnnValBinders e
166
167 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
168
169 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
170 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
171
172
173 -- | Vectorise a function where are the args have scalar type,
174 --   that is Int, Float, Double etc.
175 vectScalarLam 
176         :: [Var]        -- ^ Bound variables of function
177         -> [Var]
178         -> CoreExpr     -- ^ Function body.
179         -> VM VExpr
180         
181 vectScalarLam args recFns body
182  = do scalars' <- globalScalars
183       let scalars = unionVarSet (mkVarSet recFns) scalars'
184       onlyIfV (all is_prim_ty arg_tys
185                && is_prim_ty res_ty
186                && is_scalar (extendVarSetList scalars args) body
187                && uses scalars body)
188         $ do
189             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
190             zipf    <- zipScalars arg_tys res_ty
191             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
192                                                 (zipf `App` Var fn_var)
193             clo_var <- hoistExpr (fsLit "clo") clo DontInline
194             lclo    <- liftPD (Var clo_var)
195             return (Var clo_var, lclo)
196   where
197     arg_tys = map idType args
198     res_ty  = exprType body
199
200     is_prim_ty ty 
201         | Just (tycon, [])   <- splitTyConApp_maybe ty
202         =    tycon == intTyCon
203           || tycon == floatTyCon
204           || tycon == doubleTyCon
205
206         | otherwise = False
207     
208     cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
209          
210     maybe_parr_ty ty = maybe_parr_ty' [] ty
211       
212     maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
213     maybe_parr_ty' alreadySeen ty
214        | isPArrTyCon tycon     = True
215        | isPrimTyCon tycon     = False
216        | isAbstractTyCon tycon = True
217        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
218        | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
219                              hasParrDataCon alreadySeen tycon
220        | otherwise = True
221        where
222          Just (tycon, args) = splitTyConApp_maybe ty 
223          
224          
225          hasParrDataCon alreadySeen tycon
226            | tycon `elem` alreadySeen = False  
227            | otherwise                =  
228                any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
229          
230     -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
231     -- an external (non data constructor) variable is used, or anonymous data constructor      
232     is_scalar vs e@(Var v) 
233       | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
234       | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
235     is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
236
237     is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
238                                is_scalar vs e1 && is_scalar vs e2    
239     is_scalar vs e@(Let (NonRec b letExpr) body) 
240                              = cantbe_parr_expr e &&
241                                is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
242     is_scalar vs  e@(Let (Rec bnds) body) 
243                              =  let vs' = extendVarSetList vs (map fst bnds)
244                                 in cantbe_parr_expr e &&  
245                                    all (is_scalar vs') (map snd bnds) && is_scalar vs' body
246     is_scalar vs e@(Case eC eId ty alts)  
247                              = let vs' = extendVarSet vs eId
248                                    in cantbe_parr_expr e && 
249                                   is_prim_ty ty &&
250                                   is_scalar vs' eC   &&
251                                   (all (is_scalar_alt vs') alts)
252                                     
253     is_scalar _ _            =  False
254
255     is_scalar_alt vs (_, bs, e) 
256                              = is_scalar (extendVarSetList vs bs) e
257
258     -- A scalar function has to actually compute something. Without the check,
259     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
260     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
261     -- (\n# x -> x) which is what we want.
262     uses funs (Var v)     = v `elemVarSet` funs 
263     uses funs (App e1 e2) = uses funs e1 || uses funs e2
264     uses funs (Let (NonRec _b letExpr) body) 
265                           = uses funs letExpr || uses funs  body
266     uses funs (Case e _eId _ty alts) 
267                           = uses funs e || any (uses_alt funs) alts
268     uses _ _              = False
269
270     uses_alt funs (_, _bs, e)   
271                           = uses funs e 
272
273 -- | Vectorise a lambda abstraction.
274 vectLam 
275         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
276         -> Bool                 -- ^ Whether the binding is a loop breaker.
277         -> VarSet               -- ^ The free variables in the body.
278         -> [Var]                -- ^ Binding variables.
279         -> CoreExprWithFVs      -- ^ Body of abstraction.
280         -> VM VExpr
281
282 vectLam inline loop_breaker fvs bs body
283  = do tyvars    <- localTyVars
284       (vs, vvs) <- readLEnv $ \env ->
285                    unzip [(var, vv) | var <- varSetElems fvs
286                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
287
288       arg_tys   <- mapM (vectType . idType) bs
289       res_ty    <- vectType (exprType $ deAnnotate body)
290
291       buildClosures tyvars vvs arg_tys res_ty
292         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
293         $ do
294             lc              <- builtin liftingContext
295             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
296
297             vbody' <- break_loop lc res_ty vbody
298             return $ vLams lc vbndrs vbody'
299   where
300     maybe_inline n | inline    = Inline n
301                    | otherwise = DontInline
302
303     break_loop lc ty (ve, le)
304       | loop_breaker
305       = do
306           empty <- emptyPD ty
307           lty <- mkPDataType ty
308           return (ve, mkWildCase (Var lc) intPrimTy lty
309                         [(DEFAULT, [], le),
310                          (LitAlt (mkMachInt 0), [], empty)])
311
312       | otherwise = return (ve, le)
313  
314
315 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
316 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
317 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
318                         (ppr $ deAnnotate e `mkTyApps` tys)
319
320
321 -- | Vectorise an algebraic case expression.
322 --   We convert
323 --
324 --   case e :: t of v { ... }
325 --
326 -- to
327 --
328 --   V:    let v' = e in case v' of _ { ... }
329 --   L:    let v' = e in case v' `cast` ... of _ { ... }
330 --
331 --   When lifting, we have to do it this way because v must have the type
332 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
333 --   have to handle the case where v is a wild var correctly.
334 --
335
336 -- FIXME: this is too lazy
337 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
338             -> [(AltCon, [Var], CoreExprWithFVs)]
339             -> VM VExpr
340 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
341   = do
342       vscrut         <- vectExpr scrut
343       (vty, lty)     <- vectAndLiftType ty
344       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
345       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
346
347 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
348   = do
349       vscrut         <- vectExpr scrut
350       (vty, lty)     <- vectAndLiftType ty
351       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
352       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
353
354 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
355   = do
356       (vty, lty) <- vectAndLiftType ty
357       vexpr      <- vectExpr scrut
358       (vbndr, (vbndrs, (vect_body, lift_body)))
359          <- vect_scrut_bndr
360           . vectBndrsIn bndrs
361           $ vectExpr body
362       let (vect_bndrs, lift_bndrs) = unzip vbndrs
363       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
364       vect_dc <- maybeV (lookupDataCon dc)
365       let [pdata_dc] = tyConDataCons pdata_tc
366
367       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
368           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
369
370       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
371   where
372     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
373                     | otherwise         = vectBndrIn bndr
374
375     mk_wild_case expr ty dc bndrs body
376       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
377
378 vectAlgCase tycon _ty_args scrut bndr ty alts
379   = do
380       vect_tc     <- maybeV (lookupTyCon tycon)
381       (vty, lty)  <- vectAndLiftType ty
382
383       let arity = length (tyConDataCons vect_tc)
384       sel_ty <- builtin (selTy arity)
385       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
386       let sel = Var sel_bndr
387
388       (vbndr, valts) <- vect_scrut_bndr
389                       $ mapM (proc_alt arity sel vty lty) alts'
390       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
391
392       vexpr <- vectExpr scrut
393       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
394       let [pdata_dc] = tyConDataCons pdata_tc
395
396       let (vect_bodies, lift_bodies) = unzip vbodies
397
398       vdummy <- newDummyVar (exprType vect_scrut)
399       ldummy <- newDummyVar (exprType lift_scrut)
400       let vect_case = Case vect_scrut vdummy vty
401                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
402
403       lc <- builtin liftingContext
404       lbody <- combinePD vty (Var lc) sel lift_bodies
405       let lift_case = Case lift_scrut ldummy lty
406                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
407                              lbody)]
408
409       return . vLet (vNonRec vbndr vexpr)
410              $ (vect_case, lift_case)
411   where
412     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
413                     | otherwise         = vectBndrIn bndr
414
415     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
416
417     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
418     cmp DEFAULT       DEFAULT       = EQ
419     cmp DEFAULT       _             = LT
420     cmp _             DEFAULT       = GT
421     cmp _             _             = panic "vectAlgCase/cmp"
422
423     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
424       = do
425           vect_dc <- maybeV (lookupDataCon dc)
426           let ntag = dataConTagZ vect_dc
427               tag  = mkDataConTag vect_dc
428               fvs  = freeVarsOf body `delVarSetList` bndrs
429
430           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
431           lc        <- builtin liftingContext
432           elems     <- builtin (selElements arity ntag)
433
434           (vbndrs, vbody)
435             <- vectBndrsIn bndrs
436              . localV
437              $ do
438                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
439                            . filter isLocalId
440                            $ varSetElems fvs
441                  (ve, le) <- vectExpr body
442                  return (ve, Case (elems `App` sel) lc lty
443                              [(DEFAULT, [], (mkLets (concat binds) le))])
444                  -- empty    <- emptyPD vty
445                  -- return (ve, Case (elems `App` sel) lc lty
446                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
447                  --                             $ mkLets (concat binds) le),
448                  --               (LitAlt (mkMachInt 0), [], empty)])
449           let (vect_bndrs, lift_bndrs) = unzip vbndrs
450           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
451
452     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
453
454     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
455
456     pack_var len tags t v
457       = do
458           r <- lookupVar v
459           case r of
460             Local (vv, lv) ->
461               do
462                 lv'  <- cloneVar lv
463                 expr <- packByTagPD (idType vv) (Var lv) len tags t
464                 updLEnv (\env -> env { local_vars = extendVarEnv
465                                                 (local_vars env) v (vv, lv') })
466                 return [(NonRec lv' expr)]
467
468             _ -> return []
469