f0cb84d4d10836614b29e774451454af05506365
[ghc-hetmet.git] / ghc / compiler / simplStg / StgSATMonad.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
3 %
4 %************************************************************************
5 %*                                                                      *
6 \section[SATMonad]{The Static Argument Transformation pass Monad}
7 %*                                                                      *
8 %************************************************************************
9
10 \begin{code}
11 #include "HsVersions.h"
12
13 module StgSATMonad (
14         getArgLists, saTransform, 
15
16         Id, UniType, SplitUniqSupply, PlainStgExpr(..)
17     ) where
18
19 import AbsUniType       ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
20                           extractTyVarsFromTy, splitType, splitTyArgs,
21                           glueTyArgs, instantiateTy, TauType(..),
22                           Class, ThetaType(..), SigmaType(..),
23                           InstTyEnv(..)
24                         )
25 import IdEnv
26 import Id               ( mkSysLocal, getIdUniType, eqId )
27 import Maybes           ( Maybe(..) )
28 import StgSyn
29 import SATMonad         ( SATEnv(..), SATInfo(..), Arg(..), updSAEnv, insSAEnv,
30                           SatM(..), initSAT, thenSAT, thenSAT_,
31                           emptyEnvSAT, returnSAT, mapSAT, isStatic, dropStatics,
32                           getSATInfo, newSATName )
33 import SrcLoc           ( SrcLoc, mkUnknownSrcLoc )
34 import SplitUniq
35 import Unique
36 import UniqSet          ( UniqSet(..), emptyUniqSet )
37 import Util
38
39 \end{code}
40
41 %************************************************************************
42 %*                                                                      *
43 \subsection{Utility Functions}
44 %*                                                                      *
45 %************************************************************************
46
47 \begin{code}
48 newSATNames :: [Id] -> SatM [Id]
49 newSATNames [] = returnSAT []
50 newSATNames (id:ids) = newSATName id (getIdUniType id)  `thenSAT` \ id' ->
51                        newSATNames ids                  `thenSAT` \ ids' ->
52                        returnSAT (id:ids)
53
54 getArgLists :: PlainStgRhs -> ([Arg UniType],[Arg Id])
55 getArgLists (StgRhsCon _ _ _) 
56   = ([],[])
57 getArgLists (StgRhsClosure _ _ _ _ args _)
58   = ([], [Static v | v <- args])
59
60 \end{code}
61
62 \begin{code}
63 saTransform :: Id -> PlainStgRhs -> SatM PlainStgBinding
64 saTransform binder rhs
65   = getSATInfo binder `thenSAT` \ r ->
66     case r of
67       Just (_,args) | any isStatic args 
68       -- [Andre] test: do it only if we have more than one static argument.
69       --Just (_,args) | length (filter isStatic args) > 1
70         -> newSATName binder (new_ty args)      `thenSAT` \ binder' ->
71            let non_static_args = get_nsa args (snd (getArgLists rhs))
72            in
73            newSATNames non_static_args          `thenSAT` \ non_static_args' ->
74            mkNewRhs binder binder' args rhs non_static_args' non_static_args
75                                                 `thenSAT` \ new_rhs ->
76            trace ("SAT(STG) "++ show (length (filter isStatic args))) (
77            returnSAT (StgNonRec binder new_rhs)
78            )
79       _ -> returnSAT (StgRec [(binder, rhs)])
80
81   where
82     get_nsa :: [Arg a] -> [Arg a] -> [a]
83     get_nsa []                  _               = []
84     get_nsa _                   []              = []
85     get_nsa (NotStatic:args)    (Static v:as)   = v:get_nsa args as
86     get_nsa (_:args)            (_:as)          =   get_nsa args as
87
88     mkNewRhs binder binder' args rhs@(StgRhsClosure cc bi fvs upd rhsargs body) non_static_args' non_static_args
89       = let
90           local_body = StgApp (StgVarAtom binder')
91                          [StgVarAtom a | a <- non_static_args] emptyUniqSet
92
93           rec_body = StgRhsClosure cc bi fvs upd non_static_args'
94                        (doStgSubst binder args subst_env body)
95
96           subst_env = mkIdEnv 
97                         ((binder,binder'):zip non_static_args non_static_args')
98         in
99         returnSAT (
100             StgRhsClosure cc bi fvs upd rhsargs 
101               (StgLet (StgRec [(binder',rec_body)]) {-in-} local_body)
102         )
103
104     new_ty args
105       = instantiateTy [] (mkSigmaTy [] dict_tys' tau_ty')
106       where
107         -- get type info for the local function:
108         (tv_tmpl, dict_tys, tau_ty) = (splitType . getIdUniType) binder
109         (reg_arg_tys, res_type)     = splitTyArgs tau_ty
110
111         -- now, we drop the ones that are
112         -- static, that is, the ones we will not pass to the local function
113         l            = length dict_tys
114         dict_tys'    = dropStatics (take l args) dict_tys
115         reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
116         tau_ty'      = glueTyArgs reg_arg_tys' res_type
117 \end{code}
118
119 NOTE: This does not keep live variable/free variable information!!
120
121 \begin{code}
122 doStgSubst binder orig_args subst_env body
123   = substExpr body
124   where 
125     substExpr (StgConApp con args lvs) 
126       = StgConApp con (map substAtom args) emptyUniqSet
127     substExpr (StgPrimApp op args lvs)
128       = StgPrimApp op (map substAtom args) emptyUniqSet
129     substExpr expr@(StgApp (StgLitAtom _) [] _) 
130       = expr
131     substExpr (StgApp atom@(StgVarAtom v)  args lvs)
132       | v `eqId` binder
133       = StgApp (StgVarAtom (lookupNoFailIdEnv subst_env v))
134                (remove_static_args orig_args args) emptyUniqSet
135       | otherwise
136       = StgApp (substAtom atom) (map substAtom args) lvs
137     substExpr (StgCase scrut lv1 lv2 uniq alts)
138       = StgCase (substExpr scrut) emptyUniqSet emptyUniqSet uniq (subst_alts alts)
139       where
140         subst_alts (StgAlgAlts ty alg_alts deflt)
141           = StgAlgAlts ty (map subst_alg_alt alg_alts) (subst_deflt deflt)
142         subst_alts (StgPrimAlts ty prim_alts deflt)
143           = StgPrimAlts ty (map subst_prim_alt prim_alts) (subst_deflt deflt)
144         subst_alg_alt (con, args, use_mask, rhs)
145           = (con, args, use_mask, substExpr rhs)
146         subst_prim_alt (lit, rhs)
147           = (lit, substExpr rhs)
148         subst_deflt StgNoDefault 
149           = StgNoDefault
150         subst_deflt (StgBindDefault var used rhs)
151           = StgBindDefault var used (substExpr rhs)
152     substExpr (StgLetNoEscape fv1 fv2 b body)
153       = StgLetNoEscape emptyUniqSet emptyUniqSet (substBinding b) (substExpr body)
154     substExpr (StgLet b body)
155       = StgLet (substBinding b) (substExpr body)
156     substExpr (StgSCC ty cc expr)
157       = StgSCC ty cc (substExpr expr)
158     substRhs (StgRhsCon cc v args) 
159       = StgRhsCon cc v (map substAtom args)
160     substRhs (StgRhsClosure cc bi fvs upd args body)
161       = StgRhsClosure cc bi [] upd args (substExpr body)
162     
163     substBinding (StgNonRec binder rhs)
164       = StgNonRec binder (substRhs rhs)
165     substBinding (StgRec pairs)
166       = StgRec (zip binders (map substRhs rhss))
167       where
168         (binders,rhss) = unzip pairs
169     
170     substAtom atom@(StgLitAtom lit) = atom
171     substAtom atom@(StgVarAtom v) 
172       = case lookupIdEnv subst_env v of
173           Just v' -> StgVarAtom v'
174           Nothing -> atom
175     
176     remove_static_args _ [] 
177       = []
178     remove_static_args (Static _:origs) (_:as) 
179       = remove_static_args origs as
180     remove_static_args (NotStatic:origs) (a:as) 
181       = substAtom a:remove_static_args origs as
182 \end{code}