993ed3017d90d8ce92eae54fce316b60ecbbe163
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint             ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name                 ( mkSysTvName )
24 import NameEnv
25 import Id
26 import MkId                 ( unwrapFamInstScrut )
27
28 import DsMonad hiding (mapAndUnzipM)
29 import DsUtils              ( mkCoreTup, mkCoreTupTy )
30
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes           ( Boxity(..) )
34
35 import Outputable
36 import FastString
37 import Control.Monad        ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
38 import Data.Maybe           ( maybeToList )
39
40 vectorise :: HscEnv -> ModGuts -> IO ModGuts
41 vectorise hsc_env guts
42   | not (Opt_Vectorise `dopt` dflags) = return guts
43   | otherwise
44   = do
45       showPass dflags "Vectorisation"
46       eps <- hscEPS hsc_env
47       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
48       Just (info', guts') <- initV hsc_env guts info (vectModule guts)
49       endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
50       return $ guts' { mg_vect_info = info' }
51   where
52     dflags = hsc_dflags hsc_env
53
54 vectModule :: ModGuts -> VM ModGuts
55 vectModule guts = return guts
56
57 -- ----------------------------------------------------------------------------
58 -- Bindings
59
60 vectBndr :: Var -> VM (Var, Var)
61 vectBndr v
62   = do
63       vty <- vectType (idType v)
64       lty <- mkPArrayType vty
65       let vv = v `Id.setIdType` vty
66           lv = v `Id.setIdType` lty
67       updLEnv (mapTo vv lv)
68       return (vv, lv)
69   where
70     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
71
72 vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
73 vectBndrIn v p
74   = localV
75   $ do
76       (vv, lv) <- vectBndr v
77       x <- p
78       return (vv, lv, x)
79
80 vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
81 vectBndrsIn vs p
82   = localV
83   $ do
84       (vvs, lvs) <- mapAndUnzipM vectBndr vs
85       x <- p
86       return (vvs, lvs, x)
87
88 -- ----------------------------------------------------------------------------
89 -- Expressions
90
91 replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
92 replicateP expr len
93   = do
94       dict <- paDictOfType ty
95       rep  <- builtin replicatePAVar
96       return $ mkApps (Var rep) [Type ty, dict, expr, len]
97   where
98     ty = exprType expr
99
100 capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
101 capply (vfn, lfn) (varg, larg)
102   = do
103       apply  <- builtin applyClosureVar
104       applyP <- builtin applyClosurePVar
105       return (mkApps (Var apply)  [Type arg_ty, Type res_ty, vfn, varg],
106               mkApps (Var applyP) [Type arg_ty, Type res_ty, lfn, larg])
107   where
108     fn_ty            = exprType vfn
109     (arg_ty, res_ty) = splitClosureTy fn_ty
110
111 vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
112 vectVar lc v = local v `orElseV` global v
113   where
114     local  v = maybeV (readLEnv $ \env -> lookupVarEnv (local_vars env) v)
115     global v = do
116                  vexpr <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
117                  lexpr <- replicateP vexpr lc
118                  return (vexpr, lexpr)
119
120 vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
121 vectPolyVar lc v tys
122   = do
123       r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
124       case r of
125         Just (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
126         Nothing ->
127           do
128             poly  <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
129             vexpr <- mk_app poly
130             lexpr <- replicateP vexpr lc
131             return (vexpr, lexpr)
132   where
133     mk_app e = applyToTypes e =<< mapM vectType tys
134
135 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
136 abstractOverTyVars tvs p
137   = do
138       mdicts <- mapM mk_dict_var tvs
139       zipWithM_ (\tv -> maybe (deleteTyVarPA tv) (extendTyVarPA tv . Var)) tvs mdicts
140       p (mk_lams mdicts)
141   where
142     mk_dict_var tv = do
143                        r <- paDictArgType tv
144                        case r of
145                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
146                          Nothing -> return Nothing
147
148     mk_lams mdicts = mkLams [arg | (tv, mdict) <- zip tvs mdicts
149                                  , arg <- tv : maybeToList mdict]
150
151 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
152 applyToTypes expr tys
153   = do
154       dicts <- mapM paDictOfType tys
155       return $ mkApps expr [arg | (ty, dict) <- zip tys dicts
156                                 , arg <- [Type ty, dict]]
157     
158
159 vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
160 vectPolyExpr lc expr
161   = localV
162   . abstractOverTyVars tvs $ \mk_lams ->
163     -- FIXME: shadowing (tvs in lc)
164     do
165       (vmono, lmono) <- vectExpr lc mono
166       return $ (mk_lams vmono, mk_lams lmono)
167   where
168     (tvs, mono) = collectAnnTypeBinders expr  
169                 
170 vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
171 vectExpr lc (_, AnnType ty)
172   = do
173       vty <- vectType ty
174       return (Type vty, Type vty)
175
176 vectExpr lc (_, AnnVar v)   = vectVar lc v
177
178 vectExpr lc (_, AnnLit lit)
179   = do
180       let vexpr = Lit lit
181       lexpr <- replicateP vexpr lc
182       return (vexpr, lexpr)
183
184 vectExpr lc (_, AnnNote note expr)
185   = do
186       (vexpr, lexpr) <- vectExpr lc expr
187       return (Note note vexpr, Note note lexpr)
188
189 vectExpr lc e@(_, AnnApp _ arg)
190   | isAnnTypeArg arg
191   = vectTyAppExpr lc fn tys
192   where
193     (fn, tys) = collectAnnTypeArgs e
194
195 vectExpr lc (_, AnnApp fn arg)
196   = do
197       fn'  <- vectExpr lc fn
198       arg' <- vectExpr lc arg
199       capply fn' arg'
200
201 vectExpr lc (_, AnnCase expr bndr ty alts)
202   = panic "vectExpr: case"
203
204 vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
205   = do
206       (vrhs, lrhs) <- vectPolyExpr lc rhs
207       (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
208       return (Let (NonRec vbndr vrhs) vbody,
209               Let (NonRec lbndr lrhs) lbody)
210
211 vectExpr lc (_, AnnLet (AnnRec prs) body)
212   = do
213       (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
214       return (Let (Rec (zip vbndrs vrhss)) vbody,
215               Let (Rec (zip lbndrs lrhss)) lbody)
216   where
217     (bndrs, rhss) = unzip prs
218     
219     vect = do
220              (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
221              (vbody, lbody) <- vectPolyExpr lc body
222              return (vrhss, vbody, lrhss, lbody)
223
224 vectExpr lc e@(_, AnnLam bndr body)
225   | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
226
227 vectExpr lc (fvs, AnnLam bndr body)
228   = do
229       let tyvars = filter isTyVar (varSetElems fvs)
230       info <- mkCEnvInfo fvs bndr body
231       (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
232       let (venv, lenv) = mkClosureEnvs info lc
233
234       let env_ty = cenv_vty info
235
236       pa_dict <- paDictOfType env_ty
237
238       arg_ty <- vectType (varType bndr)
239       res_ty <- vectType (exprType $ deAnnotate body)
240
241       -- FIXME: move the functions to the top level
242       mono_vfn <- applyToTypes poly_vfn (map TyVarTy tyvars)
243       mono_lfn <- applyToTypes poly_lfn (map TyVarTy tyvars)
244
245       mk_clo <- builtin mkClosureVar
246       mk_cloP <- builtin mkClosurePVar
247
248       let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
249                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
250           
251           lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
252                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
253
254       return (vclo, lclo)
255        
256
257 data CEnvInfo = CEnvInfo {
258                cenv_vars         :: [Var]
259              , cenv_values       :: [(CoreExpr, CoreExpr)]
260              , cenv_vty          :: Type
261              , cenv_lty          :: Type
262              , cenv_repr_tycon   :: TyCon
263              , cenv_repr_tyargs  :: [Type]
264              , cenv_repr_datacon :: DataCon
265              }
266
267 mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
268 mkCEnvInfo fvs arg body
269   = do
270       locals <- readLEnv local_vars
271       let
272           (vars, vals) = unzip
273                  [(var, val) | var      <- varSetElems fvs
274                              , Just val <- [lookupVarEnv locals var]]
275       vtys <- mapM (vectType . varType) vars
276
277       (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
278       lty <- mkPArrayType vty
279       
280       return $ CEnvInfo {
281                  cenv_vars         = vars
282                , cenv_values       = vals
283                , cenv_vty          = vty
284                , cenv_lty          = lty
285                , cenv_repr_tycon   = repr_tycon
286                , cenv_repr_tyargs  = repr_tyargs
287                , cenv_repr_datacon = repr_datacon
288                }
289   where
290     mk_env_ty [vty]
291       = return (vty, error "absent cinfo_repr_tycon"
292                    , error "absent cinfo_repr_tyargs"
293                    , error "absent cinfo_repr_datacon")
294
295     mk_env_ty vtys
296       = do
297           let ty = mkCoreTupTy vtys
298           (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
299           let [repr_con] = tyConDataCons repr_tc
300           return (ty, repr_tc, repr_tyargs, repr_con)
301
302     
303
304 mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
305 mkClosureEnvs info lc
306   | [] <- vals
307   = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
308                                [lc, Var unitDataConId])
309
310   | [(vval, lval)] <- vals
311   = (vval, lval)
312
313   | otherwise
314   = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
315                       `mkTyApps` cenv_repr_tyargs info
316                       `mkApps`   (lc : lvals))
317
318   where
319     vals = cenv_values info
320     (vvals, lvals) = unzip vals
321
322 mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
323              -> VM (CoreExpr, CoreExpr)
324 mkClosureFns info tyvars arg body
325   = closedV
326   . abstractOverTyVars tyvars
327   $ \mk_tlams ->
328   do
329     (vfn, lfn) <- mkClosureMonoFns info arg body
330     return (mk_tlams vfn, mk_tlams lfn)
331
332 mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
333 mkClosureMonoFns info arg body
334   = do
335       lc_bndr <- newLocalVar FSLIT("lc") intTy
336       (varg : vbndrs, larg : lbndrs, (vbody, lbody))
337         <- vectBndrsIn (arg : cenv_vars info)
338                        (vectExpr (Var lc_bndr) body)
339
340       venv_bndr <- newLocalVar FSLIT("env") vty
341       lenv_bndr <- newLocalVar FSLIT("env") lty
342
343       let vcase = bind_venv (Var venv_bndr) vbody vbndrs
344       lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
345       return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
346   where
347     vty = cenv_vty info
348     lty = cenv_lty info
349
350     arity = length (cenv_vars info)
351
352     bind_venv venv vbody []      = vbody
353     bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
354     bind_venv venv vbody vbndrs
355       = Case venv (mkWildId vty) (exprType vbody)
356              [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
357
358     bind_lenv lenv lbody lc_bndr [lbndr]
359       = do
360           lengthPA <- builtin lengthPAVar
361           return . Let (NonRec lbndr lenv)
362                  $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
363                         lc_bndr
364                         intTy
365                         [(DEFAULT, [], lbody)]
366
367     bind_lenv lenv lbody lc_bndr lbndrs
368       = return
369       $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
370                                  (cenv_repr_tyargs info)
371                                  lenv)
372              (mkWildId lty)
373              (exprType lbody)
374              [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
375           
376 vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
377 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
378 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
379
380 -- ----------------------------------------------------------------------------
381 -- Types
382
383 vectTyCon :: TyCon -> VM TyCon
384 vectTyCon tc
385   | isFunTyCon tc        = builtin closureTyCon
386   | isBoxedTupleTyCon tc = return tc
387   | isUnLiftedTyCon tc   = return tc
388   | otherwise = do
389                   r <- lookupTyCon tc
390                   case r of
391                     Just tc' -> return tc'
392
393                     -- FIXME: just for now
394                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
395
396 vectType :: Type -> VM Type
397 vectType ty | Just ty' <- coreView ty = vectType ty
398 vectType (TyVarTy tv) = return $ TyVarTy tv
399 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
400 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
401 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
402                                              (mapM vectType [ty1,ty2])
403 vectType (ForAllTy tv ty)
404   = do
405       r   <- paDictArgType tv
406       ty' <- vectType ty
407       return $ ForAllTy tv (wrap r ty')
408   where
409     wrap Nothing      = id
410     wrap (Just pa_ty) = FunTy pa_ty
411
412 vectType ty = pprPanic "vectType:" (ppr ty)
413