Vectorisation of top-level bindings
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint             ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name                 ( mkSysTvName, getName )
24 import NameEnv
25 import Id
26 import MkId                 ( unwrapFamInstScrut )
27 import OccName
28
29 import DsMonad hiding (mapAndUnzipM)
30 import DsUtils              ( mkCoreTup, mkCoreTupTy )
31
32 import PrelNames
33 import TysWiredIn
34 import BasicTypes           ( Boxity(..) )
35
36 import Outputable
37 import FastString
38 import Control.Monad        ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
39 import Data.Maybe           ( maybeToList )
40
41 vectorise :: HscEnv -> ModGuts -> IO ModGuts
42 vectorise hsc_env guts
43   | not (Opt_Vectorise `dopt` dflags) = return guts
44   | otherwise
45   = do
46       showPass dflags "Vectorisation"
47       eps <- hscEPS hsc_env
48       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
49       Just (info', guts') <- initV hsc_env guts info (vectModule guts)
50       endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
51       return $ guts' { mg_vect_info = info' }
52   where
53     dflags = hsc_dflags hsc_env
54
55 vectModule :: ModGuts -> VM ModGuts
56 vectModule guts = return guts
57
58 vectTopBind b@(NonRec var expr)
59   = do
60       var'  <- vectTopBinder var
61       expr' <- vectTopRhs expr
62       hs    <- takeHoisted
63       return . Rec $ (var, expr) : (var', expr') : hs
64   `orElseV`
65     return b
66
67 vectTopBind b@(Rec bs)
68   = do
69       vars'  <- mapM vectTopBinder vars
70       exprs' <- mapM vectTopRhs exprs
71       hs     <- takeHoisted
72       return . Rec $ bs ++ zip vars' exprs' ++ hs
73   `orElseV`
74     return b
75   where
76     (vars, exprs) = unzip bs
77
78 vectTopBinder :: Var -> VM Var
79 vectTopBinder var
80   = do
81       vty <- liftM (mkForAllTys tyvars) $ vectType mono_ty
82       name <- cloneName mkVectOcc (getName var)
83       let var' | isExportedId var = Id.mkExportedLocalId name vty
84                | otherwise        = Id.mkLocalId         name vty
85       defGlobalVar var var'
86       return var'
87   where
88     (tyvars, mono_ty) = splitForAllTys (idType var)
89     
90 vectTopRhs :: CoreExpr -> VM CoreExpr
91 vectTopRhs = liftM fst . closedV . vectPolyExpr (panic "Empty lifting context") . freeVars
92
93 -- ----------------------------------------------------------------------------
94 -- Bindings
95
96 vectBndr :: Var -> VM (Var, Var)
97 vectBndr v
98   = do
99       vty <- vectType (idType v)
100       lty <- mkPArrayType vty
101       let vv = v `Id.setIdType` vty
102           lv = v `Id.setIdType` lty
103       updLEnv (mapTo vv lv)
104       return (vv, lv)
105   where
106     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
107
108 vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
109 vectBndrIn v p
110   = localV
111   $ do
112       (vv, lv) <- vectBndr v
113       x <- p
114       return (vv, lv, x)
115
116 vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
117 vectBndrsIn vs p
118   = localV
119   $ do
120       (vvs, lvs) <- mapAndUnzipM vectBndr vs
121       x <- p
122       return (vvs, lvs, x)
123
124 -- ----------------------------------------------------------------------------
125 -- Expressions
126
127 replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
128 replicateP expr len
129   = do
130       dict <- paDictOfType ty
131       rep  <- builtin replicatePAVar
132       return $ mkApps (Var rep) [Type ty, dict, expr, len]
133   where
134     ty = exprType expr
135
136 capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
137 capply (vfn, lfn) (varg, larg)
138   = do
139       apply  <- builtin applyClosureVar
140       applyP <- builtin applyClosurePVar
141       return (mkApps (Var apply)  [Type arg_ty, Type res_ty, vfn, varg],
142               mkApps (Var applyP) [Type arg_ty, Type res_ty, lfn, larg])
143   where
144     fn_ty            = exprType vfn
145     (arg_ty, res_ty) = splitClosureTy fn_ty
146
147 vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
148 vectVar lc v
149   = do
150       r <- lookupVar v
151       case r of
152         Local es     -> return es
153         Global vexpr -> do
154                           lexpr <- replicateP vexpr lc
155                           return (vexpr, lexpr)
156
157 vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
158 vectPolyVar lc v tys
159   = do
160       r <- lookupVar v
161       case r of
162         Local (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
163         Global poly          -> do
164                                   vexpr <- mk_app poly
165                                   lexpr <- replicateP vexpr lc
166                                   return (vexpr, lexpr)
167   where
168     mk_app e = applyToTypes e =<< mapM vectType tys
169
170 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
171 abstractOverTyVars tvs p
172   = do
173       mdicts <- mapM mk_dict_var tvs
174       zipWithM_ (\tv -> maybe (deleteTyVarPA tv) (extendTyVarPA tv . Var)) tvs mdicts
175       p (mk_lams mdicts)
176   where
177     mk_dict_var tv = do
178                        r <- paDictArgType tv
179                        case r of
180                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
181                          Nothing -> return Nothing
182
183     mk_lams mdicts = mkLams [arg | (tv, mdict) <- zip tvs mdicts
184                                  , arg <- tv : maybeToList mdict]
185
186 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
187 applyToTypes expr tys
188   = do
189       dicts <- mapM paDictOfType tys
190       return $ mkApps expr [arg | (ty, dict) <- zip tys dicts
191                                 , arg <- [Type ty, dict]]
192     
193
194 vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
195 vectPolyExpr lc expr
196   = localV
197   . abstractOverTyVars tvs $ \mk_lams ->
198     -- FIXME: shadowing (tvs in lc)
199     do
200       (vmono, lmono) <- vectExpr lc mono
201       return $ (mk_lams vmono, mk_lams lmono)
202   where
203     (tvs, mono) = collectAnnTypeBinders expr  
204                 
205 vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
206 vectExpr lc (_, AnnType ty)
207   = do
208       vty <- vectType ty
209       return (Type vty, Type vty)
210
211 vectExpr lc (_, AnnVar v)   = vectVar lc v
212
213 vectExpr lc (_, AnnLit lit)
214   = do
215       let vexpr = Lit lit
216       lexpr <- replicateP vexpr lc
217       return (vexpr, lexpr)
218
219 vectExpr lc (_, AnnNote note expr)
220   = do
221       (vexpr, lexpr) <- vectExpr lc expr
222       return (Note note vexpr, Note note lexpr)
223
224 vectExpr lc e@(_, AnnApp _ arg)
225   | isAnnTypeArg arg
226   = vectTyAppExpr lc fn tys
227   where
228     (fn, tys) = collectAnnTypeArgs e
229
230 vectExpr lc (_, AnnApp fn arg)
231   = do
232       fn'  <- vectExpr lc fn
233       arg' <- vectExpr lc arg
234       capply fn' arg'
235
236 vectExpr lc (_, AnnCase expr bndr ty alts)
237   = panic "vectExpr: case"
238
239 vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
240   = do
241       (vrhs, lrhs) <- vectPolyExpr lc rhs
242       (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
243       return (Let (NonRec vbndr vrhs) vbody,
244               Let (NonRec lbndr lrhs) lbody)
245
246 vectExpr lc (_, AnnLet (AnnRec prs) body)
247   = do
248       (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
249       return (Let (Rec (zip vbndrs vrhss)) vbody,
250               Let (Rec (zip lbndrs lrhss)) lbody)
251   where
252     (bndrs, rhss) = unzip prs
253     
254     vect = do
255              (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
256              (vbody, lbody) <- vectPolyExpr lc body
257              return (vrhss, vbody, lrhss, lbody)
258
259 vectExpr lc e@(_, AnnLam bndr body)
260   | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
261
262 vectExpr lc (fvs, AnnLam bndr body)
263   = do
264       let tyvars = filter isTyVar (varSetElems fvs)
265       info <- mkCEnvInfo fvs bndr body
266       (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
267
268       vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
269       lfn_var <- hoistExpr FSLIT("lfn") poly_lfn
270
271       let (venv, lenv) = mkClosureEnvs info lc
272
273       let env_ty = cenv_vty info
274
275       pa_dict <- paDictOfType env_ty
276
277       arg_ty <- vectType (varType bndr)
278       res_ty <- vectType (exprType $ deAnnotate body)
279
280       -- FIXME: move the functions to the top level
281       mono_vfn <- applyToTypes (Var vfn_var) (map TyVarTy tyvars)
282       mono_lfn <- applyToTypes (Var lfn_var) (map TyVarTy tyvars)
283
284       mk_clo <- builtin mkClosureVar
285       mk_cloP <- builtin mkClosurePVar
286
287       let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
288                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
289           
290           lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
291                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
292
293       return (vclo, lclo)
294        
295
296 data CEnvInfo = CEnvInfo {
297                cenv_vars         :: [Var]
298              , cenv_values       :: [(CoreExpr, CoreExpr)]
299              , cenv_vty          :: Type
300              , cenv_lty          :: Type
301              , cenv_repr_tycon   :: TyCon
302              , cenv_repr_tyargs  :: [Type]
303              , cenv_repr_datacon :: DataCon
304              }
305
306 mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
307 mkCEnvInfo fvs arg body
308   = do
309       locals <- readLEnv local_vars
310       let
311           (vars, vals) = unzip
312                  [(var, val) | var      <- varSetElems fvs
313                              , Just val <- [lookupVarEnv locals var]]
314       vtys <- mapM (vectType . varType) vars
315
316       (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
317       lty <- mkPArrayType vty
318       
319       return $ CEnvInfo {
320                  cenv_vars         = vars
321                , cenv_values       = vals
322                , cenv_vty          = vty
323                , cenv_lty          = lty
324                , cenv_repr_tycon   = repr_tycon
325                , cenv_repr_tyargs  = repr_tyargs
326                , cenv_repr_datacon = repr_datacon
327                }
328   where
329     mk_env_ty [vty]
330       = return (vty, error "absent cinfo_repr_tycon"
331                    , error "absent cinfo_repr_tyargs"
332                    , error "absent cinfo_repr_datacon")
333
334     mk_env_ty vtys
335       = do
336           let ty = mkCoreTupTy vtys
337           (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
338           let [repr_con] = tyConDataCons repr_tc
339           return (ty, repr_tc, repr_tyargs, repr_con)
340
341     
342
343 mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
344 mkClosureEnvs info lc
345   | [] <- vals
346   = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
347                                [lc, Var unitDataConId])
348
349   | [(vval, lval)] <- vals
350   = (vval, lval)
351
352   | otherwise
353   = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
354                       `mkTyApps` cenv_repr_tyargs info
355                       `mkApps`   (lc : lvals))
356
357   where
358     vals = cenv_values info
359     (vvals, lvals) = unzip vals
360
361 mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
362              -> VM (CoreExpr, CoreExpr)
363 mkClosureFns info tyvars arg body
364   = closedV
365   . abstractOverTyVars tyvars
366   $ \mk_tlams ->
367   do
368     (vfn, lfn) <- mkClosureMonoFns info arg body
369     return (mk_tlams vfn, mk_tlams lfn)
370
371 mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
372 mkClosureMonoFns info arg body
373   = do
374       lc_bndr <- newLocalVar FSLIT("lc") intTy
375       (varg : vbndrs, larg : lbndrs, (vbody, lbody))
376         <- vectBndrsIn (arg : cenv_vars info)
377                        (vectExpr (Var lc_bndr) body)
378
379       venv_bndr <- newLocalVar FSLIT("env") vty
380       lenv_bndr <- newLocalVar FSLIT("env") lty
381
382       let vcase = bind_venv (Var venv_bndr) vbody vbndrs
383       lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
384       return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
385   where
386     vty = cenv_vty info
387     lty = cenv_lty info
388
389     arity = length (cenv_vars info)
390
391     bind_venv venv vbody []      = vbody
392     bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
393     bind_venv venv vbody vbndrs
394       = Case venv (mkWildId vty) (exprType vbody)
395              [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
396
397     bind_lenv lenv lbody lc_bndr [lbndr]
398       = do
399           lengthPA <- builtin lengthPAVar
400           return . Let (NonRec lbndr lenv)
401                  $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
402                         lc_bndr
403                         intTy
404                         [(DEFAULT, [], lbody)]
405
406     bind_lenv lenv lbody lc_bndr lbndrs
407       = return
408       $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
409                                  (cenv_repr_tyargs info)
410                                  lenv)
411              (mkWildId lty)
412              (exprType lbody)
413              [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
414           
415 vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
416 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
417 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
418
419 -- ----------------------------------------------------------------------------
420 -- Types
421
422 vectTyCon :: TyCon -> VM TyCon
423 vectTyCon tc
424   | isFunTyCon tc        = builtin closureTyCon
425   | isBoxedTupleTyCon tc = return tc
426   | isUnLiftedTyCon tc   = return tc
427   | otherwise = do
428                   r <- lookupTyCon tc
429                   case r of
430                     Just tc' -> return tc'
431
432                     -- FIXME: just for now
433                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
434
435 vectType :: Type -> VM Type
436 vectType ty | Just ty' <- coreView ty = vectType ty
437 vectType (TyVarTy tv) = return $ TyVarTy tv
438 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
439 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
440 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
441                                              (mapM vectType [ty1,ty2])
442 vectType (ForAllTy tv ty)
443   = do
444       r   <- paDictArgType tv
445       ty' <- vectType ty
446       return $ ForAllTy tv (wrap r ty')
447   where
448     wrap Nothing      = id
449     wrap (Just pa_ty) = FunTy pa_ty
450
451 vectType ty = pprPanic "vectType:" (ppr ty)
452