Add (a) CoreM monad, (b) new Annotations feature
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
index eebb11c..4b1055b 100644 (file)
@@ -8,18 +8,16 @@ module LiberateCase ( liberateCase ) where
 
 #include "HsVersions.h"
 
 
 #include "HsVersions.h"
 
-import DynFlags                ( DynFlags, DynFlag(..) )
-import StaticFlags     ( opt_LiberateCaseThreshold )
-import CoreLint                ( showPass, endPass )
+import DynFlags
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
-import Id              ( Id, setIdName, idName, setIdNotExported )
+import Id
 import VarEnv
 import VarEnv
-import Name            ( localiseName )
-import Outputable
 import Util             ( notNull )
 \end{code}
 
 import Util             ( notNull )
 \end{code}
 
+The liberate-case transformation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
@@ -27,64 +25,66 @@ The criterion is:
 
 Example
 
 
 Example
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : f t
-\end{verbatim}
+   f = \ t -> case v of
+                V a b -> a : f t
 
 => the inner f is replaced.
 
 
 => the inner f is replaced.
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+   f = \ t -> case v of
+                V a b -> a : (letrec
                                f =  \ t -> case v of
                                               V a b -> a : f t
                                f =  \ t -> case v of
                                               V a b -> a : f t
-                            in f) t
-\end{verbatim}
+                              in f) t
 (note the NEED for shadowing)
 
 => Simplify
 
 (note the NEED for shadowing)
 
 => Simplify
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+  f = \ t -> case v of
+                V a b -> a : (letrec
                                f = \ t -> a : f t
                                f = \ t -> a : f t
-                            in f t)
-\begin{verbatim}
+                              in f t)
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
-Other examples we'd like to catch with this kind of transformation
+Note that this deals with *free variables*.  SpecConstr deals with
+*arguments* that are of known form.  E.g.
 
        last []     = error 
        last (x:[]) = x
        last (x:xs) = last xs
 
 
        last []     = error 
        last (x:[]) = x
        last (x:xs) = last xs
 
-We'd like to avoid the redundant pattern match, transforming to
+       
+Note [Scrutinee with cast]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this:
+    f = \ t -> case (v `cast` co) of
+                V a b -> a : f t
 
 
-       last [] = error
-       last (x:[]) = x
-       last (x:(y:ys)) = last' y ys
-               where
-                 last' y []     = y
-                 last' _ (y:ys) = last' y ys
+Exactly the same optimisation (unrolling one call to f) will work here, 
+despite the cast.  See mk_alt_env in the Case branch of libCase.
 
 
-       (is this necessarily an improvement)
 
 
+Note [Only functions!]
+~~~~~~~~~~~~~~~~~~~~~~
+Consider the following code
 
 
-Similarly drop:
+       f = g (case v of V a b -> a : t f)
 
 
-       drop n [] = []
-       drop 0 xs = xs
-       drop n (x:xs) = drop (n-1) xs
+where g is expensive. If we aren't careful, liberate case will turn this into
 
 
-Would like to pass n along unboxed.
-       
+       f = g (case v of
+               V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
+                                in f)
+             )
+
+Yikes! We evaluate g twice. This leads to a O(2^n) explosion
+if g calls back to the same code recursively.
+
+Solution: make sure that we only do the liberate-case thing on *functions*
 
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
 
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
-
 Main worry: duplicating code excessively.  At the moment we duplicate
 the entire binding group once at each recursive call.  But there may
 be a group of recursive calls which share a common set of evaluated
 Main worry: duplicating code excessively.  At the moment we duplicate
 the entire binding group once at each recursive call.  But there may
 be a group of recursive calls which share a common set of evaluated
@@ -96,7 +96,6 @@ big.
 
 Data types
 ~~~~~~~~~~
 
 Data types
 ~~~~~~~~~~
-
 The ``level'' of a binder tells how many
 recursive defns lexically enclose the binding
 A recursive defn "encloses" its RHS, not its
 The ``level'' of a binder tells how many
 recursive defns lexically enclose the binding
 A recursive defn "encloses" its RHS, not its
@@ -110,67 +109,32 @@ scope.  For example:
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
-\begin{code}
-type LibCaseLevel = Int
 
 
-topLevel :: LibCaseLevel
-topLevel = 0
-\end{code}
-
-\begin{code}
-data LibCaseEnv
-  = LibCaseEnv
-       Int                     -- Bomb-out size for deciding if
-                               -- potential liberatees are too big.
-                               -- (passed in from cmd-line args)
-
-       LibCaseLevel            -- Current level
-
-       (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
-                               -- (top-level and imported things have
-                               -- a level of zero)
+%************************************************************************
+%*                                                                     *
+        Top-level code
+%*                                                                     *
+%************************************************************************
 
 
-       (IdEnv CoreBind)        -- Binds *only* recursively defined
-                               -- Ids, to their own binding group,
-                               -- and *only* in their own RHSs
-
-       [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
-                               -- enclosing case expression, with the
-                               -- specified number of enclosing
-                               -- recursive bindings; furthermore,
-                               -- the Id is bound at a lower level
-                               -- than the case expression.  The
-                               -- order is insignificant; it's a bag
-                               -- really
-
-initEnv :: Int -> LibCaseEnv
-initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
-
-bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
-\end{code}
-
-
-Programs
-~~~~~~~~
 \begin{code}
 \begin{code}
-liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
-liberateCase dflags binds
-  = do {
-       showPass dflags "Liberate case" ;
-       let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
-       endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
-                               {- no specific flag for dumping -} 
-    }
+liberateCase :: DynFlags -> [CoreBind] -> [CoreBind]
+liberateCase dflags binds = do_prog (initEnv dflags) binds
   where
   where
-    do_prog env [] = []
+    do_prog _   [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
                             where
                               (env', bind') = libCaseBind env bind
 \end{code}
 
     do_prog env (bind:binds) = bind' : do_prog env' binds
                             where
                               (env', bind') = libCaseBind env bind
 \end{code}
 
+
+%************************************************************************
+%*                                                                     *
+        Main payload
+%*                                                                     *
+%************************************************************************
+
 Bindings
 ~~~~~~~~
 Bindings
 ~~~~~~~~
-
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
@@ -180,22 +144,22 @@ libCaseBind env (NonRec binder rhs)
 libCaseBind env (Rec pairs)
   = (env_body, Rec pairs')
   where
 libCaseBind env (Rec pairs)
   = (env_body, Rec pairs')
   where
-    (binders, rhss) = unzip pairs
+    (binders, _rhss) = unzip pairs
 
     env_body = addBinders env binders
 
     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
 
 
     env_body = addBinders env binders
 
     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
 
-    env_rhs = if all rhs_small_enough rhss then extended_env else env
+    env_rhs = if all rhs_small_enough pairs then extended_env else env
 
        -- We extend the rec-env by binding each Id to its rhs, first
        -- processing the rhs with an *un-extended* environment, so
        -- that the same process doesn't occur for ever!
        --
 
        -- We extend the rec-env by binding each Id to its rhs, first
        -- processing the rhs with an *un-extended* environment, so
        -- that the same process doesn't occur for ever!
        --
-    extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
+    extended_env = addRecBinds env [ (localiseId binder, libCase env_body rhs)
                                   | (binder, rhs) <- pairs ]
 
                                   | (binder, rhs) <- pairs ]
 
-       -- Two subtle things: 
+       -- The call to localiseId is needed for two subtle reasons
        -- (a)  Reset the export flags on the binders so
        --      that we don't get name clashes on exported things if the 
        --      local binding floats out to top level.  This is most unlikely
        -- (a)  Reset the export flags on the binders so
        --      that we don't get name clashes on exported things if the 
        --      local binding floats out to top level.  This is most unlikely
@@ -205,10 +169,11 @@ libCaseBind env (Rec pairs)
        -- (b)  Make the name an Internal one.  External Names should never be
        --      nested; if it were floated to the top level, we'd get a name
        --      clash at code generation time.
        -- (b)  Make the name an Internal one.  External Names should never be
        --      nested; if it were floated to the top level, we'd get a name
        --      clash at code generation time.
-    adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
 
 
-    rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
-    lIBERATE_BOMB_SIZE   = bombOutSize env
+    rhs_small_enough (id,rhs)
+       =  idArity id > 0       -- Note [Only functions!]
+       && maybe True (\size -> couldBeSmallEnoughToInline size rhs)
+                      (bombOutSize env)
 \end{code}
 
 
 \end{code}
 
 
@@ -220,9 +185,9 @@ libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
 
        -> CoreExpr
        -> CoreExpr
 
-libCase env (Var v)            = libCaseId env v
-libCase env (Lit lit)          = Lit lit
-libCase env (Type ty)          = Type ty
+libCase env (Var v)             = libCaseId env v
+libCase _   (Lit lit)           = Lit lit
+libCase _   (Type ty)           = Type ty
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
 libCase env (Cast e co)         = Cast (libCase env e) co
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
 libCase env (Cast e co)         = Cast (libCase env e) co
@@ -238,14 +203,17 @@ libCase env (Let bind body)
 libCase env (Case scrut bndr ty alts)
   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
 libCase env (Case scrut bndr ty alts)
   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
-    env_alts = addBinders env_with_scrut [bndr]
-    env_with_scrut = case scrut of
-                       Var scrut_var -> addScrutedVar env scrut_var
-                       other         -> env
+    env_alts = addBinders (mk_alt_env scrut) [bndr]
+    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
+    mk_alt_env (Cast scrut _)  = mk_alt_env scrut      -- Note [Scrutinee with cast]
+    mk_alt_env _              = env
 
 
+libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
+                         -> (AltCon, [CoreBndr], CoreExpr)
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
+
 Ids
 ~~~
 \begin{code}
 Ids
 ~~~
 \begin{code}
@@ -261,22 +229,57 @@ libCaseId env v
   where
     rec_id_level = lookupLevel env v
     free_scruts  = freeScruts env rec_id_level
   where
     rec_id_level = lookupLevel env v
     free_scruts  = freeScruts env rec_id_level
+
+freeScruts :: LibCaseEnv
+          -> LibCaseLevel      -- Level of the recursive Id
+          -> [Id]              -- Ids that are scrutinised between the binding
+                               -- of the recursive Id and here
+freeScruts env rec_bind_lvl
+  = [v | (v,scrut_bind_lvl) <- lc_scruts env
+       , scrut_bind_lvl <= rec_bind_lvl]
+       -- Note [When to specialise]
 \end{code}
 
 \end{code}
 
+Note [When to specialise]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+  f = \x. letrec g = \y. case x of
+                          True  -> ... (f a) ...
+                          False -> ... (g b) ...
+
+We get the following levels
+         f  0
+         x  1
+         g  1
+         y  2  
 
 
+Then 'x' is being scrutinised at a deeper level than its binding, so
+it's added to lc_sruts:  [(x,1)]  
+
+We do *not* want to specialise the call to 'f', becuase 'x' is not free 
+in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
+
+We *do* want to specialise the call to 'g', because 'x' is free in g.
+Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
+
+
+%************************************************************************
+%*                                                                     *
+       Utility functions
+%*                                                                     *
+%************************************************************************
 
 
-Utility functions
-~~~~~~~~~~~~~~~~~
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
-addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
-  = LibCaseEnv bomb lvl lvl_env' rec_env scruts
+addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
+  = env { lc_lvl_env = lvl_env' }
   where
     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
 
 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
   where
     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
 
 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
-addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
-  = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
+addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                            lc_rec_env = rec_env}) pairs
+  = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
   where
     lvl'     = lvl + 1
     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
   where
     lvl'     = lvl + 1
     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
@@ -286,33 +289,82 @@ addScrutedVar :: LibCaseEnv
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
 
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
 
-addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
+addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                               lc_scruts = scruts }) scrut_var
   | bind_lvl < lvl
   | bind_lvl < lvl
-  = LibCaseEnv bomb lvl lvl_env rec_env scruts'
+  = env { lc_scruts = scruts' }
        -- Add to scruts iff the scrut_var is being scrutinised at
        -- a deeper level than its defn
 
   | otherwise = env
   where
        -- Add to scruts iff the scrut_var is being scrutinised at
        -- a deeper level than its defn
 
   | otherwise = env
   where
-    scruts'  = (scrut_var, lvl) : scruts
+    scruts'  = (scrut_var, bind_lvl) : scruts
     bind_lvl = case lookupVarEnv lvl_env scrut_var of
                 Just lvl -> lvl
                 Nothing  -> topLevel
 
 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
     bind_lvl = case lookupVarEnv lvl_env scrut_var of
                 Just lvl -> lvl
                 Nothing  -> topLevel
 
 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
-lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = lookupVarEnv rec_env id
+lookupRecId env id = lookupVarEnv (lc_rec_env env) id
 
 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
 
 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
-lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = case lookupVarEnv lvl_env id of
+lookupLevel env id
+  = case lookupVarEnv (lc_lvl_env env) id of
       Just lvl -> lvl
       Nothing  -> topLevel
       Just lvl -> lvl
       Nothing  -> topLevel
+\end{code}
 
 
-freeScruts :: LibCaseEnv
-          -> LibCaseLevel      -- Level of the recursive Id
-          -> [Id]              -- Ids that are scrutinised between the binding
-                               -- of the recursive Id and here
-freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
-  = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
+%************************************************************************
+%*                                                                     *
+        The environment
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+type LibCaseLevel = Int
+
+topLevel :: LibCaseLevel
+topLevel = 0
+\end{code}
+
+\begin{code}
+data LibCaseEnv
+  = LibCaseEnv {
+       lc_size :: Maybe Int,   -- Bomb-out size for deciding if
+                               -- potential liberatees are too big.
+                               -- (passed in from cmd-line args)
+
+       lc_lvl :: LibCaseLevel, -- Current level
+               -- The level is incremented when (and only when) going
+               -- inside the RHS of a (sufficiently small) recursive
+               -- function.
+
+       lc_lvl_env :: IdEnv LibCaseLevel,  
+               -- Binds all non-top-level in-scope Ids (top-level and
+               -- imported things have a level of zero)
+
+       lc_rec_env :: IdEnv CoreBind, 
+               -- Binds *only* recursively defined ids, to their own
+               -- binding group, and *only* in their own RHSs
+
+       lc_scruts :: [(Id,LibCaseLevel)]
+               -- Each of these Ids was scrutinised by an enclosing
+               -- case expression, at a level deeper than its binding
+               -- level.  The LibCaseLevel recorded here is the *binding
+               -- level* of the scrutinised Id.
+               -- 
+               -- The order is insignificant; it's a bag really
+       }
+
+initEnv :: DynFlags -> LibCaseEnv
+initEnv dflags 
+  = LibCaseEnv { lc_size = liberateCaseThreshold dflags,
+                lc_lvl = 0,
+                lc_lvl_env = emptyVarEnv, 
+                lc_rec_env = emptyVarEnv,
+                lc_scruts = [] }
+
+bombOutSize :: LibCaseEnv -> Maybe Int
+bombOutSize = lc_size
 \end{code}
 \end{code}
+
+