Add built-ins to vectorisation monad
[ghc-hetmet.git] / compiler / vectorise / VectMonad.hs
1 module VectMonad (
2   Scope(..),
3   VM,
4
5   noV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV, initV,
6   liftDs,
7   cloneName, cloneId,
8   newExportedVar, newLocalVar, newDummyVar, newTyVar,
9   
10   Builtins(..),
11   builtin,
12
13   GlobalEnv(..),
14   setFamInstEnv,
15   readGEnv, setGEnv, updGEnv,
16
17   LocalEnv(..),
18   readLEnv, setLEnv, updLEnv,
19
20   getBindName, inBind,
21
22   lookupVar, defGlobalVar,
23   lookupTyCon, defTyCon,
24   lookupDataCon, defDataCon,
25   lookupTyConPA, defTyConPA, defTyConPAs, defTyConBuiltinPAs,
26   lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars,
27
28   {-lookupInst,-} lookupFamInst
29 ) where
30
31 #include "HsVersions.h"
32
33 import HscTypes
34 import CoreSyn
35 import TyCon
36 import DataCon
37 import Type
38 import Var
39 import VarEnv
40 import Id
41 import OccName
42 import Name
43 import NameEnv
44 import TysPrim       ( intPrimTy )
45 import Module
46 import IfaceEnv
47
48 import DsMonad
49 import PrelNames
50
51 import InstEnv
52 import FamInstEnv
53
54 import Panic
55 import Outputable
56 import FastString
57 import SrcLoc        ( noSrcSpan )
58
59 import Control.Monad ( liftM, zipWithM )
60
61 data Scope a b = Global a | Local b
62
63 -- ----------------------------------------------------------------------------
64 -- Vectorisation monad
65
66 data Builtins = Builtins {
67                   parrayTyCon      :: TyCon
68                 , paTyCon          :: TyCon
69                 , paDataCon        :: DataCon
70                 , preprTyCon       :: TyCon
71                 , embedTyCon       :: TyCon
72                 , embedDataCon     :: DataCon
73                 , crossTyCon       :: TyCon
74                 , crossDataCon     :: DataCon
75                 , plusTyCon        :: TyCon
76                 , leftDataCon      :: DataCon
77                 , rightDataCon     :: DataCon
78                 , closureTyCon     :: TyCon
79                 , mkClosureVar     :: Var
80                 , applyClosureVar  :: Var
81                 , mkClosurePVar    :: Var
82                 , applyClosurePVar :: Var
83                 , lengthPAVar      :: Var
84                 , replicatePAVar   :: Var
85                 , emptyPAVar       :: Var
86                 -- , packPAVar        :: Var
87                 -- , combinePAVar     :: Var
88                 , intEqPAVar       :: Var
89                 , liftingContext   :: Var
90                 }
91
92 initBuiltins :: DsM Builtins
93 initBuiltins
94   = do
95       parrayTyCon  <- dsLookupTyCon parrayTyConName
96       paTyCon      <- dsLookupTyCon paTyConName
97       let [paDataCon] = tyConDataCons paTyCon
98       preprTyCon   <- dsLookupTyCon preprTyConName
99       embedTyCon   <- dsLookupTyCon embedTyConName
100       let [embedDataCon] = tyConDataCons embedTyCon
101       crossTyCon   <- dsLookupTyCon crossTyConName
102       let [crossDataCon] = tyConDataCons crossTyCon
103       plusTyCon    <- dsLookupTyCon plusTyConName
104       let [leftDataCon, rightDataCon] = tyConDataCons plusTyCon
105       closureTyCon <- dsLookupTyCon closureTyConName
106
107       mkClosureVar     <- dsLookupGlobalId mkClosureName
108       applyClosureVar  <- dsLookupGlobalId applyClosureName
109       mkClosurePVar    <- dsLookupGlobalId mkClosurePName
110       applyClosurePVar <- dsLookupGlobalId applyClosurePName
111       lengthPAVar      <- dsLookupGlobalId lengthPAName
112       replicatePAVar   <- dsLookupGlobalId replicatePAName
113       emptyPAVar       <- dsLookupGlobalId emptyPAName
114       -- packPAVar        <- dsLookupGlobalId packPAName
115       -- combinePAVar     <- dsLookupGlobalId combinePAName
116       intEqPAVar       <- dsLookupGlobalId intEqPAName
117
118       liftingContext <- liftM (\u -> mkSysLocal FSLIT("lc") u intPrimTy)
119                               newUnique
120
121       return $ Builtins {
122                  parrayTyCon      = parrayTyCon
123                , paTyCon          = paTyCon
124                , paDataCon        = paDataCon
125                , preprTyCon       = preprTyCon
126                , embedTyCon       = embedTyCon
127                , embedDataCon     = embedDataCon
128                , crossTyCon       = crossTyCon
129                , crossDataCon     = crossDataCon
130                , plusTyCon        = plusTyCon
131                , leftDataCon      = leftDataCon
132                , rightDataCon     = rightDataCon
133                , closureTyCon     = closureTyCon
134                , mkClosureVar     = mkClosureVar
135                , applyClosureVar  = applyClosureVar
136                , mkClosurePVar    = mkClosurePVar
137                , applyClosurePVar = applyClosurePVar
138                , lengthPAVar      = lengthPAVar
139                , replicatePAVar   = replicatePAVar
140                , emptyPAVar       = emptyPAVar
141                -- , packPAVar        = packPAVar
142                -- , combinePAVar     = combinePAVar
143                , intEqPAVar       = intEqPAVar
144                , liftingContext   = liftingContext
145                }
146
147 data GlobalEnv = GlobalEnv {
148                   -- Mapping from global variables to their vectorised versions.
149                   -- 
150                   global_vars :: VarEnv Var
151
152                   -- Exported variables which have a vectorised version
153                   --
154                 , global_exported_vars :: VarEnv (Var, Var)
155
156                   -- Mapping from TyCons to their vectorised versions.
157                   -- TyCons which do not have to be vectorised are mapped to
158                   -- themselves.
159                   --
160                 , global_tycons :: NameEnv TyCon
161
162                   -- Mapping from DataCons to their vectorised versions
163                   --
164                 , global_datacons :: NameEnv DataCon
165
166                   -- Mapping from TyCons to their PA dfuns
167                   --
168                 , global_pa_funs :: NameEnv Var
169
170                 -- External package inst-env & home-package inst-env for class
171                 -- instances
172                 --
173                 , global_inst_env :: (InstEnv, InstEnv)
174
175                 -- External package inst-env & home-package inst-env for family
176                 -- instances
177                 --
178                 , global_fam_inst_env :: FamInstEnvs
179
180                 -- Hoisted bindings
181                 , global_bindings :: [(Var, CoreExpr)]
182                 }
183
184 data LocalEnv = LocalEnv {
185                  -- Mapping from local variables to their vectorised and
186                  -- lifted versions
187                  --
188                  local_vars :: VarEnv (Var, Var)
189
190                  -- In-scope type variables
191                  --
192                , local_tyvars :: [TyVar]
193
194                  -- Mapping from tyvars to their PA dictionaries
195                , local_tyvar_pa :: VarEnv CoreExpr
196
197                  -- Local binding name
198                , local_bind_name :: FastString
199                }
200               
201
202 initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> Builtins
203               -> GlobalEnv
204 initGlobalEnv info instEnvs famInstEnvs bi
205   = GlobalEnv {
206       global_vars          = mapVarEnv snd $ vectInfoVar info
207     , global_exported_vars = emptyVarEnv
208     , global_tycons        = extendNameEnv (mapNameEnv snd (vectInfoTyCon info))
209                                            (tyConName funTyCon) (closureTyCon bi)
210                               
211     , global_datacons      = mapNameEnv snd $ vectInfoDataCon info
212     , global_pa_funs       = mapNameEnv snd $ vectInfoPADFun info
213     , global_inst_env      = instEnvs
214     , global_fam_inst_env  = famInstEnvs
215     , global_bindings      = []
216     }
217
218 setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
219 setFamInstEnv l_fam_inst genv
220   = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
221   where
222     (g_fam_inst, _) = global_fam_inst_env genv
223
224 emptyLocalEnv = LocalEnv {
225                    local_vars     = emptyVarEnv
226                  , local_tyvars   = []
227                  , local_tyvar_pa = emptyVarEnv
228                  , local_bind_name  = FSLIT("fn")
229                  }
230
231 -- FIXME
232 updVectInfo :: GlobalEnv -> TypeEnv -> VectInfo -> VectInfo
233 updVectInfo env tyenv info
234   = info {
235       vectInfoVar     = global_exported_vars env
236     , vectInfoTyCon   = mk_env typeEnvTyCons global_tycons
237     , vectInfoDataCon = mk_env typeEnvDataCons global_datacons
238     , vectInfoPADFun  = mk_env typeEnvTyCons global_pa_funs
239     }
240   where
241     mk_env from_tyenv from_env = mkNameEnv [(name, (from,to))
242                                    | from <- from_tyenv tyenv
243                                    , let name = getName from
244                                    , Just to <- [lookupNameEnv (from_env env) name]]
245
246 data VResult a = Yes GlobalEnv LocalEnv a | No
247
248 newtype VM a = VM { runVM :: Builtins -> GlobalEnv -> LocalEnv -> DsM (VResult a) }
249
250 instance Monad VM where
251   return x   = VM $ \bi genv lenv -> return (Yes genv lenv x)
252   VM p >>= f = VM $ \bi genv lenv -> do
253                                       r <- p bi genv lenv
254                                       case r of
255                                         Yes genv' lenv' x -> runVM (f x) bi genv' lenv'
256                                         No                -> return No
257
258 noV :: VM a
259 noV = VM $ \_ _ _ -> return No
260
261 traceNoV :: String -> SDoc -> VM a
262 traceNoV s d = pprTrace s d noV
263
264 tryV :: VM a -> VM (Maybe a)
265 tryV (VM p) = VM $ \bi genv lenv ->
266   do
267     r <- p bi genv lenv
268     case r of
269       Yes genv' lenv' x -> return (Yes genv' lenv' (Just x))
270       No                -> return (Yes genv  lenv  Nothing)
271
272 maybeV :: VM (Maybe a) -> VM a
273 maybeV p = maybe noV return =<< p
274
275 traceMaybeV :: String -> SDoc -> VM (Maybe a) -> VM a
276 traceMaybeV s d p = maybe (traceNoV s d) return =<< p
277
278 orElseV :: VM a -> VM a -> VM a
279 orElseV p q = maybe q return =<< tryV p
280
281 fixV :: (a -> VM a) -> VM a
282 fixV f = VM (\bi genv lenv -> fixDs $ \r -> runVM (f (unYes r)) bi genv lenv )
283   where
284     unYes (Yes _ _ x) = x
285
286 localV :: VM a -> VM a
287 localV p = do
288              env <- readLEnv id
289              x <- p
290              setLEnv env
291              return x
292
293 closedV :: VM a -> VM a
294 closedV p = do
295               env <- readLEnv id
296               setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env })
297               x <- p
298               setLEnv env
299               return x
300
301 liftDs :: DsM a -> VM a
302 liftDs p = VM $ \bi genv lenv -> do { x <- p; return (Yes genv lenv x) }
303
304 builtin :: (Builtins -> a) -> VM a
305 builtin f = VM $ \bi genv lenv -> return (Yes genv lenv (f bi))
306
307 readGEnv :: (GlobalEnv -> a) -> VM a
308 readGEnv f = VM $ \bi genv lenv -> return (Yes genv lenv (f genv))
309
310 setGEnv :: GlobalEnv -> VM ()
311 setGEnv genv = VM $ \_ _ lenv -> return (Yes genv lenv ())
312
313 updGEnv :: (GlobalEnv -> GlobalEnv) -> VM ()
314 updGEnv f = VM $ \_ genv lenv -> return (Yes (f genv) lenv ())
315
316 readLEnv :: (LocalEnv -> a) -> VM a
317 readLEnv f = VM $ \bi genv lenv -> return (Yes genv lenv (f lenv))
318
319 setLEnv :: LocalEnv -> VM ()
320 setLEnv lenv = VM $ \_ genv _ -> return (Yes genv lenv ())
321
322 updLEnv :: (LocalEnv -> LocalEnv) -> VM ()
323 updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
324
325 getInstEnv :: VM (InstEnv, InstEnv)
326 getInstEnv = readGEnv global_inst_env
327
328 getFamInstEnv :: VM FamInstEnvs
329 getFamInstEnv = readGEnv global_fam_inst_env
330
331 getBindName :: VM FastString
332 getBindName = readLEnv local_bind_name
333
334 inBind :: Id -> VM a -> VM a
335 inBind id p
336   = do updLEnv $ \env -> env { local_bind_name = occNameFS (getOccName id) }
337        p
338
339 lookupExternalVar :: Module -> FastString -> VM Var
340 lookupExternalVar mod fs
341   = liftDs
342   $ dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
343
344 cloneName :: (OccName -> OccName) -> Name -> VM Name
345 cloneName mk_occ name = liftM make (liftDs newUnique)
346   where
347     occ_name = mk_occ (nameOccName name)
348
349     make u | isExternalName name = mkExternalName u (nameModule name)
350                                                     occ_name
351                                                     (nameSrcSpan name)
352            | otherwise           = mkSystemName u occ_name
353
354 cloneId :: (OccName -> OccName) -> Id -> Type -> VM Id
355 cloneId mk_occ id ty
356   = do
357       name <- cloneName mk_occ (getName id)
358       let id' | isExportedId id = Id.mkExportedLocalId name ty
359               | otherwise       = Id.mkLocalId         name ty
360       return id'
361
362 newExportedVar :: OccName -> Type -> VM Var
363 newExportedVar occ_name ty 
364   = do
365       mod <- liftDs getModuleDs
366       u   <- liftDs newUnique
367
368       let name = mkExternalName u mod occ_name noSrcSpan
369       
370       return $ Id.mkExportedLocalId name ty
371
372 newLocalVar :: FastString -> Type -> VM Var
373 newLocalVar fs ty
374   = do
375       u <- liftDs newUnique
376       return $ mkSysLocal fs u ty
377
378 newDummyVar :: Type -> VM Var
379 newDummyVar = newLocalVar FSLIT("ds")
380
381 newTyVar :: FastString -> Kind -> VM Var
382 newTyVar fs k
383   = do
384       u <- liftDs newUnique
385       return $ mkTyVar (mkSysTvName u fs) k
386
387 defGlobalVar :: Var -> Var -> VM ()
388 defGlobalVar v v' = updGEnv $ \env ->
389   env { global_vars = extendVarEnv (global_vars env) v v'
390       , global_exported_vars = upd (global_exported_vars env)
391       }
392   where
393     upd env | isExportedId v = extendVarEnv env v (v, v')
394             | otherwise      = env
395
396 lookupVar :: Var -> VM (Scope Var (Var, Var))
397 lookupVar v
398   = do
399       r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
400       case r of
401         Just e  -> return (Local e)
402         Nothing -> liftM Global
403                  $  traceMaybeV "lookupVar" (ppr v)
404                                 (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
405
406 lookupTyCon :: TyCon -> VM (Maybe TyCon)
407 lookupTyCon tc
408   | isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc)
409
410   | otherwise = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
411
412 defTyCon :: TyCon -> TyCon -> VM ()
413 defTyCon tc tc' = updGEnv $ \env ->
414   env { global_tycons = extendNameEnv (global_tycons env) (tyConName tc) tc' }
415
416 lookupDataCon :: DataCon -> VM (Maybe DataCon)
417 lookupDataCon dc = readGEnv $ \env -> lookupNameEnv (global_datacons env) (dataConName dc)
418
419 defDataCon :: DataCon -> DataCon -> VM ()
420 defDataCon dc dc' = updGEnv $ \env ->
421   env { global_datacons = extendNameEnv (global_datacons env) (dataConName dc) dc' }
422
423 lookupTyConPA :: TyCon -> VM (Maybe Var)
424 lookupTyConPA tc = readGEnv $ \env -> lookupNameEnv (global_pa_funs env) (tyConName tc)
425
426 defTyConPA :: TyCon -> Var -> VM ()
427 defTyConPA tc pa = updGEnv $ \env ->
428   env { global_pa_funs = extendNameEnv (global_pa_funs env) (tyConName tc) pa }
429
430 defTyConPAs :: [(TyCon, Var)] -> VM ()
431 defTyConPAs ps = updGEnv $ \env ->
432   env { global_pa_funs = extendNameEnvList (global_pa_funs env)
433                                            [(tyConName tc, pa) | (tc, pa) <- ps] }
434
435 defTyConBuiltinPAs :: [(Name, Module, FastString)] -> VM ()
436 defTyConBuiltinPAs ps
437   = do
438       pas <- zipWithM lookupExternalVar mods fss
439       updGEnv $ \env ->
440         env { global_pa_funs = extendNameEnvList (global_pa_funs env)
441                                                  (zip tcs pas) }
442   where
443     (tcs, mods, fss) = unzip3 ps
444
445 lookupTyVarPA :: Var -> VM (Maybe CoreExpr)
446 lookupTyVarPA tv = readLEnv $ \env -> lookupVarEnv (local_tyvar_pa env) tv 
447
448 defLocalTyVar :: TyVar -> VM ()
449 defLocalTyVar tv = updLEnv $ \env ->
450   env { local_tyvars   = tv : local_tyvars env
451       , local_tyvar_pa = local_tyvar_pa env `delVarEnv` tv
452       }
453
454 defLocalTyVarWithPA :: TyVar -> CoreExpr -> VM ()
455 defLocalTyVarWithPA tv pa = updLEnv $ \env ->
456   env { local_tyvars   = tv : local_tyvars env
457       , local_tyvar_pa = extendVarEnv (local_tyvar_pa env) tv pa
458       }
459
460 localTyVars :: VM [TyVar]
461 localTyVars = readLEnv (reverse . local_tyvars)
462
463 -- Look up the dfun of a class instance.
464 --
465 -- The match must be unique - ie, match exactly one instance - but the 
466 -- type arguments used for matching may be more specific than those of 
467 -- the class instance declaration.  The found class instances must not have
468 -- any type variables in the instance context that do not appear in the
469 -- instances head (i.e., no flexi vars); for details for what this means,
470 -- see the docs at InstEnv.lookupInstEnv.
471 --
472 {-
473 lookupInst :: Class -> [Type] -> VM (DFunId, [Type])
474 lookupInst cls tys
475   = do { instEnv <- getInstEnv
476        ; case lookupInstEnv instEnv cls tys of
477            ([(inst, inst_tys)], _) 
478              | noFlexiVar -> return (instanceDFunId inst, inst_tys')
479              | otherwise  -> pprPanic "VectMonad.lookupInst: flexi var: " 
480                                       (ppr $ mkTyConApp (classTyCon cls) tys)
481              where
482                inst_tys'  = [ty | Right ty <- inst_tys]
483                noFlexiVar = all isRight inst_tys
484            _other         -> traceNoV "lookupInst" (ppr cls <+> ppr tys)
485        }
486   where
487     isRight (Left  _) = False
488     isRight (Right _) = True
489 -}
490
491 -- Look up the representation tycon of a family instance.
492 --
493 -- The match must be unique - ie, match exactly one instance - but the 
494 -- type arguments used for matching may be more specific than those of 
495 -- the family instance declaration.
496 --
497 -- Return the instance tycon and its type instance.  For example, if we have
498 --
499 --  lookupFamInst 'T' '[Int]' yields (':R42T', 'Int')
500 --
501 -- then we have a coercion (ie, type instance of family instance coercion)
502 --
503 --  :Co:R42T Int :: T [Int] ~ :R42T Int
504 --
505 -- which implies that :R42T was declared as 'data instance T [a]'.
506 --
507 lookupFamInst :: TyCon -> [Type] -> VM (TyCon, [Type])
508 lookupFamInst tycon tys
509   = ASSERT( isOpenTyCon tycon )
510     do { instEnv <- getFamInstEnv
511        ; case lookupFamInstEnv instEnv tycon tys of
512            [(fam_inst, rep_tys)] -> return (famInstTyCon fam_inst, rep_tys)
513            _other                -> 
514              pprPanic "VectMonad.lookupFamInst: not found: " 
515                       (ppr $ mkTyConApp tycon tys)
516        }
517
518 initV :: HscEnv -> ModGuts -> VectInfo -> VM a -> IO (Maybe (VectInfo, a))
519 initV hsc_env guts info p
520   = do
521       eps <- hscEPS hsc_env
522       let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
523       let instEnvs    = (eps_inst_env     eps, mg_inst_env     guts)
524
525       Just r <- initDs hsc_env (mg_module guts)
526                                (mg_rdr_env guts)
527                                (mg_types guts)
528                                (go instEnvs famInstEnvs)
529       return r
530   where
531
532     go instEnvs famInstEnvs = 
533       do
534         builtins <- initBuiltins
535         r <- runVM p builtins (initGlobalEnv info
536                                              instEnvs
537                                              famInstEnvs
538                                              builtins)
539                    emptyLocalEnv
540         case r of
541           Yes genv _ x -> return $ Just (new_info genv, x)
542           No           -> return Nothing
543
544     new_info genv = updVectInfo genv (mg_types guts) info
545