[project @ 2001-03-19 16:24:37 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / SimplUtils.lhs
index e8a6433..55e7fca 100644 (file)
@@ -5,9 +5,9 @@
 
 \begin{code}
 module SimplUtils (
-       simplBinder, simplBinders, simplIds,
-       transformRhs,
-       mkCase, findAlt, findDefault,
+       simplBinder, simplBinders, simplRecIds, simplLetId,
+       tryRhsTyLam, tryEtaExpansion,
+       mkCase,
 
        -- The continuation type
        SimplCont(..), DupFlag(..), contIsDupable, contResultType,
@@ -23,26 +23,30 @@ import CmdLineOpts  ( switchIsOn, SimplifierSwitch(..),
                          opt_UF_UpdateInPlace
                        )
 import CoreSyn
-import CoreUtils       ( exprIsTrivial, cheapEqExpr, exprType, exprIsCheap, exprEtaExpandArity, bindNonRec )
-import Subst           ( InScopeSet, mkSubst, substBndrs, substBndr, substIds, substExpr )
+import CoreUtils       ( exprIsTrivial, cheapEqExpr, exprType, exprIsCheap, 
+                         etaExpand, exprEtaExpandArity, bindNonRec, mkCoerce,
+                         findDefault
+                       )
+import Subst           ( InScopeSet, mkSubst, substExpr )
+import qualified Subst ( simplBndrs, simplBndr, simplLetId )
 import Id              ( idType, idName, 
                          idUnfolding, idStrictness,
-                         mkVanillaId, idInfo
+                         mkLocalId, idInfo
                        )
-import IdInfo          ( StrictnessInfo(..), ArityInfo, atLeastArity )
+import IdInfo          ( StrictnessInfo(..) )
 import Maybes          ( maybeToBool, catMaybes )
 import Name            ( setNameUnique )
 import Demand          ( isStrict )
 import SimplMonad
 import Type            ( Type, mkForAllTys, seqType, repType,
-                         splitTyConApp_maybe, mkTyVarTys, splitFunTys, 
+                         splitTyConApp_maybe, tyConAppArgs, mkTyVarTys,
                          isDictTy, isDataType, isUnLiftedType,
                          splitRepFunTys
                        )
 import TyCon           ( tyConDataConsIfAvailable )
 import DataCon         ( dataConRepArity )
 import VarEnv          ( SubstEnv )
-import Util            ( lengthExceeds )
+import Util            ( lengthExceeds, mapAccumL )
 import Outputable
 \end{code}
 
@@ -365,7 +369,10 @@ interestingCallContext some_args some_val_args cont
   where
     interesting (InlinePlease _)       = True
     interesting (Select _ _ _ _ _)     = some_args
-    interesting (ApplyTo _ _ _ _)      = some_args     -- Can happen if we have (coerce t (f x)) y
+    interesting (ApplyTo _ _ _ _)      = True  -- Can happen if we have (coerce t (f x)) y
+                                               -- Perhaps True is a bit over-keen, but I've
+                                               -- seen (coerce f) x, where f has an INLINE prag,
+                                               -- So we have to give some motivaiton for inlining it
     interesting (ArgOf _ _ _)         = some_val_args
     interesting (Stop ty upd_in_place) = some_val_args && upd_in_place
     interesting (CoerceIt _ cont)      = interesting cont
@@ -425,7 +432,7 @@ simplBinders :: [InBinder] -> ([OutBinder] -> SimplM a) -> SimplM a
 simplBinders bndrs thing_inside
   = getSubst           `thenSmpl` \ subst ->
     let
-       (subst', bndrs') = substBndrs subst bndrs
+       (subst', bndrs') = Subst.simplBndrs subst bndrs
     in
     seqBndrs bndrs'    `seq`
     setSubst subst' (thing_inside bndrs')
@@ -434,23 +441,29 @@ simplBinder :: InBinder -> (OutBinder -> SimplM a) -> SimplM a
 simplBinder bndr thing_inside
   = getSubst           `thenSmpl` \ subst ->
     let
-       (subst', bndr') = substBndr subst bndr
+       (subst', bndr') = Subst.simplBndr subst bndr
     in
     seqBndr bndr'      `seq`
     setSubst subst' (thing_inside bndr')
 
 
--- Same semantics as simplBinders, but a little less 
--- plumbing and hence a little more efficient.
--- Maybe not worth the candle?
-simplIds :: [InBinder] -> ([OutBinder] -> SimplM a) -> SimplM a
-simplIds ids thing_inside
+simplRecIds :: [InBinder] -> ([OutBinder] -> SimplM a) -> SimplM a
+simplRecIds ids thing_inside
   = getSubst           `thenSmpl` \ subst ->
     let
-       (subst', bndrs') = substIds subst ids
+       (subst', ids') = mapAccumL Subst.simplLetId subst ids
     in
-    seqBndrs bndrs'    `seq`
-    setSubst subst' (thing_inside bndrs')
+    seqBndrs ids'      `seq`
+    setSubst subst' (thing_inside ids')
+
+simplLetId :: InBinder -> (OutBinder -> SimplM a) -> SimplM a
+simplLetId id thing_inside
+  = getSubst           `thenSmpl` \ subst ->
+    let
+       (subst', id') = Subst.simplLetId subst id
+    in
+    seqBndr id'        `seq`
+    setSubst subst' (thing_inside id')
 
 seqBndrs [] = ()
 seqBndrs (b:bs) = seqBndr b `seq` seqBndrs bs
@@ -464,26 +477,6 @@ seqBndr b | isTyVar b = b `seq` ()
 
 %************************************************************************
 %*                                                                     *
-\subsection{Transform a RHS}
-%*                                                                     *
-%************************************************************************
-
-Try (a) eta expansion
-    (b) type-lambda swizzling
-
-\begin{code}
-transformRhs :: OutExpr 
-            -> (ArityInfo -> OutExpr -> SimplM (OutStuff a))
-            -> SimplM (OutStuff a)
-
-transformRhs rhs thing_inside 
-  = tryRhsTyLam rhs                    $ \ rhs1 ->
-    tryEtaExpansion rhs1 thing_inside
-\end{code}
-
-
-%************************************************************************
-%*                                                                     *
 \subsection{Local tyvar-lifting}
 %*                                                                     *
 %************************************************************************
@@ -553,30 +546,34 @@ as we would normally do.
 
 
 \begin{code}
-tryRhsTyLam rhs thing_inside           -- Only does something if there's a let
-  | null tyvars || not (worth_it body) -- inside a type lambda, and a WHNF inside that
-  = thing_inside rhs
+tryRhsTyLam :: OutExpr -> SimplM ([OutBind], OutExpr)
+
+tryRhsTyLam rhs                        -- Only does something if there's a let
+  | null tyvars || not (worth_it body) -- inside a type lambda, 
+  = returnSmpl ([], rhs)               -- and a WHNF inside that
+
   | otherwise
-  = go (\x -> x) body          $ \ body' ->
-    thing_inside (mkLams tyvars body')
+  = go (\x -> x) body          `thenSmpl` \ (binds, body') ->
+    returnSmpl (binds,  mkLams tyvars body')
 
   where
     (tyvars, body) = collectTyBinders rhs
 
-    worth_it (Let _ e)      = whnf_in_middle e
-    worth_it other                  = False
+    worth_it e@(Let _ _) = whnf_in_middle e
+    worth_it e          = False
+
+    whnf_in_middle (Let (NonRec x rhs) e) | isUnLiftedType (idType x) = False
     whnf_in_middle (Let _ e) = whnf_in_middle e
     whnf_in_middle e        = exprIsCheap e
 
-
-    go fn (Let bind@(NonRec var rhs) body) thing_inside
+    go fn (Let bind@(NonRec var rhs) body)
       | exprIsTrivial rhs
-      = go (fn . Let bind) body thing_inside
+      = go (fn . Let bind) body
 
-    go fn (Let bind@(NonRec var rhs) body) thing_inside
-      = mk_poly tyvars_here var                                                `thenSmpl` \ (var', rhs') ->
-       addAuxiliaryBind (NonRec var' (mkLams tyvars_here (fn rhs)))    $
-       go (fn . Let (mk_silly_bind var rhs')) body thing_inside
+    go fn (Let (NonRec var rhs) body)
+      = mk_poly tyvars_here var                                `thenSmpl` \ (var', rhs') ->
+       go (fn . Let (mk_silly_bind var rhs')) body     `thenSmpl` \ (binds, body') ->
+       returnSmpl (NonRec var' (mkLams tyvars_here (fn rhs)) : binds, body')
 
       where
        tyvars_here = tyvars
@@ -599,13 +596,14 @@ tryRhsTyLam rhs thing_inside              -- Only does something if there's a let
                -- abstracting wrt *all* the tyvars.  We'll see if that
                -- gives rise to problems.   SLPJ June 98
 
-    go fn (Let (Rec prs) body) thing_inside
+    go fn (Let (Rec prs) body)
        = mapAndUnzipSmpl (mk_poly tyvars_here) vars    `thenSmpl` \ (vars', rhss') ->
         let
-           gn body = fn (foldr Let body (zipWith mk_silly_bind vars rhss'))
+           gn body  = fn (foldr Let body (zipWith mk_silly_bind vars rhss'))
+           new_bind = Rec (vars' `zip` [mkLams tyvars_here (gn rhs) | rhs <- rhss])
         in
-        addAuxiliaryBind (Rec (vars' `zip` [mkLams tyvars_here (gn rhs) | rhs <- rhss]))       $
-        go gn body thing_inside
+        go gn body                             `thenSmpl` \ (binds, body') -> 
+        returnSmpl (new_bind : binds, body')
        where
         (vars,rhss) = unzip prs
         tyvars_here = tyvars
@@ -613,15 +611,14 @@ tryRhsTyLam rhs thing_inside              -- Only does something if there's a let
                --       var_tys     = map idType vars
                -- See notes with tyvars_here above
 
-
-    go fn body thing_inside = thing_inside (fn body)
+    go fn body = returnSmpl ([], fn body)
 
     mk_poly tyvars_here var
       = getUniqueSmpl          `thenSmpl` \ uniq ->
        let
            poly_name = setNameUnique (idName var) uniq         -- Keep same name
            poly_ty   = mkForAllTys tyvars_here (idType var)    -- But new type of course
-           poly_id   = mkVanillaId poly_name poly_ty 
+           poly_id   = mkLocalId poly_name poly_ty 
 
                -- In the olden days, it was crucial to copy the occInfo of the original var, 
                -- because we were looking at occurrence-analysed but as yet unsimplified code!
@@ -694,81 +691,39 @@ that would leave use with some lets sandwiched between lambdas; that's
 what the final test in the first equation is for.
 
 \begin{code}
-tryEtaExpansion :: OutExpr 
-               -> (ArityInfo -> OutExpr -> SimplM (OutStuff a))
-               -> SimplM (OutStuff a)
-tryEtaExpansion rhs thing_inside
-  |  not opt_SimplDoLambdaEtaExpansion
-  || null y_tys                                -- No useful expansion
-  || not (is_case1 || is_case2)                -- Neither case matches
-  = thing_inside final_arity rhs       -- So, no eta expansion, but
-                                       -- return a good arity
-
-  | is_case1
-  = make_y_bndrs                       $ \ y_bndrs ->
-    thing_inside final_arity
-                (mkLams x_bndrs $ mkLams y_bndrs $
-                 mkApps body (map Var y_bndrs))
-
-  | otherwise  -- Must be case 2
-  = mapAndUnzipSmpl bind_z_arg arg_infos               `thenSmpl` \ (maybe_z_binds, z_args) ->
-    addAuxiliaryBinds (catMaybes maybe_z_binds)                $
-    make_y_bndrs                                       $  \ y_bndrs ->
-    thing_inside final_arity
-                (mkLams y_bndrs $
-                 mkApps (mkApps fun z_args) (map Var y_bndrs))
-  where
-    all_trivial_args = all is_trivial arg_infos
-    is_case1        = all_trivial_args
-    is_case2        = null x_bndrs && not (any unlifted_non_trivial arg_infos)
-
-    (x_bndrs, body)  = collectBinders rhs      -- NB: x_bndrs can include type variables
-    x_arity         = valBndrCount x_bndrs
+tryEtaExpansion :: OutExpr -> OutType -> SimplM ([OutBind], OutExpr)
+tryEtaExpansion rhs rhs_ty
+  |  not opt_SimplDoLambdaEtaExpansion                 -- Not if switched off
+  || exprIsTrivial rhs                         -- Not if RHS is trivial
+  || final_arity == 0                          -- Not if arity is zero
+  = returnSmpl ([], rhs)
+
+  | n_val_args == 0 && not arity_is_manifest
+  =    -- Some lambdas but not enough: case 1
+    getUniqSupplySmpl                          `thenSmpl` \ us ->
+    returnSmpl ([], etaExpand final_arity us rhs rhs_ty)
+
+  | n_val_args > 0 && not (any cant_bind arg_infos)
+  =    -- Partial application: case 2
+    mapAndUnzipSmpl bind_z_arg arg_infos       `thenSmpl` \ (maybe_z_binds, z_args) ->
+    getUniqSupplySmpl                          `thenSmpl` \ us ->
+    returnSmpl (catMaybes maybe_z_binds, 
+               etaExpand final_arity us (mkApps fun z_args) rhs_ty)
 
-    (fun, args)             = collectArgs body
-    arg_infos        = [(arg, exprType arg, exprIsTrivial arg) | arg <- args]
-
-    is_trivial          (_, _,  triv) = triv
-    unlifted_non_trivial (_, ty, triv) = not triv && isUnLiftedType ty
-
-    fun_arity       = exprEtaExpandArity fun
-
-    final_arity | all_trivial_args = atLeastArity (x_arity + extra_args_wanted)
-               | otherwise        = atLeastArity x_arity
-       -- Arity can be more than the number of lambdas
-       -- because of coerces. E.g.  \x -> coerce t (\y -> e) 
-       -- will have arity at least 2
-       -- The worker/wrapper pass will bring the coerce out to the top
+  | otherwise
+  = returnSmpl ([], rhs)
+  where
+    (fun, args)                           = collectArgs rhs
+    n_val_args                    = valArgCount args
+    (fun_arity, arity_is_manifest) = exprEtaExpandArity fun
+    final_arity                           = 0 `max` (fun_arity - n_val_args)
+    arg_infos                     = [(arg, exprType arg, exprIsTrivial arg) | arg <- args]
+    cant_bind (_, ty, triv)       = not triv && isUnLiftedType ty
 
     bind_z_arg (arg, arg_ty, trivial_arg) 
        | trivial_arg = returnSmpl (Nothing, arg)
         | otherwise   = newId SLIT("z") arg_ty $ \ z ->
                        returnSmpl (Just (NonRec z arg), Var z)
-
-    make_y_bndrs thing_inside 
-       = ASSERT( not (exprIsTrivial rhs) )
-         newIds SLIT("y") y_tys                        $ \ y_bndrs ->
-         tick (EtaExpansion (head y_bndrs))            `thenSmpl_`
-         thing_inside y_bndrs
-
-    (potential_extra_arg_tys, _) = splitFunTys (exprType body)
-       
-    y_tys :: [InType]
-    y_tys  = take extra_args_wanted potential_extra_arg_tys
-       
-    extra_args_wanted :: Int   -- Number of extra args we want
-    extra_args_wanted = 0 `max` (fun_arity - valArgCount args)
-
-       -- We used to expand the arity to the previous arity fo the
-       -- function; but this is pretty dangerous.  Consdier
-       --      f = \xy -> e
-       -- so that f has arity 2.  Now float something into f's RHS:
-       --      f = let z = BIG in \xy -> e
-       -- The last thing we want to do now is to put some lambdas
-       -- outside, to get
-       --      f = \xy -> let z = BIG in e
-       --
-       -- (bndr_arity - no_of_xs)              `max`
 \end{code}
 
 
@@ -847,15 +802,28 @@ and similar friends.
 mkCase scrut case_bndr alts
   | all identity_alt alts
   = tick (CaseIdentity case_bndr)              `thenSmpl_`
-    returnSmpl scrut
+    returnSmpl (re_note scrut)
   where
-    identity_alt (DEFAULT, [], Var v)     = v == case_bndr
-    identity_alt (DataAlt con, args, rhs) = cheapEqExpr rhs
-                                                       (mkConApp con (map Type arg_tys ++ map varToCoreExpr args))
-    identity_alt other                   = False
-
-    arg_tys = case splitTyConApp_maybe (idType case_bndr) of
-               Just (tycon, arg_tys) -> arg_tys
+    identity_alt (con, args, rhs) = de_note rhs `cheapEqExpr` identity_rhs con args
+
+    identity_rhs (DataAlt con) args = mkConApp con (arg_tys ++ map varToCoreExpr args)
+    identity_rhs (LitAlt lit)  _    = Lit lit
+    identity_rhs DEFAULT       _    = Var case_bndr
+
+    arg_tys = map Type (tyConAppArgs (idType case_bndr))
+
+       -- We've seen this:
+       --      case coerce T e of x { _ -> coerce T' x }
+       -- And we definitely want to eliminate this case!
+       -- So we throw away notes from the RHS, and reconstruct
+       -- (at least an approximation) at the other end
+    de_note (Note _ e) = de_note e
+    de_note e         = e
+
+       -- re_note wraps a coerce if it might be necessary
+    re_note scrut = case head alts of
+                       (_,_,rhs1@(Note _ _)) -> mkCoerce (exprType rhs1) (idType case_bndr) scrut
+                       other                 -> scrut
 \end{code}
 
 The catch-all case
@@ -866,22 +834,3 @@ mkCase other_scrut case_bndr other_alts
 \end{code}
 
 
-\begin{code}
-findDefault :: [CoreAlt] -> ([CoreAlt], Maybe CoreExpr)
-findDefault []                         = ([], Nothing)
-findDefault ((DEFAULT,args,rhs) : alts) = ASSERT( null alts && null args ) 
-                                         ([], Just rhs)
-findDefault (alt : alts)               = case findDefault alts of 
-                                           (alts', deflt) -> (alt : alts', deflt)
-
-findAlt :: AltCon -> [CoreAlt] -> CoreAlt
-findAlt con alts
-  = go alts
-  where
-    go []          = pprPanic "Missing alternative" (ppr con $$ vcat (map ppr alts))
-    go (alt : alts) | matches alt = alt
-                   | otherwise   = go alts
-
-    matches (DEFAULT, _, _) = True
-    matches (con1, _, _)    = con == con1
-\end{code}