Utility functions for vectorisation
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
index 67bacc7..71514e1 100644 (file)
@@ -17,14 +17,15 @@ import Var
 import VarEnv
 import Name                 ( mkSysTvName )
 import NameEnv
+import Id
 
-import DsMonad
+import DsMonad hiding (mapAndUnzipM)
 
 import PrelNames
 
 import Outputable
 import FastString
-import Control.Monad        ( liftM2 )
+import Control.Monad        ( liftM, liftM2, mapAndUnzipM )
 
 vectorise :: HscEnv -> ModGuts -> IO ModGuts
 vectorise hsc_env guts
@@ -177,6 +178,13 @@ maybeV p = maybe noV return =<< p
 orElseV :: VM a -> VM a -> VM a
 orElseV p q = maybe q return =<< tryV p
 
+localV :: VM a -> VM a
+localV p = do
+             env <- readLEnv id
+             x <- p
+             setLEnv env
+             return x
+
 liftDs :: DsM a -> VM a
 liftDs p = VM $ \bi genv lenv -> do { x <- p; return (Yes genv lenv x) }
 
@@ -201,6 +209,12 @@ setLEnv lenv = VM $ \_ genv _ -> return (Yes genv lenv ())
 updLEnv :: (LocalEnv -> LocalEnv) -> VM ()
 updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
 
+newLocalVar :: FastString -> Type -> VM Var
+newLocalVar fs ty
+  = do
+      u <- liftDs newUnique
+      return $ mkSysLocal fs u ty
+
 newTyVar :: FastString -> Kind -> VM Var
 newTyVar fs k
   = do
@@ -210,6 +224,10 @@ newTyVar fs k
 lookupTyCon :: TyCon -> VM (Maybe TyCon)
 lookupTyCon tc = readGEnv $ \env -> lookupNameEnv (global_tycons env) (tyConName tc)
 
+
+extendTyVarPA :: Var -> CoreExpr -> VM ()
+extendTyVarPA tv pa = updLEnv $ \env -> env { local_tyvar_pa = extendVarEnv (local_tyvar_pa env) tv pa }
+
 -- ----------------------------------------------------------------------------
 -- Bindings
 
@@ -225,6 +243,36 @@ vectoriseModule info guts
 vectModule :: ModGuts -> VM ModGuts
 vectModule guts = return guts
 
+
+
+vectBndr :: Var -> VM (Var, Var)
+vectBndr v
+  = do
+      vty <- vectType (idType v)
+      lty <- mkPArrayTy vty
+      let vv = v `Id.setIdType` vty
+          lv = v `Id.setIdType` lty
+      updLEnv (mapTo vv lv)
+      return (vv, lv)
+  where
+    mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (Var vv, Var lv) }
+
+vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
+vectBndrIn v p
+  = localV
+  $ do
+      (vv, lv) <- vectBndr v
+      x <- p
+      return (vv, lv, x)
+
+vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
+vectBndrsIn vs p
+  = localV
+  $ do
+      (vvs, lvs) <- mapAndUnzipM vectBndr vs
+      x <- p
+      return (vvs, lvs, x)
+
 -- ----------------------------------------------------------------------------
 -- Expressions
 
@@ -277,6 +325,26 @@ vectExpr lc (_, AnnApp fn arg)
       fn'  <- vectExpr lc fn
       arg' <- vectExpr lc arg
       capply fn' arg'
+vectExpr lc (_, AnnCase expr bndr ty alts)
+  = panic "vectExpr: case"
+vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
+  = do
+      (vrhs, lrhs) <- vectExpr lc rhs
+      (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
+      return (Let (NonRec vbndr vrhs) vbody,
+              Let (NonRec lbndr lrhs) lbody)
+vectExpr lc (_, AnnLet (AnnRec prs) body)
+  = do
+      (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
+      return (Let (Rec (zip vbndrs vrhss)) vbody,
+              Let (Rec (zip lbndrs lrhss)) lbody)
+  where
+    (bndrs, rhss) = unzip prs
+    
+    vect = do
+             (vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
+             (vbody, lbody) <- vectExpr lc body
+             return (vrhss, vbody, lrhss, lbody)
 
 -- ----------------------------------------------------------------------------
 -- PA dictionaries
@@ -377,3 +445,8 @@ splitClosureTy ty
 
   | otherwise = pprPanic "splitClosureTy" (ppr ty)
 
+mkPArrayTy :: Type -> VM Type
+mkPArrayTy ty = do
+                  tc <- builtin parrayTyCon
+                  return $ TyConApp tc [ty]
+