[project @ 1996-03-19 08:58:34 by partain]
[ghc-hetmet.git] / ghc / compiler / simplStg / StgSATMonad.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
3 %
4 %************************************************************************
5 %*                                                                      *
6 \section[SATMonad]{The Static Argument Transformation pass Monad}
7 %*                                                                      *
8 %************************************************************************
9
10 \begin{code}
11 #include "HsVersions.h"
12
13 module StgSATMonad (
14         getArgLists, saTransform
15     ) where
16
17 import Type             ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
18                           extractTyVarsFromTy, splitSigmaTy, splitTyArgs,
19                           glueTyArgs, instantiateTy, TauType(..),
20                           Class, ThetaType(..), SigmaType(..),
21                           InstTyEnv(..)
22                         )
23 import Id               ( mkSysLocal, idType, eqId )
24 import Maybes           ( Maybe(..) )
25 import StgSyn
26 import SATMonad         ( SATEnv(..), SATInfo(..), Arg(..), updSAEnv, insSAEnv,
27                           SatM(..), initSAT, thenSAT, thenSAT_,
28                           emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
29                           getSATInfo, newSATName )
30 import SrcLoc           ( SrcLoc, mkUnknownSrcLoc )
31 import UniqSupply
32 import UniqSet          ( UniqSet(..), emptyUniqSet )
33 import Util
34
35 \end{code}
36
37 %************************************************************************
38 %*                                                                      *
39 \subsection{Utility Functions}
40 %*                                                                      *
41 %************************************************************************
42
43 \begin{code}
44 newSATNames :: [Id] -> SatM [Id]
45 newSATNames [] = returnSAT []
46 newSATNames (id:ids) = newSATName id (idType id)        `thenSAT` \ id' ->
47                        newSATNames ids                  `thenSAT` \ ids' ->
48                        returnSAT (id:ids)
49
50 getArgLists :: StgRhs -> ([Arg Type],[Arg Id])
51 getArgLists (StgRhsCon _ _ _)
52   = ([],[])
53 getArgLists (StgRhsClosure _ _ _ _ args _)
54   = ([], [Static v | v <- args])
55
56 \end{code}
57
58 \begin{code}
59 saTransform :: Id -> StgRhs -> SatM StgBinding
60 saTransform binder rhs
61   = getSATInfo binder `thenSAT` \ r ->
62     case r of
63       Just (_,args) | any isStatic args
64       -- [Andre] test: do it only if we have more than one static argument.
65       --Just (_,args) | length (filter isStatic args) > 1
66         -> newSATName binder (new_ty args)      `thenSAT` \ binder' ->
67            let non_static_args = get_nsa args (snd (getArgLists rhs))
68            in
69            newSATNames non_static_args          `thenSAT` \ non_static_args' ->
70            mkNewRhs binder binder' args rhs non_static_args' non_static_args
71                                                 `thenSAT` \ new_rhs ->
72            trace ("SAT(STG) "++ show (length (filter isStatic args))) (
73            returnSAT (StgNonRec binder new_rhs)
74            )
75       _ -> returnSAT (StgRec [(binder, rhs)])
76
77   where
78     get_nsa :: [Arg a] -> [Arg a] -> [a]
79     get_nsa []                  _               = []
80     get_nsa _                   []              = []
81     get_nsa (NotStatic:args)    (Static v:as)   = v:get_nsa args as
82     get_nsa (_:args)            (_:as)          =   get_nsa args as
83
84     mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
85       = let
86           local_body = StgApp (StgVarArg binder')
87                          [StgVarArg a | a <- non_static_args] emptyUniqSet
88
89           rec_body = StgRhsClosure cc bi fvs upd non_static_args'
90                        (doStgSubst binder args subst_env body)
91
92           subst_env = mkIdEnv
93                         ((binder,binder'):zip non_static_args non_static_args')
94         in
95         returnSAT (
96             StgRhsClosure cc bi fvs upd rhsargs
97               (StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
98         )
99
100     new_ty args
101       = instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
102       where
103         -- get type info for the local function:
104         (tv_tmpl, dict_tys, tau_ty) = (splitSigmaTy . idType) binder
105         (reg_arg_tys, res_type)     = splitTyArgs tau_ty
106
107         -- now, we drop the ones that are
108         -- static, that is, the ones we will not pass to the local function
109         l            = length dict_tys
110         dict_tys'    = dropStatics (take l args) dict_tys
111         reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
112         tau_ty'      = glueTyArgs reg_arg_tys' res_type
113 \end{code}
114
115 NOTE: This does not keep live variable/free variable information!!
116
117 \begin{code}
118 doStgSubst binder orig_args subst_env body
119   = substExpr body
120   where
121     substExpr (StgCon con args lvs)
122       = StgCon con (map substAtom args) emptyUniqSet
123     substExpr (StgPrim op args lvs)
124       = StgPrim op (map substAtom args) emptyUniqSet
125     substExpr expr@(StgApp (StgLitArg _) [] _)
126       = expr
127     substExpr (StgApp atom@(StgVarArg v)  args lvs)
128       | v `eqId` binder
129       = StgApp (StgVarArg (lookupNoFailIdEnv subst_env v))
130                (remove_static_args orig_args args) emptyUniqSet
131       | otherwise
132       = StgApp (substAtom atom) (map substAtom args) lvs
133     substExpr (StgCase scrut lv1 lv2 uniq alts)
134       = StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
135       where
136         subst_alts (StgAlgAlts ty alg_alts deflt)
137           = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
138         subst_alts (StgPrimAlts ty prim_alts deflt)
139           = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
140         subst_alg_alt (con, args, use_mask, rhs)
141           = (con, args, use_mask, substExpr rhs)
142         subst_prim_alt (lit, rhs)
143           = (lit, substExpr rhs)
144         subst_deflt StgNoDefault
145           = StgNoDefault
146         subst_deflt (StgBindDefault var used rhs)
147           = StgBindDefault var used (substExpr rhs)
148     substExpr (StgLetNoEscape fv1 fv2 b body)
149       = StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr body)
150     substExpr (StgLet b body)
151       = StgLet (substBinding b) (substExpr body)
152     substExpr (StgSCC ty cc expr)
153       = StgSCC ty cc (substExpr expr)
154     substRhs (StgRhsCon cc v args)
155       = StgRhsCon cc v (map substAtom args)
156     substRhs (StgRhsClosure cc bi fvs upd args body)
157       = StgRhsClosure cc bi [] upd args (substExpr body)
158
159     substBinding (StgNonRec binder rhs)
160       = StgNonRec binder (substRhs rhs)
161     substBinding (StgRec pairs)
162       = StgRec (zip binders (map substRhs rhss))
163       where
164         (binders,rhss) = unzip pairs
165
166     substAtom atom@(StgLitArg lit) = atom
167     substAtom atom@(StgVarArg v)
168       = case lookupIdEnv subst_env v of
169           Just v' -> StgVarArg v'
170           Nothing -> atom
171
172     remove_static_args _ []
173       = []
174     remove_static_args (Static _:origs) (_:as)
175       = remove_static_args origs as
176     remove_static_args (NotStatic:origs) (a:as)
177       = substAtom a:remove_static_args origs as
178 \end{code}