Refactoring
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint             ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name                 ( mkSysTvName )
24 import NameEnv
25 import Id
26 import MkId                 ( unwrapFamInstScrut )
27
28 import DsMonad hiding (mapAndUnzipM)
29 import DsUtils              ( mkCoreTup, mkCoreTupTy )
30
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes           ( Boxity(..) )
34
35 import Outputable
36 import FastString
37 import Control.Monad        ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
38 import Data.Maybe           ( maybeToList )
39
40 vectorise :: HscEnv -> ModGuts -> IO ModGuts
41 vectorise hsc_env guts
42   | not (Opt_Vectorise `dopt` dflags) = return guts
43   | otherwise
44   = do
45       showPass dflags "Vectorisation"
46       eps <- hscEPS hsc_env
47       let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
48       Just (info', guts') <- initV hsc_env guts info (vectModule guts)
49       endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
50       return $ guts' { mg_vect_info = info' }
51   where
52     dflags = hsc_dflags hsc_env
53
54 vectModule :: ModGuts -> VM ModGuts
55 vectModule guts = return guts
56
57 -- ----------------------------------------------------------------------------
58 -- Bindings
59
60 vectBndr :: Var -> VM (Var, Var)
61 vectBndr v
62   = do
63       vty <- vectType (idType v)
64       lty <- mkPArrayType vty
65       let vv = v `Id.setIdType` vty
66           lv = v `Id.setIdType` lty
67       updLEnv (mapTo vv lv)
68       return (vv, lv)
69   where
70     mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
71
72 vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
73 vectBndrIn v p
74   = localV
75   $ do
76       (vv, lv) <- vectBndr v
77       x <- p
78       return (vv, lv, x)
79
80 vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
81 vectBndrsIn vs p
82   = localV
83   $ do
84       (vvs, lvs) <- mapAndUnzipM vectBndr vs
85       x <- p
86       return (vvs, lvs, x)
87
88 -- ----------------------------------------------------------------------------
89 -- Expressions
90
91 replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
92 replicateP expr len
93   = do
94       dict <- paDictOfType ty
95       rep  <- builtin replicatePAVar
96       return $ mkApps (Var rep) [Type ty, dict, expr, len]
97   where
98     ty = exprType expr
99
100 capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
101 capply (vfn, lfn) (varg, larg)
102   = do
103       apply  <- builtin applyClosureVar
104       applyP <- builtin applyClosurePVar
105       return (mkApps (Var apply)  [Type arg_ty, Type res_ty, vfn, varg],
106               mkApps (Var applyP) [Type arg_ty, Type res_ty, lfn, larg])
107   where
108     fn_ty            = exprType vfn
109     (arg_ty, res_ty) = splitClosureTy fn_ty
110
111 vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
112 vectVar lc v
113   = do
114       r <- lookupVar v
115       case r of
116         Local es     -> return es
117         Global vexpr -> do
118                           lexpr <- replicateP vexpr lc
119                           return (vexpr, lexpr)
120
121 vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
122 vectPolyVar lc v tys
123   = do
124       r <- lookupVar v
125       case r of
126         Local (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
127         Global poly          -> do
128                                   vexpr <- mk_app poly
129                                   lexpr <- replicateP vexpr lc
130                                   return (vexpr, lexpr)
131   where
132     mk_app e = applyToTypes e =<< mapM vectType tys
133
134 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
135 abstractOverTyVars tvs p
136   = do
137       mdicts <- mapM mk_dict_var tvs
138       zipWithM_ (\tv -> maybe (deleteTyVarPA tv) (extendTyVarPA tv . Var)) tvs mdicts
139       p (mk_lams mdicts)
140   where
141     mk_dict_var tv = do
142                        r <- paDictArgType tv
143                        case r of
144                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
145                          Nothing -> return Nothing
146
147     mk_lams mdicts = mkLams [arg | (tv, mdict) <- zip tvs mdicts
148                                  , arg <- tv : maybeToList mdict]
149
150 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
151 applyToTypes expr tys
152   = do
153       dicts <- mapM paDictOfType tys
154       return $ mkApps expr [arg | (ty, dict) <- zip tys dicts
155                                 , arg <- [Type ty, dict]]
156     
157
158 vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
159 vectPolyExpr lc expr
160   = localV
161   . abstractOverTyVars tvs $ \mk_lams ->
162     -- FIXME: shadowing (tvs in lc)
163     do
164       (vmono, lmono) <- vectExpr lc mono
165       return $ (mk_lams vmono, mk_lams lmono)
166   where
167     (tvs, mono) = collectAnnTypeBinders expr  
168                 
169 vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
170 vectExpr lc (_, AnnType ty)
171   = do
172       vty <- vectType ty
173       return (Type vty, Type vty)
174
175 vectExpr lc (_, AnnVar v)   = vectVar lc v
176
177 vectExpr lc (_, AnnLit lit)
178   = do
179       let vexpr = Lit lit
180       lexpr <- replicateP vexpr lc
181       return (vexpr, lexpr)
182
183 vectExpr lc (_, AnnNote note expr)
184   = do
185       (vexpr, lexpr) <- vectExpr lc expr
186       return (Note note vexpr, Note note lexpr)
187
188 vectExpr lc e@(_, AnnApp _ arg)
189   | isAnnTypeArg arg
190   = vectTyAppExpr lc fn tys
191   where
192     (fn, tys) = collectAnnTypeArgs e
193
194 vectExpr lc (_, AnnApp fn arg)
195   = do
196       fn'  <- vectExpr lc fn
197       arg' <- vectExpr lc arg
198       capply fn' arg'
199
200 vectExpr lc (_, AnnCase expr bndr ty alts)
201   = panic "vectExpr: case"
202
203 vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
204   = do
205       (vrhs, lrhs) <- vectPolyExpr lc rhs
206       (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
207       return (Let (NonRec vbndr vrhs) vbody,
208               Let (NonRec lbndr lrhs) lbody)
209
210 vectExpr lc (_, AnnLet (AnnRec prs) body)
211   = do
212       (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
213       return (Let (Rec (zip vbndrs vrhss)) vbody,
214               Let (Rec (zip lbndrs lrhss)) lbody)
215   where
216     (bndrs, rhss) = unzip prs
217     
218     vect = do
219              (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
220              (vbody, lbody) <- vectPolyExpr lc body
221              return (vrhss, vbody, lrhss, lbody)
222
223 vectExpr lc e@(_, AnnLam bndr body)
224   | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
225
226 vectExpr lc (fvs, AnnLam bndr body)
227   = do
228       let tyvars = filter isTyVar (varSetElems fvs)
229       info <- mkCEnvInfo fvs bndr body
230       (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
231
232       vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
233       lfn_var <- hoistExpr FSLIT("lfn") poly_lfn
234
235       let (venv, lenv) = mkClosureEnvs info lc
236
237       let env_ty = cenv_vty info
238
239       pa_dict <- paDictOfType env_ty
240
241       arg_ty <- vectType (varType bndr)
242       res_ty <- vectType (exprType $ deAnnotate body)
243
244       -- FIXME: move the functions to the top level
245       mono_vfn <- applyToTypes (Var vfn_var) (map TyVarTy tyvars)
246       mono_lfn <- applyToTypes (Var lfn_var) (map TyVarTy tyvars)
247
248       mk_clo <- builtin mkClosureVar
249       mk_cloP <- builtin mkClosurePVar
250
251       let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
252                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
253           
254           lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
255                              `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
256
257       return (vclo, lclo)
258        
259
260 data CEnvInfo = CEnvInfo {
261                cenv_vars         :: [Var]
262              , cenv_values       :: [(CoreExpr, CoreExpr)]
263              , cenv_vty          :: Type
264              , cenv_lty          :: Type
265              , cenv_repr_tycon   :: TyCon
266              , cenv_repr_tyargs  :: [Type]
267              , cenv_repr_datacon :: DataCon
268              }
269
270 mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
271 mkCEnvInfo fvs arg body
272   = do
273       locals <- readLEnv local_vars
274       let
275           (vars, vals) = unzip
276                  [(var, val) | var      <- varSetElems fvs
277                              , Just val <- [lookupVarEnv locals var]]
278       vtys <- mapM (vectType . varType) vars
279
280       (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
281       lty <- mkPArrayType vty
282       
283       return $ CEnvInfo {
284                  cenv_vars         = vars
285                , cenv_values       = vals
286                , cenv_vty          = vty
287                , cenv_lty          = lty
288                , cenv_repr_tycon   = repr_tycon
289                , cenv_repr_tyargs  = repr_tyargs
290                , cenv_repr_datacon = repr_datacon
291                }
292   where
293     mk_env_ty [vty]
294       = return (vty, error "absent cinfo_repr_tycon"
295                    , error "absent cinfo_repr_tyargs"
296                    , error "absent cinfo_repr_datacon")
297
298     mk_env_ty vtys
299       = do
300           let ty = mkCoreTupTy vtys
301           (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
302           let [repr_con] = tyConDataCons repr_tc
303           return (ty, repr_tc, repr_tyargs, repr_con)
304
305     
306
307 mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
308 mkClosureEnvs info lc
309   | [] <- vals
310   = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
311                                [lc, Var unitDataConId])
312
313   | [(vval, lval)] <- vals
314   = (vval, lval)
315
316   | otherwise
317   = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
318                       `mkTyApps` cenv_repr_tyargs info
319                       `mkApps`   (lc : lvals))
320
321   where
322     vals = cenv_values info
323     (vvals, lvals) = unzip vals
324
325 mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
326              -> VM (CoreExpr, CoreExpr)
327 mkClosureFns info tyvars arg body
328   = closedV
329   . abstractOverTyVars tyvars
330   $ \mk_tlams ->
331   do
332     (vfn, lfn) <- mkClosureMonoFns info arg body
333     return (mk_tlams vfn, mk_tlams lfn)
334
335 mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
336 mkClosureMonoFns info arg body
337   = do
338       lc_bndr <- newLocalVar FSLIT("lc") intTy
339       (varg : vbndrs, larg : lbndrs, (vbody, lbody))
340         <- vectBndrsIn (arg : cenv_vars info)
341                        (vectExpr (Var lc_bndr) body)
342
343       venv_bndr <- newLocalVar FSLIT("env") vty
344       lenv_bndr <- newLocalVar FSLIT("env") lty
345
346       let vcase = bind_venv (Var venv_bndr) vbody vbndrs
347       lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
348       return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
349   where
350     vty = cenv_vty info
351     lty = cenv_lty info
352
353     arity = length (cenv_vars info)
354
355     bind_venv venv vbody []      = vbody
356     bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
357     bind_venv venv vbody vbndrs
358       = Case venv (mkWildId vty) (exprType vbody)
359              [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
360
361     bind_lenv lenv lbody lc_bndr [lbndr]
362       = do
363           lengthPA <- builtin lengthPAVar
364           return . Let (NonRec lbndr lenv)
365                  $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
366                         lc_bndr
367                         intTy
368                         [(DEFAULT, [], lbody)]
369
370     bind_lenv lenv lbody lc_bndr lbndrs
371       = return
372       $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
373                                  (cenv_repr_tyargs info)
374                                  lenv)
375              (mkWildId lty)
376              (exprType lbody)
377              [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
378           
379 vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
380 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
381 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
382
383 -- ----------------------------------------------------------------------------
384 -- Types
385
386 vectTyCon :: TyCon -> VM TyCon
387 vectTyCon tc
388   | isFunTyCon tc        = builtin closureTyCon
389   | isBoxedTupleTyCon tc = return tc
390   | isUnLiftedTyCon tc   = return tc
391   | otherwise = do
392                   r <- lookupTyCon tc
393                   case r of
394                     Just tc' -> return tc'
395
396                     -- FIXME: just for now
397                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
398
399 vectType :: Type -> VM Type
400 vectType ty | Just ty' <- coreView ty = vectType ty
401 vectType (TyVarTy tv) = return $ TyVarTy tv
402 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
403 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
404 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
405                                              (mapM vectType [ty1,ty2])
406 vectType (ForAllTy tv ty)
407   = do
408       r   <- paDictArgType tv
409       ty' <- vectType ty
410       return $ ForAllTy tv (wrap r ty')
411   where
412     wrap Nothing      = id
413     wrap (Just pa_ty) = FunTy pa_ty
414
415 vectType ty = pprPanic "vectType:" (ppr ty)
416