Add built-ins to vectorisation monad
[ghc-hetmet.git] / compiler / vectorise / VectMonad.hs
index d4fa8f8..22b776e 100644 (file)
@@ -1,36 +1,49 @@
 module VectMonad (
+  Scope(..),
   VM,
 
-  noV, tryV, maybeV, orElseV, localV, closedV, initV,
-  newLocalVar, newTyVar,
+  noV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV, initV,
+  liftDs,
+  cloneName, cloneId,
+  newExportedVar, newLocalVar, newDummyVar, newTyVar,
   
-  Builtins(..), paDictTyCon,
+  Builtins(..),
   builtin,
 
   GlobalEnv(..),
+  setFamInstEnv,
   readGEnv, setGEnv, updGEnv,
 
   LocalEnv(..),
   readLEnv, setLEnv, updLEnv,
 
-  lookupTyCon,
-  lookupTyVarPA, extendTyVarPA, deleteTyVarPA,
+  getBindName, inBind,
 
-  lookupInst, lookupFamInst
+  lookupVar, defGlobalVar,
+  lookupTyCon, defTyCon,
+  lookupDataCon, defDataCon,
+  lookupTyConPA, defTyConPA, defTyConPAs, defTyConBuiltinPAs,
+  lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars,
+
+  {-lookupInst,-} lookupFamInst
 ) where
 
 #include "HsVersions.h"
 
 import HscTypes
 import CoreSyn
-import Class
 import TyCon
+import DataCon
 import Type
 import Var
 import VarEnv
 import Id
+import OccName
 import Name
 import NameEnv
+import TysPrim       ( intPrimTy )
+import Module
+import IfaceEnv
 
 import DsMonad
 import PrelNames
@@ -41,13 +54,27 @@ import FamInstEnv
 import Panic
 import Outputable
 import FastString
+import SrcLoc        ( noSrcSpan )
+
+import Control.Monad ( liftM, zipWithM )
+
+data Scope a b = Global a | Local b
 
 -- ----------------------------------------------------------------------------
 -- Vectorisation monad
 
 data Builtins = Builtins {
                   parrayTyCon      :: TyCon
-                , paClass          :: Class
+                , paTyCon          :: TyCon
+                , paDataCon        :: DataCon
+                , preprTyCon       :: TyCon
+                , embedTyCon       :: TyCon
+                , embedDataCon     :: DataCon
+                , crossTyCon       :: TyCon
+                , crossDataCon     :: DataCon
+                , plusTyCon        :: TyCon
+                , leftDataCon      :: DataCon
+                , rightDataCon     :: DataCon
                 , closureTyCon     :: TyCon
                 , mkClosureVar     :: Var
                 , applyClosureVar  :: Var
@@ -55,16 +82,26 @@ data Builtins = Builtins {
                 , applyClosurePVar :: Var
                 , lengthPAVar      :: Var
                 , replicatePAVar   :: Var
+                , emptyPAVar       :: Var
+                -- , packPAVar        :: Var
+                -- , combinePAVar     :: Var
+                , intEqPAVar       :: Var
+                , liftingContext   :: Var
                 }
 
-paDictTyCon :: Builtins -> TyCon
-paDictTyCon = classTyCon . paClass
-
 initBuiltins :: DsM Builtins
 initBuiltins
   = do
       parrayTyCon  <- dsLookupTyCon parrayTyConName
-      paClass      <- dsLookupClass paClassName
+      paTyCon      <- dsLookupTyCon paTyConName
+      let [paDataCon] = tyConDataCons paTyCon
+      preprTyCon   <- dsLookupTyCon preprTyConName
+      embedTyCon   <- dsLookupTyCon embedTyConName
+      let [embedDataCon] = tyConDataCons embedTyCon
+      crossTyCon   <- dsLookupTyCon crossTyConName
+      let [crossDataCon] = tyConDataCons crossTyCon
+      plusTyCon    <- dsLookupTyCon plusTyConName
+      let [leftDataCon, rightDataCon] = tyConDataCons plusTyCon
       closureTyCon <- dsLookupTyCon closureTyConName
 
       mkClosureVar     <- dsLookupGlobalId mkClosureName
@@ -73,10 +110,26 @@ initBuiltins
       applyClosurePVar <- dsLookupGlobalId applyClosurePName
       lengthPAVar      <- dsLookupGlobalId lengthPAName
       replicatePAVar   <- dsLookupGlobalId replicatePAName
+      emptyPAVar       <- dsLookupGlobalId emptyPAName
+      -- packPAVar        <- dsLookupGlobalId packPAName
+      -- combinePAVar     <- dsLookupGlobalId combinePAName
+      intEqPAVar       <- dsLookupGlobalId intEqPAName
+
+      liftingContext <- liftM (\u -> mkSysLocal FSLIT("lc") u intPrimTy)
+                              newUnique
 
       return $ Builtins {
                  parrayTyCon      = parrayTyCon
-               , paClass          = paClass
+               , paTyCon          = paTyCon
+               , paDataCon        = paDataCon
+               , preprTyCon       = preprTyCon
+               , embedTyCon       = embedTyCon
+               , embedDataCon     = embedDataCon
+               , crossTyCon       = crossTyCon
+               , crossDataCon     = crossDataCon
+               , plusTyCon        = plusTyCon
+               , leftDataCon      = leftDataCon
+               , rightDataCon     = rightDataCon
                , closureTyCon     = closureTyCon
                , mkClosureVar     = mkClosureVar
                , applyClosureVar  = applyClosureVar
@@ -84,12 +137,17 @@ initBuiltins
                , applyClosurePVar = applyClosurePVar
                , lengthPAVar      = lengthPAVar
                , replicatePAVar   = replicatePAVar
+               , emptyPAVar       = emptyPAVar
+               -- , packPAVar        = packPAVar
+               -- , combinePAVar     = combinePAVar
+               , intEqPAVar       = intEqPAVar
+               , liftingContext   = liftingContext
                }
 
 data GlobalEnv = GlobalEnv {
                   -- Mapping from global variables to their vectorised versions.
                   -- 
-                  global_vars :: VarEnv CoreExpr
+                  global_vars :: VarEnv Var
 
                   -- Exported variables which have a vectorised version
                   --
@@ -101,9 +159,13 @@ data GlobalEnv = GlobalEnv {
                   --
                 , global_tycons :: NameEnv TyCon
 
-                  -- Mapping from TyCons to their PA dictionaries
+                  -- Mapping from DataCons to their vectorised versions
                   --
-                , global_tycon_pa :: NameEnv CoreExpr
+                , global_datacons :: NameEnv DataCon
+
+                  -- Mapping from TyCons to their PA dfuns
+                  --
+                , global_pa_funs :: NameEnv Var
 
                 -- External package inst-env & home-package inst-env for class
                 -- instances
@@ -114,51 +176,72 @@ data GlobalEnv = GlobalEnv {
                 -- instances
                 --
                 , global_fam_inst_env :: FamInstEnvs
+
+                -- Hoisted bindings
+                , global_bindings :: [(Var, CoreExpr)]
                 }
 
 data LocalEnv = LocalEnv {
                  -- Mapping from local variables to their vectorised and
                  -- lifted versions
                  --
-                 local_vars :: VarEnv (CoreExpr, CoreExpr)
+                 local_vars :: VarEnv (Var, Var)
+
+                 -- In-scope type variables
+                 --
+               , local_tyvars :: [TyVar]
 
                  -- Mapping from tyvars to their PA dictionaries
                , local_tyvar_pa :: VarEnv CoreExpr
 
-                 -- Hoisted bindings
-               , local_bindings :: [(Var, CoreExpr)]
+                 -- Local binding name
+               , local_bind_name :: FastString
                }
               
 
-initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> GlobalEnv
-initGlobalEnv info instEnvs famInstEnvs
+initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> Builtins
+              -> GlobalEnv
+initGlobalEnv info instEnvs famInstEnvs bi
   = GlobalEnv {
-      global_vars          = mapVarEnv  (Var . snd) $ vectInfoVar   info
+      global_vars          = mapVarEnv snd $ vectInfoVar info
     , global_exported_vars = emptyVarEnv
-    , global_tycons        = mapNameEnv snd $ vectInfoTyCon info
-    , global_tycon_pa      = emptyNameEnv
+    , global_tycons        = extendNameEnv (mapNameEnv snd (vectInfoTyCon info))
+                                           (tyConName funTyCon) (closureTyCon bi)
+                              
+    , global_datacons      = mapNameEnv snd $ vectInfoDataCon info
+    , global_pa_funs       = mapNameEnv snd $ vectInfoPADFun info
     , global_inst_env      = instEnvs
     , global_fam_inst_env  = famInstEnvs
+    , global_bindings      = []
     }
 
+setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
+setFamInstEnv l_fam_inst genv
+  = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
+  where
+    (g_fam_inst, _) = global_fam_inst_env genv
+
 emptyLocalEnv = LocalEnv {
                    local_vars     = emptyVarEnv
+                 , local_tyvars   = []
                  , local_tyvar_pa = emptyVarEnv
-                 , local_bindings = []
+                 , local_bind_name  = FSLIT("fn")
                  }
 
 -- FIXME
 updVectInfo :: GlobalEnv -> TypeEnv -> VectInfo -> VectInfo
 updVectInfo env tyenv info
   = info {
-      vectInfoVar   = global_exported_vars env
-    , vectInfoTyCon = tc_env
+      vectInfoVar     = global_exported_vars env
+    , vectInfoTyCon   = mk_env typeEnvTyCons global_tycons
+    , vectInfoDataCon = mk_env typeEnvDataCons global_datacons
+    , vectInfoPADFun  = mk_env typeEnvTyCons global_pa_funs
     }
   where
-    tc_env = mkNameEnv [(tc_name, (tc,tc'))
-               | tc <- typeEnvTyCons tyenv
-               , let tc_name = tyConName tc
-               , Just tc' <- [lookupNameEnv (global_tycons env) tc_name]]
+    mk_env from_tyenv from_env = mkNameEnv [(name, (from,to))
+                                   | from <- from_tyenv tyenv
+                                   , let name = getName from
+                                   , Just to <- [lookupNameEnv (from_env env) name]]
 
 data VResult a = Yes GlobalEnv LocalEnv a | No
 
@@ -175,6 +258,9 @@ instance Monad VM where
 noV :: VM a
 noV = VM $ \_ _ _ -> return No
 
+traceNoV :: String -> SDoc -> VM a
+traceNoV s d = pprTrace s d noV
+
 tryV :: VM a -> VM (Maybe a)
 tryV (VM p) = VM $ \bi genv lenv ->
   do
@@ -186,9 +272,17 @@ tryV (VM p) = VM $ \bi genv lenv ->
 maybeV :: VM (Maybe a) -> VM a
 maybeV p = maybe noV return =<< p
 
+traceMaybeV :: String -> SDoc -> VM (Maybe a) -> VM a
+traceMaybeV s d p = maybe (traceNoV s d) return =<< p
+
 orElseV :: VM a -> VM a -> VM a
 orElseV p q = maybe q return =<< tryV p
 
+fixV :: (a -> VM a) -> VM a
+fixV f = VM (\bi genv lenv -> fixDs $ \r -> runVM (f (unYes r)) bi genv lenv )
+  where
+    unYes (Yes _ _ x) = x
+
 localV :: VM a -> VM a
 localV p = do
              env <- readLEnv id
@@ -199,7 +293,7 @@ localV p = do
 closedV :: VM a -> VM a
 closedV p = do
               env <- readLEnv id
-              setLEnv emptyLocalEnv
+              setLEnv (emptyLocalEnv { local_bind_name = local_bind_name env })
               x <- p
               setLEnv env
               return x
@@ -234,29 +328,137 @@ getInstEnv = readGEnv global_inst_env
 getFamInstEnv :: VM FamInstEnvs
 getFamInstEnv = readGEnv global_fam_inst_env
 
+getBindName :: VM FastString
+getBindName = readLEnv local_bind_name
+
+inBind :: Id -> VM a -> VM a
+inBind id p
+  = do updLEnv $ \env -> env { local_bind_name = occNameFS (getOccName id) }
+       p
+
+lookupExternalVar :: Module -> FastString -> VM Var
+lookupExternalVar mod fs
+  = liftDs
+  $ dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
+
+cloneName :: (OccName -> OccName) -> Name -> VM Name
+cloneName mk_occ name = liftM make (liftDs newUnique)
+  where
+    occ_name = mk_occ (nameOccName name)
+
+    make u | isExternalName name = mkExternalName u (nameModule name)
+                                                    occ_name
+                                                    (nameSrcSpan name)
+           | otherwise           = mkSystemName u occ_name
+
+cloneId :: (OccName -> OccName) -> Id -> Type -> VM Id
+cloneId mk_occ id ty
+  = do
+      name <- cloneName mk_occ (getName id)
+      let id' | isExportedId id = Id.mkExportedLocalId name ty
+              | otherwise       = Id.mkLocalId         name ty
+      return id'
+
+newExportedVar :: OccName -> Type -> VM Var
+newExportedVar occ_name ty 
+  = do
+      mod <- liftDs getModuleDs
+      u   <- liftDs newUnique
+
+      let name = mkExternalName u mod occ_name noSrcSpan
+      
+      return $ Id.mkExportedLocalId name ty
+
 newLocalVar :: FastString -> Type -> VM Var
 newLocalVar fs ty
   = do
       u <- liftDs newUnique
       return $ mkSysLocal fs u ty
 
+newDummyVar :: Type -> VM Var
+newDummyVar = newLocalVar FSLIT("ds")
+
 newTyVar :: FastString -> Kind -> VM Var
 newTyVar fs k
   = do
       u <- liftDs newUnique
       return $ mkTyVar (mkSysTvName u fs) k
 
+defGlobalVar :: Var -> Var -> VM ()
+defGlobalVar v v' = updGEnv $ \env ->
+  env { global_vars = extendVarEnv (global_vars env) v v'
+      , global_exported_vars = upd (global_exported_vars env)
+      }
+  where
+    upd env | isExportedId v = extendVarEnv env v (v, v')
+            | otherwise      = env
+
+lookupVar :: Var -> VM (Scope Var (Var, Var))
+lookupVar v
+  = do
+      r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
+      case r of
+        Just e  -> return (Local e)
+        Nothing -> liftM Global
+                 $  traceMaybeV "lookupVar" (ppr v)
+                                (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
+
 lookupTyCon :: TyCon -> VM (Maybe TyCon)
-lookupTyCon tc = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
+lookupTyCon tc
+  | isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc)
+
+  | otherwise = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
+
+defTyCon :: TyCon -> TyCon -> VM ()
+defTyCon tc tc' = updGEnv $ \env ->
+  env { global_tycons = extendNameEnv (global_tycons env) (tyConName tc) tc' }
+
+lookupDataCon :: DataCon -> VM (Maybe DataCon)
+lookupDataCon dc = readGEnv $ \env -> lookupNameEnv (global_datacons env) (dataConName dc)
+
+defDataCon :: DataCon -> DataCon -> VM ()
+defDataCon dc dc' = updGEnv $ \env ->
+  env { global_datacons = extendNameEnv (global_datacons env) (dataConName dc) dc' }
+
+lookupTyConPA :: TyCon -> VM (Maybe Var)
+lookupTyConPA tc = readGEnv $ \env -> lookupNameEnv (global_pa_funs env) (tyConName tc)
+
+defTyConPA :: TyCon -> Var -> VM ()
+defTyConPA tc pa = updGEnv $ \env ->
+  env { global_pa_funs = extendNameEnv (global_pa_funs env) (tyConName tc) pa }
+
+defTyConPAs :: [(TyCon, Var)] -> VM ()
+defTyConPAs ps = updGEnv $ \env ->
+  env { global_pa_funs = extendNameEnvList (global_pa_funs env)
+                                           [(tyConName tc, pa) | (tc, pa) <- ps] }
+
+defTyConBuiltinPAs :: [(Name, Module, FastString)] -> VM ()
+defTyConBuiltinPAs ps
+  = do
+      pas <- zipWithM lookupExternalVar mods fss
+      updGEnv $ \env ->
+        env { global_pa_funs = extendNameEnvList (global_pa_funs env)
+                                                 (zip tcs pas) }
+  where
+    (tcs, mods, fss) = unzip3 ps
 
 lookupTyVarPA :: Var -> VM (Maybe CoreExpr)
 lookupTyVarPA tv = readLEnv $ \env -> lookupVarEnv (local_tyvar_pa env) tv 
 
-extendTyVarPA :: Var -> CoreExpr -> VM ()
-extendTyVarPA tv pa = updLEnv $ \env -> env { local_tyvar_pa = extendVarEnv (local_tyvar_pa env) tv pa }
+defLocalTyVar :: TyVar -> VM ()
+defLocalTyVar tv = updLEnv $ \env ->
+  env { local_tyvars   = tv : local_tyvars env
+      , local_tyvar_pa = local_tyvar_pa env `delVarEnv` tv
+      }
+
+defLocalTyVarWithPA :: TyVar -> CoreExpr -> VM ()
+defLocalTyVarWithPA tv pa = updLEnv $ \env ->
+  env { local_tyvars   = tv : local_tyvars env
+      , local_tyvar_pa = extendVarEnv (local_tyvar_pa env) tv pa
+      }
 
-deleteTyVarPA :: Var -> VM ()
-deleteTyVarPA tv = updLEnv $ \env -> env { local_tyvar_pa = delVarEnv (local_tyvar_pa env) tv }
+localTyVars :: VM [TyVar]
+localTyVars = readLEnv (reverse . local_tyvars)
 
 -- Look up the dfun of a class instance.
 --
@@ -267,6 +469,7 @@ deleteTyVarPA tv = updLEnv $ \env -> env { local_tyvar_pa = delVarEnv (local_tyv
 -- instances head (i.e., no flexi vars); for details for what this means,
 -- see the docs at InstEnv.lookupInstEnv.
 --
+{-
 lookupInst :: Class -> [Type] -> VM (DFunId, [Type])
 lookupInst cls tys
   = do { instEnv <- getInstEnv
@@ -278,11 +481,12 @@ lookupInst cls tys
              where
                inst_tys'  = [ty | Right ty <- inst_tys]
                noFlexiVar = all isRight inst_tys
-          _other         -> noV
+          _other         -> traceNoV "lookupInst" (ppr cls <+> ppr tys)
        }
   where
     isRight (Left  _) = False
     isRight (Right _) = True
+-}
 
 -- Look up the representation tycon of a family instance.
 --
@@ -328,7 +532,10 @@ initV hsc_env guts info p
     go instEnvs famInstEnvs = 
       do
         builtins <- initBuiltins
-        r <- runVM p builtins (initGlobalEnv info instEnvs famInstEnvs) 
+        r <- runVM p builtins (initGlobalEnv info
+                                             instEnvs
+                                             famInstEnvs
+                                             builtins)
                    emptyLocalEnv
         case r of
           Yes genv _ x -> return $ Just (new_info genv, x)