Fix initialisation of strictness in the demand analyser
authorsimonpj@microsoft.com <unknown>
Tue, 26 Oct 2010 08:17:57 +0000 (08:17 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 26 Oct 2010 08:17:57 +0000 (08:17 +0000)
Previously, the demand analyser assumed that every binder
starts off with no strictness info.  But now that we are
preserving strictness on nesting bindings in interface files,
that assumption is no longer correct, because an inlined function
might have a nested binding with strictness set.

So we need to know when we are in the initial sweep, so that we can
set the strictness to 'bottom'.

See Note [Initialising strictness]

compiler/stranal/DmdAnal.lhs

index 32986e5..7c9ddd5 100644 (file)
@@ -24,17 +24,16 @@ import DataCon              ( dataConTyCon, dataConRepStrictness )
 import TyCon           ( isProductTyCon, isRecursiveTyCon )
 import Id              ( Id, idType, idInlineActivation,
                          isDataConWorkId, isGlobalId, idArity,
-                         idStrictness, idStrictness_maybe,
+                         idStrictness, 
                          setIdStrictness, idDemandInfo, idUnfolding,
-                         idDemandInfo_maybe,
-                         setIdDemandInfo
+                         idDemandInfo_maybe, setIdDemandInfo
                        )
 import Var             ( Var )
 import VarEnv
 import TysWiredIn      ( unboxedPairDataCon )
 import TysPrim         ( realWorldStatePrimTy )
 import UniqFM          ( addToUFM_Directly, lookupUFM_Directly,
-                         minusUFM, ufmToList, filterUFM )
+                         minusUFM, filterUFM )
 import Type            ( isUnLiftedType, coreEqType, splitTyConApp_maybe )
 import Coercion         ( coercionKind )
 import Util            ( mapAndUnzip, lengthIs, zipEqual )
@@ -43,6 +42,7 @@ import BasicTypes     ( Arity, TopLevelFlag(..), isTopLevel, isNeverActive,
 import Maybes          ( orElse, expectJust )
 import Outputable
 import Data.List
+import FastString
 \end{code}
 
 To think about
@@ -380,6 +380,85 @@ if X is monomorphic, and has an UNPACK pragma, then this optimisation
 is even more important.  We don't want the wrapper to rebox an unboxed
 argument, and pass an Int to $wfoo!
 
+
+%************************************************************************
+%*                                                                     *
+                    Demand transformer
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+dmdTransform :: SigEnv         -- The strictness environment
+            -> Id              -- The function
+            -> Demand          -- The demand on the function
+            -> DmdType         -- The demand type of the function in this context
+       -- Returned DmdEnv includes the demand on 
+       -- this function plus demand on its free variables
+
+dmdTransform sigs var dmd
+
+------         DATA CONSTRUCTOR
+  | isDataConWorkId var                -- Data constructor
+  = let 
+       StrictSig dmd_ty    = idStrictness var  -- It must have a strictness sig
+       DmdType _ _ con_res = dmd_ty
+       arity               = idArity var
+    in
+    if arity == call_depth then                -- Saturated, so unleash the demand
+       let 
+               -- Important!  If we Keep the constructor application, then
+               -- we need the demands the constructor places (always lazy)
+               -- If not, we don't need to.  For example:
+               --      f p@(x,y) = (p,y)       -- S(AL)
+               --      g a b     = f (a,b)
+               -- It's vital that we don't calculate Absent for a!
+          dmd_ds = case res_dmd of
+                       Box (Eval ds) -> mapDmds box ds
+                       Eval ds       -> ds
+                       _             -> Poly Top
+
+               -- ds can be empty, when we are just seq'ing the thing
+               -- If so we must make up a suitable bunch of demands
+          arg_ds = case dmd_ds of
+                     Poly d  -> replicate arity d
+                     Prod ds -> ASSERT( ds `lengthIs` arity ) ds
+
+       in
+       mkDmdType emptyDmdEnv arg_ds con_res
+               -- Must remember whether it's a product, hence con_res, not TopRes
+    else
+       topDmdType
+
+------         IMPORTED FUNCTION
+  | isGlobalId var,            -- Imported function
+    let StrictSig dmd_ty = idStrictness var
+  = -- pprTrace "strict-sig" (ppr var $$ ppr dmd_ty) $
+    if dmdTypeDepth dmd_ty <= call_depth then  -- Saturated, so unleash the demand
+       dmd_ty
+    else
+       topDmdType
+
+------         LOCAL LET/REC BOUND THING
+  | Just (StrictSig dmd_ty, top_lvl) <- lookupSigEnv sigs var
+  = let
+       fn_ty | dmdTypeDepth dmd_ty <= call_depth = dmd_ty 
+             | otherwise                         = deferType dmd_ty
+       -- NB: it's important to use deferType, and not just return topDmdType
+       -- Consider     let { f x y = p + x } in f 1
+       -- The application isn't saturated, but we must nevertheless propagate 
+       --      a lazy demand for p!  
+    in
+    if isTopLevel top_lvl then fn_ty   -- Don't record top level things
+    else addVarDmd fn_ty var dmd
+
+------         LOCAL NON-LET/REC BOUND THING
+  | otherwise                  -- Default case
+  = unitVarDmd var dmd
+
+  where
+    (call_depth, res_dmd) = splitCallDmd dmd
+\end{code}
+
 %************************************************************************
 %*                                                                     *
 \subsection{Bindings}
@@ -397,7 +476,7 @@ dmdFix top_lvl sigs orig_pairs
   = loop 1 initial_sigs orig_pairs
   where
     bndrs        = map fst orig_pairs
-    initial_sigs = extendSigEnvList sigs [(id, (initialSig id, top_lvl)) | id <- bndrs]
+    initial_sigs = addInitialSigs top_lvl sigs bndrs
     
     loop :: Int
         -> SigEnv                      -- Already contains the current sigs
@@ -406,55 +485,43 @@ dmdFix top_lvl sigs orig_pairs
     loop n sigs pairs
       | found_fixpoint
       = (sigs', lazy_fv, pairs')
-               -- Note: use pairs', not pairs.   pairs' is the result of 
+               -- Note: return pairs', not pairs.   pairs' is the result of 
                -- processing the RHSs with sigs (= sigs'), whereas pairs 
                -- is the result of processing the RHSs with the *previous* 
                -- iteration of sigs.
 
-      | n >= 10  = pprTrace "dmdFix loop" (ppr n <+> (vcat 
-                               [ text "Sigs:" <+> ppr [(id,lookup sigs id, lookup sigs' id) | (id,_) <- pairs],
-                                 text "env:" <+> ppr (ufmToList sigs),
-                                 text "binds:" <+> pprCoreBinding (Rec pairs)]))
-                             (emptySigEnv, lazy_fv, orig_pairs)        -- Safe output
-                       -- The lazy_fv part is really important!  orig_pairs has no strictness
-                       -- info, including nothing about free vars.  But if we have
-                       --      letrec f = ....y..... in ...f...
-                       -- where 'y' is free in f, we must record that y is mentioned, 
-                       -- otherwise y will get recorded as absent altogether
-
-      | otherwise    = loop (n+1) sigs' pairs'
+      | n >= 10  
+      = pprTrace "dmdFix loop" (ppr n <+> (vcat 
+                       [ text "Sigs:" <+> ppr [ (id,lookupSigEnv sigs id, lookupSigEnv sigs' id) 
+                                               | (id,_) <- pairs],
+                         text "env:" <+> ppr sigs,
+                         text "binds:" <+> pprCoreBinding (Rec pairs)]))
+       (emptySigEnv, lazy_fv, orig_pairs)      -- Safe output
+               -- The lazy_fv part is really important!  orig_pairs has no strictness
+               -- info, including nothing about free vars.  But if we have
+               --      letrec f = ....y..... in ...f...
+               -- where 'y' is free in f, we must record that y is mentioned, 
+               -- otherwise y will get recorded as absent altogether
+
+      | otherwise
+      = loop (n+1) (setNonVirgin sigs') pairs'
       where
        found_fixpoint = all (same_sig sigs sigs') bndrs 
                -- Use the new signature to do the next pair
                -- The occurrence analyser has arranged them in a good order
                -- so this can significantly reduce the number of iterations needed
-       ((sigs',lazy_fv), pairs') = mapAccumL (my_downRhs top_lvl) (sigs, emptyDmdEnv) pairs
+       ((sigs',lazy_fv), pairs') = mapAccumL my_downRhs (sigs, emptyDmdEnv) pairs
        
-    my_downRhs top_lvl (sigs,lazy_fv) (id,rhs)
-       = -- pprTrace "downRhs {" (ppr id <+> (ppr old_sig))
-         -- (new_sig `seq` 
-         --    pprTrace "downRhsEnd" (ppr id <+> ppr new_sig <+> char '}' ) 
-         ((sigs', lazy_fv'), pair')
-         --     )
+    my_downRhs (sigs,lazy_fv) (id,rhs) = ((sigs', lazy_fv'), pair')
        where
          (sigs', lazy_fv1, pair') = dmdAnalRhs top_lvl Recursive sigs (id,rhs)
          lazy_fv'                 = plusVarEnv_C both lazy_fv lazy_fv1   
-         -- old_sig               = lookup sigs id
-         -- new_sig               = lookup sigs' id
           
     same_sig sigs sigs' var = lookup sigs var == lookup sigs' var
-    lookup sigs var = case lookupVarEnv sigs var of
+    lookup sigs var = case lookupSigEnv sigs var of
                        Just (sig,_) -> sig
                         Nothing      -> pprPanic "dmdFix" (ppr var)
 
-       -- Get an initial strictness signature from the Id
-       -- itself.  That way we make use of earlier iterations
-       -- of the fixpoint algorithm.  (Cunning plan.)
-       -- Note that the cunning plan extends to the DmdEnv too,
-       -- since it is part of the strictness signature
-initialSig :: Id -> StrictSig
-initialSig id = idStrictness_maybe id `orElse` botSig
-
 dmdAnalRhs :: TopLevelFlag -> RecFlag
        -> SigEnv -> (Id, CoreExpr)
        -> (SigEnv,  DmdEnv, (Id, CoreExpr))
@@ -475,6 +542,7 @@ dmdAnalRhs top_lvl rec_flag sigs (id, rhs)
   sigs'                     = extendSigEnv top_lvl sigs id sig_ty
 \end{code}
 
+
 %************************************************************************
 %*                                                                     *
 \subsection{Strictness signatures and types}
@@ -838,22 +906,44 @@ forget that fact, otherwise we might make 'x' absent when it isn't.
 %************************************************************************
 
 \begin{code}
-type SigEnv  = VarEnv (StrictSig, TopLevelFlag)
-       -- We use the SigEnv to tell us whether to
+data SigEnv  
+  = SE { se_env    :: VarEnv (StrictSig, TopLevelFlag)
+       , se_virgin :: Bool }  -- True on first iteration only
+                             -- See Note [Initialising strictness]
+       -- We use the se_env to tell us whether to
        -- record info about a variable in the DmdEnv
        -- We do so if it's a LocalId, but not top-level
        --
        -- The DmdEnv gives the demand on the free vars of the function
        -- when it is given enough args to satisfy the strictness signature
 
+instance Outputable SigEnv where
+  ppr (SE { se_env = env, se_virgin = virgin })
+    = ptext (sLit "SE") <+> braces (vcat 
+         [ ptext (sLit "se_virgin =") <+> ppr virgin
+         , ptext (sLit "se_env =") <+> ppr env ])
+
 emptySigEnv :: SigEnv
-emptySigEnv  = emptyVarEnv
+emptySigEnv  = SE { se_env = emptyVarEnv, se_virgin = True }
 
 extendSigEnv :: TopLevelFlag -> SigEnv -> Id -> StrictSig -> SigEnv
-extendSigEnv top_lvl env var sig = extendVarEnv env var (sig, top_lvl)
+extendSigEnv top_lvl sigs var sig 
+  = sigs { se_env = extendVarEnv (se_env sigs) var (sig, top_lvl) }
+
+lookupSigEnv :: SigEnv -> Id -> Maybe (StrictSig, TopLevelFlag)
+lookupSigEnv sigs id = lookupVarEnv (se_env sigs) id
+
+addInitialSigs :: TopLevelFlag -> SigEnv -> [Id] -> SigEnv
+-- See Note [Initialising strictness]
+addInitialSigs top_lvl sigs@(SE { se_env = env, se_virgin = virgin }) ids
+  = sigs { se_env = extendVarEnvList env [ (id, (init_sig id, top_lvl)) 
+                                         | id <- ids ] }
+  where
+    init_sig | virgin    = \_ -> botSig
+             | otherwise = idStrictness
 
-extendSigEnvList :: SigEnv -> [(Id, (StrictSig, TopLevelFlag))] -> SigEnv
-extendSigEnvList = extendVarEnvList
+setNonVirgin :: SigEnv -> SigEnv
+setNonVirgin sigs = sigs { se_virgin = False }
 
 extendSigsWithLam :: SigEnv -> Id -> SigEnv
 -- Extend the SigEnv when we meet a lambda binder
@@ -873,87 +963,36 @@ extendSigsWithLam :: SigEnv -> Id -> SigEnv
 
 extendSigsWithLam sigs id
   = case idDemandInfo_maybe id of
-       Nothing              -> extendVarEnv sigs id (cprSig, NotTopLevel)
+       Nothing              -> extendSigEnv NotTopLevel sigs id cprSig
                -- Optimistic in the Nothing case;
                -- See notes [CPR-AND-STRICTNESS]
-       Just (Eval (Prod _)) -> extendVarEnv sigs id (cprSig, NotTopLevel)
+       Just (Eval (Prod _)) -> extendSigEnv NotTopLevel sigs id cprSig
        _                    -> sigs
+\end{code}
 
+Note [Initialising strictness]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Our basic plan is to initialise the strictness of each Id in 
+a recursive group to "bottom", and find a fixpoint from there.
+However, this group A might be inside an *enclosing* recursive
+group B, in which case we'll do the entire fixpoint shebang on A
+for each iteration of B.
 
-dmdTransform :: SigEnv         -- The strictness environment
-            -> Id              -- The function
-            -> Demand          -- The demand on the function
-            -> DmdType         -- The demand type of the function in this context
-       -- Returned DmdEnv includes the demand on 
-       -- this function plus demand on its free variables
-
-dmdTransform sigs var dmd
-
-------         DATA CONSTRUCTOR
-  | isDataConWorkId var                -- Data constructor
-  = let 
-       StrictSig dmd_ty    = idStrictness var  -- It must have a strictness sig
-       DmdType _ _ con_res = dmd_ty
-       arity               = idArity var
-    in
-    if arity == call_depth then                -- Saturated, so unleash the demand
-       let 
-               -- Important!  If we Keep the constructor application, then
-               -- we need the demands the constructor places (always lazy)
-               -- If not, we don't need to.  For example:
-               --      f p@(x,y) = (p,y)       -- S(AL)
-               --      g a b     = f (a,b)
-               -- It's vital that we don't calculate Absent for a!
-          dmd_ds = case res_dmd of
-                       Box (Eval ds) -> mapDmds box ds
-                       Eval ds       -> ds
-                       _             -> Poly Top
-
-               -- ds can be empty, when we are just seq'ing the thing
-               -- If so we must make up a suitable bunch of demands
-          arg_ds = case dmd_ds of
-                     Poly d  -> replicate arity d
-                     Prod ds -> ASSERT( ds `lengthIs` arity ) ds
-
-       in
-       mkDmdType emptyDmdEnv arg_ds con_res
-               -- Must remember whether it's a product, hence con_res, not TopRes
-    else
-       topDmdType
-
-------         IMPORTED FUNCTION
-  | isGlobalId var,            -- Imported function
-    let StrictSig dmd_ty = idStrictness var
-  = if dmdTypeDepth dmd_ty <= call_depth then  -- Saturated, so unleash the demand
-       dmd_ty
-    else
-       topDmdType
-
-------         LOCAL LET/REC BOUND THING
-  | Just (StrictSig dmd_ty, top_lvl) <- lookupVarEnv sigs var
-  = let
-       fn_ty | dmdTypeDepth dmd_ty <= call_depth = dmd_ty 
-             | otherwise                         = deferType dmd_ty
-       -- NB: it's important to use deferType, and not just return topDmdType
-       -- Consider     let { f x y = p + x } in f 1
-       -- The application isn't saturated, but we must nevertheless propagate 
-       --      a lazy demand for p!  
-    in
-    if isTopLevel top_lvl then fn_ty   -- Don't record top level things
-    else addVarDmd fn_ty var dmd
-
-------         LOCAL NON-LET/REC BOUND THING
-  | otherwise                  -- Default case
-  = unitVarDmd var dmd
+To speed things up, we initialise each iteration of B from the result
+of the last one, which is neatly recorded in each binder.  That way we
+make use of earlier iterations of the fixpoint algorithm.  (Cunning
+plan.)  
 
-  where
-    (call_depth, res_dmd) = splitCallDmd dmd
-\end{code}
+But on the *first* iteration we want to *ignore* the current strictness
+of the Id, and start from "bottom".  Nowadays the Id can have a current
+strictness, because interface files record strictness for nested bindings.
+To know when we are in the first iteration, we look at the se_virgin
+field of the SigEnv.
 
 
 %************************************************************************
 %*                                                                     *
-\subsection{Demands}
+                   Demands
 %*                                                                     *
 %************************************************************************