Vectoriser: only treat a function as scalar if it actually computes something
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1
2 module Vectorise( vectorise )
3 where
4
5 import VectMonad
6 import VectUtils
7 import VectType
8 import VectCore
9
10 import HscTypes hiding      ( MonadThings(..) )
11
12 import Module               ( PackageId )
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold           ( mkInlineRule )
16 import MkCore               ( mkWildCase )
17 import CoreFVs
18 import CoreMonad            ( CoreM, getHscEnv )
19 import DataCon
20 import TyCon
21 import Type
22 import FamInstEnv           ( extendFamInstEnvList )
23 import Var
24 import VarEnv
25 import VarSet
26 import Id
27 import OccName
28 import BasicTypes           ( isLoopBreaker )
29
30 import Literal              ( Literal, mkMachInt )
31 import TysWiredIn
32 import TysPrim              ( intPrimTy )
33
34 import Outputable
35 import FastString
36 import Util                 ( zipLazy )
37 import Control.Monad
38 import Data.List            ( sortBy, unzip4 )
39
40 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
41 vectorise backend guts = do
42     hsc_env <- getHscEnv
43     liftIO $ vectoriseIO backend hsc_env guts
44
45 vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
46 vectoriseIO backend hsc_env guts
47   = do
48       eps <- hscEPS hsc_env
49       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
50       Just (info', guts') <- initV backend hsc_env guts info (vectModule guts)
51       return (guts' { mg_vect_info = info' })
52
53 vectModule :: ModGuts -> VM ModGuts
54 vectModule guts
55   = do
56       (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
57
58       let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
59       updGEnv (setFamInstEnv fam_inst_env')
60
61       -- dicts   <- mapM buildPADict pa_insts
62       -- workers <- mapM vectDataConWorkers pa_insts
63       binds'  <- mapM vectTopBind (mg_binds guts)
64       return $ guts { mg_types        = types'
65                     , mg_binds        = Rec tc_binds : binds'
66                     , mg_fam_inst_env = fam_inst_env'
67                     , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
68                     }
69
70 vectTopBind :: CoreBind -> VM CoreBind
71 vectTopBind b@(NonRec var expr)
72   = do
73       (inline, expr') <- vectTopRhs var expr
74       var' <- vectTopBinder var inline expr'
75       hs    <- takeHoisted
76       cexpr <- tryConvert var var' expr
77       return . Rec $ (var, cexpr) : (var', expr') : hs
78   `orElseV`
79     return b
80
81 vectTopBind b@(Rec bs)
82   = do
83       (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
84         do
85           vars' <- sequence [vectTopBinder var inline rhs
86                                | (var, ~(inline, rhs))
87                                  <- zipLazy vars (zip inlines rhss)]
88           (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
89           return (vars', inlines', exprs')
90       hs     <- takeHoisted
91       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
92       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
93   `orElseV`
94     return b
95   where
96     (vars, exprs) = unzip bs
97
98 -- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
99 -- used inside of fixV in vectTopBind
100 vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
101 vectTopBinder var inline expr
102   = do
103       vty  <- vectType (idType var)
104       var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
105       defGlobalVar var var'
106       return var'
107   where
108     unfolding = case inline of
109                   Inline arity -> mkInlineRule expr (Just arity)
110                   DontInline   -> noUnfolding
111
112 vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
113 vectTopRhs var expr
114   = closedV
115   $ do
116       (inline, vexpr) <- inBind var
117                        $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
118                                       (freeVars expr)
119       return (inline, vectorised vexpr)
120
121 tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
122 tryConvert var vect_var rhs
123   = fromVect (idType var) (Var vect_var) `orElseV` return rhs
124
125 -- ----------------------------------------------------------------------------
126 -- Bindings
127
128 vectBndr :: Var -> VM VVar
129 vectBndr v
130   = do
131       (vty, lty) <- vectAndLiftType (idType v)
132       let vv = v `Id.setIdType` vty
133           lv = v `Id.setIdType` lty
134       updLEnv (mapTo vv lv)
135       return (vv, lv)
136   where
137     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (vv, lv) }
138
139 vectBndrNew :: Var -> FastString -> VM VVar
140 vectBndrNew v fs
141   = do
142       vty <- vectType (idType v)
143       vv  <- newLocalVVar fs vty
144       updLEnv (upd vv)
145       return vv
146   where
147     upd vv env = env { local_vars = extendVarEnv (local_vars env) v vv }
148
149 vectBndrIn :: Var -> VM a -> VM (VVar, a)
150 vectBndrIn v p
151   = localV
152   $ do
153       vv <- vectBndr v
154       x <- p
155       return (vv, x)
156
157 vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
158 vectBndrNewIn v fs p
159   = localV
160   $ do
161       vv <- vectBndrNew v fs
162       x  <- p
163       return (vv, x)
164
165 vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
166 vectBndrsIn vs p
167   = localV
168   $ do
169       vvs <- mapM vectBndr vs
170       x <- p
171       return (vvs, x)
172
173 -- ----------------------------------------------------------------------------
174 -- Expressions
175
176 vectVar :: Var -> VM VExpr
177 vectVar v
178   = do
179       r <- lookupVar v
180       case r of
181         Local (vv,lv) -> return (Var vv, Var lv)
182         Global vv     -> do
183                            let vexpr = Var vv
184                            lexpr <- liftPD vexpr
185                            return (vexpr, lexpr)
186
187 vectPolyVar :: Var -> [Type] -> VM VExpr
188 vectPolyVar v tys
189   = do
190       vtys <- mapM vectType tys
191       r <- lookupVar v
192       case r of
193         Local (vv, lv) -> liftM2 (,) (polyApply (Var vv) vtys)
194                                      (polyApply (Var lv) vtys)
195         Global poly    -> do
196                             vexpr <- polyApply (Var poly) vtys
197                             lexpr <- liftPD vexpr
198                             return (vexpr, lexpr)
199
200 vectLiteral :: Literal -> VM VExpr
201 vectLiteral lit
202   = do
203       lexpr <- liftPD (Lit lit)
204       return (Lit lit, lexpr)
205
206 vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
207 vectPolyExpr loop_breaker (_, AnnNote note expr)
208   = do
209       (inline, expr') <- vectPolyExpr loop_breaker expr
210       return (inline, vNote note expr')
211 vectPolyExpr loop_breaker expr
212   = do
213       arity <- polyArity tvs
214       polyAbstract tvs $ \args ->
215         do
216           (inline, mono') <- vectFnExpr False loop_breaker mono
217           return (addInlineArity inline arity,
218                   mapVect (mkLams $ tvs ++ args) mono')
219   where
220     (tvs, mono) = collectAnnTypeBinders expr
221
222 vectExpr :: CoreExprWithFVs -> VM VExpr
223 vectExpr (_, AnnType ty)
224   = liftM vType (vectType ty)
225
226 vectExpr (_, AnnVar v) = vectVar v
227
228 vectExpr (_, AnnLit lit) = vectLiteral lit
229
230 vectExpr (_, AnnNote note expr)
231   = liftM (vNote note) (vectExpr expr)
232
233 vectExpr e@(_, AnnApp _ arg)
234   | isAnnTypeArg arg
235   = vectTyAppExpr fn tys
236   where
237     (fn, tys) = collectAnnTypeArgs e
238
239 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
240   | Just con <- isDataConId_maybe v
241   , is_special_con con
242   = do
243       let vexpr = App (Var v) (Lit lit)
244       lexpr <- liftPD vexpr
245       return (vexpr, lexpr)
246   where
247     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
248
249
250 vectExpr (_, AnnApp fn arg)
251   = do
252       arg_ty' <- vectType arg_ty
253       res_ty' <- vectType res_ty
254       fn'     <- vectExpr fn
255       arg'    <- vectExpr arg
256       mkClosureApp arg_ty' res_ty' fn' arg'
257   where
258     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
259
260 vectExpr (_, AnnCase scrut bndr ty alts)
261   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
262   , isAlgTyCon tycon
263   = vectAlgCase tycon ty_args scrut bndr ty alts
264   where
265     scrut_ty = exprType (deAnnotate scrut)
266
267 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
268   = do
269       vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
270       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
271       return $ vLet (vNonRec vbndr vrhs) vbody
272
273 vectExpr (_, AnnLet (AnnRec bs) body)
274   = do
275       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
276                                 $ liftM2 (,)
277                                   (zipWithM vect_rhs bndrs rhss)
278                                   (vectExpr body)
279       return $ vLet (vRec vbndrs vrhss) vbody
280   where
281     (bndrs, rhss) = unzip bs
282
283     vect_rhs bndr rhs = localV
284                       . inBind bndr
285                       . liftM snd
286                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
287
288 vectExpr e@(_, AnnLam bndr _)
289   | isId bndr = liftM snd $ vectFnExpr True False e
290 {-
291 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
292                 `orElseV` vectLam True fvs bs body
293   where
294     (bs,body) = collectAnnValBinders e
295 -}
296
297 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
298
299 vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
300 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
301   | isId bndr = onlyIfV (isEmptyVarSet fvs)
302                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
303                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
304   where
305     (bs,body) = collectAnnValBinders e
306 vectFnExpr _ _ e = mark DontInline $ vectExpr e
307
308 mark :: Inline -> VM a -> VM (Inline, a)
309 mark b p = do { x <- p; return (b,x) }
310
311 vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
312 vectScalarLam args body
313   = do
314       scalars <- globalScalars
315       onlyIfV (all is_scalar_ty arg_tys
316                && is_scalar_ty res_ty
317                && is_scalar (extendVarSetList scalars args) body
318                && uses scalars body)
319         $ do
320             fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
321             zipf <- zipScalars arg_tys res_ty
322             clo <- scalarClosure arg_tys res_ty (Var fn_var)
323                                                 (zipf `App` Var fn_var)
324             clo_var <- hoistExpr (fsLit "clo") clo DontInline
325             lclo <- liftPD (Var clo_var)
326             return (Var clo_var, lclo)
327   where
328     arg_tys = map idType args
329     res_ty  = exprType body
330
331     is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
332                     = tycon == intTyCon
333                       || tycon == floatTyCon
334                       || tycon == doubleTyCon
335
336                     | otherwise = False
337
338     is_scalar vs (Var v)     = v `elemVarSet` vs
339     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
340     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
341     is_scalar _ _            = False
342
343     -- A scalar function has to actually compute something. Without the check,
344     -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
345     -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
346     -- (\n# x -> x) which is what we want.
347     uses funs (Var v)     = v `elemVarSet` funs 
348     uses funs (App e1 e2) = uses funs e1 || uses funs e2
349     uses _ _              = False
350
351 vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
352 vectLam inline loop_breaker fvs bs body
353   = do
354       tyvars <- localTyVars
355       (vs, vvs) <- readLEnv $ \env ->
356                    unzip [(var, vv) | var <- varSetElems fvs
357                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
358
359       arg_tys <- mapM (vectType . idType) bs
360       res_ty  <- vectType (exprType $ deAnnotate body)
361
362       buildClosures tyvars vvs arg_tys res_ty
363         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
364         $ do
365             lc <- builtin liftingContext
366             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
367                                            (vectExpr body)
368             vbody' <- break_loop lc res_ty vbody
369             return $ vLams lc vbndrs vbody'
370   where
371     maybe_inline n | inline    = Inline n
372                    | otherwise = DontInline
373
374     break_loop lc ty (ve, le)
375       | loop_breaker
376       = do
377           empty <- emptyPD ty
378           lty <- mkPDataType ty
379           return (ve, mkWildCase (Var lc) intPrimTy lty
380                         [(DEFAULT, [], le),
381                          (LitAlt (mkMachInt 0), [], empty)])
382
383       | otherwise = return (ve, le)
384  
385
386 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
387 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
388 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
389                         (ppr $ deAnnotate e `mkTyApps` tys)
390
391 -- We convert
392 --
393 --   case e :: t of v { ... }
394 --
395 -- to
396 --
397 --   V:    let v' = e in case v' of _ { ... }
398 --   L:    let v' = e in case v' `cast` ... of _ { ... }
399 --
400 -- When lifting, we have to do it this way because v must have the type
401 -- [:V(T):] but the scrutinee must be cast to the representation type. We also
402 -- have to handle the case where v is a wild var correctly.
403 --
404
405 -- FIXME: this is too lazy
406 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
407             -> [(AltCon, [Var], CoreExprWithFVs)]
408             -> VM VExpr
409 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
410   = do
411       vscrut         <- vectExpr scrut
412       (vty, lty)     <- vectAndLiftType ty
413       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
414       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
415
416 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
417   = do
418       vscrut         <- vectExpr scrut
419       (vty, lty)     <- vectAndLiftType ty
420       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
421       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
422
423 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
424   = do
425       (vty, lty) <- vectAndLiftType ty
426       vexpr      <- vectExpr scrut
427       (vbndr, (vbndrs, (vect_body, lift_body)))
428          <- vect_scrut_bndr
429           . vectBndrsIn bndrs
430           $ vectExpr body
431       let (vect_bndrs, lift_bndrs) = unzip vbndrs
432       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
433       vect_dc <- maybeV (lookupDataCon dc)
434       let [pdata_dc] = tyConDataCons pdata_tc
435
436       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
437           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
438
439       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
440   where
441     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
442                     | otherwise         = vectBndrIn bndr
443
444     mk_wild_case expr ty dc bndrs body
445       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
446
447 vectAlgCase tycon _ty_args scrut bndr ty alts
448   = do
449       vect_tc     <- maybeV (lookupTyCon tycon)
450       (vty, lty)  <- vectAndLiftType ty
451
452       let arity = length (tyConDataCons vect_tc)
453       sel_ty <- builtin (selTy arity)
454       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
455       let sel = Var sel_bndr
456
457       (vbndr, valts) <- vect_scrut_bndr
458                       $ mapM (proc_alt arity sel vty lty) alts'
459       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
460
461       vexpr <- vectExpr scrut
462       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
463       let [pdata_dc] = tyConDataCons pdata_tc
464
465       let (vect_bodies, lift_bodies) = unzip vbodies
466
467       vdummy <- newDummyVar (exprType vect_scrut)
468       ldummy <- newDummyVar (exprType lift_scrut)
469       let vect_case = Case vect_scrut vdummy vty
470                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
471
472       lc <- builtin liftingContext
473       lbody <- combinePD vty (Var lc) sel lift_bodies
474       let lift_case = Case lift_scrut ldummy lty
475                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
476                              lbody)]
477
478       return . vLet (vNonRec vbndr vexpr)
479              $ (vect_case, lift_case)
480   where
481     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
482                     | otherwise         = vectBndrIn bndr
483
484     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
485
486     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
487     cmp DEFAULT       DEFAULT       = EQ
488     cmp DEFAULT       _             = LT
489     cmp _             DEFAULT       = GT
490     cmp _             _             = panic "vectAlgCase/cmp"
491
492     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
493       = do
494           vect_dc <- maybeV (lookupDataCon dc)
495           let ntag = dataConTagZ vect_dc
496               tag  = mkDataConTag vect_dc
497               fvs  = freeVarsOf body `delVarSetList` bndrs
498
499           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
500           lc        <- builtin liftingContext
501           elems     <- builtin (selElements arity ntag)
502
503           (vbndrs, vbody)
504             <- vectBndrsIn bndrs
505              . localV
506              $ do
507                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
508                            . filter isLocalId
509                            $ varSetElems fvs
510                  (ve, le) <- vectExpr body
511                  return (ve, Case (elems `App` sel) lc lty
512                              [(DEFAULT, [], (mkLets (concat binds) le))])
513                  -- empty    <- emptyPD vty
514                  -- return (ve, Case (elems `App` sel) lc lty
515                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
516                  --                             $ mkLets (concat binds) le),
517                  --               (LitAlt (mkMachInt 0), [], empty)])
518           let (vect_bndrs, lift_bndrs) = unzip vbndrs
519           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
520
521     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
522
523     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
524
525     pack_var len tags t v
526       = do
527           r <- lookupVar v
528           case r of
529             Local (vv, lv) ->
530               do
531                 lv'  <- cloneVar lv
532                 expr <- packByTagPD (idType vv) (Var lv) len tags t
533                 updLEnv (\env -> env { local_vars = extendVarEnv
534                                                 (local_vars env) v (vv, lv') })
535                 return [(NonRec lv' expr)]
536
537             _ -> return []
538