Fixed calling convention for unboxed tuples
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1
2 module Vectorise( vectorise )
3 where
4
5 import VectMonad
6 import VectUtils
7 import VectType
8 import VectCore
9
10 import HscTypes hiding      ( MonadThings(..) )
11
12 import Module               ( PackageId )
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold           ( mkInlineRule )
16 import MkCore               ( mkWildCase )
17 import CoreFVs
18 import CoreMonad            ( CoreM, getHscEnv )
19 import DataCon
20 import TyCon
21 import Type
22 import FamInstEnv           ( extendFamInstEnvList )
23 import Var
24 import VarEnv
25 import VarSet
26 import Id
27 import OccName
28 import BasicTypes           ( isLoopBreaker )
29
30 import Literal              ( Literal, mkMachInt )
31 import TysWiredIn
32 import TysPrim              ( intPrimTy )
33
34 import Outputable
35 import FastString
36 import Util                 ( zipLazy )
37 import Control.Monad
38 import Data.List            ( sortBy, unzip4 )
39
40 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
41 vectorise backend guts = do
42     hsc_env <- getHscEnv
43     liftIO $ vectoriseIO backend hsc_env guts
44
45 vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
46 vectoriseIO backend hsc_env guts
47   = do
48       eps <- hscEPS hsc_env
49       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
50       Just (info', guts') <- initV backend hsc_env guts info (vectModule guts)
51       return (guts' { mg_vect_info = info' })
52
53 vectModule :: ModGuts -> VM ModGuts
54 vectModule guts
55   = do
56       (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
57
58       let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
59       updGEnv (setFamInstEnv fam_inst_env')
60
61       -- dicts   <- mapM buildPADict pa_insts
62       -- workers <- mapM vectDataConWorkers pa_insts
63       binds'  <- mapM vectTopBind (mg_binds guts)
64       return $ guts { mg_types        = types'
65                     , mg_binds        = Rec tc_binds : binds'
66                     , mg_fam_inst_env = fam_inst_env'
67                     , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
68                     }
69
70 vectTopBind :: CoreBind -> VM CoreBind
71 vectTopBind b@(NonRec var expr)
72   = do
73       (inline, expr') <- vectTopRhs var expr
74       var' <- vectTopBinder var inline expr'
75       hs    <- takeHoisted
76       cexpr <- tryConvert var var' expr
77       return . Rec $ (var, cexpr) : (var', expr') : hs
78   `orElseV`
79     return b
80
81 vectTopBind b@(Rec bs)
82   = do
83       (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
84         do
85           vars' <- sequence [vectTopBinder var inline rhs
86                                | (var, ~(inline, rhs))
87                                  <- zipLazy vars (zip inlines rhss)]
88           (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
89           return (vars', inlines', exprs')
90       hs     <- takeHoisted
91       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
92       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
93   `orElseV`
94     return b
95   where
96     (vars, exprs) = unzip bs
97
98 -- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
99 -- used inside of fixV in vectTopBind
100 vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
101 vectTopBinder var inline expr
102   = do
103       vty  <- vectType (idType var)
104       var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
105       defGlobalVar var var'
106       return var'
107   where
108     unfolding = case inline of
109                   Inline arity -> mkInlineRule InlSat expr arity
110                   DontInline   -> noUnfolding
111
112 vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
113 vectTopRhs var expr
114   = closedV
115   $ do
116       (inline, vexpr) <- inBind var
117                        $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
118                                       (freeVars expr)
119       return (inline, vectorised vexpr)
120
121 tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
122 tryConvert var vect_var rhs
123   = fromVect (idType var) (Var vect_var) `orElseV` return rhs
124
125 -- ----------------------------------------------------------------------------
126 -- Bindings
127
128 vectBndr :: Var -> VM VVar
129 vectBndr v
130   = do
131       (vty, lty) <- vectAndLiftType (idType v)
132       let vv = v `Id.setIdType` vty
133           lv = v `Id.setIdType` lty
134       updLEnv (mapTo vv lv)
135       return (vv, lv)
136   where
137     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (vv, lv) }
138
139 vectBndrNew :: Var -> FastString -> VM VVar
140 vectBndrNew v fs
141   = do
142       vty <- vectType (idType v)
143       vv  <- newLocalVVar fs vty
144       updLEnv (upd vv)
145       return vv
146   where
147     upd vv env = env { local_vars = extendVarEnv (local_vars env) v vv }
148
149 vectBndrIn :: Var -> VM a -> VM (VVar, a)
150 vectBndrIn v p
151   = localV
152   $ do
153       vv <- vectBndr v
154       x <- p
155       return (vv, x)
156
157 vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
158 vectBndrNewIn v fs p
159   = localV
160   $ do
161       vv <- vectBndrNew v fs
162       x  <- p
163       return (vv, x)
164
165 vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
166 vectBndrsIn vs p
167   = localV
168   $ do
169       vvs <- mapM vectBndr vs
170       x <- p
171       return (vvs, x)
172
173 -- ----------------------------------------------------------------------------
174 -- Expressions
175
176 vectVar :: Var -> VM VExpr
177 vectVar v
178   = do
179       r <- lookupVar v
180       case r of
181         Local (vv,lv) -> return (Var vv, Var lv)
182         Global vv     -> do
183                            let vexpr = Var vv
184                            lexpr <- liftPD vexpr
185                            return (vexpr, lexpr)
186
187 vectPolyVar :: Var -> [Type] -> VM VExpr
188 vectPolyVar v tys
189   = do
190       vtys <- mapM vectType tys
191       r <- lookupVar v
192       case r of
193         Local (vv, lv) -> liftM2 (,) (polyApply (Var vv) vtys)
194                                      (polyApply (Var lv) vtys)
195         Global poly    -> do
196                             vexpr <- polyApply (Var poly) vtys
197                             lexpr <- liftPD vexpr
198                             return (vexpr, lexpr)
199
200 vectLiteral :: Literal -> VM VExpr
201 vectLiteral lit
202   = do
203       lexpr <- liftPD (Lit lit)
204       return (Lit lit, lexpr)
205
206 vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
207 vectPolyExpr loop_breaker (_, AnnNote note expr)
208   = do
209       (inline, expr') <- vectPolyExpr loop_breaker expr
210       return (inline, vNote note expr')
211 vectPolyExpr loop_breaker expr
212   = do
213       arity <- polyArity tvs
214       polyAbstract tvs $ \args ->
215         do
216           (inline, mono') <- vectFnExpr False loop_breaker mono
217           return (addInlineArity inline arity,
218                   mapVect (mkLams $ tvs ++ args) mono')
219   where
220     (tvs, mono) = collectAnnTypeBinders expr
221
222 vectExpr :: CoreExprWithFVs -> VM VExpr
223 vectExpr (_, AnnType ty)
224   = liftM vType (vectType ty)
225
226 vectExpr (_, AnnVar v) = vectVar v
227
228 vectExpr (_, AnnLit lit) = vectLiteral lit
229
230 vectExpr (_, AnnNote note expr)
231   = liftM (vNote note) (vectExpr expr)
232
233 vectExpr e@(_, AnnApp _ arg)
234   | isAnnTypeArg arg
235   = vectTyAppExpr fn tys
236   where
237     (fn, tys) = collectAnnTypeArgs e
238
239 vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
240   | Just con <- isDataConId_maybe v
241   , is_special_con con
242   = do
243       let vexpr = App (Var v) (Lit lit)
244       lexpr <- liftPD vexpr
245       return (vexpr, lexpr)
246   where
247     is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]
248
249
250 vectExpr (_, AnnApp fn arg)
251   = do
252       arg_ty' <- vectType arg_ty
253       res_ty' <- vectType res_ty
254       fn'     <- vectExpr fn
255       arg'    <- vectExpr arg
256       mkClosureApp arg_ty' res_ty' fn' arg'
257   where
258     (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
259
260 vectExpr (_, AnnCase scrut bndr ty alts)
261   | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
262   , isAlgTyCon tycon
263   = vectAlgCase tycon ty_args scrut bndr ty alts
264   where
265     scrut_ty = exprType (deAnnotate scrut)
266
267 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
268   = do
269       vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
270       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
271       return $ vLet (vNonRec vbndr vrhs) vbody
272
273 vectExpr (_, AnnLet (AnnRec bs) body)
274   = do
275       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
276                                 $ liftM2 (,)
277                                   (zipWithM vect_rhs bndrs rhss)
278                                   (vectExpr body)
279       return $ vLet (vRec vbndrs vrhss) vbody
280   where
281     (bndrs, rhss) = unzip bs
282
283     vect_rhs bndr rhs = localV
284                       . inBind bndr
285                       . liftM snd
286                       $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
287
288 vectExpr e@(_, AnnLam bndr _)
289   | isId bndr = liftM snd $ vectFnExpr True False e
290 {-
291 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
292                 `orElseV` vectLam True fvs bs body
293   where
294     (bs,body) = collectAnnValBinders e
295 -}
296
297 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
298
299 vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
300 vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
301   | isId bndr = onlyIfV (isEmptyVarSet fvs)
302                         (mark DontInline . vectScalarLam bs $ deAnnotate body)
303                 `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
304   where
305     (bs,body) = collectAnnValBinders e
306 vectFnExpr _ _ e = mark DontInline $ vectExpr e
307
308 mark :: Inline -> VM a -> VM (Inline, a)
309 mark b p = do { x <- p; return (b,x) }
310
311 vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
312 vectScalarLam args body
313   = do
314       scalars <- globalScalars
315       onlyIfV (all is_scalar_ty arg_tys
316                && is_scalar_ty res_ty
317                && is_scalar (extendVarSetList scalars args) body)
318         $ do
319             fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
320             zipf <- zipScalars arg_tys res_ty
321             clo <- scalarClosure arg_tys res_ty (Var fn_var)
322                                                 (zipf `App` Var fn_var)
323             clo_var <- hoistExpr (fsLit "clo") clo DontInline
324             lclo <- liftPD (Var clo_var)
325             return (Var clo_var, lclo)
326   where
327     arg_tys = map idType args
328     res_ty  = exprType body
329
330     is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
331                     = tycon == intTyCon
332                       || tycon == floatTyCon
333                       || tycon == doubleTyCon
334
335                     | otherwise = False
336
337     is_scalar vs (Var v)     = v `elemVarSet` vs
338     is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
339     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
340     is_scalar _ _            = False
341
342 vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
343 vectLam inline loop_breaker fvs bs body
344   = do
345       tyvars <- localTyVars
346       (vs, vvs) <- readLEnv $ \env ->
347                    unzip [(var, vv) | var <- varSetElems fvs
348                                     , Just vv <- [lookupVarEnv (local_vars env) var]]
349
350       arg_tys <- mapM (vectType . idType) bs
351       res_ty  <- vectType (exprType $ deAnnotate body)
352
353       buildClosures tyvars vvs arg_tys res_ty
354         . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
355         $ do
356             lc <- builtin liftingContext
357             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
358                                            (vectExpr body)
359             vbody' <- break_loop lc res_ty vbody
360             return $ vLams lc vbndrs vbody'
361   where
362     maybe_inline n | inline    = Inline n
363                    | otherwise = DontInline
364
365     break_loop lc ty (ve, le)
366       | loop_breaker
367       = do
368           empty <- emptyPD ty
369           lty <- mkPDataType ty
370           return (ve, mkWildCase (Var lc) intPrimTy lty
371                         [(DEFAULT, [], le),
372                          (LitAlt (mkMachInt 0), [], empty)])
373
374       | otherwise = return (ve, le)
375  
376
377 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
378 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
379 vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
380                         (ppr $ deAnnotate e `mkTyApps` tys)
381
382 -- We convert
383 --
384 --   case e :: t of v { ... }
385 --
386 -- to
387 --
388 --   V:    let v' = e in case v' of _ { ... }
389 --   L:    let v' = e in case v' `cast` ... of _ { ... }
390 --
391 -- When lifting, we have to do it this way because v must have the type
392 -- [:V(T):] but the scrutinee must be cast to the representation type. We also
393 -- have to handle the case where v is a wild var correctly.
394 --
395
396 -- FIXME: this is too lazy
397 vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
398             -> [(AltCon, [Var], CoreExprWithFVs)]
399             -> VM VExpr
400 vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
401   = do
402       vscrut         <- vectExpr scrut
403       (vty, lty)     <- vectAndLiftType ty
404       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
405       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
406
407 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
408   = do
409       vscrut         <- vectExpr scrut
410       (vty, lty)     <- vectAndLiftType ty
411       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
412       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
413
414 vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
415   = do
416       (vty, lty) <- vectAndLiftType ty
417       vexpr      <- vectExpr scrut
418       (vbndr, (vbndrs, (vect_body, lift_body)))
419          <- vect_scrut_bndr
420           . vectBndrsIn bndrs
421           $ vectExpr body
422       let (vect_bndrs, lift_bndrs) = unzip vbndrs
423       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
424       vect_dc <- maybeV (lookupDataCon dc)
425       let [pdata_dc] = tyConDataCons pdata_tc
426
427       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
428           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
429
430       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
431   where
432     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
433                     | otherwise         = vectBndrIn bndr
434
435     mk_wild_case expr ty dc bndrs body
436       = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]
437
438 vectAlgCase tycon _ty_args scrut bndr ty alts
439   = do
440       vect_tc     <- maybeV (lookupTyCon tycon)
441       (vty, lty)  <- vectAndLiftType ty
442
443       let arity = length (tyConDataCons vect_tc)
444       sel_ty <- builtin (selTy arity)
445       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
446       let sel = Var sel_bndr
447
448       (vbndr, valts) <- vect_scrut_bndr
449                       $ mapM (proc_alt arity sel vty lty) alts'
450       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
451
452       vexpr <- vectExpr scrut
453       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
454       let [pdata_dc] = tyConDataCons pdata_tc
455
456       let (vect_bodies, lift_bodies) = unzip vbodies
457
458       vdummy <- newDummyVar (exprType vect_scrut)
459       ldummy <- newDummyVar (exprType lift_scrut)
460       let vect_case = Case vect_scrut vdummy vty
461                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
462
463       lc <- builtin liftingContext
464       lbody <- combinePD vty (Var lc) sel lift_bodies
465       let lift_case = Case lift_scrut ldummy lty
466                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
467                              lbody)]
468
469       return . vLet (vNonRec vbndr vexpr)
470              $ (vect_case, lift_case)
471   where
472     vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
473                     | otherwise         = vectBndrIn bndr
474
475     alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
476
477     cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
478     cmp DEFAULT       DEFAULT       = EQ
479     cmp DEFAULT       _             = LT
480     cmp _             DEFAULT       = GT
481     cmp _             _             = panic "vectAlgCase/cmp"
482
483     proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
484       = do
485           vect_dc <- maybeV (lookupDataCon dc)
486           let ntag = dataConTagZ vect_dc
487               tag  = mkDataConTag vect_dc
488               fvs  = freeVarsOf body `delVarSetList` bndrs
489
490           sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
491           lc        <- builtin liftingContext
492           elems     <- builtin (selElements arity ntag)
493
494           (vbndrs, vbody)
495             <- vectBndrsIn bndrs
496              . localV
497              $ do
498                  binds    <- mapM (pack_var (Var lc) sel_tags tag)
499                            . filter isLocalId
500                            $ varSetElems fvs
501                  (ve, le) <- vectExpr body
502                  return (ve, Case (elems `App` sel) lc lty
503                              [(DEFAULT, [], (mkLets (concat binds) le))])
504                  -- empty    <- emptyPD vty
505                  -- return (ve, Case (elems `App` sel) lc lty
506                  --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
507                  --                             $ mkLets (concat binds) le),
508                  --               (LitAlt (mkMachInt 0), [], empty)])
509           let (vect_bndrs, lift_bndrs) = unzip vbndrs
510           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
511
512     proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
513
514     mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
515
516     pack_var len tags t v
517       = do
518           r <- lookupVar v
519           case r of
520             Local (vv, lv) ->
521               do
522                 lv'  <- cloneVar lv
523                 expr <- packByTagPD (idType vv) (Var lv) len tags t
524                 updLEnv (\env -> env { local_vars = extendVarEnv
525                                                 (local_vars env) v (vv, lv') })
526                 return [(NonRec lv' expr)]
527
528             _ -> return []
529