Make mkDFunUnfolding more robust
[ghc-hetmet.git] / compiler / coreSyn / CoreUnfold.lhs
index 519fb74..06a2d72 100644 (file)
@@ -26,7 +26,7 @@ module CoreUnfold (
 
        interestingArg, ArgSummary(..),
 
-       couldBeSmallEnoughToInline, 
+       couldBeSmallEnoughToInline, inlineBoringOk,
        certainlyWillInline, smallEnoughToInline,
 
        callSiteInline, CallCtxt(..), 
@@ -41,8 +41,8 @@ import StaticFlags
 import DynFlags
 import CoreSyn
 import PprCore         ()      -- Instances
-import TcType          ( tcSplitSigmaTy, tcSplitDFunHead )
-import OccurAnal
+import TcType           ( tcSplitDFunTy )
+import OccurAnal        ( occurAnalyseExpr )
 import CoreSubst hiding( substTy )
 import CoreFVs         ( exprFreeVars )
 import CoreArity       ( manifestArity, exprBotStrictness_maybe )
@@ -54,8 +54,7 @@ import Literal
 import PrimOp
 import IdInfo
 import BasicTypes      ( Arity )
-import TcType          ( tcSplitDFunTy )
-import Type 
+import Type
 import Coercion
 import PrelNames
 import VarEnv           ( mkInScopeSet )
@@ -95,11 +94,8 @@ mkDFunUnfolding :: Type -> [DFunArg CoreExpr] -> Unfolding
 mkDFunUnfolding dfun_ty ops 
   = DFunUnfolding dfun_nargs data_con ops
   where
-    (tvs, theta, head_ty) = tcSplitSigmaTy dfun_ty
-         -- NB: tcSplitSigmaTy: do not look through a newtype
-         --     when the dictionary type is a newtype
-    (cls, _)   = tcSplitDFunHead head_ty
-    dfun_nargs = length tvs + length theta
+    (tvs, n_theta, cls, _) = tcSplitDFunTy dfun_ty
+    dfun_nargs = length tvs + n_theta
     data_con   = classDataCon cls
 
 mkWwInlineRule :: Id -> CoreExpr -> Arity -> Unfolding
@@ -126,12 +122,7 @@ mkInlineUnfolding mb_arity expr
                           Nothing -> (unSaturatedOk, manifestArity expr')
                           Just ar -> (needSaturated, ar)
               
-    boring_ok = case calcUnfoldingGuidance True    -- Treat as cheap
-                                          False   -- But not bottoming
-                                           (arity+1) expr' of
-                 (_, UnfWhen _ boring_ok) -> boring_ok
-                 _other                   -> boringCxtNotOk
-     -- See Note [INLINE for small functions]
+    boring_ok = inlineBoringOk expr'
 
 mkInlinableUnfolding :: CoreExpr -> Unfolding
 mkInlinableUnfolding expr
@@ -162,6 +153,10 @@ mkUnfolding :: UnfoldingSource -> Bool -> Bool -> CoreExpr -> Unfolding
 -- Calculates unfolding guidance
 -- Occurrence-analyses the expression before capturing it
 mkUnfolding src top_lvl is_bottoming expr
+  | top_lvl && is_bottoming
+  , not (exprIsTrivial expr)
+  = NoUnfolding    -- See Note [Do not inline top-level bottoming functions]
+  | otherwise
   = CoreUnfolding { uf_tmpl      = occurAnalyseExpr expr,
                    uf_src        = src,
                    uf_arity      = arity,
@@ -173,7 +168,7 @@ mkUnfolding src top_lvl is_bottoming expr
                    uf_guidance   = guidance }
   where
     is_cheap = exprIsCheap expr
-    (arity, guidance) = calcUnfoldingGuidance is_cheap (top_lvl && is_bottoming) 
+    (arity, guidance) = calcUnfoldingGuidance is_cheap
                                               opt_UF_CreationThreshold expr
        -- Sometimes during simplification, there's a large let-bound thing     
        -- which has been substituted, and so is now dead; so 'expr' contains
@@ -193,15 +188,35 @@ mkUnfolding src top_lvl is_bottoming expr
 %************************************************************************
 
 \begin{code}
+inlineBoringOk :: CoreExpr -> Bool
+-- See Note [INLINE for small functions]
+-- True => the result of inlining the expression is 
+--         no bigger than the expression itself
+--     eg      (\x y -> f y x)
+-- This is a quick and dirty version. It doesn't attempt
+-- to deal with  (\x y z -> x (y z))
+-- The really important one is (x `cast` c)
+inlineBoringOk e
+  = go 0 e
+  where
+    go :: Int -> CoreExpr -> Bool
+    go credit (Lam x e) | isId x           = go (credit+1) e
+                        | otherwise        = go credit e
+    go credit (App f (Type {}))            = go credit f
+    go credit (App f a) | credit > 0  
+                        , exprIsTrivial a  = go (credit-1) f
+    go credit (Note _ e)                  = go credit e     
+    go credit (Cast e _)                  = go credit e
+    go _      (Var {})                            = boringCxtOk
+    go _      _                                   = boringCxtNotOk
+
 calcUnfoldingGuidance
        :: Bool         -- True <=> the rhs is cheap, or we want to treat it
                        --          as cheap (INLINE things)     
-        -> Bool                -- True <=> this is a top-level unfolding for a
-                       --          diverging function; don't inline this
         -> Int         -- Bomb out if size gets bigger than this
        -> CoreExpr     -- Expression to look at
        -> (Arity, UnfoldingGuidance)
-calcUnfoldingGuidance expr_is_cheap top_bot bOMB_OUT_SIZE expr
+calcUnfoldingGuidance expr_is_cheap bOMB_OUT_SIZE expr
   = case collectBinders expr of { (bndrs, body) ->
     let
         val_bndrs   = filter isId bndrs
@@ -214,9 +229,6 @@ calcUnfoldingGuidance expr_is_cheap top_bot bOMB_OUT_SIZE expr
                | uncondInline n_val_bndrs (iBox size)
                 , expr_is_cheap
                -> UnfWhen unSaturatedOk boringCxtOk   -- Note [INLINE for small functions]
-               | top_bot  -- See Note [Do not inline top-level bottoming functions]
-               -> UnfNever
-
                | otherwise
                -> UnfIfGoodArgs { ug_args  = map (discount cased_bndrs) val_bndrs
                                 , ug_size  = iBox size
@@ -1269,7 +1281,7 @@ exprIsConApp_maybe id_unf expr
         , let sat = length args == dfun_nargs    -- See Note [DFun arity check]
           in if sat then True else 
              pprTrace "Unsaturated dfun" (ppr fun <+> int dfun_nargs $$ ppr args) False   
-        , let (dfun_tvs, _cls, dfun_res_tys) = tcSplitDFunTy (idType fun)
+        , let (dfun_tvs, _n_theta, _cls, dfun_res_tys) = tcSplitDFunTy (idType fun)
               subst    = zipOpenTvSubst dfun_tvs (stripTypeArgs (takeList dfun_tvs args))
               mk_arg (DFunConstArg e) = e
               mk_arg (DFunLamArg i)   = args !! i