module VectMonad (
+ Scope(..),
VM,
noV, tryV, maybeV, orElseV, localV, closedV, initV,
- newLocalVar, newTyVar,
+ cloneName, newLocalVar, newTyVar,
Builtins(..), paDictTyCon,
builtin,
LocalEnv(..),
readLEnv, setLEnv, updLEnv,
+ defGlobalVar, lookupVar,
lookupTyCon,
lookupTyVarPA, extendTyVarPA, deleteTyVarPA,
import Outputable
import FastString
+import Control.Monad ( liftM )
+
+data Scope a b = Global a | Local b
+
-- ----------------------------------------------------------------------------
-- Vectorisation monad
u <- liftDs newUnique
return $ mkTyVar (mkSysTvName u fs) k
+defGlobalVar :: Var -> CoreExpr -> VM ()
+defGlobalVar v e = updGEnv $ \env -> env { global_vars = extendVarEnv (global_vars env) v e }
+
+lookupVar :: Var -> VM (Scope CoreExpr (CoreExpr, CoreExpr))
+lookupVar v
+ = do
+ r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
+ case r of
+ Just e -> return (Local e)
+ Nothing -> liftM Global
+ $ maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
+
lookupTyCon :: TyCon -> VM (Maybe TyCon)
lookupTyCon tc = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
(arg_ty, res_ty) = splitClosureTy fn_ty
vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
-vectVar lc v = local v `orElseV` global v
- where
- local v = maybeV (readLEnv $ \env -> lookupVarEnv (local_vars env) v)
- global v = do
- vexpr <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
- lexpr <- replicateP vexpr lc
- return (vexpr, lexpr)
+vectVar lc v
+ = do
+ r <- lookupVar v
+ case r of
+ Local es -> return es
+ Global vexpr -> do
+ lexpr <- replicateP vexpr lc
+ return (vexpr, lexpr)
vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
vectPolyVar lc v tys
= do
- r <- readLEnv $ \env -> lookupVarEnv (local_vars env) v
+ r <- lookupVar v
case r of
- Just (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
- Nothing ->
- do
- poly <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
- vexpr <- mk_app poly
- lexpr <- replicateP vexpr lc
- return (vexpr, lexpr)
+ Local (vexpr, lexpr) -> liftM2 (,) (mk_app vexpr) (mk_app lexpr)
+ Global poly -> do
+ vexpr <- mk_app poly
+ lexpr <- replicateP vexpr lc
+ return (vexpr, lexpr)
where
mk_app e = applyToTypes e =<< mapM vectType tys