This BIG PATCH contains most of the work for the New Coercion Representation
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp (
4
5   -- Vectorise a polymorphic expression
6   vectPolyExpr, 
7   
8   -- Vectorise a scalar expression of functional type
9   vectScalarFun
10 ) where
11
12 #include "HsVersions.h"
13
14 import Vectorise.Type.Type
15 import Vectorise.Var
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19 import Vectorise.Builtins
20 import Vectorise.Utils
21
22 import CoreSyn
23 import CoreUtils
24 import MkCore
25 import CoreFVs
26 import DataCon
27 import TyCon
28 import Type
29 import Var
30 import VarEnv
31 import VarSet
32 import Id
33 import BasicTypes( isLoopBreaker )
34 import Literal
35 import TysWiredIn
36 import TysPrim
37 import Outputable
38 import FastString
39 import Control.Monad
40 import Data.List
41
42
43 -- | Vectorise a polymorphic expression.
44 --
45 vectPolyExpr :: Bool            -- ^ When vectorising the RHS of a binding, whether that
46                                               --   binding is a loop breaker.
47                    -> [Var]                     
48                    -> CoreExprWithFVs
49                    -> VM (Inline, Bool, VExpr)
50 vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
51  = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
52       return (inline, isScalarFn, vNote note expr')
53 vectPolyExpr loop_breaker recFns expr
54  = do
55       arity <- polyArity tvs
56       polyAbstract tvs $ \args ->
57         do
58           (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
59           return (addInlineArity inline arity, isScalarFn, 
60                   mapVect (mkLams $ tvs ++ args) mono')
61   where
62     (tvs, mono) = collectAnnTypeBinders expr
63
64
65 -- | Vectorise an expression.
66 vectExpr :: CoreExprWithFVs -> VM VExpr
67 vectExpr (_, AnnType ty)
68   = liftM vType (vectType ty)
69
70 vectExpr (_, AnnVar v) 
71   = vectVar v
72
73 vectExpr (_, AnnLit lit) 
74   = vectLiteral lit
75
76 vectExpr (_, AnnNote note expr)
77   = liftM (vNote note) (vectExpr expr)
78
79 vectExpr e@(_, AnnApp _ arg)
80   | isAnnTypeArg arg
81   = vectTyAppExpr fn tys
82   where
83     (fn, tys) = collectAnnTypeArgs e
84
85 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
86   | Just con <- isDataConId_maybe v
87   , is_special_con con
88   = do
89       let vexpr = App (Var v) (Lit lit)
90       lexpr <- liftPD vexpr
91       return (vexpr, lexpr)
92   where
93     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
94
95
96 -- TODO: Avoid using closure application for dictionaries.
97 -- vectExpr (_, AnnApp fn arg)
98 --  | if is application of dictionary 
99 --    just use regular app instead of closure app.
100
101 -- for lifted version. 
102 --      do liftPD (sub a dNumber)
103 --      lift the result of the selection, not sub and dNumber seprately. 
104
105 vectExpr (_, AnnApp fn arg)
106  = do
107       arg_ty' <- vectType arg_ty
108       res_ty' <- vectType res_ty
109
110       fn'     <- vectExpr fn
111       arg'    <- vectExpr arg
112
113       mkClosureApp arg_ty' res_ty' fn' arg'
114   where
115     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
116
117 vectExpr (_, AnnCase scrut bndr ty alts)
118   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
119   , isAlgTyCon tycon
120   = vectAlgCase tycon ty_args scrut bndr ty alts
121   | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
122   where
123     scrut_ty = exprType (deAnnotate scrut)
124
125 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
126   = do
127       vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
128       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
129       return $ vLet (vNonRec vbndr vrhs) vbody
130
131 vectExpr (_, AnnLet (AnnRec bs) body)
132   = do
133       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
134                                 $ liftM2 (,)
135                                   (zipWithM vect_rhs bndrs rhss)
136                                   (vectExpr body)
137       return $ vLet (vRec vbndrs vrhss) vbody
138   where
139     (bndrs, rhss) = unzip bs
140
141     vect_rhs bndr rhs = localV
142                       . inBind bndr
143                       . liftM (\(_,_,z)->z)
144                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
145
146 vectExpr e@(_, AnnLam bndr _)
147   | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
148 {-
149 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
150                 `orElseV` vectLam True fvs bs body
151   where
152     (bs,body) = collectAnnValBinders e
153 -}
154
155 vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
156
157 -- | Vectorise an expression with an outer lambda abstraction.
158 --
159 vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
160                                --   be inlined
161            -> Bool             -- ^ Whether the binding is a loop breaker
162            -> [Var]            -- ^ Names of function in same recursive binding group
163            -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
164            -> VM (Inline, Bool, VExpr)
165 vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
166   | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
167                 `orElseV` 
168                 mark inlineMe False (vectLam inline loop_breaker expr)
169 vectFnExpr _ _ _  e = mark DontInline False $ vectExpr e
170
171 mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
172 mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
173
174 -- |Vectorise an expression of functional type, where all arguments and the result are of scalar
175 -- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
176 -- involve parallel arrays.  Such functionals do not requires the full blown vectorisation
177 -- transformation; instead, they can be lifted by application of a member of the zipWith family
178 -- (i.e., 'map', 'zipWith', zipWith3', etc.)
179 --
180 vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
181               -> [Var]      -- ^ Functions names in same recursive binding group
182               -> CoreExpr   -- ^ Expression to be vectorised
183               -> VM VExpr
184 vectScalarFun forceScalar recFns expr
185  = do { gscalars <- globalScalars
186       ; let scalars = gscalars `extendVarSetList` recFns
187             (arg_tys, res_ty) = splitFunTys (exprType expr)
188       ; MASSERT( not $ null arg_tys )
189       ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
190                  ||
191                  all is_prim_ty arg_tys         -- check whether the function is scalar
192                   && is_prim_ty res_ty
193                   && is_scalar scalars expr
194                   && uses scalars expr)
195         $ mkScalarFun arg_tys res_ty expr
196       }
197   where
198     -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
199     is_prim_ty ty 
200         | Just (tycon, [])   <- splitTyConApp_maybe ty
201         =    tycon == intTyCon
202           || tycon == floatTyCon
203           || tycon == doubleTyCon
204         | otherwise = False
205
206     -- Checks whether an expression contain a non-scalar subexpression. 
207     --
208     -- Precodition: The variables in the first argument are scalar.
209     --
210     -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
211     -- them to the list of scalar variables) and then check them.  If one of them turns out not to
212     -- be scalar, the entire group is regarded as not being scalar.
213     --
214     -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
215     --        data constructor as scalar.  Should be changed once scalar types are passed
216     --        through VectInfo.
217     --
218     is_scalar :: VarSet -> CoreExpr -> Bool
219     is_scalar scalars  (Var v)         = v `elemVarSet` scalars
220     is_scalar _scalars (Lit _)         = True
221     is_scalar scalars  e@(App e1 e2) 
222       | maybe_parr_ty  (exprType e)    = False
223       | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
224     is_scalar scalars  (Lam var body)  
225       | maybe_parr_ty  (varType var)   = False
226       | otherwise                      = is_scalar (scalars `extendVarSet` var) body
227     is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
228       where
229         (bindsAreScalar, scalars') = is_scalar_bind scalars bind
230     is_scalar scalars  (Case e var ty alts)
231       | is_prim_ty ty                  = is_scalar scalars' e && all (is_scalar_alt scalars') alts
232       | otherwise                      = False
233       where
234         scalars' = scalars `extendVarSet` var
235     is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
236     is_scalar scalars  (Note _ e   )   = is_scalar scalars e
237     is_scalar _scalars (Type {})       = True
238     is_scalar _scalars (Coercion {})   = True
239
240     -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
241     is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
242     is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
243       where
244         (vars, es) = unzip bnds
245         scalars'   = scalars `extendVarSetList` vars
246
247     is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
248
249     -- Checks whether the type might be a parallel array type.  In particular, if the outermost
250     -- constructor is a type family, we conservatively assume that it may be a parallel array type.
251     maybe_parr_ty :: Type -> Bool
252     maybe_parr_ty ty 
253       | Just ty'        <- coreView ty            = maybe_parr_ty ty'
254       | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon 
255     maybe_parr_ty _                               = False
256
257     -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identify functions
258     --        is called by some other function that is otherwise scalar, it would be very bad
259     --        that just this call to the identity makes it not be scalar.
260     -- A scalar function has to actually compute something. Without the check,
261     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
262     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
263     -- (\n# x -> x) which is what we want.
264     uses funs (Var v)       = v `elemVarSet` funs 
265     uses funs (App e1 e2)   = uses funs e1 || uses funs e2
266     uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
267     uses funs (Let (NonRec _b letExpr) body) 
268                             = uses funs letExpr || uses funs  body
269     uses funs (Case e _eId _ty alts) 
270                             = uses funs e || any (uses_alt funs) alts
271     uses _ _                = False
272
273     uses_alt funs (_, _bs, e) = uses funs e 
274
275 mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
276 mkScalarFun arg_tys res_ty expr
277   = do { fn_var  <- hoistExpr (fsLit "fn") expr DontInline
278        ; zipf    <- zipScalars arg_tys res_ty
279        ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
280        ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
281        ; lclo    <- liftPD (Var clo_var)
282        ; return (Var clo_var, lclo)
283        }
284
285 -- | Vectorise a lambda abstraction.
286 --
287 vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
288         -> Bool             -- ^ Whether the binding is a loop breaker.
289         -> CoreExprWithFVs  -- ^ Body of abstraction.
290         -> VM VExpr
291 vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
292  = do let (bs, body) = collectAnnValBinders expr
293  
294       tyvars    <- localTyVars
295       (vs, vvs) <- readLEnv $ \env ->
296                    unzip [(var, vv) | var <- varSetElems fvs
297                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
298
299       arg_tys   <- mapM (vectType . idType) bs
300       res_ty    <- vectType (exprType $ deAnnotate body)
301
302       buildClosures tyvars vvs arg_tys res_ty
303         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
304         $ do
305             lc              <- builtin liftingContext
306             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
307
308             vbody' <- break_loop lc res_ty vbody
309             return $ vLams lc vbndrs vbody'
310   where
311     maybe_inline n | inline    = Inline n
312                    | otherwise = DontInline
313
314     break_loop lc ty (ve, le)
315       | loop_breaker
316       = do
317           empty <- emptyPD ty
318           lty <- mkPDataType ty
319           return (ve, mkWildCase (Var lc) intPrimTy lty
320                         [(DEFAULT, [], le),
321                          (LitAlt (mkMachInt 0), [], empty)])
322
323       | otherwise = return (ve, le)
324 vectLam _ _ _ = panic "vectLam"
325  
326
327 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
328 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
329 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
330                         (ppr $ deAnnotate e `mkTyApps` tys)
331
332
333 -- | Vectorise an algebraic case expression.
334 --   We convert
335 --
336 --   case e :: t of v { ... }
337 --
338 -- to
339 --
340 --   V:    let v' = e in case v' of _ { ... }
341 --   L:    let v' = e in case v' `cast` ... of _ { ... }
342 --
343 --   When lifting, we have to do it this way because v must have the type
344 --   [:V(T):] but the scrutinee must be cast to the representation type. We also
345 --   have to handle the case where v is a wild var correctly.
346 --
347
348 -- FIXME: this is too lazy
349 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
350             -> [(AltCon, [Var], CoreExprWithFVs)]
351             -> VM VExpr
352 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
353   = do
354       vscrut         <- vectExpr scrut
355       (vty, lty)     <- vectAndLiftType ty
356       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
357       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
358
359 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
360   = do
361       vscrut         <- vectExpr scrut
362       (vty, lty)     <- vectAndLiftType ty
363       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
364       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
365
366 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
367   = do
368       (vty, lty) <- vectAndLiftType ty
369       vexpr      <- vectExpr scrut
370       (vbndr, (vbndrs, (vect_body, lift_body)))
371          <- vect_scrut_bndr
372           . vectBndrsIn bndrs
373           $ vectExpr body
374       let (vect_bndrs, lift_bndrs) = unzip vbndrs
375       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
376       vect_dc <- maybeV (lookupDataCon dc)
377       let [pdata_dc] = tyConDataCons pdata_tc
378
379       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
380           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
381
382       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
383   where
384     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
385                     | otherwise         = vectBndrIn bndr
386
387     mk_wild_case expr ty dc bndrs body
388       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
389
390 vectAlgCase tycon _ty_args scrut bndr ty alts
391   = do
392       vect_tc     <- maybeV (lookupTyCon tycon)
393       (vty, lty)  <- vectAndLiftType ty
394
395       let arity = length (tyConDataCons vect_tc)
396       sel_ty <- builtin (selTy arity)
397       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
398       let sel = Var sel_bndr
399
400       (vbndr, valts) <- vect_scrut_bndr
401                       $ mapM (proc_alt arity sel vty lty) alts'
402       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
403
404       vexpr <- vectExpr scrut
405       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
406       let [pdata_dc] = tyConDataCons pdata_tc
407
408       let (vect_bodies, lift_bodies) = unzip vbodies
409
410       vdummy <- newDummyVar (exprType vect_scrut)
411       ldummy <- newDummyVar (exprType lift_scrut)
412       let vect_case = Case vect_scrut vdummy vty
413                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
414
415       lc <- builtin liftingContext
416       lbody <- combinePD vty (Var lc) sel lift_bodies
417       let lift_case = Case lift_scrut ldummy lty
418                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
419                              lbody)]
420
421       return . vLet (vNonRec vbndr vexpr)
422              $ (vect_case, lift_case)
423   where
424     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
425                     | otherwise         = vectBndrIn bndr
426
427     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
428
429     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
430     cmp DEFAULT       DEFAULT       = EQ
431     cmp DEFAULT       _             = LT
432     cmp _             DEFAULT       = GT
433     cmp _             _             = panic "vectAlgCase/cmp"
434
435     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
436       = do
437           vect_dc <- maybeV (lookupDataCon dc)
438           let ntag = dataConTagZ vect_dc
439               tag  = mkDataConTag vect_dc
440               fvs  = freeVarsOf body `delVarSetList` bndrs
441
442           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
443           lc        <- builtin liftingContext
444           elems     <- builtin (selElements arity ntag)
445
446           (vbndrs, vbody)
447             <- vectBndrsIn bndrs
448              . localV
449              $ do
450                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
451                            . filter isLocalId
452                            $ varSetElems fvs
453                  (ve, le) <- vectExpr body
454                  return (ve, Case (elems `App` sel) lc lty
455                              [(DEFAULT, [], (mkLets (concat binds) le))])
456                  -- empty    <- emptyPD vty
457                  -- return (ve, Case (elems `App` sel) lc lty
458                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
459                  --                             $ mkLets (concat binds) le),
460                  --               (LitAlt (mkMachInt 0), [], empty)])
461           let (vect_bndrs, lift_bndrs) = unzip vbndrs
462           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
463
464     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
465
466     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
467
468     pack_var len tags t v
469       = do
470           r <- lookupVar v
471           case r of
472             Local (vv, lv) ->
473               do
474                 lv'  <- cloneVar lv
475                 expr <- packByTagPD (idType vv) (Var lv) len tags t
476                 updLEnv (\env -> env { local_vars = extendVarEnv
477                                                 (local_vars env) v (vv, lv') })
478                 return [(NonRec lv' expr)]
479
480             _ -> return []
481