[project @ 1997-03-14 07:52:06 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
index 40d180a..1abccae 100644 (file)
@@ -1,5 +1,5 @@
 %
-% (c) The AQUA Project, Glasgow University, 1994-1995
+% (c) The AQUA Project, Glasgow University, 1994-1996
 %
 \section[LambdaLift]{A STG-code lambda lifter}
 
@@ -8,18 +8,21 @@
 
 module LambdaLift ( liftProgram ) where
 
+IMP_Ubiq(){-uitous-}
+
 import StgSyn
 
-import Type            ( mkForallTy, splitForalls, glueTyArgs,
-                         Type, RhoType(..), TauType(..)
+import Bag             ( emptyBag, unionBags, unitBag, snocBag, bagToList )
+import Id              ( idType, mkSysLocal, addIdArity, 
+                         mkIdSet, unitIdSet, minusIdSet, setIdVisibility,
+                         unionManyIdSets, idSetToList, SYN_IE(IdSet),
+                         nullIdEnv, growIdEnvList, lookupIdEnv, SYN_IE(IdEnv)
                        )
-import Bag
-import Id              ( mkSysLocal, idType, addIdArity, Id )
-import Maybes
-import UniqSupply
-import SrcLoc          ( mkUnknownSrcLoc, SrcLoc )
-import UniqSet
-import Util
+import IdInfo          ( ArityInfo, exactArity )
+import SrcLoc          ( noSrcLoc )
+import Type            ( splitForAllTy, mkForAllTys, mkFunTys )
+import UniqSupply      ( getUnique, splitUniqSupply )
+import Util            ( zipEqual, panic, assertPanic )
 \end{code}
 
 This is the lambda lifter.  It turns lambda abstractions into
@@ -84,11 +87,13 @@ supercombinators on a selective basis:
   recursive calls, which may now have lots of free vars.
 
 Recent Observations:
+
 * 2 might be already ``too many'' variables to abstract.
   The problem is that the increase in the number of free variables
   of closures refering to the lifted function (which is always # of
   abstracted args - 1) may increase heap allocation a lot.
   Expeiments are being done to check this...
+
 * We do not lambda lift if the function has at least one occurrence
   without any arguments. This caused lots of problems. Ex:
   h = \ x -> ... let y = ...
@@ -117,8 +122,8 @@ Recent Observations:
 %************************************************************************
 
 \begin{code}
-liftProgram :: UniqSupply -> [StgBinding] -> [StgBinding]
-liftProgram us prog = concat (runLM Nothing us (mapLM liftTopBind prog))
+liftProgram :: Module -> UniqSupply -> [StgBinding] -> [StgBinding]
+liftProgram mod us prog = concat (runLM mod Nothing us (mapLM liftTopBind prog))
 
 
 liftTopBind :: StgBinding -> LiftM [StgBinding]
@@ -145,8 +150,9 @@ liftExpr expr@(StgCon con args lvs) = returnLM (expr, emptyLiftInfo)
 liftExpr expr@(StgPrim op args lvs) = returnLM (expr, emptyLiftInfo)
 
 liftExpr expr@(StgApp (StgLitArg lit) args lvs) = returnLM (expr, emptyLiftInfo)
+liftExpr expr@(StgApp (StgConArg con) args lvs) = returnLM (expr, emptyLiftInfo)
 liftExpr expr@(StgApp (StgVarArg v)  args lvs)
-  = lookup v           `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
+  = lookUp v           `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
                                                        -- poke these bindings too early!
     returnLM (StgApp (StgVarArg sc) (map StgVarArg sc_args ++ args) lvs,
              emptyLiftInfo)
@@ -196,7 +202,7 @@ liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
   = liftExpr body                      `thenLM` \ (body', body_info) ->
     mapAndUnzipLM dontLiftRhs rhss     `thenLM` \ (rhss', rhs_infos) ->
-    returnLM (StgLet (StgRec (binders `zipEqual` rhss')) body',
+    returnLM (StgLet (StgRec (zipEqual "liftExpr" binders rhss')) body',
              foldr unionLiftInfo body_info rhs_infos)
   where
    (binders,rhss) = unzip pairs
@@ -238,7 +244,7 @@ liftExpr (StgLet (StgRec pairs) body)
   | not (all isLiftableRec rhss)
   = liftExpr body                      `thenLM` \ (body', body_info) ->
     mapAndUnzipLM dontLiftRhs rhss     `thenLM` \ (rhss', rhs_infos) ->
-    returnLM (StgLet (StgRec (binders `zipEqual` rhss')) body',
+    returnLM (StgLet (StgRec (zipEqual "liftExpr2" binders rhss')) body',
              foldr unionLiftInfo body_info rhs_infos)
 
   | otherwise  -- All rhss are liftable
@@ -251,9 +257,9 @@ liftExpr (StgLet (StgRec pairs) body)
       let
        -- Find the free vars of all the rhss,
        -- excluding the binders themselves.
-       rhs_free_vars = unionManyUniqSets (map rhsFreeVars rhss)
-                       `minusUniqSet`
-                       mkUniqSet binders
+       rhs_free_vars = unionManyIdSets (map rhsFreeVars rhss)
+                       `minusIdSet`
+                       mkIdSet binders
 
        rhs_info      = unionLiftInfos rhs_infos
       in
@@ -335,7 +341,7 @@ isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _
 isLiftableRec other_rhs = False
 
 rhsFreeVars :: StgRhs -> IdSet
-rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkUniqSet fvs
+rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkIdSet fvs
 rhsFreeVars other                        = panic "rhsFreeVars"
 \end{code}
 
@@ -364,22 +370,18 @@ mkScPieces :: IdSet               -- Extra args for the supercombinator
 mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
   = ASSERT( n_args > 0 )
        -- Construct the rhs of the supercombinator, and its Id
-    -- this trace blackholes sometimes, don't use it
-    -- trace ("LL " ++ show (length (uniqSetToList extra_arg_set))) (
     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
-
     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
-    --)
   where
     n_args     = length args
-    extra_args = uniqSetToList extra_arg_set
+    extra_args = idSetToList extra_arg_set
     arity      = n_args + length extra_args
 
        -- Construct the supercombinator type
     type_of_original_id = idType id
     extra_arg_tys       = map idType extra_args
-    (tyvars, rest)      = splitForalls type_of_original_id
-    sc_ty              = mkForallTy tyvars (glueTyArgs extra_arg_tys rest)
+    (tyvars, rest)      = splitForAllTy type_of_original_id
+    sc_ty              = mkForAllTys tyvars (mkFunTys extra_arg_tys rest)
 
     sc_rhs = StgRhsClosure cc bi [] upd (extra_args ++ args) body
 \end{code}
@@ -394,7 +396,8 @@ mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
 The monad is used only to distribute global stuff, and the unique supply.
 
 \begin{code}
-type LiftM a =  LiftFlags
+type LiftM a =  Module 
+            -> LiftFlags
             -> UniqSupply
             -> (IdEnv                          -- Domain = candidates for lifting
                       (Id,                     -- The supercombinator
@@ -407,22 +410,22 @@ type LiftFlags = Maybe Int        -- No of fvs reqd to float recursive
                                -- binding; Nothing == infinity
 
 
-runLM :: LiftFlags -> UniqSupply -> LiftM a -> a
-runLM flags us m = m flags us nullIdEnv
+runLM :: Module -> LiftFlags -> UniqSupply -> LiftM a -> a
+runLM mod flags us m = m mod flags us nullIdEnv
 
 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
-thenLM m k ci us idenv
-  = k (m ci us1 idenv) ci us2 idenv
+thenLM m k mod ci us idenv
+  = k (m mod ci us1 idenv) mod ci us2 idenv
   where
     (us1, us2) = splitUniqSupply us
 
 returnLM :: a -> LiftM a
-returnLM a ci us idenv = a
+returnLM a mod ci us idenv = a
 
 fixLM :: (a -> LiftM a) -> LiftM a
-fixLM k ci us idenv = r
+fixLM k mod ci us idenv = r
                       where
-                        r = k r ci us idenv
+                        r = k r mod ci us idenv
 
 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
 mapLM f [] = returnLM []
@@ -442,22 +445,22 @@ newSupercombinator :: Type
                   -> Int               -- Arity
                   -> LiftM Id
 
-newSupercombinator ty arity ci us idenv
-  = (mkSysLocal SLIT("sc") uniq ty mkUnknownSrcLoc)    -- ToDo: improve location
-    `addIdArity` arity
+newSupercombinator ty arity mod ci us idenv
+  = setIdVisibility mod (mkSysLocal SLIT("sc") uniq ty noSrcLoc)
+    `addIdArity` exactArity arity
        -- ToDo: rm the addIdArity?  Just let subsequent stg-saturation pass do it?
   where
     uniq = getUnique us
 
-lookup :: Id -> LiftM (Id,[Id])
-lookup v ci us idenv
-  = case lookupIdEnv idenv v of
-       Just result -> result
-       Nothing     -> (v, [])
+lookUp :: Id -> LiftM (Id,[Id])
+lookUp v mod ci us idenv
+  = case (lookupIdEnv idenv v) of
+      Just result -> result
+      Nothing     -> (v, [])
 
 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
-addScInlines ids values m ci us idenv
-  = m ci us idenv'
+addScInlines ids values m mod ci us idenv
+  = m mod ci us idenv'
   where
     idenv' = growIdEnvList idenv (ids `zip_lazy` values)
 
@@ -487,15 +490,14 @@ addScInlines ids values m ci us idenv
 
 getFinalFreeVars :: IdSet -> LiftM IdSet
 
-getFinalFreeVars free_vars ci us idenv
-  = unionManyUniqSets (map munge_it (uniqSetToList free_vars))
+getFinalFreeVars free_vars mod ci us idenv
+  = unionManyIdSets (map munge_it (idSetToList free_vars))
   where
     munge_it :: Id -> IdSet    -- Takes a free var and maps it to the "real"
                                -- free var
-    munge_it id = case lookupIdEnv idenv id of
-                       Just (_, args) -> mkUniqSet args
-                       Nothing        -> singletonUniqSet id
-
+    munge_it id = case (lookupIdEnv idenv id) of
+                   Just (_, args) -> mkIdSet args
+                   Nothing        -> unitIdSet id
 \end{code}