Vectorise top-level bindings of a module
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint             ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name                 ( mkSysTvName, getName )
24 import NameEnv
25 import Id
26 import MkId                 ( unwrapFamInstScrut )
27 import OccName
28
29 import DsMonad hiding (mapAndUnzipM)
30 import DsUtils              ( mkCoreTup, mkCoreTupTy )
31
32 import PrelNames
33 import TysWiredIn
34 import BasicTypes           ( Boxity(..) )
35
36 import Outputable
37 import FastString
38 import Control.Monad        ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
39 import Data.Maybe           ( maybeToList )
40
41 vectorise :: HscEnv -> ModGuts -> IO ModGuts
42 vectorise hsc_env guts
43   | not (Opt_Vectorise `dopt` dflags) = return guts
44   | otherwise
45   = do
46       showPass dflags "Vectorisation"
47       eps <- hscEPS hsc_env
48       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
49       Just (info', guts') <- initV hsc_env guts info (vectModule guts)
50       endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
51       return $ guts' { mg_vect_info = info' }
52   where
53     dflags = hsc_dflags hsc_env
54
55 vectModule :: ModGuts -> VM ModGuts
56 vectModule guts
57   = do
58       binds' <- mapM vectTopBind (mg_binds guts)
59       return $ guts { mg_binds = binds' }
60
61 vectTopBind :: CoreBind -> VM CoreBind
62 vectTopBind b@(NonRec var expr)
63   = do
64       var'  <- vectTopBinder var
65       expr' <- vectTopRhs expr
66       hs    <- takeHoisted
67       return . Rec $ (var, expr) : (var', expr') : hs
68   `orElseV`
69     return b
70
71 vectTopBind b@(Rec bs)
72   = do
73       vars'  <- mapM vectTopBinder vars
74       exprs' <- mapM vectTopRhs exprs
75       hs     <- takeHoisted
76       return . Rec $ bs ++ zip vars' exprs' ++ hs
77   `orElseV`
78     return b
79   where
80     (vars, exprs) = unzip bs
81
82 vectTopBinder :: Var -> VM Var
83 vectTopBinder var
84   = do
85       vty <- liftM (mkForAllTys tyvars) $ vectType mono_ty
86       name <- cloneName mkVectOcc (getName var)
87       let var' | isExportedId var = Id.mkExportedLocalId name vty
88                | otherwise        = Id.mkLocalId         name vty
89       defGlobalVar var var'
90       return var'
91   where
92     (tyvars, mono_ty) = splitForAllTys (idType var)
93     
94 vectTopRhs :: CoreExpr -> VM CoreExpr
95 vectTopRhs = liftM fst . closedV . vectPolyExpr (panic "Empty lifting context") . freeVars
96
97 -- ----------------------------------------------------------------------------
98 -- Bindings
99
100 vectBndr :: Var -> VM (Var, Var)
101 vectBndr v
102   = do
103       vty <- vectType (idType v)
104       lty <- mkPArrayType vty
105       let vv = v `Id.setIdType` vty
106           lv = v `Id.setIdType` lty
107       updLEnv (mapTo vv lv)
108       return (vv, lv)
109   where
110     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
111
112 vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
113 vectBndrIn v p
114   = localV
115   $ do
116       (vv, lv) <- vectBndr v
117       x <- p
118       return (vv, lv, x)
119
120 vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
121 vectBndrsIn vs p
122   = localV
123   $ do
124       (vvs, lvs) <- mapAndUnzipM vectBndr vs
125       x <- p
126       return (vvs, lvs, x)
127
128 -- ----------------------------------------------------------------------------
129 -- Expressions
130
131 replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
132 replicateP expr len
133   = do
134       dict <- paDictOfType ty
135       rep  <- builtin replicatePAVar
136       return $ mkApps (Var rep) [Type ty, dict, expr, len]
137   where
138     ty = exprType expr
139
140 capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
141 capply (vfn, lfn) (varg, larg)
142   = do
143       apply  <- builtin applyClosureVar
144       applyP <- builtin applyClosurePVar
145       return (mkApps (Var apply)  [Type arg_ty, Type res_ty, vfn, varg],
146               mkApps (Var applyP) [Type arg_ty, Type res_ty, lfn, larg])
147   where
148     fn_ty            = exprType vfn
149     (arg_ty, res_ty) = splitClosureTy fn_ty
150
151 vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
152 vectVar lc v
153   = do
154       r <- lookupVar v
155       case r of
156         Local es     -> return es
157         Global vexpr -> do
158                           lexpr <- replicateP vexpr lc
159                           return (vexpr, lexpr)
160
161 vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
162 vectPolyVar lc v tys
163   = do
164       r <- lookupVar v
165       case r of
166         Local (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
167         Global poly          -> do
168                                   vexpr <- mk_app poly
169                                   lexpr <- replicateP vexpr lc
170                                   return (vexpr, lexpr)
171   where
172     mk_app e = applyToTypes e =<< mapM vectType tys
173
174 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
175 abstractOverTyVars tvs p
176   = do
177       mdicts <- mapM mk_dict_var tvs
178       zipWithM_ (\tv -> maybe (deleteTyVarPA tv) (extendTyVarPA tv . Var)) tvs mdicts
179       p (mk_lams mdicts)
180   where
181     mk_dict_var tv = do
182                        r <- paDictArgType tv
183                        case r of
184                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
185                          Nothing -> return Nothing
186
187     mk_lams mdicts = mkLams [arg | (tv, mdict) <- zip tvs mdicts
188                                  , arg <- tv : maybeToList mdict]
189
190 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
191 applyToTypes expr tys
192   = do
193       dicts <- mapM paDictOfType tys
194       return $ mkApps expr [arg | (ty, dict) <- zip tys dicts
195                                 , arg <- [Type ty, dict]]
196     
197
198 vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
199 vectPolyExpr lc expr
200   = localV
201   . abstractOverTyVars tvs $ \mk_lams ->
202     -- FIXME: shadowing (tvs in lc)
203     do
204       (vmono, lmono) <- vectExpr lc mono
205       return $ (mk_lams vmono, mk_lams lmono)
206   where
207     (tvs, mono) = collectAnnTypeBinders expr  
208                 
209 vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
210 vectExpr lc (_, AnnType ty)
211   = do
212       vty <- vectType ty
213       return (Type vty, Type vty)
214
215 vectExpr lc (_, AnnVar v)   = vectVar lc v
216
217 vectExpr lc (_, AnnLit lit)
218   = do
219       let vexpr = Lit lit
220       lexpr <- replicateP vexpr lc
221       return (vexpr, lexpr)
222
223 vectExpr lc (_, AnnNote note expr)
224   = do
225       (vexpr, lexpr) <- vectExpr lc expr
226       return (Note note vexpr, Note note lexpr)
227
228 vectExpr lc e@(_, AnnApp _ arg)
229   | isAnnTypeArg arg
230   = vectTyAppExpr lc fn tys
231   where
232     (fn, tys) = collectAnnTypeArgs e
233
234 vectExpr lc (_, AnnApp fn arg)
235   = do
236       fn'  <- vectExpr lc fn
237       arg' <- vectExpr lc arg
238       capply fn' arg'
239
240 vectExpr lc (_, AnnCase expr bndr ty alts)
241   = panic "vectExpr: case"
242
243 vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
244   = do
245       (vrhs, lrhs) <- vectPolyExpr lc rhs
246       (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
247       return (Let (NonRec vbndr vrhs) vbody,
248               Let (NonRec lbndr lrhs) lbody)
249
250 vectExpr lc (_, AnnLet (AnnRec prs) body)
251   = do
252       (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
253       return (Let (Rec (zip vbndrs vrhss)) vbody,
254               Let (Rec (zip lbndrs lrhss)) lbody)
255   where
256     (bndrs, rhss) = unzip prs
257     
258     vect = do
259              (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
260              (vbody, lbody) <- vectPolyExpr lc body
261              return (vrhss, vbody, lrhss, lbody)
262
263 vectExpr lc e@(_, AnnLam bndr body)
264   | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
265
266 vectExpr lc (fvs, AnnLam bndr body)
267   = do
268       let tyvars = filter isTyVar (varSetElems fvs)
269       info <- mkCEnvInfo fvs bndr body
270       (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
271
272       vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
273       lfn_var <- hoistExpr FSLIT("lfn") poly_lfn
274
275       let (venv, lenv) = mkClosureEnvs info lc
276
277       let env_ty = cenv_vty info
278
279       pa_dict <- paDictOfType env_ty
280
281       arg_ty <- vectType (varType bndr)
282       res_ty <- vectType (exprType $ deAnnotate body)
283
284       -- FIXME: move the functions to the top level
285       mono_vfn <- applyToTypes (Var vfn_var) (map TyVarTy tyvars)
286       mono_lfn <- applyToTypes (Var lfn_var) (map TyVarTy tyvars)
287
288       mk_clo <- builtin mkClosureVar
289       mk_cloP <- builtin mkClosurePVar
290
291       let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
292                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
293           
294           lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
295                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
296
297       return (vclo, lclo)
298        
299
300 data CEnvInfo = CEnvInfo {
301                cenv_vars         :: [Var]
302              , cenv_values       :: [(CoreExpr, CoreExpr)]
303              , cenv_vty          :: Type
304              , cenv_lty          :: Type
305              , cenv_repr_tycon   :: TyCon
306              , cenv_repr_tyargs  :: [Type]
307              , cenv_repr_datacon :: DataCon
308              }
309
310 mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
311 mkCEnvInfo fvs arg body
312   = do
313       locals <- readLEnv local_vars
314       let
315           (vars, vals) = unzip
316                  [(var, val) | var      <- varSetElems fvs
317                              , Just val <- [lookupVarEnv locals var]]
318       vtys <- mapM (vectType . varType) vars
319
320       (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
321       lty <- mkPArrayType vty
322       
323       return $ CEnvInfo {
324                  cenv_vars         = vars
325                , cenv_values       = vals
326                , cenv_vty          = vty
327                , cenv_lty          = lty
328                , cenv_repr_tycon   = repr_tycon
329                , cenv_repr_tyargs  = repr_tyargs
330                , cenv_repr_datacon = repr_datacon
331                }
332   where
333     mk_env_ty [vty]
334       = return (vty, error "absent cinfo_repr_tycon"
335                    , error "absent cinfo_repr_tyargs"
336                    , error "absent cinfo_repr_datacon")
337
338     mk_env_ty vtys
339       = do
340           let ty = mkCoreTupTy vtys
341           (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
342           let [repr_con] = tyConDataCons repr_tc
343           return (ty, repr_tc, repr_tyargs, repr_con)
344
345     
346
347 mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
348 mkClosureEnvs info lc
349   | [] <- vals
350   = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
351                                [lc, Var unitDataConId])
352
353   | [(vval, lval)] <- vals
354   = (vval, lval)
355
356   | otherwise
357   = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
358                       `mkTyApps` cenv_repr_tyargs info
359                       `mkApps`   (lc : lvals))
360
361   where
362     vals = cenv_values info
363     (vvals, lvals) = unzip vals
364
365 mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
366              -> VM (CoreExpr, CoreExpr)
367 mkClosureFns info tyvars arg body
368   = closedV
369   . abstractOverTyVars tyvars
370   $ \mk_tlams ->
371   do
372     (vfn, lfn) <- mkClosureMonoFns info arg body
373     return (mk_tlams vfn, mk_tlams lfn)
374
375 mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
376 mkClosureMonoFns info arg body
377   = do
378       lc_bndr <- newLocalVar FSLIT("lc") intTy
379       (varg : vbndrs, larg : lbndrs, (vbody, lbody))
380         <- vectBndrsIn (arg : cenv_vars info)
381                        (vectExpr (Var lc_bndr) body)
382
383       venv_bndr <- newLocalVar FSLIT("env") vty
384       lenv_bndr <- newLocalVar FSLIT("env") lty
385
386       let vcase = bind_venv (Var venv_bndr) vbody vbndrs
387       lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
388       return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
389   where
390     vty = cenv_vty info
391     lty = cenv_lty info
392
393     arity = length (cenv_vars info)
394
395     bind_venv venv vbody []      = vbody
396     bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
397     bind_venv venv vbody vbndrs
398       = Case venv (mkWildId vty) (exprType vbody)
399              [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
400
401     bind_lenv lenv lbody lc_bndr [lbndr]
402       = do
403           lengthPA <- builtin lengthPAVar
404           return . Let (NonRec lbndr lenv)
405                  $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
406                         lc_bndr
407                         intTy
408                         [(DEFAULT, [], lbody)]
409
410     bind_lenv lenv lbody lc_bndr lbndrs
411       = return
412       $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
413                                  (cenv_repr_tyargs info)
414                                  lenv)
415              (mkWildId lty)
416              (exprType lbody)
417              [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
418           
419 vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
420 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
421 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
422
423 -- ----------------------------------------------------------------------------
424 -- Types
425
426 vectTyCon :: TyCon -> VM TyCon
427 vectTyCon tc
428   | isFunTyCon tc        = builtin closureTyCon
429   | isBoxedTupleTyCon tc = return tc
430   | isUnLiftedTyCon tc   = return tc
431   | otherwise = do
432                   r <- lookupTyCon tc
433                   case r of
434                     Just tc' -> return tc'
435
436                     -- FIXME: just for now
437                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
438
439 vectType :: Type -> VM Type
440 vectType ty | Just ty' <- coreView ty = vectType ty
441 vectType (TyVarTy tv) = return $ TyVarTy tv
442 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
443 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
444 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
445                                              (mapM vectType [ty1,ty2])
446 vectType (ForAllTy tv ty)
447   = do
448       r   <- paDictArgType tv
449       ty' <- vectType ty
450       return $ ForAllTy tv (wrap r ty')
451   where
452     wrap Nothing      = id
453     wrap (Just pa_ty) = FunTy pa_ty
454
455 vectType ty = pprPanic "vectType:" (ppr ty)
456