[project @ 1999-05-28 13:32:50 by simonmar]
authorsimonmar <unknown>
Fri, 28 May 1999 13:32:50 +0000 (13:32 +0000)
committersimonmar <unknown>
Fri, 28 May 1999 13:32:50 +0000 (13:32 +0000)
Fixes for case-of-case and let-no-escape.

ghc/compiler/simplStg/SRT.lhs

index 00c31f5..356d84b 100644 (file)
@@ -40,6 +40,8 @@ Our functions have type
 
        :: UniqFM CafInfo       -- which top-level ids don't refer to any CAfs
        -> SrtOffset            -- next free offset within the SRT
+       -> (UniqSet Id,         -- global refs in the continuation
+           UniqFM (UniqSet Id))-- global refs in let-no-escaped variables
 {- * -}        -> StgExpr              -- expression to analyse
 
        -> (StgExpr,            -- (e) newly annotated expression
@@ -111,7 +113,8 @@ srtTopBind
 srtTopBind rho (StgNonRec binder rhs) =
 
    -- no need to use circularity for non-recursive bindings
-   srtRhs rho 0{-initial offset-} rhs          =: \(rhs, g, srt, off) ->
+   srtRhs rho (emptyUniqSet,emptyUFM) 0{-initial offset-} rhs
+                                       =: \(rhs, g, srt, off) ->
    let
        filtered_g = filter (mayHaveCafRefs rho) (uniqSetToList g)
         caf_info   = mk_caf_info rhs filtered_g
@@ -156,7 +159,8 @@ srtTopBind rho (StgRec bs) =
 
     doBinds [] new_binds g srt off = (reverse new_binds, g, srt, off)
     doBinds ((binder,rhs):binds) new_binds g srt off =
-       srtRhs rho' off rhs =: \(rhs, rhs_g, rhs_srt, off) ->
+       srtRhs rho' (emptyUniqSet,emptyUFM) off rhs 
+                               =: \(rhs, rhs_g, rhs_srt, off) ->
        let 
            g'   = unionUniqSets rhs_g g
            srt' = rhs_srt ++ srt
@@ -171,14 +175,14 @@ caf_rhs _ = False
 Non-top-level bindings
 
 \begin{code}
-srtBind :: UniqFM CafInfo -> Int -> StgBinding
-       -> (StgBinding, UniqSet Id, [Id], Int)
+srtBind :: UniqFM CafInfo -> (UniqSet Id, UniqFM (UniqSet Id))
+       -> Int -> StgBinding -> (StgBinding, UniqSet Id, [Id], Int)
 
-srtBind rho off (StgNonRec binder rhs) =
-  srtRhs rho off rhs   =: \(rhs, g, srt, off) ->
+srtBind rho cont_refs off (StgNonRec binder rhs) =
+  srtRhs rho cont_refs off rhs   =: \(rhs, g, srt, off) ->
   (StgNonRec binder rhs, g, srt, off)
 
-srtBind rho off (StgRec binds) =
+srtBind rho cont_refs off (StgRec binds) =
     (StgRec new_binds, g, srt, new_off)
   where
     -- process each binding
@@ -186,7 +190,7 @@ srtBind rho off (StgRec binds) =
 
     doBinds [] g srt off new_binds = (reverse new_binds, g, srt, off)
     doBinds ((binder,rhs):binds) g srt off new_binds =
-        srtRhs rho off rhs   =: \(rhs, g', srt', off) ->
+        srtRhs rho cont_refs off rhs   =: \(rhs, g', srt', off) ->
        doBinds binds (unionUniqSets g g') (srt'++srt) off
                ((binder,rhs):new_binds)
 \end{code}
@@ -195,14 +199,14 @@ srtBind rho off (StgRec binds) =
 Right Hand Sides
 
 \begin{code}
-srtRhs :: UniqFM CafInfo -> Int -> StgRhs
-       -> (StgRhs, UniqSet Id, [Id], Int)
+srtRhs         :: UniqFM CafInfo -> (UniqSet Id, UniqFM (UniqSet Id))
+       -> Int -> StgRhs -> (StgRhs, UniqSet Id, [Id], Int)
 
-srtRhs rho off (StgRhsClosure cc bi old_srt free_vars u args body) =
-    srtExpr rho off body       =: \(body, g, srt, off) ->
+srtRhs rho cont off (StgRhsClosure cc bi old_srt free_vars u args body) =
+    srtExpr rho cont off body  =: \(body, g, srt, off) ->
     (StgRhsClosure cc bi old_srt free_vars u args body, g, srt, off)
 
-srtRhs rho off e@(StgRhsCon cc con args) =
+srtRhs rho cont off e@(StgRhsCon cc con args) =
     (e, getGlobalRefs rho args, [], off)
 \end{code}
 
@@ -210,24 +214,27 @@ srtRhs rho off e@(StgRhsCon cc con args) =
 Expressions
 
 \begin{code}
-srtExpr :: UniqFM CafInfo -> Int -> StgExpr 
-       -> (StgExpr, UniqSet Id, [Id], Int)
+srtExpr :: UniqFM CafInfo -> (UniqSet Id, UniqFM (UniqSet Id))
+       -> Int -> StgExpr -> (StgExpr, UniqSet Id, [Id], Int)
 
-srtExpr rho off e@(StgApp f args) =
-   (e, getGlobalRefs rho (StgVarArg f:args), [], off)
+srtExpr rho (cont,lne) off e@(StgApp f args) = (e, global_refs, [], off)
+  where global_refs = 
+               cont `unionUniqSets`
+               getGlobalRefs rho (StgVarArg f:args) `unionUniqSets`
+               lookupPossibleLNE lne f
 
-srtExpr rho off e@(StgCon con args ty) =
-   (e, getGlobalRefs rho args, [], off)
+srtExpr rho (cont,lne) off e@(StgCon con args ty) =
+   (e, cont `unionUniqSets` getGlobalRefs rho args, [], off)
 
-srtExpr rho off (StgCase scrut live1 live2 uniq _{-srt-} alts) =
-   srtCaseAlts rho off alts    =: \(alts, alts_g, alts_srt, alts_off) ->
-   let
-       extra_refs = filter (`notElem` alts_srt)
-                       (filter (mayHaveCafRefs rho) (uniqSetToList alts_g))
-       this_srt = extra_refs ++ alts_srt
-       scrut_off = alts_off + length extra_refs
-   in
-   srtExpr rho scrut_off scrut         =: \(scrut, scrut_g, scrut_srt, case_off) ->
+srtExpr rho c@(cont,lne) off (StgCase scrut live1 live2 uniq _{-srt-} alts) =
+   srtCaseAlts rho c off alts =: \(alts, alts_g, alts_srt, alts_off) ->
+
+       -- construct the SRT for this case
+   let (this_srt, scrut_off) = construct_srt rho alts_g alts_srt alts_off in
+
+       -- global refs in the continuation is alts_g.
+   srtExpr rho (alts_g,lne) scrut_off scrut
+                               =: \(scrut, scrut_g, scrut_srt, case_off) ->
    let
        g = unionUniqSets alts_g scrut_g
        srt = scrut_srt ++ this_srt
@@ -237,23 +244,29 @@ srtExpr rho off (StgCase scrut live1 live2 uniq _{-srt-} alts) =
    in
    (StgCase scrut live1 live2 uniq srt_info alts, g, srt, case_off)
 
-srtExpr rho off (StgLet bind body) =
-   srtLet rho off bind body StgLet
-
-   -- let-no-escapes are delicate, see below
-srtExpr rho off (StgLetNoEscape live1 live2 bind body) =
-   srtLet rho off bind body (StgLetNoEscape live1 live2) 
-               =: \(expr, g, srt, off') ->
-   let
-       -- find the SRT for the *whole* expression
-       length = off' - off
-       all_srt | length == 0 = NoSRT
-               | otherwise   = SRT off length
-   in
-   (fixLNE_srt all_srt expr, g, srt, off')
-
-srtExpr rho off (StgSCC cc expr) =
-   srtExpr rho off expr                =: \(expr, g, srt, off) ->
+srtExpr rho cont off (StgLet bind body) =
+   srtLet rho cont off bind body StgLet (\_ cont -> cont)
+
+srtExpr rho cont off (StgLetNoEscape live1 live2 b@(StgNonRec bndr rhs) body)
+  = srtLet rho cont off b body (StgLetNoEscape live1 live2) calc_cont
+  where calc_cont g (cont,lne) = (cont,addToUFM lne bndr g)
+
+-- for recursive let-no-escapes, we do *two* passes, the first time
+-- just to extract the list of global refs, and the second time we actually
+-- construct the SRT now that we know what global refs should be in
+-- the various let-no-escape continuations.
+srtExpr rho conts@(cont,lne) off 
+       (StgLetNoEscape live1 live2 bind@(StgRec pairs) body)
+  = srtBind rho conts off bind =: \(_, g, _, _) ->
+    let 
+       lne' = addListToUFM lne [ (bndr,g) | (bndr,_) <- pairs ]
+       calc_cont _ conts = conts
+    in
+    srtLet rho (cont,lne') off bind body (StgLetNoEscape live1 live2) calc_cont
+
+
+srtExpr rho cont off (StgSCC cc expr) =
+   srtExpr rho cont off expr   =: \(expr, g, srt, off) ->
    (StgSCC cc expr, g, srt, off)
 \end{code}
 
@@ -263,13 +276,13 @@ Let-expressions
 This is quite complicated stuff...
 
 \begin{code}
-srtLet rho off bind body let_constr
+srtLet rho cont off bind body let_constr calc_cont
 
  -- If the bindings are all constructors, then we don't need to
  -- buid an SRT at all...
  | all_con_binds bind =
-   srtBind rho off bind                =: \(bind, bind_g, bind_srt, off) ->
-   srtExpr rho off body                =: \(body, body_g, body_srt, off) ->
+   srtBind rho cont off bind   =: \(bind, bind_g, bind_srt, off) ->
+   srtExpr rho cont off body   =: \(body, body_g, body_srt, off) ->
    let
        g   = unionUniqSets bind_g body_g
        srt = body_srt ++ bind_srt
@@ -280,23 +293,16 @@ srtLet rho off bind body let_constr
  | otherwise =
 
     -- first, find the sub-SRTs in the binding
-   srtBind rho off bind                =: \(bind, bind_g, bind_srt, bind_off) ->
+   srtBind rho cont off bind   =: \(bind, bind_g, bind_srt, bind_off) ->
 
-   -- Construct the SRT for this binding from its sub-SRTs and any new global
-   -- references which aren't already contained in one of the sub-SRTs (and
-   -- which are "live").  
-   let
-       extra_refs = filter (`notElem` bind_srt) 
-                       (filter (mayHaveCafRefs rho) (uniqSetToList bind_g))
-       this_srt = extra_refs ++ bind_srt
+    -- construct the SRT for this binding
+   let (this_srt, body_off) = construct_srt rho bind_g bind_srt bind_off in
 
-       -- Add the length of the new entries to the     
-        -- current offset to get the next free offset in the global SRT.
-       body_off = bind_off + length extra_refs
-   in
+    -- get the new continuation information (if a let-no-escape)
+   let new_cont = calc_cont bind_g cont in
 
-   -- now find the SRTs in the body
-   srtExpr rho body_off body   =: \(body, body_g, body_srt, let_off) ->
+    -- now find the SRTs in the body
+   srtExpr rho cont body_off body  =: \(body, body_g, body_srt, let_off) ->
 
    let
        -- union all the global references together
@@ -312,53 +318,73 @@ srtLet rho off bind body let_constr
 \end{code}
 
 -----------------------------------------------------------------------------
+Construct an SRT.
+
+Construct the SRT at this point from its sub-SRTs and any new global
+references which aren't already contained in one of the sub-SRTs (and
+which are "live").
+
+\begin{code}
+construct_srt rho global_refs sub_srt current_offset
+   = let
+       extra_refs = filter (`notElem` sub_srt) 
+                     (filter (mayHaveCafRefs rho) (uniqSetToList global_refs))
+       this_srt = extra_refs ++ sub_srt
+
+       -- Add the length of the new entries to the     
+        -- current offset to get the next free offset in the global SRT.
+       new_offset = current_offset + length extra_refs
+   in (this_srt, new_offset)
+\end{code}
+
+-----------------------------------------------------------------------------
 Case Alternatives
 
 \begin{code}
-srtCaseAlts :: UniqFM CafInfo -> Int -> StgCaseAlts ->
-       (StgCaseAlts, UniqSet Id, [Id], Int)
+srtCaseAlts :: UniqFM CafInfo -> (UniqSet Id, UniqFM (UniqSet Id))
+       -> Int -> StgCaseAlts -> (StgCaseAlts, UniqSet Id, [Id], Int)
 
-srtCaseAlts rho off (StgAlgAlts  t alts dflt) =
-   srtAlgAlts rho off alts [] emptyUniqSet []  
+srtCaseAlts rho cont off (StgAlgAlts  t alts dflt) =
+   srtAlgAlts rho cont off alts [] emptyUniqSet []  
                                  =: \(alts, alts_g, alts_srt, off) ->
-   srtDefault rho off dflt               =: \(dflt, dflt_g, dflt_srt, off) ->
+   srtDefault rho cont off dflt          =: \(dflt, dflt_g, dflt_srt, off) ->
    let
        g   = unionUniqSets alts_g dflt_g
        srt = dflt_srt ++ alts_srt
    in
    (StgAlgAlts t alts dflt, g, srt, off)
 
-srtCaseAlts rho off (StgPrimAlts t alts dflt) =
-   srtPrimAlts rho off alts [] emptyUniqSet []  
+srtCaseAlts rho cont off (StgPrimAlts t alts dflt) =
+   srtPrimAlts rho cont off alts [] emptyUniqSet []  
                                   =: \(alts, alts_g, alts_srt, off) ->
-   srtDefault rho off dflt                =: \(dflt, dflt_g, dflt_srt, off) ->
+   srtDefault rho cont off dflt           =: \(dflt, dflt_g, dflt_srt, off) ->
    let
        g   = unionUniqSets alts_g dflt_g
        srt = dflt_srt ++ alts_srt
    in
    (StgPrimAlts t alts dflt, g, srt, off)
 
-srtAlgAlts rho off [] new_alts g srt = (reverse new_alts, g, srt, off)
-srtAlgAlts rho off ((con,args,used,rhs):alts) new_alts g srt =
-   srtExpr rho off rhs                 =: \(rhs, rhs_g, rhs_srt, off) ->
+srtAlgAlts rho cont off [] new_alts g srt = (reverse new_alts, g, srt, off)
+srtAlgAlts rho cont off ((con,args,used,rhs):alts) new_alts g srt =
+   srtExpr rho cont off rhs    =: \(rhs, rhs_g, rhs_srt, off) ->
    let
        g'   = unionUniqSets rhs_g g
        srt' = rhs_srt ++ srt
    in
-   srtAlgAlts rho off alts ((con,args,used,rhs) : new_alts) g' srt'
+   srtAlgAlts rho cont off alts ((con,args,used,rhs) : new_alts) g' srt'
 
-srtPrimAlts rho off [] new_alts g srt = (reverse new_alts, g, srt, off)
-srtPrimAlts rho off ((lit,rhs):alts) new_alts g srt =
-   srtExpr rho off rhs                 =: \(rhs, rhs_g, rhs_srt, off) ->
+srtPrimAlts rho cont off [] new_alts g srt = (reverse new_alts, g, srt, off)
+srtPrimAlts rho cont off ((lit,rhs):alts) new_alts g srt =
+   srtExpr rho cont off rhs    =: \(rhs, rhs_g, rhs_srt, off) ->
    let
        g'   = unionUniqSets rhs_g g
        srt' = rhs_srt ++ srt
    in
-   srtPrimAlts rho off alts ((lit,rhs) : new_alts) g' srt'
+   srtPrimAlts rho cont off alts ((lit,rhs) : new_alts) g' srt'
 
-srtDefault rho off StgNoDefault = (StgNoDefault,emptyUniqSet,[],off)
-srtDefault rho off (StgBindDefault rhs) =
-   srtExpr rho off rhs                 =: \(rhs, g, srt, off) ->
+srtDefault rho cont off StgNoDefault = (StgNoDefault,emptyUniqSet,[],off)
+srtDefault rho cont off (StgBindDefault rhs) =
+   srtExpr rho cont off rhs    =: \(rhs, g, srt, off) ->
    (StgBindDefault rhs, g, srt, off)
 \end{code}
 
@@ -521,60 +547,8 @@ binding have their SRTs replaced with the SRT for the binding group
 (*not* the SRT of the whole let-no-escape expression).
 
 \begin{code}
-fixLNE_srt :: SRT -> StgExpr -> StgExpr
-fixLNE_srt all_srt (StgLetNoEscape live1 live2 (StgNonRec id rhs) body)
-  = StgLetNoEscape live1 live2 (StgNonRec id rhs) (fixLNE [id] all_srt body)
-  
-fixLNE_srt all_srt (StgLetNoEscape live1 live2 (StgRec pairs) body)
-  = StgLetNoEscape live1 live2
-        (StgRec (map fixLNE_rec pairs)) (fixLNE binders all_srt body)
-  where
-       binders = map fst pairs
-       fixLNE_rec (id,StgRhsClosure cc bi srt fvs uf args e) = 
-          (id, StgRhsClosure cc bi srt fvs uf args (fixLNE binders srt e))
-        fixLNE_rec (id,con) = (id,con)
-
-fixLNE :: [Id] -> SRT -> StgExpr -> StgExpr
-
-fixLNE ids srt expr@(StgCase scrut live rhs_live bndr old_srt alts)
-  | any (`elementOfUniqSet` rhs_live) ids
-    = StgCase scrut live rhs_live bndr srt (fixLNE_alts ids srt alts)
-  | otherwise = expr
-  -- can't be in the scrutinee, because it's a let-no-escape!
-
-fixLNE ids srt expr@(StgLetNoEscape live rhs_live bind body)
-  | any (`elementOfUniqSet` rhs_live) ids =
-       StgLetNoEscape live rhs_live (fixLNE_bind ids srt bind)
-                                    (fixLNE      ids srt body)
-  | any (`elementOfUniqSet` live) ids = 
-       StgLetNoEscape live rhs_live bind (fixLNE ids srt body)
-  | otherwise = expr
-
-fixLNE ids srt (StgLet bind body)  = StgLet bind (fixLNE ids srt body)
-fixLNE ids srt (StgSCC cc expr)    = StgSCC cc (fixLNE ids srt expr)
-fixLNE ids srt expr               = expr
-
-fixLNE_alts ids srt (StgAlgAlts t alts dflt)
-  = StgAlgAlts  t (map (fixLNE_algalt  ids srt) alts) (fixLNE_dflt ids srt dflt)
-
-fixLNE_alts ids srt (StgPrimAlts t alts dflt)
-  = StgPrimAlts t (map (fixLNE_primalt ids srt) alts) (fixLNE_dflt ids srt dflt)
-
-fixLNE_algalt  ids srt (con,args,used,rhs) = (con,args,used, fixLNE ids srt rhs)
-fixLNE_primalt ids srt (lit,rhs)           = (lit,           fixLNE ids srt rhs)
-
-fixLNE_dflt    ids srt (StgNoDefault)     = StgNoDefault
-fixLNE_dflt    ids srt (StgBindDefault rhs) = StgBindDefault (fixLNE ids srt rhs)
-
-fixLNE_bind ids srt (StgNonRec bndr rhs) 
-  = StgNonRec bndr (fixLNE_rhs ids srt rhs)
-fixLNE_bind ids srt (StgRec pairs) 
-  = StgRec [ (bndr, fixLNE_rhs ids srt rhs) | (bndr,rhs) <- pairs ]
-
-fixLNE_rhs ids srt rhs@(StgRhsClosure cc bi old_srt fvs uf args expr)
-  | any (`elem` fvs) ids 
-      = StgRhsClosure cc bi srt fvs uf args (fixLNE ids srt expr)
-  | otherwise     = rhs
-fixLNE_rhs ids srt rhs@(StgRhsCon cc con args) = rhs
-
+lookupPossibleLNE lne_env f = 
+  case lookupUFM lne_env f of
+       Nothing   -> emptyUniqSet
+       Just refs -> refs
 \end{code}