[project @ 1996-04-05 08:26:04 by partain]
[ghc-hetmet.git] / ghc / compiler / simplStg / StgSATMonad.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
3 %
4 %************************************************************************
5 %*                                                                      *
6 \section[SATMonad]{The Static Argument Transformation pass Monad}
7 %*                                                                      *
8 %************************************************************************
9
10 \begin{code}
11 #include "HsVersions.h"
12
13 module StgSATMonad ( getArgLists, saTransform ) where
14
15 import Ubiq{-uitous-}
16
17 import Util             ( panic )
18
19 getArgLists = panic "StgSATMonad.getArgLists"
20 saTransform = panic "StgSATMonad.saTransform"
21 \end{code}
22
23 %************************************************************************
24 %*                                                                      *
25 \subsection{Utility Functions}
26 %*                                                                      *
27 %************************************************************************
28
29 \begin{code}
30 {- LATER: to end of file:
31
32 newSATNames :: [Id] -> SatM [Id]
33 newSATNames [] = returnSAT []
34 newSATNames (id:ids) = newSATName id (idType id)        `thenSAT` \ id' ->
35                        newSATNames ids                  `thenSAT` \ ids' ->
36                        returnSAT (id:ids)
37
38 getArgLists :: StgRhs -> ([Arg Type],[Arg Id])
39 getArgLists (StgRhsCon _ _ _)
40   = ([],[])
41 getArgLists (StgRhsClosure _ _ _ _ args _)
42   = ([], [Static v | v <- args])
43
44 \end{code}
45
46 \begin{code}
47 saTransform :: Id -> StgRhs -> SatM StgBinding
48 saTransform binder rhs
49   = getSATInfo binder `thenSAT` \ r ->
50     case r of
51       Just (_,args) | any isStatic args
52       -- [Andre] test: do it only if we have more than one static argument.
53       --Just (_,args) | length (filter isStatic args) > 1
54         -> newSATName binder (new_ty args)      `thenSAT` \ binder' ->
55            let non_static_args = get_nsa args (snd (getArgLists rhs))
56            in
57            newSATNames non_static_args          `thenSAT` \ non_static_args' ->
58            mkNewRhs binder binder' args rhs non_static_args' non_static_args
59                                                 `thenSAT` \ new_rhs ->
60            trace ("SAT(STG) "++ show (length (filter isStatic args))) (
61            returnSAT (StgNonRec binder new_rhs)
62            )
63       _ -> returnSAT (StgRec [(binder, rhs)])
64
65   where
66     get_nsa :: [Arg a] -> [Arg a] -> [a]
67     get_nsa []                  _               = []
68     get_nsa _                   []              = []
69     get_nsa (NotStatic:args)    (Static v:as)   = v:get_nsa args as
70     get_nsa (_:args)            (_:as)          =   get_nsa args as
71
72     mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
73       = let
74           local_body = StgApp (StgVarArg binder')
75                          [StgVarArg a | a <- non_static_args] emptyUniqSet
76
77           rec_body = StgRhsClosure cc bi fvs upd non_static_args'
78                        (doStgSubst binder args subst_env body)
79
80           subst_env = mkIdEnv
81                         ((binder,binder'):zip non_static_args non_static_args')
82         in
83         returnSAT (
84             StgRhsClosure cc bi fvs upd rhsargs
85               (StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
86         )
87
88     new_ty args
89       = instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
90       where
91         -- get type info for the local function:
92         (tv_tmpl, dict_tys, tau_ty) = (splitSigmaTy . idType) binder
93         (reg_arg_tys, res_type)     = splitTyArgs tau_ty
94
95         -- now, we drop the ones that are
96         -- static, that is, the ones we will not pass to the local function
97         l            = length dict_tys
98         dict_tys'    = dropStatics (take l args) dict_tys
99         reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
100         tau_ty'      = glueTyArgs reg_arg_tys' res_type
101 \end{code}
102
103 NOTE: This does not keep live variable/free variable information!!
104
105 \begin{code}
106 doStgSubst binder orig_args subst_env body
107   = substExpr body
108   where
109     substExpr (StgCon con args lvs)
110       = StgCon con (map substAtom args) emptyUniqSet
111     substExpr (StgPrim op args lvs)
112       = StgPrim op (map substAtom args) emptyUniqSet
113     substExpr expr@(StgApp (StgLitArg _) [] _)
114       = expr
115     substExpr (StgApp atom@(StgVarArg v)  args lvs)
116       | v `eqId` binder
117       = StgApp (StgVarArg (lookupNoFailIdEnv subst_env v))
118                (remove_static_args orig_args args) emptyUniqSet
119       | otherwise
120       = StgApp (substAtom atom) (map substAtom args) lvs
121     substExpr (StgCase scrut lv1 lv2 uniq alts)
122       = StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
123       where
124         subst_alts (StgAlgAlts ty alg_alts deflt)
125           = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
126         subst_alts (StgPrimAlts ty prim_alts deflt)
127           = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
128         subst_alg_alt (con, args, use_mask, rhs)
129           = (con, args, use_mask, substExpr rhs)
130         subst_prim_alt (lit, rhs)
131           = (lit, substExpr rhs)
132         subst_deflt StgNoDefault
133           = StgNoDefault
134         subst_deflt (StgBindDefault var used rhs)
135           = StgBindDefault var used (substExpr rhs)
136     substExpr (StgLetNoEscape fv1 fv2 b body)
137       = StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr body)
138     substExpr (StgLet b body)
139       = StgLet (substBinding b) (substExpr body)
140     substExpr (StgSCC ty cc expr)
141       = StgSCC ty cc (substExpr expr)
142     substRhs (StgRhsCon cc v args)
143       = StgRhsCon cc v (map substAtom args)
144     substRhs (StgRhsClosure cc bi fvs upd args body)
145       = StgRhsClosure cc bi [] upd args (substExpr body)
146
147     substBinding (StgNonRec binder rhs)
148       = StgNonRec binder (substRhs rhs)
149     substBinding (StgRec pairs)
150       = StgRec (zip binders (map substRhs rhss))
151       where
152         (binders,rhss) = unzip pairs
153
154     substAtom atom@(StgLitArg lit) = atom
155     substAtom atom@(StgVarArg v)
156       = case lookupIdEnv subst_env v of
157           Just v' -> StgVarArg v'
158           Nothing -> atom
159
160     remove_static_args _ []
161       = []
162     remove_static_args (Static _:origs) (_:as)
163       = remove_static_args origs as
164     remove_static_args (NotStatic:origs) (a:as)
165       = substAtom a:remove_static_args origs as
166 -}
167 \end{code}