Added handling of non-recursive module global functions to isScalar check
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import Vectorise.Utils
7 import Vectorise.Type.Type
8 import Vectorise.Var
9 import Vectorise.Vect
10 import Vectorise.Env
11 import Vectorise.Monad
12 import Vectorise.Builtins
13
14 import CoreSyn
15 import CoreUtils
16 import MkCore
17 import CoreFVs
18 import DataCon
19 import TyCon
20 import Type
21 import Var
22 import VarEnv
23 import VarSet
24 import Id
25 import BasicTypes( isLoopBreaker )
26 import Literal
27 import TysWiredIn
28 import TysPrim
29 import Outputable
30 import FastString
31 import Control.Monad
32 import Data.List
33
34
35 -- | Vectorise a polymorphic expression.
36 vectPolyExpr 
37         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
38                                 --   binding is a loop breaker.
39         -> CoreExprWithFVs
40         -> VM (Inline, Bool, VExpr)
41
42 vectPolyExpr loop_breaker (_, AnnNote note expr)
43  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
44       return (inline, isScalarFn, vNote note expr')
45
46 vectPolyExpr loop_breaker expr
47  = do
48       arity <- polyArity tvs
49       polyAbstract tvs $ \args ->
50         do
51           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
52           return (addInlineArity inline arity, isScalarFn, 
53                   mapVect (mkLams $ tvs ++ args) mono')
54   where
55     (tvs, mono) = collectAnnTypeBinders expr
56
57
58 -- | Vectorise an expression.
59 vectExpr :: CoreExprWithFVs -> VM VExpr
60 vectExpr (_, AnnType ty)
61   = liftM vType (vectType ty)
62
63 vectExpr (_, AnnVar v) 
64   = vectVar v
65
66 vectExpr (_, AnnLit lit) 
67   = vectLiteral lit
68
69 vectExpr (_, AnnNote note expr)
70   = liftM (vNote note) (vectExpr expr)
71
72 vectExpr e@(_, AnnApp _ arg)
73   | isAnnTypeArg arg
74   = vectTyAppExpr fn tys
75   where
76     (fn, tys) = collectAnnTypeArgs e
77
78 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
79   | Just con <- isDataConId_maybe v
80   , is_special_con con
81   = do
82       let vexpr = App (Var v) (Lit lit)
83       lexpr <- liftPD vexpr
84       return (vexpr, lexpr)
85   where
86     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
87
88
89 -- TODO: Avoid using closure application for dictionaries.
90 -- vectExpr (_, AnnApp fn arg)
91 --  | if is application of dictionary 
92 --    just use regular app instead of closure app.
93
94 -- for lifted version. 
95 --      do liftPD (sub a dNumber)
96 --      lift the result of the selection, not sub and dNumber seprately. 
97
98 vectExpr (_, AnnApp fn arg)
99  = do
100       arg_ty' <- vectType arg_ty
101       res_ty' <- vectType res_ty
102
103       fn'     <- vectExpr fn
104       arg'    <- vectExpr arg
105
106       mkClosureApp arg_ty' res_ty' fn' arg'
107   where
108     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
109
110 vectExpr (_, AnnCase scrut bndr ty alts)
111   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
112   , isAlgTyCon tycon
113   = vectAlgCase tycon ty_args scrut bndr ty alts
114   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
115   where
116     scrut_ty = exprType (deAnnotate scrut)
117
118 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
119   = do
120       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
121       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
122       return $ vLet (vNonRec vbndr vrhs) vbody
123
124 vectExpr (_, AnnLet (AnnRec bs) body)
125   = do
126       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
127                                 $ liftM2 (,)
128                                   (zipWithM vect_rhs bndrs rhss)
129                                   (vectExpr body)
130       return $ vLet (vRec vbndrs vrhss) vbody
131   where
132     (bndrs, rhss) = unzip bs
133
134     vect_rhs bndr rhs = localV
135                       . inBind bndr
136                       . liftM (\(_,_,z)->z)
137                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
138
139 vectExpr e@(_, AnnLam bndr _)
140   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
141 {-
142 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
143                 `orElseV` vectLam True fvs bs body
144   where
145     (bs,body) = collectAnnValBinders e
146 -}
147
148 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
149
150
151 -- | Vectorise an expression with an outer lambda abstraction.
152 vectFnExpr 
153         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
154         -> Bool                 -- ^ Whether the binding is a loop breaker.
155         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
156         -> VM (Inline, Bool, VExpr)
157
158 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
159   | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
160                  onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
161                         (mark DontInline True . vectScalarLam bs $ deAnnotate body)
162                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
163   where
164     (bs,body) = collectAnnValBinders e
165
166 vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
167
168 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
169 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
170
171
172 -- | Vectorise a function where are the args have scalar type,
173 --   that is Int, Float, Double etc.
174 vectScalarLam 
175         :: [Var]        -- ^ Bound variables of function.
176         -> CoreExpr     -- ^ Function body.
177         -> VM VExpr
178         
179 vectScalarLam args body
180  = do scalars <- globalScalars
181       pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
182         onlyIfV (all is_prim_ty arg_tys
183                && is_prim_ty res_ty
184                && is_scalar (extendVarSetList scalars args) body
185                && uses scalars body)
186         $ do
187             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
188             zipf    <- zipScalars arg_tys res_ty
189             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
190                                                 (zipf `App` Var fn_var)
191             clo_var <- hoistExpr (fsLit "clo") clo DontInline
192             lclo    <- liftPD (Var clo_var)
193             pprTrace "  lam is scalar" (ppr "") $
194               return (Var clo_var, lclo)
195   where
196     arg_tys = map idType args
197     res_ty  = exprType body
198
199     is_prim_ty ty 
200         | Just (tycon, [])   <- splitTyConApp_maybe ty
201         =    tycon == intTyCon
202           || tycon == floatTyCon
203           || tycon == doubleTyCon
204
205         | otherwise = False
206     
207     cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
208          
209     maybe_parr_ty ty = maybe_parr_ty' [] ty
210       
211     maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
212     maybe_parr_ty' alreadySeen ty
213        | isPArrTyCon tycon     = True
214        | isPrimTyCon tycon     = False
215        | isAbstractTyCon tycon = True
216        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
217        | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ 
218                              any (maybe_parr_ty' alreadySeen) args || 
219                              hasParrDataCon alreadySeen tycon
220        | otherwise = True
221        where
222          Just (tycon, args) = splitTyConApp_maybe ty 
223          
224          
225          hasParrDataCon alreadySeen tycon
226            | tycon `elem` alreadySeen = False  
227            | otherwise                =  
228                any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
229          
230     -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
231     -- an external (non data constructor) variable is used, or anonymous data constructor      
232     is_scalar vs e@(Var v) 
233       | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
234       | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
235     is_scalar _ e@(Lit _)    = -- pprTrace "is_scalar  Lit" (ppr e) $ 
236                                cantbe_parr_expr e  
237
238     is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar  App" (ppr e) $  
239                                cantbe_parr_expr e &&
240                                is_scalar vs e1 && is_scalar vs e2    
241     is_scalar vs e@(Let (NonRec b letExpr) body) 
242                              = -- pprTrace "is_scalar  Let" (ppr e) $  
243                                cantbe_parr_expr e &&
244                                is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
245     is_scalar vs e@(Let (Rec bnds) body) 
246                              =  let vs' = extendVarSetList vs (map fst bnds)
247                                 in -- pprTrace "is_scalar  Rec" (ppr e) $  
248                                    cantbe_parr_expr e &&  
249                                    all (is_scalar vs') (map snd bnds) && is_scalar vs' body
250     is_scalar vs e@(Case eC eId ty alts)  
251                              = let vs' = extendVarSet vs eId
252                                    in -- pprTrace "is_scalar  Case" (ppr e) $ 
253                                       cantbe_parr_expr e && 
254                                   is_prim_ty ty &&
255                                   is_scalar vs' eC   &&
256                                   (all (is_scalar_alt vs') alts)
257                                     
258     is_scalar _ e            =  -- pprTrace "is_scalar  other" (ppr e) $  
259                                 False
260
261     is_scalar_alt vs (_, bs, e) 
262                              = is_scalar (extendVarSetList vs bs) e
263
264     -- A scalar function has to actually compute something. Without the check,
265     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
266     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
267     -- (\n# x -> x) which is what we want.
268     uses funs (Var v)     = v `elemVarSet` funs 
269     uses funs (App e1 e2) = uses funs e1 || uses funs e2
270     uses funs (Let (NonRec _b letExpr) body) 
271                           = uses funs letExpr || uses funs  body
272     uses funs (Case e _eId _ty alts) 
273                           = uses funs e || any (uses_alt funs) alts
274     uses _ _              = False
275
276     uses_alt funs (_, _bs, e)   
277                           = uses funs e 
278
279 -- | Vectorise a lambda abstraction.
280 vectLam 
281         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
282         -> Bool                 -- ^ Whether the binding is a loop breaker.
283         -> VarSet               -- ^ The free variables in the body.
284         -> [Var]                -- ^ Binding variables.
285         -> CoreExprWithFVs      -- ^ Body of abstraction.
286         -> VM VExpr
287
288 vectLam inline loop_breaker fvs bs body
289  = do tyvars    <- localTyVars
290       (vs, vvs) <- readLEnv $ \env ->
291                    unzip [(var, vv) | var <- varSetElems fvs
292                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
293
294       arg_tys   <- mapM (vectType . idType) bs
295       res_ty    <- vectType (exprType $ deAnnotate body)
296
297       buildClosures tyvars vvs arg_tys res_ty
298         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
299         $ do
300             lc              <- builtin liftingContext
301             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
302
303             vbody' <- break_loop lc res_ty vbody
304             return $ vLams lc vbndrs vbody'
305   where
306     maybe_inline n | inline    = Inline n
307                    | otherwise = DontInline
308
309     break_loop lc ty (ve, le)
310       | loop_breaker
311       = do
312           empty <- emptyPD ty
313           lty <- mkPDataType ty
314           return (ve, mkWildCase (Var lc) intPrimTy lty
315                         [(DEFAULT, [], le),
316                          (LitAlt (mkMachInt 0), [], empty)])
317
318       | otherwise = return (ve, le)
319  
320
321 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
322 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
323 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
324                         (ppr $ deAnnotate e `mkTyApps` tys)
325
326
327 -- | Vectorise an algebraic case expression.
328 --   We convert
329 --
330 --   case e :: t of v { ... }
331 --
332 -- to
333 --
334 --   V:    let v' = e in case v' of _ { ... }
335 --   L:    let v' = e in case v' `cast` ... of _ { ... }
336 --
337 --   When lifting, we have to do it this way because v must have the type
338 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
339 --   have to handle the case where v is a wild var correctly.
340 --
341
342 -- FIXME: this is too lazy
343 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
344             -> [(AltCon, [Var], CoreExprWithFVs)]
345             -> VM VExpr
346 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
347   = do
348       vscrut         <- vectExpr scrut
349       (vty, lty)     <- vectAndLiftType ty
350       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
351       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
352
353 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
354   = do
355       vscrut         <- vectExpr scrut
356       (vty, lty)     <- vectAndLiftType ty
357       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
358       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
359
360 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
361   = do
362       (vty, lty) <- vectAndLiftType ty
363       vexpr      <- vectExpr scrut
364       (vbndr, (vbndrs, (vect_body, lift_body)))
365          <- vect_scrut_bndr
366           . vectBndrsIn bndrs
367           $ vectExpr body
368       let (vect_bndrs, lift_bndrs) = unzip vbndrs
369       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
370       vect_dc <- maybeV (lookupDataCon dc)
371       let [pdata_dc] = tyConDataCons pdata_tc
372
373       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
374           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
375
376       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
377   where
378     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
379                     | otherwise         = vectBndrIn bndr
380
381     mk_wild_case expr ty dc bndrs body
382       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
383
384 vectAlgCase tycon _ty_args scrut bndr ty alts
385   = do
386       vect_tc     <- maybeV (lookupTyCon tycon)
387       (vty, lty)  <- vectAndLiftType ty
388
389       let arity = length (tyConDataCons vect_tc)
390       sel_ty <- builtin (selTy arity)
391       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
392       let sel = Var sel_bndr
393
394       (vbndr, valts) <- vect_scrut_bndr
395                       $ mapM (proc_alt arity sel vty lty) alts'
396       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
397
398       vexpr <- vectExpr scrut
399       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
400       let [pdata_dc] = tyConDataCons pdata_tc
401
402       let (vect_bodies, lift_bodies) = unzip vbodies
403
404       vdummy <- newDummyVar (exprType vect_scrut)
405       ldummy <- newDummyVar (exprType lift_scrut)
406       let vect_case = Case vect_scrut vdummy vty
407                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
408
409       lc <- builtin liftingContext
410       lbody <- combinePD vty (Var lc) sel lift_bodies
411       let lift_case = Case lift_scrut ldummy lty
412                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
413                              lbody)]
414
415       return . vLet (vNonRec vbndr vexpr)
416              $ (vect_case, lift_case)
417   where
418     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
419                     | otherwise         = vectBndrIn bndr
420
421     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
422
423     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
424     cmp DEFAULT       DEFAULT       = EQ
425     cmp DEFAULT       _             = LT
426     cmp _             DEFAULT       = GT
427     cmp _             _             = panic "vectAlgCase/cmp"
428
429     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
430       = do
431           vect_dc <- maybeV (lookupDataCon dc)
432           let ntag = dataConTagZ vect_dc
433               tag  = mkDataConTag vect_dc
434               fvs  = freeVarsOf body `delVarSetList` bndrs
435
436           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
437           lc        <- builtin liftingContext
438           elems     <- builtin (selElements arity ntag)
439
440           (vbndrs, vbody)
441             <- vectBndrsIn bndrs
442              . localV
443              $ do
444                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
445                            . filter isLocalId
446                            $ varSetElems fvs
447                  (ve, le) <- vectExpr body
448                  return (ve, Case (elems `App` sel) lc lty
449                              [(DEFAULT, [], (mkLets (concat binds) le))])
450                  -- empty    <- emptyPD vty
451                  -- return (ve, Case (elems `App` sel) lc lty
452                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
453                  --                             $ mkLets (concat binds) le),
454                  --               (LitAlt (mkMachInt 0), [], empty)])
455           let (vect_bndrs, lift_bndrs) = unzip vbndrs
456           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
457
458     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
459
460     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
461
462     pack_var len tags t v
463       = do
464           r <- lookupVar v
465           case r of
466             Local (vv, lv) ->
467               do
468                 lv'  <- cloneVar lv
469                 expr <- packByTagPD (idType vv) (Var lv) len tags t
470                 updLEnv (\env -> env { local_vars = extendVarEnv
471                                                 (local_vars env) v (vv, lv') })
472                 return [(NonRec lv' expr)]
473
474             _ -> return []
475