Collect hoisted vectorised functions
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint             ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name                 ( mkSysTvName )
24 import NameEnv
25 import Id
26 import MkId                 ( unwrapFamInstScrut )
27
28 import DsMonad hiding (mapAndUnzipM)
29 import DsUtils              ( mkCoreTup, mkCoreTupTy )
30
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes           ( Boxity(..) )
34
35 import Outputable
36 import FastString
37 import Control.Monad        ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
38 import Data.Maybe           ( maybeToList )
39
40 vectorise :: HscEnv -> ModGuts -> IO ModGuts
41 vectorise hsc_env guts
42   | not (Opt_Vectorise `dopt` dflags) = return guts
43   | otherwise
44   = do
45       showPass dflags "Vectorisation"
46       eps <- hscEPS hsc_env
47       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
48       Just (info', guts') <- initV hsc_env guts info (vectModule guts)
49       endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
50       return $ guts' { mg_vect_info = info' }
51   where
52     dflags = hsc_dflags hsc_env
53
54 vectModule :: ModGuts -> VM ModGuts
55 vectModule guts = return guts
56
57 -- ----------------------------------------------------------------------------
58 -- Bindings
59
60 vectBndr :: Var -> VM (Var, Var)
61 vectBndr v
62   = do
63       vty <- vectType (idType v)
64       lty <- mkPArrayType vty
65       let vv = v `Id.setIdType` vty
66           lv = v `Id.setIdType` lty
67       updLEnv (mapTo vv lv)
68       return (vv, lv)
69   where
70     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
71
72 vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
73 vectBndrIn v p
74   = localV
75   $ do
76       (vv, lv) <- vectBndr v
77       x <- p
78       return (vv, lv, x)
79
80 vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
81 vectBndrsIn vs p
82   = localV
83   $ do
84       (vvs, lvs) <- mapAndUnzipM vectBndr vs
85       x <- p
86       return (vvs, lvs, x)
87
88 -- ----------------------------------------------------------------------------
89 -- Expressions
90
91 replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
92 replicateP expr len
93   = do
94       dict <- paDictOfType ty
95       rep  <- builtin replicatePAVar
96       return $ mkApps (Var rep) [Type ty, dict, expr, len]
97   where
98     ty = exprType expr
99
100 capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
101 capply (vfn, lfn) (varg, larg)
102   = do
103       apply  <- builtin applyClosureVar
104       applyP <- builtin applyClosurePVar
105       return (mkApps (Var apply)  [Type arg_ty, Type res_ty, vfn, varg],
106               mkApps (Var applyP) [Type arg_ty, Type res_ty, lfn, larg])
107   where
108     fn_ty            = exprType vfn
109     (arg_ty, res_ty) = splitClosureTy fn_ty
110
111 vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
112 vectVar lc v = local v `orElseV` global v
113   where
114     local  v = maybeV (readLEnv $ \env -> lookupVarEnv (local_vars env) v)
115     global v = do
116                  vexpr <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
117                  lexpr <- replicateP vexpr lc
118                  return (vexpr, lexpr)
119
120 vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
121 vectPolyVar lc v tys
122   = do
123       r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
124       case r of
125         Just (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
126         Nothing ->
127           do
128             poly  <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
129             vexpr <- mk_app poly
130             lexpr <- replicateP vexpr lc
131             return (vexpr, lexpr)
132   where
133     mk_app e = applyToTypes e =<< mapM vectType tys
134
135 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
136 abstractOverTyVars tvs p
137   = do
138       mdicts <- mapM mk_dict_var tvs
139       zipWithM_ (\tv -> maybe (deleteTyVarPA tv) (extendTyVarPA tv . Var)) tvs mdicts
140       p (mk_lams mdicts)
141   where
142     mk_dict_var tv = do
143                        r <- paDictArgType tv
144                        case r of
145                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
146                          Nothing -> return Nothing
147
148     mk_lams mdicts = mkLams [arg | (tv, mdict) <- zip tvs mdicts
149                                  , arg <- tv : maybeToList mdict]
150
151 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
152 applyToTypes expr tys
153   = do
154       dicts <- mapM paDictOfType tys
155       return $ mkApps expr [arg | (ty, dict) <- zip tys dicts
156                                 , arg <- [Type ty, dict]]
157     
158
159 vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
160 vectPolyExpr lc expr
161   = localV
162   . abstractOverTyVars tvs $ \mk_lams ->
163     -- FIXME: shadowing (tvs in lc)
164     do
165       (vmono, lmono) <- vectExpr lc mono
166       return $ (mk_lams vmono, mk_lams lmono)
167   where
168     (tvs, mono) = collectAnnTypeBinders expr  
169                 
170 vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
171 vectExpr lc (_, AnnType ty)
172   = do
173       vty <- vectType ty
174       return (Type vty, Type vty)
175
176 vectExpr lc (_, AnnVar v)   = vectVar lc v
177
178 vectExpr lc (_, AnnLit lit)
179   = do
180       let vexpr = Lit lit
181       lexpr <- replicateP vexpr lc
182       return (vexpr, lexpr)
183
184 vectExpr lc (_, AnnNote note expr)
185   = do
186       (vexpr, lexpr) <- vectExpr lc expr
187       return (Note note vexpr, Note note lexpr)
188
189 vectExpr lc e@(_, AnnApp _ arg)
190   | isAnnTypeArg arg
191   = vectTyAppExpr lc fn tys
192   where
193     (fn, tys) = collectAnnTypeArgs e
194
195 vectExpr lc (_, AnnApp fn arg)
196   = do
197       fn'  <- vectExpr lc fn
198       arg' <- vectExpr lc arg
199       capply fn' arg'
200
201 vectExpr lc (_, AnnCase expr bndr ty alts)
202   = panic "vectExpr: case"
203
204 vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
205   = do
206       (vrhs, lrhs) <- vectPolyExpr lc rhs
207       (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
208       return (Let (NonRec vbndr vrhs) vbody,
209               Let (NonRec lbndr lrhs) lbody)
210
211 vectExpr lc (_, AnnLet (AnnRec prs) body)
212   = do
213       (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
214       return (Let (Rec (zip vbndrs vrhss)) vbody,
215               Let (Rec (zip lbndrs lrhss)) lbody)
216   where
217     (bndrs, rhss) = unzip prs
218     
219     vect = do
220              (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
221              (vbody, lbody) <- vectPolyExpr lc body
222              return (vrhss, vbody, lrhss, lbody)
223
224 vectExpr lc e@(_, AnnLam bndr body)
225   | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
226
227 vectExpr lc (fvs, AnnLam bndr body)
228   = do
229       let tyvars = filter isTyVar (varSetElems fvs)
230       info <- mkCEnvInfo fvs bndr body
231       (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
232
233       vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
234       lfn_var <- hoistExpr FSLIT("lfn") poly_lfn
235
236       let (venv, lenv) = mkClosureEnvs info lc
237
238       let env_ty = cenv_vty info
239
240       pa_dict <- paDictOfType env_ty
241
242       arg_ty <- vectType (varType bndr)
243       res_ty <- vectType (exprType $ deAnnotate body)
244
245       -- FIXME: move the functions to the top level
246       mono_vfn <- applyToTypes (Var vfn_var) (map TyVarTy tyvars)
247       mono_lfn <- applyToTypes (Var lfn_var) (map TyVarTy tyvars)
248
249       mk_clo <- builtin mkClosureVar
250       mk_cloP <- builtin mkClosurePVar
251
252       let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
253                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
254           
255           lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
256                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
257
258       return (vclo, lclo)
259        
260
261 data CEnvInfo = CEnvInfo {
262                cenv_vars         :: [Var]
263              , cenv_values       :: [(CoreExpr, CoreExpr)]
264              , cenv_vty          :: Type
265              , cenv_lty          :: Type
266              , cenv_repr_tycon   :: TyCon
267              , cenv_repr_tyargs  :: [Type]
268              , cenv_repr_datacon :: DataCon
269              }
270
271 mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
272 mkCEnvInfo fvs arg body
273   = do
274       locals <- readLEnv local_vars
275       let
276           (vars, vals) = unzip
277                  [(var, val) | var      <- varSetElems fvs
278                              , Just val <- [lookupVarEnv locals var]]
279       vtys <- mapM (vectType . varType) vars
280
281       (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
282       lty <- mkPArrayType vty
283       
284       return $ CEnvInfo {
285                  cenv_vars         = vars
286                , cenv_values       = vals
287                , cenv_vty          = vty
288                , cenv_lty          = lty
289                , cenv_repr_tycon   = repr_tycon
290                , cenv_repr_tyargs  = repr_tyargs
291                , cenv_repr_datacon = repr_datacon
292                }
293   where
294     mk_env_ty [vty]
295       = return (vty, error "absent cinfo_repr_tycon"
296                    , error "absent cinfo_repr_tyargs"
297                    , error "absent cinfo_repr_datacon")
298
299     mk_env_ty vtys
300       = do
301           let ty = mkCoreTupTy vtys
302           (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
303           let [repr_con] = tyConDataCons repr_tc
304           return (ty, repr_tc, repr_tyargs, repr_con)
305
306     
307
308 mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
309 mkClosureEnvs info lc
310   | [] <- vals
311   = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
312                                [lc, Var unitDataConId])
313
314   | [(vval, lval)] <- vals
315   = (vval, lval)
316
317   | otherwise
318   = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
319                       `mkTyApps` cenv_repr_tyargs info
320                       `mkApps`   (lc : lvals))
321
322   where
323     vals = cenv_values info
324     (vvals, lvals) = unzip vals
325
326 mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
327              -> VM (CoreExpr, CoreExpr)
328 mkClosureFns info tyvars arg body
329   = closedV
330   . abstractOverTyVars tyvars
331   $ \mk_tlams ->
332   do
333     (vfn, lfn) <- mkClosureMonoFns info arg body
334     return (mk_tlams vfn, mk_tlams lfn)
335
336 mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
337 mkClosureMonoFns info arg body
338   = do
339       lc_bndr <- newLocalVar FSLIT("lc") intTy
340       (varg : vbndrs, larg : lbndrs, (vbody, lbody))
341         <- vectBndrsIn (arg : cenv_vars info)
342                        (vectExpr (Var lc_bndr) body)
343
344       venv_bndr <- newLocalVar FSLIT("env") vty
345       lenv_bndr <- newLocalVar FSLIT("env") lty
346
347       let vcase = bind_venv (Var venv_bndr) vbody vbndrs
348       lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
349       return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
350   where
351     vty = cenv_vty info
352     lty = cenv_lty info
353
354     arity = length (cenv_vars info)
355
356     bind_venv venv vbody []      = vbody
357     bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
358     bind_venv venv vbody vbndrs
359       = Case venv (mkWildId vty) (exprType vbody)
360              [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
361
362     bind_lenv lenv lbody lc_bndr [lbndr]
363       = do
364           lengthPA <- builtin lengthPAVar
365           return . Let (NonRec lbndr lenv)
366                  $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
367                         lc_bndr
368                         intTy
369                         [(DEFAULT, [], lbody)]
370
371     bind_lenv lenv lbody lc_bndr lbndrs
372       = return
373       $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
374                                  (cenv_repr_tyargs info)
375                                  lenv)
376              (mkWildId lty)
377              (exprType lbody)
378              [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
379           
380 vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
381 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
382 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
383
384 -- ----------------------------------------------------------------------------
385 -- Types
386
387 vectTyCon :: TyCon -> VM TyCon
388 vectTyCon tc
389   | isFunTyCon tc        = builtin closureTyCon
390   | isBoxedTupleTyCon tc = return tc
391   | isUnLiftedTyCon tc   = return tc
392   | otherwise = do
393                   r <- lookupTyCon tc
394                   case r of
395                     Just tc' -> return tc'
396
397                     -- FIXME: just for now
398                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
399
400 vectType :: Type -> VM Type
401 vectType ty | Just ty' <- coreView ty = vectType ty
402 vectType (TyVarTy tv) = return $ TyVarTy tv
403 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
404 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
405 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
406                                              (mapM vectType [ty1,ty2])
407 vectType (ForAllTy tv ty)
408   = do
409       r   <- paDictArgType tv
410       ty' <- vectType ty
411       return $ ForAllTy tv (wrap r ty')
412   where
413     wrap Nothing      = id
414     wrap (Just pa_ty) = FunTy pa_ty
415
416 vectType ty = pprPanic "vectType:" (ppr ty)
417