Improve the treatment of 'seq' (Trac #2273)
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
index 0df9b37..ab79239 100644 (file)
@@ -53,30 +53,13 @@ Example
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
-Other examples we'd like to catch with this kind of transformation
+Note that this deals with *free variables*.  SpecConstr deals with
+*arguments* that are of known form.  E.g.
 
        last []     = error 
        last (x:[]) = x
        last (x:xs) = last xs
 
 
        last []     = error 
        last (x:[]) = x
        last (x:xs) = last xs
 
-We'd like to avoid the redundant pattern match, transforming to
-
-       last [] = error
-       last (x:[]) = x
-       last (x:(y:ys)) = last' y ys
-               where
-                 last' y []     = y
-                 last' _ (y:ys) = last' y ys
-
-       (is this necessarily an improvement)
-
-Similarly drop:
-
-       drop n [] = []
-       drop 0 xs = xs
-       drop n (x:xs) = drop (n-1) xs
-
-Would like to pass n along unboxed.
        
 Note [Scrutinee with cast]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
        
 Note [Scrutinee with cast]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -88,9 +71,26 @@ Exactly the same optimisation (unrolling one call to f) will work here,
 despite the cast.  See mk_alt_env in the Case branch of libCase.
 
 
 despite the cast.  See mk_alt_env in the Case branch of libCase.
 
 
+Note [Only functions!]
+~~~~~~~~~~~~~~~~~~~~~~
+Consider the following code
+
+       f = g (case v of V a b -> a : t f)
+
+where g is expensive. If we aren't careful, liberate case will turn this into
+
+       f = g (case v of
+               V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
+                                in f)
+             )
+
+Yikes! We evaluate g twice. This leads to a O(2^n) explosion
+if g calls back to the same code recursively.
+
+Solution: make sure that we only do the liberate-case thing on *functions*
+
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
-
 Main worry: duplicating code excessively.  At the moment we duplicate
 the entire binding group once at each recursive call.  But there may
 be a group of recursive calls which share a common set of evaluated
 Main worry: duplicating code excessively.  At the moment we duplicate
 the entire binding group once at each recursive call.  But there may
 be a group of recursive calls which share a common set of evaluated
@@ -135,7 +135,7 @@ liberateCase hsc_env _ _ guts
                        {- no specific flag for dumping -} 
        ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
   where
                        {- no specific flag for dumping -} 
        ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
   where
-    do_prog env [] = []
+    do_prog _   [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
                             where
                               (env', bind') = libCaseBind env bind
     do_prog env (bind:binds) = bind' : do_prog env' binds
                             where
                               (env', bind') = libCaseBind env bind
@@ -159,13 +159,13 @@ libCaseBind env (NonRec binder rhs)
 libCaseBind env (Rec pairs)
   = (env_body, Rec pairs')
   where
 libCaseBind env (Rec pairs)
   = (env_body, Rec pairs')
   where
-    (binders, rhss) = unzip pairs
+    (binders, _rhss) = unzip pairs
 
     env_body = addBinders env binders
 
     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
 
 
     env_body = addBinders env binders
 
     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
 
-    env_rhs = if all rhs_small_enough rhss then extended_env else env
+    env_rhs = if all rhs_small_enough pairs then extended_env else env
 
        -- We extend the rec-env by binding each Id to its rhs, first
        -- processing the rhs with an *un-extended* environment, so
 
        -- We extend the rec-env by binding each Id to its rhs, first
        -- processing the rhs with an *un-extended* environment, so
@@ -186,8 +186,10 @@ libCaseBind env (Rec pairs)
        --      clash at code generation time.
     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
 
        --      clash at code generation time.
     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
 
-    rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
-    lIBERATE_BOMB_SIZE   = bombOutSize env
+    rhs_small_enough (id,rhs)
+       =  idArity id > 0       -- Note [Only functions!]
+       && maybe True (\size -> couldBeSmallEnoughToInline size rhs)
+                      (bombOutSize env)
 \end{code}
 
 
 \end{code}
 
 
@@ -199,9 +201,9 @@ libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
 
        -> CoreExpr
        -> CoreExpr
 
-libCase env (Var v)            = libCaseId env v
-libCase env (Lit lit)          = Lit lit
-libCase env (Type ty)          = Type ty
+libCase env (Var v)             = libCaseId env v
+libCase _   (Lit lit)           = Lit lit
+libCase _   (Type ty)           = Type ty
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
 libCase env (Cast e co)         = Cast (libCase env e) co
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
 libCase env (Cast e co)         = Cast (libCase env e) co
@@ -220,8 +222,10 @@ libCase env (Case scrut bndr ty alts)
     env_alts = addBinders (mk_alt_env scrut) [bndr]
     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
     mk_alt_env (Cast scrut _)  = mk_alt_env scrut      -- Note [Scrutinee with cast]
     env_alts = addBinders (mk_alt_env scrut) [bndr]
     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
     mk_alt_env (Cast scrut _)  = mk_alt_env scrut      -- Note [Scrutinee with cast]
-    mk_alt_env otehr          = env
+    mk_alt_env _              = env
 
 
+libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
+                         -> (AltCon, [CoreBndr], CoreExpr)
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
@@ -241,8 +245,39 @@ libCaseId env v
   where
     rec_id_level = lookupLevel env v
     free_scruts  = freeScruts env rec_id_level
   where
     rec_id_level = lookupLevel env v
     free_scruts  = freeScruts env rec_id_level
+
+freeScruts :: LibCaseEnv
+          -> LibCaseLevel      -- Level of the recursive Id
+          -> [Id]              -- Ids that are scrutinised between the binding
+                               -- of the recursive Id and here
+freeScruts env rec_bind_lvl
+  = [v | (v,scrut_bind_lvl) <- lc_scruts env
+       , scrut_bind_lvl <= rec_bind_lvl]
+       -- Note [When to specialise]
 \end{code}
 
 \end{code}
 
+Note [When to specialise]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+  f = \x. letrec g = \y. case x of
+                          True  -> ... (f a) ...
+                          False -> ... (g b) ...
+
+We get the following levels
+         f  0
+         x  1
+         g  1
+         y  2  
+
+Then 'x' is being scrutinised at a deeper level than its binding, so
+it's added to lc_sruts:  [(x,1)]  
+
+We do *not* want to specialise the call to 'f', becuase 'x' is not free 
+in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
+
+We *do* want to specialise the call to 'g', because 'x' is free in g.
+Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
+
 
 %************************************************************************
 %*                                                                     *
 
 %************************************************************************
 %*                                                                     *
@@ -279,7 +314,7 @@ addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env,
 
   | otherwise = env
   where
 
   | otherwise = env
   where
-    scruts'  = (scrut_var, lvl) : scruts
+    scruts'  = (scrut_var, bind_lvl) : scruts
     bind_lvl = case lookupVarEnv lvl_env scrut_var of
                 Just lvl -> lvl
                 Nothing  -> topLevel
     bind_lvl = case lookupVarEnv lvl_env scrut_var of
                 Just lvl -> lvl
                 Nothing  -> topLevel
@@ -292,13 +327,6 @@ lookupLevel env id
   = case lookupVarEnv (lc_lvl_env env) id of
       Just lvl -> lvl
       Nothing  -> topLevel
   = case lookupVarEnv (lc_lvl_env env) id of
       Just lvl -> lvl
       Nothing  -> topLevel
-
-freeScruts :: LibCaseEnv
-          -> LibCaseLevel      -- Level of the recursive Id
-          -> [Id]              -- Ids that are scrutinised between the binding
-                               -- of the recursive Id and here
-freeScruts env rec_bind_lvl
-  = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
 \end{code}
 
 %************************************************************************
 \end{code}
 
 %************************************************************************
@@ -317,40 +345,41 @@ topLevel = 0
 \begin{code}
 data LibCaseEnv
   = LibCaseEnv {
 \begin{code}
 data LibCaseEnv
   = LibCaseEnv {
-       lc_size :: Int,         -- Bomb-out size for deciding if
+       lc_size :: Maybe Int,   -- Bomb-out size for deciding if
                                -- potential liberatees are too big.
                                -- (passed in from cmd-line args)
 
        lc_lvl :: LibCaseLevel, -- Current level
                                -- potential liberatees are too big.
                                -- (passed in from cmd-line args)
 
        lc_lvl :: LibCaseLevel, -- Current level
+               -- The level is incremented when (and only when) going
+               -- inside the RHS of a (sufficiently small) recursive
+               -- function.
 
        lc_lvl_env :: IdEnv LibCaseLevel,  
 
        lc_lvl_env :: IdEnv LibCaseLevel,  
-                       -- Binds all non-top-level in-scope Ids
-                       -- (top-level and imported things have
-                       -- a level of zero)
+               -- Binds all non-top-level in-scope Ids (top-level and
+               -- imported things have a level of zero)
 
        lc_rec_env :: IdEnv CoreBind, 
 
        lc_rec_env :: IdEnv CoreBind, 
-                       -- Binds *only* recursively defined ids, 
-                       -- to their own binding group,
-                       -- and *only* in their own RHSs
+               -- Binds *only* recursively defined ids, to their own
+               -- binding group, and *only* in their own RHSs
 
        lc_scruts :: [(Id,LibCaseLevel)]
 
        lc_scruts :: [(Id,LibCaseLevel)]
-                       -- Each of these Ids was scrutinised by an
-                       -- enclosing case expression, with the
-                       -- specified number of enclosing
-                       -- recursive bindings; furthermore,
-                       -- the Id is bound at a lower level
-                       -- than the case expression.  The order is
-                       -- insignificant; it's a bag really
+               -- Each of these Ids was scrutinised by an enclosing
+               -- case expression, at a level deeper than its binding
+               -- level.  The LibCaseLevel recorded here is the *binding
+               -- level* of the scrutinised Id.
+               -- 
+               -- The order is insignificant; it's a bag really
        }
 
 initEnv :: DynFlags -> LibCaseEnv
 initEnv dflags 
        }
 
 initEnv :: DynFlags -> LibCaseEnv
 initEnv dflags 
-  = LibCaseEnv { lc_size = specThreshold dflags,
+  = LibCaseEnv { lc_size = liberateCaseThreshold dflags,
                 lc_lvl = 0,
                 lc_lvl_env = emptyVarEnv, 
                 lc_rec_env = emptyVarEnv,
                 lc_scruts = [] }
 
                 lc_lvl = 0,
                 lc_lvl_env = emptyVarEnv, 
                 lc_rec_env = emptyVarEnv,
                 lc_scruts = [] }
 
+bombOutSize :: LibCaseEnv -> Maybe Int
 bombOutSize = lc_size
 \end{code}
 
 bombOutSize = lc_size
 \end{code}