079e8265c4a52eeb4347af3ddc7b61a703d9b7dc
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp
4         (vectPolyExpr)
5 where
6 import Vectorise.Utils
7 import Vectorise.Type.Type
8 import Vectorise.Var
9 import Vectorise.Vect
10 import Vectorise.Env
11 import Vectorise.Monad
12 import Vectorise.Builtins
13
14 import CoreSyn
15 import CoreUtils
16 import MkCore
17 import CoreFVs
18 import DataCon
19 import TyCon
20 import Type
21 import Var
22 import VarEnv
23 import VarSet
24 import Id
25 import BasicTypes( isLoopBreaker )
26 import Literal
27 import TysWiredIn
28 import TysPrim
29 import Outputable
30 import FastString
31 import Control.Monad
32 import Data.List
33
34
35 -- | Vectorise a polymorphic expression.
36 vectPolyExpr 
37         :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
38                                     --   binding is a loop breaker.
39         -> [Var]                        
40         -> CoreExprWithFVs
41         -> VM (Inline, Bool, VExpr)
42
43 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
44  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
45       return (inline, isScalarFn, vNote note expr')
46
47 vectPolyExpr loop_breaker recFns expr
48  = do
49       arity <- polyArity tvs
50       polyAbstract tvs $ \args ->
51         do
52           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
53           return (addInlineArity inline arity, isScalarFn, 
54                   mapVect (mkLams $ tvs ++ args) mono')
55   where
56     (tvs, mono) = collectAnnTypeBinders expr
57
58
59 -- | Vectorise an expression.
60 vectExpr :: CoreExprWithFVs -> VM VExpr
61 vectExpr (_, AnnType ty)
62   = liftM vType (vectType ty)
63
64 vectExpr (_, AnnVar v) 
65   = vectVar v
66
67 vectExpr (_, AnnLit lit) 
68   = vectLiteral lit
69
70 vectExpr (_, AnnNote note expr)
71   = liftM (vNote note) (vectExpr expr)
72
73 vectExpr e@(_, AnnApp _ arg)
74   | isAnnTypeArg arg
75   = vectTyAppExpr fn tys
76   where
77     (fn, tys) = collectAnnTypeArgs e
78
79 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
80   | Just con <- isDataConId_maybe v
81   , is_special_con con
82   = do
83       let vexpr = App (Var v) (Lit lit)
84       lexpr <- liftPD vexpr
85       return (vexpr, lexpr)
86   where
87     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
88
89
90 -- TODO: Avoid using closure application for dictionaries.
91 -- vectExpr (_, AnnApp fn arg)
92 --  | if is application of dictionary 
93 --    just use regular app instead of closure app.
94
95 -- for lifted version. 
96 --      do liftPD (sub a dNumber)
97 --      lift the result of the selection, not sub and dNumber seprately. 
98
99 vectExpr (_, AnnApp fn arg)
100  = do
101       arg_ty' <- vectType arg_ty
102       res_ty' <- vectType res_ty
103
104       fn'     <- vectExpr fn
105       arg'    <- vectExpr arg
106
107       mkClosureApp arg_ty' res_ty' fn' arg'
108   where
109     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
110
111 vectExpr (_, AnnCase scrut bndr ty alts)
112   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
113   , isAlgTyCon tycon
114   = vectAlgCase tycon ty_args scrut bndr ty alts
115   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
116   where
117     scrut_ty = exprType (deAnnotate scrut)
118
119 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
120   = do
121       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
122       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
123       return $ vLet (vNonRec vbndr vrhs) vbody
124
125 vectExpr (_, AnnLet (AnnRec bs) body)
126   = do
127       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
128                                 $ liftM2 (,)
129                                   (zipWithM vect_rhs bndrs rhss)
130                                   (vectExpr body)
131       return $ vLet (vRec vbndrs vrhss) vbody
132   where
133     (bndrs, rhss) = unzip bs
134
135     vect_rhs bndr rhs = localV
136                       . inBind bndr
137                       . liftM (\(_,_,z)->z)
138                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
139
140 vectExpr e@(_, AnnLam bndr _)
141   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
142 {-
143 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
144                 `orElseV` vectLam True fvs bs body
145   where
146     (bs,body) = collectAnnValBinders e
147 -}
148
149 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
150
151
152 -- | Vectorise an expression with an outer lambda abstraction.
153 vectFnExpr 
154         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
155         -> Bool                 -- ^ Whether the binding is a loop breaker.
156         -> [Var]
157         -> CoreExprWithFVs      -- ^ Expression to vectorise. Must have an outer `AnnLam`.
158         -> VM (Inline, Bool, VExpr)
159
160 vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
161   | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
162                  onlyIfV True -- (isEmptyVarSet fvs)  -- we check for free variables later. TODO: clean up
163                         (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
164                 `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
165   where
166     (bs,body) = collectAnnValBinders e
167
168 vectFnExpr _ _ _  e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
169
170 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
171 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
172
173
174 -- | Vectorise a function where are the args have scalar type,
175 --   that is Int, Float, Double etc.
176 vectScalarLam 
177         :: [Var]        -- ^ Bound variables of function
178         -> [Var]
179         -> CoreExpr     -- ^ Function body.
180         -> VM VExpr
181         
182 vectScalarLam args recFns body
183  = do scalars' <- globalScalars
184       let scalars = unionVarSet (mkVarSet recFns) scalars'
185 {-      pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
186         pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
187         pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
188         pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
189       onlyIfV (all is_prim_ty arg_tys
190                && is_prim_ty res_ty
191                && is_scalar (extendVarSetList scalars args) body
192                && uses scalars body)
193         $ do
194             fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
195             zipf    <- zipScalars arg_tys res_ty
196             clo     <- scalarClosure arg_tys res_ty (Var fn_var)
197                                                 (zipf `App` Var fn_var)
198             clo_var <- hoistExpr (fsLit "clo") clo DontInline
199             lclo    <- liftPD (Var clo_var)
200             {- pprTrace "  lam is scalar" (ppr "") $ -}
201             return (Var clo_var, lclo)
202   where
203     arg_tys = map idType args
204     res_ty  = exprType body
205
206     is_prim_ty ty 
207         | Just (tycon, [])   <- splitTyConApp_maybe ty
208         =    tycon == intTyCon
209           || tycon == floatTyCon
210           || tycon == doubleTyCon
211
212         | otherwise = False
213     
214     cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
215          
216     maybe_parr_ty ty = maybe_parr_ty' [] ty
217       
218     maybe_parr_ty' _           ty | Nothing <- splitTyConApp_maybe ty = False   -- TODO: is this really what we want to do with polym. types?
219     maybe_parr_ty' alreadySeen ty
220        | isPArrTyCon tycon     = True
221        | isPrimTyCon tycon     = False
222        | isAbstractTyCon tycon = True
223        | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon  = any (maybe_parr_ty' alreadySeen) args     
224        | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $ 
225                              any (maybe_parr_ty' alreadySeen) args || 
226                              hasParrDataCon alreadySeen tycon
227        | otherwise = True
228        where
229          Just (tycon, args) = splitTyConApp_maybe ty 
230          
231          
232          hasParrDataCon alreadySeen tycon
233            | tycon `elem` alreadySeen = False  
234            | otherwise                =  
235                any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $  map dataConOrigArgTys $ tyConDataCons tycon 
236          
237     -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
238     -- an external (non data constructor) variable is used, or anonymous data constructor      
239     is_scalar vs e@(Var v) 
240       | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
241       | otherwise                     = cantbe_parr_expr e &&  (v `elemVarSet` vs)
242     is_scalar _ e@(Lit _)    = -- pprTrace "is_scalar  Lit" (ppr e) $ 
243                                cantbe_parr_expr e  
244
245     is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar  App" (ppr e) $  
246                                cantbe_parr_expr e &&
247                                is_scalar vs e1 && is_scalar vs e2    
248     is_scalar vs e@(Let (NonRec b letExpr) body) 
249                              = -- pprTrace "is_scalar  Let" (ppr e) $  
250                                cantbe_parr_expr e &&
251                                is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
252     is_scalar vs e@(Let (Rec bnds) body) 
253                              =  let vs' = extendVarSetList vs (map fst bnds)
254                                 in -- pprTrace "is_scalar  Rec" (ppr e) $  
255                                    cantbe_parr_expr e &&  
256                                    all (is_scalar vs') (map snd bnds) && is_scalar vs' body
257     is_scalar vs e@(Case eC eId ty alts)  
258                              = let vs' = extendVarSet vs eId
259                                    in -- pprTrace "is_scalar  Case" (ppr e) $ 
260                                       cantbe_parr_expr e && 
261                                   is_prim_ty ty &&
262                                   is_scalar vs' eC   &&
263                                   (all (is_scalar_alt vs') alts)
264                                     
265     is_scalar _ e            =  -- pprTrace "is_scalar  other" (ppr e) $  
266                                 False
267
268     is_scalar_alt vs (_, bs, e) 
269                              = is_scalar (extendVarSetList vs bs) e
270
271     -- A scalar function has to actually compute something. Without the check,
272     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
273     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
274     -- (\n# x -> x) which is what we want.
275     uses funs (Var v)     = v `elemVarSet` funs 
276     uses funs (App e1 e2) = uses funs e1 || uses funs e2
277     uses funs (Let (NonRec _b letExpr) body) 
278                           = uses funs letExpr || uses funs  body
279     uses funs (Case e _eId _ty alts) 
280                           = uses funs e || any (uses_alt funs) alts
281     uses _ _              = False
282
283     uses_alt funs (_, _bs, e)   
284                           = uses funs e 
285
286 -- | Vectorise a lambda abstraction.
287 vectLam 
288         :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
289         -> Bool                 -- ^ Whether the binding is a loop breaker.
290         -> VarSet               -- ^ The free variables in the body.
291         -> [Var]                -- ^ Binding variables.
292         -> CoreExprWithFVs      -- ^ Body of abstraction.
293         -> VM VExpr
294
295 vectLam inline loop_breaker fvs bs body
296  = do tyvars    <- localTyVars
297       (vs, vvs) <- readLEnv $ \env ->
298                    unzip [(var, vv) | var <- varSetElems fvs
299                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
300
301       arg_tys   <- mapM (vectType . idType) bs
302       res_ty    <- vectType (exprType $ deAnnotate body)
303
304       buildClosures tyvars vvs arg_tys res_ty
305         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
306         $ do
307             lc              <- builtin liftingContext
308             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
309
310             vbody' <- break_loop lc res_ty vbody
311             return $ vLams lc vbndrs vbody'
312   where
313     maybe_inline n | inline    = Inline n
314                    | otherwise = DontInline
315
316     break_loop lc ty (ve, le)
317       | loop_breaker
318       = do
319           empty <- emptyPD ty
320           lty <- mkPDataType ty
321           return (ve, mkWildCase (Var lc) intPrimTy lty
322                         [(DEFAULT, [], le),
323                          (LitAlt (mkMachInt 0), [], empty)])
324
325       | otherwise = return (ve, le)
326  
327
328 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
329 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
330 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
331                         (ppr $ deAnnotate e `mkTyApps` tys)
332
333
334 -- | Vectorise an algebraic case expression.
335 --   We convert
336 --
337 --   case e :: t of v { ... }
338 --
339 -- to
340 --
341 --   V:    let v' = e in case v' of _ { ... }
342 --   L:    let v' = e in case v' `cast` ... of _ { ... }
343 --
344 --   When lifting, we have to do it this way because v must have the type
345 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
346 --   have to handle the case where v is a wild var correctly.
347 --
348
349 -- FIXME: this is too lazy
350 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
351             -> [(AltCon, [Var], CoreExprWithFVs)]
352             -> VM VExpr
353 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
354   = do
355       vscrut         <- vectExpr scrut
356       (vty, lty)     <- vectAndLiftType ty
357       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
358       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
359
360 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
361   = do
362       vscrut         <- vectExpr scrut
363       (vty, lty)     <- vectAndLiftType ty
364       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
365       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
366
367 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
368   = do
369       (vty, lty) <- vectAndLiftType ty
370       vexpr      <- vectExpr scrut
371       (vbndr, (vbndrs, (vect_body, lift_body)))
372          <- vect_scrut_bndr
373           . vectBndrsIn bndrs
374           $ vectExpr body
375       let (vect_bndrs, lift_bndrs) = unzip vbndrs
376       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
377       vect_dc <- maybeV (lookupDataCon dc)
378       let [pdata_dc] = tyConDataCons pdata_tc
379
380       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
381           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
382
383       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
384   where
385     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
386                     | otherwise         = vectBndrIn bndr
387
388     mk_wild_case expr ty dc bndrs body
389       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
390
391 vectAlgCase tycon _ty_args scrut bndr ty alts
392   = do
393       vect_tc     <- maybeV (lookupTyCon tycon)
394       (vty, lty)  <- vectAndLiftType ty
395
396       let arity = length (tyConDataCons vect_tc)
397       sel_ty <- builtin (selTy arity)
398       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
399       let sel = Var sel_bndr
400
401       (vbndr, valts) <- vect_scrut_bndr
402                       $ mapM (proc_alt arity sel vty lty) alts'
403       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
404
405       vexpr <- vectExpr scrut
406       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
407       let [pdata_dc] = tyConDataCons pdata_tc
408
409       let (vect_bodies, lift_bodies) = unzip vbodies
410
411       vdummy <- newDummyVar (exprType vect_scrut)
412       ldummy <- newDummyVar (exprType lift_scrut)
413       let vect_case = Case vect_scrut vdummy vty
414                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
415
416       lc <- builtin liftingContext
417       lbody <- combinePD vty (Var lc) sel lift_bodies
418       let lift_case = Case lift_scrut ldummy lty
419                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
420                              lbody)]
421
422       return . vLet (vNonRec vbndr vexpr)
423              $ (vect_case, lift_case)
424   where
425     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
426                     | otherwise         = vectBndrIn bndr
427
428     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
429
430     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
431     cmp DEFAULT       DEFAULT       = EQ
432     cmp DEFAULT       _             = LT
433     cmp _             DEFAULT       = GT
434     cmp _             _             = panic "vectAlgCase/cmp"
435
436     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
437       = do
438           vect_dc <- maybeV (lookupDataCon dc)
439           let ntag = dataConTagZ vect_dc
440               tag  = mkDataConTag vect_dc
441               fvs  = freeVarsOf body `delVarSetList` bndrs
442
443           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
444           lc        <- builtin liftingContext
445           elems     <- builtin (selElements arity ntag)
446
447           (vbndrs, vbody)
448             <- vectBndrsIn bndrs
449              . localV
450              $ do
451                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
452                            . filter isLocalId
453                            $ varSetElems fvs
454                  (ve, le) <- vectExpr body
455                  return (ve, Case (elems `App` sel) lc lty
456                              [(DEFAULT, [], (mkLets (concat binds) le))])
457                  -- empty    <- emptyPD vty
458                  -- return (ve, Case (elems `App` sel) lc lty
459                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
460                  --                             $ mkLets (concat binds) le),
461                  --               (LitAlt (mkMachInt 0), [], empty)])
462           let (vect_bndrs, lift_bndrs) = unzip vbndrs
463           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
464
465     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
466
467     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
468
469     pack_var len tags t v
470       = do
471           r <- lookupVar v
472           case r of
473             Local (vv, lv) ->
474               do
475                 lv'  <- cloneVar lv
476                 expr <- packByTagPD (idType vv) (Var lv) len tags t
477                 updLEnv (\env -> env { local_vars = extendVarEnv
478                                                 (local_vars env) v (vv, lv') })
479                 return [(NonRec lv' expr)]
480
481             _ -> return []
482