Added a VECTORISE pragma
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import Vectorise.Utils
7 import Vectorise.Type.Type
8 import Vectorise.Var
9 import Vectorise.Vect
10 import Vectorise.Env
11 import Vectorise.Monad
12 import Vectorise.Builtins
13
14 import CoreSyn
15 import CoreUtils
16 import MkCore
17 import CoreFVs
18 import DataCon
19 import TyCon
20 import Type
21 import Var
22 import VarEnv
23 import VarSet
24 import Id
25 import BasicTypes( isLoopBreaker )
26 import Literal
27 import TysWiredIn
28 import TysPrim
29 import Outputable
30 import FastString
31 import Control.Monad
32 import Data.List
33
34
35 -- | Vectorise a polymorphic expression.
36 --
37 vectPolyExpr :: Bool            -- ^ When vectorising the RHS of a binding, whether that
38                                               --   binding is a loop breaker.
39                    -> [Var]                     
40                    -> CoreExprWithFVs
41                    -> VM (Inline, Bool, VExpr)
42 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
43  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
44       return (inline, isScalarFn, vNote note expr')
45 vectPolyExpr loop_breaker recFns expr
46  = do
47       arity <- polyArity tvs
48       polyAbstract tvs $ \args ->
49         do
50           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
51           return (addInlineArity inline arity, isScalarFn, 
52                   mapVect (mkLams $ tvs ++ args) mono')
53   where
54     (tvs, mono) = collectAnnTypeBinders expr
55
56
57 -- | Vectorise an expression.
58 vectExpr :: CoreExprWithFVs -> VM VExpr
59 vectExpr (_, AnnType ty)
60   = liftM vType (vectType ty)
61
62 vectExpr (_, AnnVar v) 
63   = vectVar v
64
65 vectExpr (_, AnnLit lit) 
66   = vectLiteral lit
67
68 vectExpr (_, AnnNote note expr)
69   = liftM (vNote note) (vectExpr expr)
70
71 vectExpr e@(_, AnnApp _ arg)
72   | isAnnTypeArg arg
73   = vectTyAppExpr fn tys
74   where
75     (fn, tys) = collectAnnTypeArgs e
76
77 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
78   | Just con <- isDataConId_maybe v
79   , is_special_con con
80   = do
81       let vexpr = App (Var v) (Lit lit)
82       lexpr <- liftPD vexpr
83       return (vexpr, lexpr)
84   where
85     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
86
87
88 -- TODO: Avoid using closure application for dictionaries.
89 -- vectExpr (_, AnnApp fn arg)
90 --  | if is application of dictionary 
91 --    just use regular app instead of closure app.
92
93 -- for lifted version. 
94 --      do liftPD (sub a dNumber)
95 --      lift the result of the selection, not sub and dNumber seprately. 
96
97 vectExpr (_, AnnApp fn arg)
98  = do
99       arg_ty' <- vectType arg_ty
100       res_ty' <- vectType res_ty
101
102       fn'     <- vectExpr fn
103       arg'    <- vectExpr arg
104
105       mkClosureApp arg_ty' res_ty' fn' arg'
106   where
107     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
108
109 vectExpr (_, AnnCase scrut bndr ty alts)
110   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
111   , isAlgTyCon tycon
112   = vectAlgCase tycon ty_args scrut bndr ty alts
113   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
114   where
115     scrut_ty = exprType (deAnnotate scrut)
116
117 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
118   = do
119       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
120       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
121       return $ vLet (vNonRec vbndr vrhs) vbody
122
123 vectExpr (_, AnnLet (AnnRec bs) body)
124   = do
125       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
126                                 $ liftM2 (,)
127                                   (zipWithM vect_rhs bndrs rhss)
128                                   (vectExpr body)
129       return $ vLet (vRec vbndrs vrhss) vbody
130   where
131     (bndrs, rhss) = unzip bs
132
133     vect_rhs bndr rhs = localV
134                       . inBind bndr
135                       . liftM (\(_,_,z)->z)
136                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
137
138 vectExpr e@(_, AnnLam bndr _)
139   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
140 {-
141 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
142                 `orElseV` vectLam True fvs bs body
143   where
144     (bs,body) = collectAnnValBinders e
145 -}
146
147 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
148
149 -- | Vectorise an expression with an outer lambda abstraction.
150 --
151 vectFnExpr :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
152            -> Bool             -- ^ Whether the binding is a loop breaker.
153            -> [Var]
154            -> CoreExprWithFVs  -- ^ Expression to vectorise. Must have an outer `AnnLam`.
155            -> VM (Inline, Bool, VExpr)
156 vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
157   | isId bndr = onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
158                         (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
159                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
160   where
161     (bs,body) = collectAnnValBinders e
162 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
163
164 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
165 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
166
167
168 -- | Vectorise a function where are the args have scalar type,
169 --   that is Int, Float, Double etc.
170 vectScalarLam 
171         :: [Var]        -- ^ Bound variables of function
172         -> [Var]
173         -> CoreExpr     -- ^ Function body.
174         -> VM VExpr
175         
176 vectScalarLam args recFns body
177  = do scalars' <- globalScalars
178       let scalars = unionVarSet (mkVarSet recFns) scalars'
179       onlyIfV (all is_prim_ty arg_tys
180                && is_prim_ty res_ty
181                && is_scalar (extendVarSetList scalars args) body
182                && uses scalars body)
183         $ do
184             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
185             zipf    <- zipScalars arg_tys res_ty
186             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
187                                                 (zipf `App` Var fn_var)
188             clo_var <- hoistExpr (fsLit "clo") clo DontInline
189             lclo    <- liftPD (Var clo_var)
190             return (Var clo_var, lclo)
191   where
192     arg_tys = map idType args
193     res_ty  = exprType body
194
195     is_prim_ty ty 
196         | Just (tycon, [])   <- splitTyConApp_maybe ty
197         =    tycon == intTyCon
198           || tycon == floatTyCon
199           || tycon == doubleTyCon
200
201         | otherwise = False
202     
203     cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
204          
205     maybe_parr_ty ty = maybe_parr_ty' [] ty
206       
207     maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
208     maybe_parr_ty' alreadySeen ty
209        | isPArrTyCon tycon     = True
210        | isPrimTyCon tycon     = False
211        | isAbstractTyCon tycon = True
212        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
213        | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || 
214                              hasParrDataCon alreadySeen tycon
215        | otherwise = True
216        where
217          Just (tycon, args) = splitTyConApp_maybe ty 
218          
219          
220          hasParrDataCon alreadySeen tycon
221            | tycon `elem` alreadySeen = False  
222            | otherwise                =  
223                any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
224          
225     -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
226     -- an external (non data constructor) variable is used, or anonymous data constructor      
227     is_scalar vs e@(Var v) 
228       | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
229       | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
230     is_scalar _ e@(Lit _)    = cantbe_parr_expr e  
231
232     is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
233                                is_scalar vs e1 && is_scalar vs e2    
234     is_scalar vs e@(Let (NonRec b letExpr) body) 
235                              = cantbe_parr_expr e &&
236                                is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
237     is_scalar vs  e@(Let (Rec bnds) body) 
238                              =  let vs' = extendVarSetList vs (map fst bnds)
239                                 in cantbe_parr_expr e &&  
240                                    all (is_scalar vs') (map snd bnds) && is_scalar vs' body
241     is_scalar vs e@(Case eC eId ty alts)  
242                              = let vs' = extendVarSet vs eId
243                                    in cantbe_parr_expr e && 
244                                   is_prim_ty ty &&
245                                   is_scalar vs' eC   &&
246                                   (all (is_scalar_alt vs') alts)
247                                     
248     is_scalar _ _            =  False
249
250     is_scalar_alt vs (_, bs, e) 
251                              = is_scalar (extendVarSetList vs bs) e
252
253     -- A scalar function has to actually compute something. Without the check,
254     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
255     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
256     -- (\n# x -> x) which is what we want.
257     uses funs (Var v)     = v `elemVarSet` funs 
258     uses funs (App e1 e2) = uses funs e1 || uses funs e2
259     uses funs (Let (NonRec _b letExpr) body) 
260                           = uses funs letExpr || uses funs  body
261     uses funs (Case e _eId _ty alts) 
262                           = uses funs e || any (uses_alt funs) alts
263     uses _ _              = False
264
265     uses_alt funs (_, _bs, e)   
266                           = uses funs e 
267
268 -- | Vectorise a lambda abstraction.
269 vectLam 
270         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
271         -> Bool                 -- ^ Whether the binding is a loop breaker.
272         -> VarSet               -- ^ The free variables in the body.
273         -> [Var]                -- ^ Binding variables.
274         -> CoreExprWithFVs      -- ^ Body of abstraction.
275         -> VM VExpr
276
277 vectLam inline loop_breaker fvs bs body
278  = do tyvars    <- localTyVars
279       (vs, vvs) <- readLEnv $ \env ->
280                    unzip [(var, vv) | var <- varSetElems fvs
281                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
282
283       arg_tys   <- mapM (vectType . idType) bs
284       res_ty    <- vectType (exprType $ deAnnotate body)
285
286       buildClosures tyvars vvs arg_tys res_ty
287         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
288         $ do
289             lc              <- builtin liftingContext
290             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
291
292             vbody' <- break_loop lc res_ty vbody
293             return $ vLams lc vbndrs vbody'
294   where
295     maybe_inline n | inline    = Inline n
296                    | otherwise = DontInline
297
298     break_loop lc ty (ve, le)
299       | loop_breaker
300       = do
301           empty <- emptyPD ty
302           lty <- mkPDataType ty
303           return (ve, mkWildCase (Var lc) intPrimTy lty
304                         [(DEFAULT, [], le),
305                          (LitAlt (mkMachInt 0), [], empty)])
306
307       | otherwise = return (ve, le)
308  
309
310 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
311 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
312 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
313                         (ppr $ deAnnotate e `mkTyApps` tys)
314
315
316 -- | Vectorise an algebraic case expression.
317 --   We convert
318 --
319 --   case e :: t of v { ... }
320 --
321 -- to
322 --
323 --   V:    let v' = e in case v' of _ { ... }
324 --   L:    let v' = e in case v' `cast` ... of _ { ... }
325 --
326 --   When lifting, we have to do it this way because v must have the type
327 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
328 --   have to handle the case where v is a wild var correctly.
329 --
330
331 -- FIXME: this is too lazy
332 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
333             -> [(AltCon, [Var], CoreExprWithFVs)]
334             -> VM VExpr
335 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
336   = do
337       vscrut         <- vectExpr scrut
338       (vty, lty)     <- vectAndLiftType ty
339       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
340       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
341
342 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
343   = do
344       vscrut         <- vectExpr scrut
345       (vty, lty)     <- vectAndLiftType ty
346       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
347       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
348
349 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
350   = do
351       (vty, lty) <- vectAndLiftType ty
352       vexpr      <- vectExpr scrut
353       (vbndr, (vbndrs, (vect_body, lift_body)))
354          <- vect_scrut_bndr
355           . vectBndrsIn bndrs
356           $ vectExpr body
357       let (vect_bndrs, lift_bndrs) = unzip vbndrs
358       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
359       vect_dc <- maybeV (lookupDataCon dc)
360       let [pdata_dc] = tyConDataCons pdata_tc
361
362       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
363           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
364
365       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
366   where
367     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
368                     | otherwise         = vectBndrIn bndr
369
370     mk_wild_case expr ty dc bndrs body
371       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
372
373 vectAlgCase tycon _ty_args scrut bndr ty alts
374   = do
375       vect_tc     <- maybeV (lookupTyCon tycon)
376       (vty, lty)  <- vectAndLiftType ty
377
378       let arity = length (tyConDataCons vect_tc)
379       sel_ty <- builtin (selTy arity)
380       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
381       let sel = Var sel_bndr
382
383       (vbndr, valts) <- vect_scrut_bndr
384                       $ mapM (proc_alt arity sel vty lty) alts'
385       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
386
387       vexpr <- vectExpr scrut
388       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
389       let [pdata_dc] = tyConDataCons pdata_tc
390
391       let (vect_bodies, lift_bodies) = unzip vbodies
392
393       vdummy <- newDummyVar (exprType vect_scrut)
394       ldummy <- newDummyVar (exprType lift_scrut)
395       let vect_case = Case vect_scrut vdummy vty
396                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
397
398       lc <- builtin liftingContext
399       lbody <- combinePD vty (Var lc) sel lift_bodies
400       let lift_case = Case lift_scrut ldummy lty
401                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
402                              lbody)]
403
404       return . vLet (vNonRec vbndr vexpr)
405              $ (vect_case, lift_case)
406   where
407     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
408                     | otherwise         = vectBndrIn bndr
409
410     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
411
412     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
413     cmp DEFAULT       DEFAULT       = EQ
414     cmp DEFAULT       _             = LT
415     cmp _             DEFAULT       = GT
416     cmp _             _             = panic "vectAlgCase/cmp"
417
418     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
419       = do
420           vect_dc <- maybeV (lookupDataCon dc)
421           let ntag = dataConTagZ vect_dc
422               tag  = mkDataConTag vect_dc
423               fvs  = freeVarsOf body `delVarSetList` bndrs
424
425           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
426           lc        <- builtin liftingContext
427           elems     <- builtin (selElements arity ntag)
428
429           (vbndrs, vbody)
430             <- vectBndrsIn bndrs
431              . localV
432              $ do
433                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
434                            . filter isLocalId
435                            $ varSetElems fvs
436                  (ve, le) <- vectExpr body
437                  return (ve, Case (elems `App` sel) lc lty
438                              [(DEFAULT, [], (mkLets (concat binds) le))])
439                  -- empty    <- emptyPD vty
440                  -- return (ve, Case (elems `App` sel) lc lty
441                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
442                  --                             $ mkLets (concat binds) le),
443                  --               (LitAlt (mkMachInt 0), [], empty)])
444           let (vect_bndrs, lift_bndrs) = unzip vbndrs
445           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
446
447     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
448
449     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
450
451     pack_var len tags t v
452       = do
453           r <- lookupVar v
454           case r of
455             Local (vv, lv) ->
456               do
457                 lv'  <- cloneVar lv
458                 expr <- packByTagPD (idType vv) (Var lv) len tags t
459                 updLEnv (\env -> env { local_vars = extendVarEnv
460                                                 (local_vars env) v (vv, lv') })
461                 return [(NonRec lv' expr)]
462
463             _ -> return []
464