Re-do the arity calculation mechanism again (fix Trac #3959)
authorsimonpj@microsoft.com <unknown>
Fri, 13 Aug 2010 16:11:51 +0000 (16:11 +0000)
committersimonpj@microsoft.com <unknown>
Fri, 13 Aug 2010 16:11:51 +0000 (16:11 +0000)
After rumination, yet again, on the subject of arity calculation,
I have redone what an ArityType is (it's purely internal to the
CoreArity module), and documented it better.  The result should
fix #3959, and I hope the related #3961, #3983.

There is lots of new documentation: in particular
 * Note [ArityType]
   describes the new datatype for arity info

 * Note [State hack and bottoming functions]
   says how bottoming functions are dealt with, particularly
   covering catch# and Trac #3959

I also found I had to be careful not to eta-expand single-method
class constructors; see Note [Newtype classes and eta expansion].
I think this part could be done better, but it works ok.

compiler/coreSyn/CoreArity.lhs
compiler/types/Type.lhs

index d5849cb..e63d121 100644 (file)
@@ -23,14 +23,13 @@ import Var
 import VarEnv
 import Id
 import Type
-import TyCon   ( isRecursiveTyCon )
+import TyCon   ( isRecursiveTyCon, isClassTyCon )
 import TcType  ( isDictLikeTy )
 import Coercion
 import BasicTypes
 import Unique
 import Outputable
 import DynFlags
-import StaticFlags     ( opt_NoStateHack )
 import FastString
 \end{code}
 
@@ -67,10 +66,19 @@ should have arity 3, regardless of f's arity.
 Note [exprArity invariant]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 exprArity has the following invariant:
-       (exprArity e) = n, then manifestArity (etaExpand e n) = n
 
-That is, if exprArity says "the arity is n" then etaExpand really can get
-"n" manifest lambdas to the top.
+  * If typeArity (exprType e) = n,
+    then manifestArity (etaExpand e n) = n
+    That is, etaExpand can always expand as much as typeArity says
+    So the case analysis in etaExpand and in typeArity must match
+  * exprArity e <= typeArity (exprType e)      
+
+  * Hence if (exprArity e) = n, then manifestArity (etaExpand e n) = n
+
+    That is, if exprArity says "the arity is n" then etaExpand really 
+    can get "n" manifest lambdas to the top.
 
 Why is this important?  Because 
   - In TidyPgm we use exprArity to fix the *final arity* of 
@@ -101,14 +109,82 @@ exprArity e = go e
     go (Lam x e) | isId x         = go e + 1
                 | otherwise       = go e
     go (Note _ e)                  = go e
-    go (Cast e co)                 = go e `min` typeArity (snd (coercionKind co))
+    go (Cast e co)                 = go e `min` length (typeArity (snd (coercionKind co)))
                                                -- Note [exprArity invariant]
     go (App e (Type _))            = go e
     go (App f a) | exprIsTrivial a = (go f - 1) `max` 0
         -- See Note [exprArity for applications]
     go _                          = 0
+
+
+typeArity :: Type -> [OneShot]
+-- How many value arrows are visible in the type?
+-- We look through foralls, and newtypes
+-- See Note [exprArity invariant]
+typeArity ty 
+  | Just (_, ty')  <- splitForAllTy_maybe ty 
+  = typeArity ty'
+
+  | Just (arg,res) <- splitFunTy_maybe ty    
+  = isStateHackType arg : typeArity res
+
+  | Just (tc,tys) <- splitTyConApp_maybe ty 
+  , Just (ty', _) <- instNewTyCon_maybe tc tys
+  , not (isRecursiveTyCon tc)
+  , not (isClassTyCon tc)      -- Do not eta-expand through newtype classes
+                               -- See Note [Newtype classes and eta expansion]
+  = typeArity ty'
+       -- Important to look through non-recursive newtypes, so that, eg 
+       --      (f x)   where f has arity 2, f :: Int -> IO ()
+       -- Here we want to get arity 1 for the result!
+
+  | otherwise
+  = []
 \end{code}
 
+Note [Newtype classes and eta expansion]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We have to be careful when eta-expanding through newtypes.  In general
+it's a good idea, but annoyingly it interacts badly with the class-op 
+rule mechanism.  Consider
+   class C a where { op :: a -> a }
+   instance C b => C [b] where
+     op x = ...
+
+These translate to
+
+   co :: forall a. (a->a) ~ C a
+
+   $copList :: C b -> [b] -> [b]
+   $copList d x = ...
+
+   $dfList :: C b -> C [b]
+   {-# DFunUnfolding = [$copList] #-}
+   $dfList d = $copList d |> co@[b]
+
+Now suppose we have:
+
+   dCInt :: C Int    
+
+   blah :: [Int] -> [Int]
+   blah = op ($dfList dCInt)
+
+Now we want the built-in op/$dfList rule will fire to give
+   blah = $copList dCInt
+
+But with eta-expansion 'blah' might (and in Trac #3772, which is
+slightly more complicated, does) turn into
+
+   blah = op (\eta. ($dfList dCInt |> sym co) eta)
+
+and now it is *much* harder for the op/$dfList rule to fire, becuase
+exprIsConApp_maybe won't hold of the argument to op.  I considered
+trying to *make* it hold, but it's tricky and I gave up.
+
+The test simplCore/should_compile/T3722 is an excellent example.
+
+
 Note [exprArity for applications]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 When we come to an application we check that the arg is trivial.
@@ -133,24 +209,14 @@ When we come to an application we check that the arg is trivial.
 %************************************************************************
 
 \begin{code}
--- ^ The Arity returned is the number of value args the 
--- expression can be applied to without doing much work
-exprEtaExpandArity :: DynFlags -> CoreExpr -> Arity
--- exprEtaExpandArity is used when eta expanding
---     e  ==>  \xy -> e x y
-exprEtaExpandArity dflags e
-    = applyStateHack e (arityType dicts_cheap e)
-  where
-    dicts_cheap = dopt Opt_DictsCheap dflags
-
 exprBotStrictness_maybe :: CoreExpr -> Maybe (Arity, StrictSig)
 -- A cheap and cheerful function that identifies bottoming functions
 -- and gives them a suitable strictness signatures.  It's used during
 -- float-out
 exprBotStrictness_maybe e
-  = case arityType False e of
-       AT _ ATop -> Nothing
-       AT a ABot -> Just (a, mkStrictSig (mkTopDmdType (replicate a topDmd) BotRes))
+  = case getBotArity (arityType False e) of
+       Nothing -> Nothing
+       Just ar -> Just (ar, mkStrictSig (mkTopDmdType (replicate ar topDmd) BotRes))
 \end{code}     
 
 Note [Definition of arity]
@@ -234,40 +300,6 @@ we want to get:             coerce T (\x::[T] -> (coerce ([T]->Int) e) x)
   And since negate has arity 2, you might try to eta expand.  But you can't
   decopose Int to a function type.   Hence the final case in eta_expand.
   
-\begin{code}
-applyStateHack :: CoreExpr -> ArityType -> Arity
-applyStateHack e (AT orig_arity is_bot)
-  | opt_NoStateHack = orig_arity
-  | ABot <- is_bot  = orig_arity   -- Note [State hack and bottoming functions]
-  | otherwise       = go orig_ty orig_arity
-  where                        -- Note [The state-transformer hack]
-    orig_ty = exprType e
-    go :: Type -> Arity -> Arity
-    go ty arity                -- This case analysis should match that in eta_expand
-       | Just (_, ty') <- splitForAllTy_maybe ty   = go ty' arity
-       | Just (arg,res) <- splitFunTy_maybe ty
-       , arity > 0 || isStateHackType arg = 1 + go res (arity-1)
-
--- See Note [trimCast]
-       | Just (tc,tys) <- splitTyConApp_maybe ty 
-       , Just (ty', _) <- instNewTyCon_maybe tc tys
-       , not (isRecursiveTyCon tc)                 = go ty' arity
-               -- Important to look through non-recursive newtypes, so that, eg 
-               --      (f x)   where f has arity 2, f :: Int -> IO ()
-               -- Here we want to get arity 1 for the result!
--------
-
-{-
-       = if arity > 0 then 1 + go res (arity-1)
-         else if isStateHackType arg then
-               pprTrace "applystatehack" (vcat [ppr orig_arity, ppr orig_ty,
-                                               ppr ty, ppr res, ppr e]) $
-               1 + go res (arity-1)
-          else WARN( arity > 0, ppr arity ) 0
--}                                              
-       | otherwise = WARN( arity > 0, ppr arity <+> ppr ty) 0
-\end{code}
-
 Note [The state-transformer hack]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Suppose we have 
@@ -320,90 +352,169 @@ Extrude g1.g3
 And now we can repeat the whole loop.  Aargh!  The bug is in applying the
 state hack to a function which then swallows the argument.
 
+This arose in another guise in Trac #3959.  Here we had
+
+     catch# (throw exn >> return ())
+
+Note that (throw :: forall a e. Exn e => e -> a) is called with [a = IO ()].
+After inlining (>>) we get 
+
+     catch# (\_. throw {IO ()} exn)
+
+We must *not* eta-expand to 
+
+     catch# (\_ _. throw {...} exn)
+
+because 'catch#' expects to get a (# _,_ #) after applying its argument to
+a State#, not another function!  
+
+In short, we use the state hack to allow us to push let inside a lambda,
+but not to introduce a new lambda.
+
+
+Note [ArityType]
+~~~~~~~~~~~~~~~~
+ArityType is the result of a compositional analysis on expressions,
+from which we can decide the real arity of the expression (extracted
+with function getArity).
+
+Here is what the fields mean. If e has ArityType 
+     (AT as r), where n = length as, 
+then
+
+ * If r is ABot then (e x1..xn) definitely diverges
+   Partial applications may or may not diverge
+
+ * If r is ACheap then (e x1..x(n-1)) is cheap,
+   including any nested sub-expressions inside e
+   (say e is (f e1 e2) then e1,e2 are cheap too)
+
+ * e, (e x1), ... (e x1 ... x(n-1)) are definitely really 
+   functions, or bottom, not casts from a data type
+   So eta expansion is dynamically ok; 
+    see Note [State hack and bottoming functions], 
+    the part about catch#
+
+We regard ABot as stronger than ACheap; ie if ABot holds
+we don't bother about ACheap
+
+Suppose f = \xy. x+y
+Then  f             :: AT [False,False] ACheap
+      f v           :: AT [False]      ACheap
+      f <expensive> :: AT [False]      ATop
+Note the ArityRes flag tells whether the whole expression is cheap.
+Note also that having a non-empty 'as' doesn't mean it has that
+arity; see (f <expensive>) which does not have arity 1!
+
+The key function getArity extracts the arity (which in turn guides
+eta-expansion) from ArityType. 
+  * If the term is cheap or diverges we can certainly eta expand it
+      e.g.   (f x)   where x has arity 2
+  
+  * If its a function whose first arg is one-shot (probably via the
+    state hack) we can eta expand it
+      e.g.   (getChar <expensive>)  
 
 -------------------- Main arity code ----------------------------
 \begin{code}
--- If e has ArityType (AT n r), then the term 'e'
---  * Must be applied to at least n *value* args 
---     before doing any significant work
---  * It will not diverge before being applied to n
---     value arguments
---  * If 'r' is ABot, then it guarantees to diverge if 
---     applied to n arguments (or more)
-
-data ArityType = AT Arity ArityRes
-data ArityRes  = ATop                  -- Know nothing
-              | ABot                   -- Diverges
+-- See Note [ArityType]
+data ArityType = AT [OneShot] ArityRes
+     -- There is always an explicit lambda
+     -- to justify the [OneShot]
+
+type OneShot = Bool    -- False <=> Know nothing
+                       -- True  <=> Can definitely float inside this lambda
+                      -- The 'True' case can arise either because a binder
+                      -- is marked one-shot, or because it's a state lambda
+                      -- and we have the state hack on
+
+data ArityRes  = ATop | ACheap | ABot
 
 vanillaArityType :: ArityType
-vanillaArityType = AT 0 ATop   -- Totally uninformative
+vanillaArityType = AT [] ATop  -- Totally uninformative
 
-incArity :: ArityType -> ArityType
-incArity (AT a r) = AT (a+1) r
+-- ^ The Arity returned is the number of value args the [_$_]
+-- expression can be applied to without doing much work
+exprEtaExpandArity :: DynFlags -> CoreExpr -> Arity
+-- exprEtaExpandArity is used when eta expanding
+--     e  ==>  \xy -> e x y
+exprEtaExpandArity dflags e
+  = case (arityType dicts_cheap e) of
+      AT (a:as) res | want_eta a res -> 1 + length as
+      _                              -> 0
+  where
+    want_eta one_shot ATop   = one_shot
+    want_eta _        _      = True
 
-decArity :: ArityType -> ArityType
-decArity (AT 0 r) = AT 0     r
-decArity (AT a r) = AT (a-1) r
+    dicts_cheap = dopt Opt_DictsCheap dflags
 
-andArityType :: ArityType -> ArityType -> ArityType   -- Used for branches of a 'case'
-andArityType (AT a1 ATop) (AT a2 ATop) = AT (a1 `min` a2) ATop
-andArityType (AT _  ABot) (AT a2 ATop) = AT a2           ATop
-andArityType (AT a1 ATop) (AT _  ABot) = AT a1           ATop
-andArityType (AT a1 ABot) (AT a2 ABot) = AT (a1 `max` a2) ABot
+getBotArity :: ArityType -> Maybe Arity
+-- Arity of a divergent function
+getBotArity (AT as ABot) = Just (length as)
+getBotArity _            = Nothing
 
----------------------------
-trimCast :: Coercion -> ArityType -> ArityType
--- Trim the arity to be no more than allowed by the
--- arrows in ty2, where co :: ty1~ty2
-trimCast _ at = at
-
-{-        Omitting for now Note [trimCast]
-trimCast co at@(AT ar _)
-  | ar > co_arity = AT co_arity ATop
-  | otherwise     = at
+arityLam :: Id -> ArityType -> ArityType
+arityLam id (AT as r) = AT (isOneShotBndr id : as) r
+
+floatIn :: Bool -> ArityType -> ArityType
+-- We have something like (let x = E in b), 
+-- where b has the given arity type.  
+floatIn c (AT as r) = AT as (extendArityRes r c)
+
+arityApp :: ArityType -> CoreExpr -> ArityType
+-- Processing (fun arg) where at is the ArityType of fun,
+arityApp (AT [] r)     arg = AT [] (extendArityRes r (exprIsCheap arg))
+arityApp (AT (_:as) r) arg = AT as (extendArityRes r (exprIsCheap arg))
+
+extendArityRes :: ArityRes -> Bool -> ArityRes
+extendArityRes ABot   _    = ABot
+extendArityRes ACheap True = ACheap
+extendArityRes _      _    = ATop
+
+andArityType :: ArityType -> ArityType -> ArityType   -- Used for branches of a 'case'
+andArityType (AT as1 r1) (AT as2 r2) 
+  = AT (go_as as1 as2) (go_r r1 r2)
   where
-    (_, ty2) = coercionKind co
-    co_arity = typeArity ty2
--}
+    go_r ABot ABot     = ABot
+    go_r ABot ACheap   = ACheap
+    go_r ACheap ABot   = ACheap
+    go_r ACheap ACheap = ACheap
+    go_r _    _        = ATop
+
+    go_as (os1:as1) (os2:as2) = (os1 || os2) : go_as as1 as2
+    go_as []        as2       = as2 
+    go_as as1       []        = as1
 \end{code}
 
-Note [trimCast]
-~~~~~~~~~~~~~~~
-When you try putting trimCast back in, comment out the snippets
-flagged by the other references to Note [trimCast]
 
 \begin{code}
 ---------------------------
-trimArity :: Bool -> ArityType -> ArityType
--- We have something like (let x = E in b), where b has the given
--- arity type.  Then
---     * If E is cheap we can push it inside as far as we like
---     * If b eventually diverges, we allow ourselves to push inside
---       arbitrarily, even though that is not quite right
-trimArity _cheap (AT a ABot) = AT a ABot
-trimArity True   (AT a ATop) = AT a ATop
-trimArity False  (AT _ ATop) = AT 0 ATop       -- Bale out
-
----------------------------
 arityType :: Bool -> CoreExpr -> ArityType
 arityType _ (Var v)
   | Just strict_sig <- idStrictness_maybe v
   , (ds, res) <- splitStrictSig strict_sig
-  , isBotRes res
-  = AT (length ds) ABot        -- Function diverges
+  = mk_arity (length ds) res
   | otherwise
-  = AT (idArity v) ATop
+  = mk_arity (idArity v) TopRes
+
+  where
+    mk_arity id_arity res 
+      | isBotRes res = AT (take id_arity one_shots) ABot
+      | id_arity>0   = AT (take id_arity one_shots) ACheap
+      | otherwise    = AT []                        ATop
+
+    one_shots = typeArity (idType v)
 
        -- Lambdas; increase arity
 arityType dicts_cheap (Lam x e)
-  | isId x    = incArity (arityType dicts_cheap e)
+  | isId x    = arityLam x (arityType dicts_cheap e)
   | otherwise = arityType dicts_cheap e
 
        -- Applications; decrease arity
 arityType dicts_cheap (App fun (Type _))
    = arityType dicts_cheap fun
 arityType dicts_cheap (App fun arg )
-   = trimArity (exprIsCheap arg) (decArity (arityType dicts_cheap fun))
+   = arityApp (arityType dicts_cheap fun) arg 
 
        -- Case/Let; keep arity if either the expression is cheap
        -- or it's a 1-shot lambda
@@ -413,11 +524,11 @@ arityType dicts_cheap (App fun arg )
        --      f x y = case x of { (a,b) -> e }
        -- The difference is observable using 'seq'
 arityType dicts_cheap (Case scrut _ _ alts)
-  = trimArity (exprIsCheap scrut)
+  = floatIn (exprIsCheap scrut)
              (foldr1 andArityType [arityType dicts_cheap rhs | (_,_,rhs) <- alts])
 
 arityType dicts_cheap (Let b e) 
-  = trimArity (cheap_bind b) (arityType dicts_cheap e)
+  = floatIn (cheap_bind b) (arityType dicts_cheap e)
   where
     cheap_bind (NonRec b e) = is_cheap (b,e)
     cheap_bind (Rec prs)    = all is_cheap prs
@@ -443,9 +554,9 @@ arityType dicts_cheap (Let b e)
        -- See Note [Dictionary-like types] in TcType.lhs for why we use
        -- isDictLikeTy here rather than isDictTy
 
-arityType dicts_cheap (Note _ e)  = arityType dicts_cheap e
-arityType dicts_cheap (Cast e co) = trimCast co (arityType dicts_cheap e)
-arityType _           _           = vanillaArityType
+arityType dicts_cheap (Note _ e) = arityType dicts_cheap e
+arityType dicts_cheap (Cast e _) = arityType dicts_cheap e
+arityType _           _          = vanillaArityType
 \end{code}
   
   
@@ -589,7 +700,7 @@ mkEtaWW orig_n in_scope orig_ty
   where
     empty_subst = mkTvSubst in_scope emptyTvSubstEnv
 
-    go n subst ty eis
+    go n subst ty eis      -- See Note [exprArity invariant]
        | n == 0
        = (getTvInScope subst, reverse eis)
 
@@ -603,7 +714,6 @@ mkEtaWW orig_n in_scope orig_ty
            -- Avoid free vars of the original expression
        = go (n-1) subst' res_ty (EtaVar eta_id' : eis)
                                           
--- See Note [trimCast]
        | Just(ty',co) <- splitNewTypeRepCo_maybe ty
        =       -- Given this:
                        --      newtype T = MkT ([T] -> Int)
@@ -612,7 +722,6 @@ mkEtaWW orig_n in_scope orig_ty
                        -- We want to get
                        --      coerce T (\x::[T] -> (coerce ([T]->Int) e) x)
          go n subst ty' (EtaCo (Type.substTy subst co) : eis)
--------
 
        | otherwise                        -- We have an expression of arity > 0, 
        = WARN( True, ppr orig_n <+> ppr orig_ty )
index 579c5da..8817222 100644 (file)
@@ -30,7 +30,7 @@ module Type (
 
        mkFunTy, mkFunTys, splitFunTy, splitFunTy_maybe, 
        splitFunTys, splitFunTysN,
-       funResultTy, funArgTy, zipFunTys, typeArity,
+       funResultTy, funArgTy, zipFunTys, 
 
        mkTyConApp, mkTyConTy, 
        tyConAppTyCon, tyConAppArgs, 
@@ -141,7 +141,6 @@ import VarSet
 import Name
 import Class
 import TyCon
-import BasicTypes      ( Arity )
 
 -- others
 import StaticFlags
@@ -498,14 +497,6 @@ funArgTy :: Type -> Type
 funArgTy ty | Just ty' <- coreView ty = funArgTy ty'
 funArgTy (FunTy arg _res)  = arg
 funArgTy ty                = pprPanic "funArgTy" (ppr ty)
-
-typeArity :: Type -> Arity
--- How many value arrows are visible in the type?
--- We look through foralls, but not through newtypes, dictionaries etc
-typeArity ty | Just ty' <- coreView ty = typeArity ty'
-typeArity (FunTy _ ty)    = 1 + typeArity ty
-typeArity (ForAllTy _ ty) = typeArity ty
-typeArity _               = 0
 \end{code}
 
 ---------------------------------------------------------------------