[project @ 1996-01-08 20:28:12 by partain]
[ghc-hetmet.git] / ghc / compiler / simplStg / StgSATMonad.lhs
diff --git a/ghc/compiler/simplStg/StgSATMonad.lhs b/ghc/compiler/simplStg/StgSATMonad.lhs
new file mode 100644 (file)
index 0000000..f0cb84d
--- /dev/null
@@ -0,0 +1,182 @@
+%
+% (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
+%
+%************************************************************************
+%*                                                                     *
+\section[SATMonad]{The Static Argument Transformation pass Monad}
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+#include "HsVersions.h"
+
+module StgSATMonad (
+       getArgLists, saTransform, 
+
+       Id, UniType, SplitUniqSupply, PlainStgExpr(..)
+    ) where
+
+import AbsUniType      ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
+                         extractTyVarsFromTy, splitType, splitTyArgs,
+                         glueTyArgs, instantiateTy, TauType(..),
+                         Class, ThetaType(..), SigmaType(..),
+                         InstTyEnv(..)
+                       )
+import IdEnv
+import Id              ( mkSysLocal, getIdUniType, eqId )
+import Maybes          ( Maybe(..) )
+import StgSyn
+import SATMonad         ( SATEnv(..), SATInfo(..), Arg(..), updSAEnv, insSAEnv,
+                          SatM(..), initSAT, thenSAT, thenSAT_,
+                          emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
+                          getSATInfo, newSATName )
+import SrcLoc          ( SrcLoc, mkUnknownSrcLoc )
+import SplitUniq
+import Unique
+import UniqSet         ( UniqSet(..), emptyUniqSet )
+import Util
+
+\end{code}
+
+%************************************************************************
+%*                                                                     *
+\subsection{Utility Functions}
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+newSATNames :: [Id] -> SatM [Id]
+newSATNames [] = returnSAT []
+newSATNames (id:ids) = newSATName id (getIdUniType id) `thenSAT` \ id' ->
+                       newSATNames ids                 `thenSAT` \ ids' ->
+                       returnSAT (id:ids)
+
+getArgLists :: PlainStgRhs -> ([Arg UniType],[Arg Id])
+getArgLists (StgRhsCon _ _ _) 
+  = ([],[])
+getArgLists (StgRhsClosure _ _ _ _ args _)
+  = ([], [Static v | v <- args])
+
+\end{code}
+
+\begin{code}
+saTransform :: Id -> PlainStgRhs -> SatM PlainStgBinding
+saTransform binder rhs
+  = getSATInfo binder `thenSAT` \ r ->
+    case r of
+      Just (_,args) | any isStatic args 
+      -- [Andre] test: do it only if we have more than one static argument.
+      --Just (_,args) | length (filter isStatic args) > 1
+       -> newSATName binder (new_ty args)      `thenSAT` \ binder' ->
+           let non_static_args = get_nsa args (snd (getArgLists rhs))
+           in
+          newSATNames non_static_args          `thenSAT` \ non_static_args' ->
+          mkNewRhs binder binder' args rhs non_static_args' non_static_args
+                                               `thenSAT` \ new_rhs ->
+          trace ("SAT(STG) "++ show (length (filter isStatic args))) (
+           returnSAT (StgNonRec binder new_rhs)
+           )
+      _ -> returnSAT (StgRec [(binder, rhs)])
+
+  where
+    get_nsa :: [Arg a] -> [Arg a] -> [a]
+    get_nsa []                 _               = []
+    get_nsa _                  []              = []
+    get_nsa (NotStatic:args)   (Static v:as)   = v:get_nsa args as
+    get_nsa (_:args)           (_:as)          =   get_nsa args as
+
+    mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
+      = let
+         local_body = StgApp (StgVarAtom binder')
+                        [StgVarAtom a | a <- non_static_args] emptyUniqSet
+
+         rec_body = StgRhsClosure cc bi fvs upd non_static_args'
+                      (doStgSubst binder args subst_env body)
+
+         subst_env = mkIdEnv 
+                        ((binder,binder'):zip non_static_args non_static_args')
+       in
+       returnSAT (
+           StgRhsClosure cc bi fvs upd rhsargs 
+             (StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
+       )
+
+    new_ty args
+      = instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
+      where
+       -- get type info for the local function:
+       (tv_tmpl, dict_tys, tau_ty) = (splitType . getIdUniType) binder
+       (reg_arg_tys, res_type)     = splitTyArgs tau_ty
+
+       -- now, we drop the ones that are
+       -- static, that is, the ones we will not pass to the local function
+       l            = length dict_tys
+       dict_tys'    = dropStatics (take l args) dict_tys
+       reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
+       tau_ty'      = glueTyArgs reg_arg_tys' res_type
+\end{code}
+
+NOTE: This does not keep live variable/free variable information!!
+
+\begin{code}
+doStgSubst binder orig_args subst_env body
+  = substExpr body
+  where 
+    substExpr (StgConApp con args lvs) 
+      = StgConApp con (map substAtom args) emptyUniqSet
+    substExpr (StgPrimApp op args lvs)
+      = StgPrimApp op (map substAtom args) emptyUniqSet
+    substExpr expr@(StgApp (StgLitAtom _) [] _) 
+      = expr
+    substExpr (StgApp atom@(StgVarAtom v)  args lvs)
+      | v `eqId` binder
+      = StgApp (StgVarAtom (lookupNoFailIdEnv subst_env v))
+               (remove_static_args orig_args args) emptyUniqSet
+      | otherwise
+      = StgApp (substAtom atom) (map substAtom args) lvs
+    substExpr (StgCase scrut lv1 lv2 uniq alts)
+      = StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
+      where
+        subst_alts (StgAlgAlts ty alg_alts deflt)
+          = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
+        subst_alts (StgPrimAlts ty prim_alts deflt)
+          = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
+        subst_alg_alt (con, args, use_mask, rhs)
+          = (con, args, use_mask, substExpr rhs)
+        subst_prim_alt (lit, rhs)
+          = (lit, substExpr rhs)
+        subst_deflt StgNoDefault 
+          = StgNoDefault
+        subst_deflt (StgBindDefault var used rhs)
+          = StgBindDefault var used (substExpr rhs)
+    substExpr (StgLetNoEscape fv1 fv2 b body)
+      = StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr body)
+    substExpr (StgLet b body)
+      = StgLet (substBinding b) (substExpr body)
+    substExpr (StgSCC ty cc expr)
+      = StgSCC ty cc (substExpr expr)
+    substRhs (StgRhsCon cc v args) 
+      = StgRhsCon cc v (map substAtom args)
+    substRhs (StgRhsClosure cc bi fvs upd args body)
+      = StgRhsClosure cc bi [] upd args (substExpr body)
+    
+    substBinding (StgNonRec binder rhs)
+      = StgNonRec binder (substRhs rhs)
+    substBinding (StgRec pairs)
+      = StgRec (zip binders (map substRhs rhss))
+      where
+        (binders,rhss) = unzip pairs
+    
+    substAtom atom@(StgLitAtom lit) = atom
+    substAtom atom@(StgVarAtom v) 
+      = case lookupIdEnv subst_env v of
+          Just v' -> StgVarAtom v'
+          Nothing -> atom
+    
+    remove_static_args _ [] 
+      = []
+    remove_static_args (Static _:origs) (_:as) 
+      = remove_static_args origs as
+    remove_static_args (NotStatic:origs) (a:as) 
+      = substAtom a:remove_static_args origs as
+\end{code}