Use new closure generation code in vectorisation
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 1 Aug 2007 01:37:28 +0000 (01:37 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 1 Aug 2007 01:37:28 +0000 (01:37 +0000)
compiler/vectorise/VectCore.hs
compiler/vectorise/VectUtils.hs
compiler/vectorise/Vectorise.hs

index 9118214..23fe0e4 100644 (file)
@@ -7,6 +7,7 @@ module VectCore (
   vNonRec, vRec,
 
   vVar, vType, vNote, vLet,
+  vLams,
   mkVLams, mkVVarApps
 ) where
 
@@ -54,6 +55,11 @@ vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
 vLet :: VBind -> VExpr -> VExpr
 vLet = zipWithVect Let
 
+vLams :: Var -> [VVar] -> VExpr -> VExpr
+vLams lc vs (ve, le) = (mkLams vvs ve, mkLams (lc:lvs) le)
+  where
+    (vvs,lvs) = unzip vs
+
 mkVLams :: [VVar] -> VExpr -> VExpr
 mkVLams vvs (ve,le) = (mkLams vs ve, mkLams ls le)
   where
index 199af1a..73c986b 100644 (file)
@@ -4,9 +4,10 @@ module VectUtils (
   mkPADictType, mkPArrayType,
   paDictArgType, paDictOfType,
   paMethod, lengthPA, replicatePA, emptyPA,
-  polyAbstract, polyApply,
+  polyAbstract, polyApply, polyVApply,
   lookupPArrayFamInst,
-  hoistExpr, takeHoisted
+  hoistExpr, hoistPolyVExpr, takeHoisted,
+  buildClosure
 ) where
 
 #include "HsVersions.h"
@@ -177,6 +178,12 @@ polyApply expr tys
       dicts <- mapM paDictOfType tys
       return $ expr `mkTyApps` tys `mkApps` dicts
 
+polyVApply :: VExpr -> [Type] -> VM VExpr
+polyVApply expr tys
+  = do
+      dicts <- mapM paDictOfType tys
+      return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
+
 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
 
@@ -188,19 +195,20 @@ hoistExpr fs expr
         env { global_bindings = (var, expr) : global_bindings env }
       return var
 
-hoistPolyExpr :: FastString -> [TyVar] -> CoreExpr -> VM CoreExpr
-hoistPolyExpr fs tvs expr
+hoistVExpr :: FastString -> VExpr -> VM VVar
+hoistVExpr fs (ve, le)
   = do
-      poly_expr <- closedV . polyAbstract tvs $ \abstract -> return (abstract expr)
-      fn        <- hoistExpr fs poly_expr
-      polyApply (Var fn) (mkTyVarTys tvs)
+      vv <- hoistExpr ('v' `consFS` fs) ve
+      lv <- hoistExpr ('l' `consFS` fs) le
+      return (vv, lv)
 
-hoistPolyVExpr :: FastString -> [TyVar] -> VExpr -> VM VExpr
-hoistPolyVExpr fs tvs (ve, le)
+hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
+hoistPolyVExpr fs tvs p
   = do
-      ve' <- hoistPolyExpr ('v' `consFS` fs) tvs ve
-      le' <- hoistPolyExpr ('l' `consFS` fs) tvs le
-      return (ve',le')
+      expr <- closedV . polyAbstract tvs $ \abstract ->
+              liftM (mapVect abstract) p
+      fn   <- hoistVExpr fs expr
+      polyVApply (vVar fn) (mkTyVarTys tvs)
 
 takeHoisted :: VM [(Var, CoreExpr)]
 takeHoisted
@@ -224,31 +232,33 @@ mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
 
-buildClosure :: [TyVar] -> Var -> [VVar] -> VVar -> VExpr -> VM VExpr
-buildClosure tvs lv vars arg body
+buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
+buildClosure tvs lv vars arg_ty res_ty mk_body
   = do
       (env_ty, env, bind) <- buildEnv lv vars
-      env_bndr            <- newLocalVVar FSLIT("env") env_ty
+      env_bndr <- newLocalVVar FSLIT("env") env_ty
+      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
 
       fn <- hoistPolyVExpr FSLIT("fn") tvs
-          . mkVLams [env_bndr, arg]
-          . bind (vVar env_bndr)
-          $ mkVVarApps lv body (vars ++ [arg])
+          $ do
+              body  <- mk_body
+              body' <- bind (vVar env_bndr)
+                            (mkVVarApps lv body (vars ++ [arg_bndr]))
+              return (mkVLams [env_bndr, arg_bndr] body')
 
       mkClosure arg_ty res_ty env_ty fn env
 
-  where
-    arg_ty = idType (vectorised arg)
-    res_ty = exprType (vectorised body)
-
-
-buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
+buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
 buildEnv lv vvs
   = do
       let (ty, venv, vbind) = mkVectEnv tys vs
       (lenv, lbind) <- mkLiftEnv lv tys ls
       return (ty, (venv, lenv),
-              \(venv,lenv) (vbody,lbody) -> (vbind venv vbody, lbind lenv lbody))
+              \(venv,lenv) (vbody,lbody) ->
+              do
+                let vbody' = vbind venv vbody
+                lbody' <- lbind lenv lbody
+                return (vbody', lbody'))
   where
     (vs,ls) = unzip vvs
     tys     = map idType vs
@@ -262,12 +272,13 @@ mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
   where
     ty = mkCoreTupTy tys
 
-mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
+mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
 mkLiftEnv lv [ty] [v]
-  = do
-      len <- lengthPA (Var v)
-      return (Var v, \env body -> Let (NonRec v env)
-                                $ Case len lv (exprType body) [(DEFAULT, [], body)])
+  = return (Var v, \env body ->
+                   do
+                     len <- lengthPA (Var v)
+                     return . Let (NonRec v env)
+                            $ Case len lv (exprType body) [(DEFAULT, [], body)])
 
 -- NOTE: this transparently deals with empty environments
 mkLiftEnv lv tys vs
@@ -281,9 +292,13 @@ mkLiftEnv lv tys vs
 
           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
                           in
-                          Case scrut (mkWildId (exprType scrut)) (exprType body)
-                            [(DataAlt env_con, lv : vs, body)]
+                          return $ Case scrut (mkWildId (exprType scrut))
+                                        (exprType body)
+                                        [(DataAlt env_con, lv : bndrs, body)]
       return (env, bind)
   where
     vty = mkCoreTupTy tys
 
+    bndrs | null vs   = [mkWildId unitTy]
+          | otherwise = vs
+
index 489d6cc..055137a 100644 (file)
@@ -109,7 +109,11 @@ vectTopBinder var
       return var'
     
 vectTopRhs :: CoreExpr -> VM CoreExpr
-vectTopRhs = liftM fst . closedV . vectPolyExpr (panic "Empty lifting context") . freeVars
+vectTopRhs expr
+  = do
+      lc <- newLocalVar FSLIT("lc") intPrimTy
+      closedV . liftM vectorised
+              $ vectPolyExpr lc (freeVars expr)
 
 -- ----------------------------------------------------------------------------
 -- Bindings
@@ -244,160 +248,20 @@ vectExpr lc e@(_, AnnLam bndr body)
 vectExpr lc (fvs, AnnLam bndr body)
   = do
       tyvars <- localTyVars
-      info <- mkCEnvInfo fvs bndr body
-      (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body
-
-      vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
-      lfn_var <- hoistExpr FSLIT("lfn") poly_lfn
-
-      let (venv, lenv) = mkClosureEnvs info (Var lc)
-
-      let env_ty = cenv_vty info
-
-      pa_dict <- paDictOfType env_ty
+      (vs, vvs) <- readLEnv $ \env ->
+                   unzip [(var, vv) | var <- varSetElems fvs
+                                    , Just vv <- [lookupVarEnv (local_vars env) var]]
 
-      arg_ty <- vectType (varType bndr)
+      arg_ty <- vectType (idType bndr)
       res_ty <- vectType (exprType $ deAnnotate body)
+      buildClosure tyvars lc vvs arg_ty res_ty
+        . hoistPolyVExpr FSLIT("fn") tyvars
+        $ do
+            new_lc <- newLocalVar FSLIT("lc") intPrimTy
+            (vbndrs, vbody) <- vectBndrsIn (vs ++ [bndr])
+                                           (vectExpr new_lc body)
+            return $ vLams new_lc vbndrs vbody
 
-      -- FIXME: move the functions to the top level
-      mono_vfn <- polyApply (Var vfn_var) (mkTyVarTys tyvars)
-      mono_lfn <- polyApply (Var lfn_var) (mkTyVarTys tyvars)
-
-      mk_clo <- builtin mkClosureVar
-      mk_cloP <- builtin mkClosurePVar
-
-      let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
-                             `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]
-          
-          lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
-                             `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]
-
-      return (vclo, lclo)
-
-data CEnvInfo = CEnvInfo {
-               cenv_vars         :: [Var]
-             , cenv_values       :: [(CoreExpr, CoreExpr)]
-             , cenv_vty          :: Type
-             , cenv_lty          :: Type
-             , cenv_repr_tycon   :: TyCon
-             , cenv_repr_tyargs  :: [Type]
-             , cenv_repr_datacon :: DataCon
-             }
-
-mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
-mkCEnvInfo fvs arg body
-  = do
-      locals <- readLEnv local_vars
-      let
-          (vars, vals) = unzip
-                 [(var, (Var v, Var v')) | var      <- varSetElems fvs
-                                         , Just (v,v') <- [lookupVarEnv locals var]]
-      vtys <- mapM (vectType . varType) vars
-
-      (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
-      lty <- mkPArrayType vty
-      
-      return $ CEnvInfo {
-                 cenv_vars         = vars
-               , cenv_values       = vals
-               , cenv_vty          = vty
-               , cenv_lty          = lty
-               , cenv_repr_tycon   = repr_tycon
-               , cenv_repr_tyargs  = repr_tyargs
-               , cenv_repr_datacon = repr_datacon
-               }
-  where
-    mk_env_ty [vty]
-      = return (vty, error "absent cinfo_repr_tycon"
-                   , error "absent cinfo_repr_tyargs"
-                   , error "absent cinfo_repr_datacon")
-
-    mk_env_ty vtys
-      = do
-          let ty = mkCoreTupTy vtys
-          (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
-          let [repr_con] = tyConDataCons repr_tc
-          return (ty, repr_tc, repr_tyargs, repr_con)
-
-    
-
-mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
-mkClosureEnvs info lc
-  | [] <- vals
-  = (Var unitDataConId, mkApps (Var $ dataConWrapId (cenv_repr_datacon info))
-                               [lc, Var unitDataConId])
-
-  | [(vval, lval)] <- vals
-  = (vval, lval)
-
-  | otherwise
-  = (mkCoreTup vvals, Var (dataConWrapId $ cenv_repr_datacon info)
-                      `mkTyApps` cenv_repr_tyargs info
-                      `mkApps`   (lc : lvals))
-
-  where
-    vals = cenv_values info
-    (vvals, lvals) = unzip vals
-
-mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
-             -> VM (CoreExpr, CoreExpr)
-mkClosureFns info tyvars arg body
-  = closedV
-  . polyAbstract tyvars
-  $ \mk_tlams ->
-  do
-    (vfn, lfn) <- mkClosureMonoFns info arg body
-    return (mk_tlams vfn, mk_tlams lfn)
-
-mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
-mkClosureMonoFns info arg body
-  = do
-      lc_bndr <- newLocalVar FSLIT("lc") intPrimTy
-      (bndrs, (vbody, lbody))
-        <- vectBndrsIn (arg : cenv_vars info)
-                       (vectExpr lc_bndr body)
-      let (varg : vbndrs, larg : lbndrs) = unzip bndrs
-
-      venv_bndr <- newLocalVar FSLIT("env") vty
-      lenv_bndr <- newLocalVar FSLIT("env") lty
-
-      let vcase = bind_venv (Var venv_bndr) vbody vbndrs
-      lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
-      return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
-  where
-    vty = cenv_vty info
-    lty = cenv_lty info
-
-    arity = length (cenv_vars info)
-
-    bind_venv venv vbody []      = vbody
-    bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
-    bind_venv venv vbody vbndrs
-      = Case venv (mkWildId vty) (exprType vbody)
-             [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]
-
-    bind_lenv lenv lbody lc_bndr [lbndr]
-      = do
-          len <- lengthPA (Var lbndr)
-          return . Let (NonRec lbndr lenv)
-                 $ Case len
-                        lc_bndr
-                        (exprType lbody)
-                        [(DEFAULT, [], lbody)]
-
-    bind_lenv lenv lbody lc_bndr lbndrs
-      = let scrut = unwrapFamInstScrut (cenv_repr_tycon info)
-                                       (cenv_repr_tyargs info)
-                                       lenv
-            lbndrs' | null lbndrs = [mkWildId unitTy]
-                    | otherwise   = lbndrs
-        in
-        return
-      $ Case scrut
-             (mkWildId (exprType scrut))
-             (exprType lbody)
-             [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs', lbody)]
-          
 vectTyAppExpr :: Var -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
 vectTyAppExpr lc (_, AnnVar v) tys = vectPolyVar lc v tys
 vectTyAppExpr lc e tys = pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)