Use new closure generation code in vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
index 199af1a..73c986b 100644 (file)
@@ -4,9 +4,10 @@ module VectUtils (
   mkPADictType, mkPArrayType,
   paDictArgType, paDictOfType,
   paMethod, lengthPA, replicatePA, emptyPA,
-  polyAbstract, polyApply,
+  polyAbstract, polyApply, polyVApply,
   lookupPArrayFamInst,
-  hoistExpr, takeHoisted
+  hoistExpr, hoistPolyVExpr, takeHoisted,
+  buildClosure
 ) where
 
 #include "HsVersions.h"
@@ -177,6 +178,12 @@ polyApply expr tys
       dicts <- mapM paDictOfType tys
       return $ expr `mkTyApps` tys `mkApps` dicts
 
+polyVApply :: VExpr -> [Type] -> VM VExpr
+polyVApply expr tys
+  = do
+      dicts <- mapM paDictOfType tys
+      return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
+
 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
 
@@ -188,19 +195,20 @@ hoistExpr fs expr
         env { global_bindings = (var, expr) : global_bindings env }
       return var
 
-hoistPolyExpr :: FastString -> [TyVar] -> CoreExpr -> VM CoreExpr
-hoistPolyExpr fs tvs expr
+hoistVExpr :: FastString -> VExpr -> VM VVar
+hoistVExpr fs (ve, le)
   = do
-      poly_expr <- closedV . polyAbstract tvs $ \abstract -> return (abstract expr)
-      fn        <- hoistExpr fs poly_expr
-      polyApply (Var fn) (mkTyVarTys tvs)
+      vv <- hoistExpr ('v' `consFS` fs) ve
+      lv <- hoistExpr ('l' `consFS` fs) le
+      return (vv, lv)
 
-hoistPolyVExpr :: FastString -> [TyVar] -> VExpr -> VM VExpr
-hoistPolyVExpr fs tvs (ve, le)
+hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
+hoistPolyVExpr fs tvs p
   = do
-      ve' <- hoistPolyExpr ('v' `consFS` fs) tvs ve
-      le' <- hoistPolyExpr ('l' `consFS` fs) tvs le
-      return (ve',le')
+      expr <- closedV . polyAbstract tvs $ \abstract ->
+              liftM (mapVect abstract) p
+      fn   <- hoistVExpr fs expr
+      polyVApply (vVar fn) (mkTyVarTys tvs)
 
 takeHoisted :: VM [(Var, CoreExpr)]
 takeHoisted
@@ -224,31 +232,33 @@ mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
 
-buildClosure :: [TyVar] -> Var -> [VVar] -> VVar -> VExpr -> VM VExpr
-buildClosure tvs lv vars arg body
+buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
+buildClosure tvs lv vars arg_ty res_ty mk_body
   = do
       (env_ty, env, bind) <- buildEnv lv vars
-      env_bndr            <- newLocalVVar FSLIT("env") env_ty
+      env_bndr <- newLocalVVar FSLIT("env") env_ty
+      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
 
       fn <- hoistPolyVExpr FSLIT("fn") tvs
-          . mkVLams [env_bndr, arg]
-          . bind (vVar env_bndr)
-          $ mkVVarApps lv body (vars ++ [arg])
+          $ do
+              body  <- mk_body
+              body' <- bind (vVar env_bndr)
+                            (mkVVarApps lv body (vars ++ [arg_bndr]))
+              return (mkVLams [env_bndr, arg_bndr] body')
 
       mkClosure arg_ty res_ty env_ty fn env
 
-  where
-    arg_ty = idType (vectorised arg)
-    res_ty = exprType (vectorised body)
-
-
-buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
+buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
 buildEnv lv vvs
   = do
       let (ty, venv, vbind) = mkVectEnv tys vs
       (lenv, lbind) <- mkLiftEnv lv tys ls
       return (ty, (venv, lenv),
-              \(venv,lenv) (vbody,lbody) -> (vbind venv vbody, lbind lenv lbody))
+              \(venv,lenv) (vbody,lbody) ->
+              do
+                let vbody' = vbind venv vbody
+                lbody' <- lbind lenv lbody
+                return (vbody', lbody'))
   where
     (vs,ls) = unzip vvs
     tys     = map idType vs
@@ -262,12 +272,13 @@ mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
   where
     ty = mkCoreTupTy tys
 
-mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
+mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
 mkLiftEnv lv [ty] [v]
-  = do
-      len <- lengthPA (Var v)
-      return (Var v, \env body -> Let (NonRec v env)
-                                $ Case len lv (exprType body) [(DEFAULT, [], body)])
+  = return (Var v, \env body ->
+                   do
+                     len <- lengthPA (Var v)
+                     return . Let (NonRec v env)
+                            $ Case len lv (exprType body) [(DEFAULT, [], body)])
 
 -- NOTE: this transparently deals with empty environments
 mkLiftEnv lv tys vs
@@ -281,9 +292,13 @@ mkLiftEnv lv tys vs
 
           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
                           in
-                          Case scrut (mkWildId (exprType scrut)) (exprType body)
-                            [(DataAlt env_con, lv : vs, body)]
+                          return $ Case scrut (mkWildId (exprType scrut))
+                                        (exprType body)
+                                        [(DataAlt env_con, lv : bndrs, body)]
       return (env, bind)
   where
     vty = mkCoreTupTy tys
 
+    bndrs | null vs   = [mkWildId unitTy]
+          | otherwise = vs
+