#include "HsVersions.h"
module StgSATMonad (
- getArgLists, saTransform,
-
- Id, UniType, SplitUniqSupply, PlainStgExpr(..)
+ getArgLists, saTransform
) where
-import AbsUniType ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
- extractTyVarsFromTy, splitType, splitTyArgs,
+import Type ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
+ extractTyVarsFromTy, splitSigmaTy, splitTyArgs,
glueTyArgs, instantiateTy, TauType(..),
Class, ThetaType(..), SigmaType(..),
InstTyEnv(..)
)
-import IdEnv
-import Id ( mkSysLocal, getIdUniType, eqId )
+import Id ( mkSysLocal, idType, eqId )
import Maybes ( Maybe(..) )
import StgSyn
import SATMonad ( SATEnv(..), SATInfo(..), Arg(..), updSAEnv, insSAEnv,
- SatM(..), initSAT, thenSAT, thenSAT_,
- emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
- getSATInfo, newSATName )
+ SatM(..), initSAT, thenSAT, thenSAT_,
+ emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
+ getSATInfo, newSATName )
import SrcLoc ( SrcLoc, mkUnknownSrcLoc )
-import SplitUniq
-import Unique
+import UniqSupply
import UniqSet ( UniqSet(..), emptyUniqSet )
import Util
\begin{code}
newSATNames :: [Id] -> SatM [Id]
newSATNames [] = returnSAT []
-newSATNames (id:ids) = newSATName id (getIdUniType id) `thenSAT` \ id' ->
- newSATNames ids `thenSAT` \ ids' ->
- returnSAT (id:ids)
+newSATNames (id:ids) = newSATName id (idType id) `thenSAT` \ id' ->
+ newSATNames ids `thenSAT` \ ids' ->
+ returnSAT (id:ids)
-getArgLists :: PlainStgRhs -> ([Arg UniType],[Arg Id])
-getArgLists (StgRhsCon _ _ _)
+getArgLists :: StgRhs -> ([Arg Type],[Arg Id])
+getArgLists (StgRhsCon _ _ _)
= ([],[])
getArgLists (StgRhsClosure _ _ _ _ args _)
= ([], [Static v | v <- args])
\end{code}
\begin{code}
-saTransform :: Id -> PlainStgRhs -> SatM PlainStgBinding
+saTransform :: Id -> StgRhs -> SatM StgBinding
saTransform binder rhs
= getSATInfo binder `thenSAT` \ r ->
case r of
- Just (_,args) | any isStatic args
+ Just (_,args) | any isStatic args
-- [Andre] test: do it only if we have more than one static argument.
--Just (_,args) | length (filter isStatic args) > 1
-> newSATName binder (new_ty args) `thenSAT` \ binder' ->
- let non_static_args = get_nsa args (snd (getArgLists rhs))
- in
+ let non_static_args = get_nsa args (snd (getArgLists rhs))
+ in
newSATNames non_static_args `thenSAT` \ non_static_args' ->
mkNewRhs binder binder' args rhs non_static_args' non_static_args
`thenSAT` \ new_rhs ->
trace ("SAT(STG) "++ show (length (filter isStatic args))) (
- returnSAT (StgNonRec binder new_rhs)
- )
+ returnSAT (StgNonRec binder new_rhs)
+ )
_ -> returnSAT (StgRec [(binder, rhs)])
where
mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
= let
- local_body = StgApp (StgVarAtom binder')
- [StgVarAtom a | a <- non_static_args] emptyUniqSet
+ local_body = StgApp (StgVarArg binder')
+ [StgVarArg a | a <- non_static_args] emptyUniqSet
rec_body = StgRhsClosure cc bi fvs upd non_static_args'
- (doStgSubst binder args subst_env body)
+ (doStgSubst binder args subst_env body)
- subst_env = mkIdEnv
- ((binder,binder'):zip non_static_args non_static_args')
+ subst_env = mkIdEnv
+ ((binder,binder'):zip non_static_args non_static_args')
in
returnSAT (
- StgRhsClosure cc bi fvs upd rhsargs
+ StgRhsClosure cc bi fvs upd rhsargs
(StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
)
= instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
where
-- get type info for the local function:
- (tv_tmpl, dict_tys, tau_ty) = (splitType . getIdUniType) binder
+ (tv_tmpl, dict_tys, tau_ty) = (splitSigmaTy . idType) binder
(reg_arg_tys, res_type) = splitTyArgs tau_ty
-- now, we drop the ones that are
\begin{code}
doStgSubst binder orig_args subst_env body
= substExpr body
- where
- substExpr (StgConApp con args lvs)
- = StgConApp con (map substAtom args) emptyUniqSet
- substExpr (StgPrimApp op args lvs)
- = StgPrimApp op (map substAtom args) emptyUniqSet
- substExpr expr@(StgApp (StgLitAtom _) [] _)
+ where
+ substExpr (StgCon con args lvs)
+ = StgCon con (map substAtom args) emptyUniqSet
+ substExpr (StgPrim op args lvs)
+ = StgPrim op (map substAtom args) emptyUniqSet
+ substExpr expr@(StgApp (StgLitArg _) [] _)
= expr
- substExpr (StgApp atom@(StgVarAtom v) args lvs)
+ substExpr (StgApp atom@(StgVarArg v) args lvs)
| v `eqId` binder
- = StgApp (StgVarAtom (lookupNoFailIdEnv subst_env v))
- (remove_static_args orig_args args) emptyUniqSet
+ = StgApp (StgVarArg (lookupNoFailIdEnv subst_env v))
+ (remove_static_args orig_args args) emptyUniqSet
| otherwise
= StgApp (substAtom atom) (map substAtom args) lvs
substExpr (StgCase scrut lv1 lv2 uniq alts)
= StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
where
- subst_alts (StgAlgAlts ty alg_alts deflt)
- = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
- subst_alts (StgPrimAlts ty prim_alts deflt)
- = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
- subst_alg_alt (con, args, use_mask, rhs)
- = (con, args, use_mask, substExpr rhs)
- subst_prim_alt (lit, rhs)
- = (lit, substExpr rhs)
- subst_deflt StgNoDefault
- = StgNoDefault
- subst_deflt (StgBindDefault var used rhs)
- = StgBindDefault var used (substExpr rhs)
+ subst_alts (StgAlgAlts ty alg_alts deflt)
+ = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
+ subst_alts (StgPrimAlts ty prim_alts deflt)
+ = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
+ subst_alg_alt (con, args, use_mask, rhs)
+ = (con, args, use_mask, substExpr rhs)
+ subst_prim_alt (lit, rhs)
+ = (lit, substExpr rhs)
+ subst_deflt StgNoDefault
+ = StgNoDefault
+ subst_deflt (StgBindDefault var used rhs)
+ = StgBindDefault var used (substExpr rhs)
substExpr (StgLetNoEscape fv1 fv2 b body)
= StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr body)
substExpr (StgLet b body)
= StgLet (substBinding b) (substExpr body)
substExpr (StgSCC ty cc expr)
= StgSCC ty cc (substExpr expr)
- substRhs (StgRhsCon cc v args)
+ substRhs (StgRhsCon cc v args)
= StgRhsCon cc v (map substAtom args)
substRhs (StgRhsClosure cc bi fvs upd args body)
= StgRhsClosure cc bi [] upd args (substExpr body)
-
+
substBinding (StgNonRec binder rhs)
= StgNonRec binder (substRhs rhs)
substBinding (StgRec pairs)
= StgRec (zip binders (map substRhs rhss))
where
- (binders,rhss) = unzip pairs
-
- substAtom atom@(StgLitAtom lit) = atom
- substAtom atom@(StgVarAtom v)
+ (binders,rhss) = unzip pairs
+
+ substAtom atom@(StgLitArg lit) = atom
+ substAtom atom@(StgVarArg v)
= case lookupIdEnv subst_env v of
- Just v' -> StgVarAtom v'
- Nothing -> atom
-
- remove_static_args _ []
+ Just v' -> StgVarArg v'
+ Nothing -> atom
+
+ remove_static_args _ []
= []
- remove_static_args (Static _:origs) (_:as)
+ remove_static_args (Static _:origs) (_:as)
= remove_static_args origs as
- remove_static_args (NotStatic:origs) (a:as)
+ remove_static_args (NotStatic:origs) (a:as)
= substAtom a:remove_static_args origs as
\end{code}