From 635952097df211953c4bd0456b37eba64c485f60 Mon Sep 17 00:00:00 2001 From: "simonpj@microsoft.com" Date: Fri, 13 Aug 2010 16:11:51 +0000 Subject: [PATCH] Re-do the arity calculation mechanism again (fix Trac #3959) After rumination, yet again, on the subject of arity calculation, I have redone what an ArityType is (it's purely internal to the CoreArity module), and documented it better. The result should fix #3959, and I hope the related #3961, #3983. There is lots of new documentation: in particular * Note [ArityType] describes the new datatype for arity info * Note [State hack and bottoming functions] says how bottoming functions are dealt with, particularly covering catch# and Trac #3959 I also found I had to be careful not to eta-expand single-method class constructors; see Note [Newtype classes and eta expansion]. I think this part could be done better, but it works ok. --- compiler/coreSyn/CoreArity.lhs | 341 ++++++++++++++++++++++++++-------------- compiler/types/Type.lhs | 11 +- 2 files changed, 226 insertions(+), 126 deletions(-) diff --git a/compiler/coreSyn/CoreArity.lhs b/compiler/coreSyn/CoreArity.lhs index d5849cb..e63d121 100644 --- a/compiler/coreSyn/CoreArity.lhs +++ b/compiler/coreSyn/CoreArity.lhs @@ -23,14 +23,13 @@ import Var import VarEnv import Id import Type -import TyCon ( isRecursiveTyCon ) +import TyCon ( isRecursiveTyCon, isClassTyCon ) import TcType ( isDictLikeTy ) import Coercion import BasicTypes import Unique import Outputable import DynFlags -import StaticFlags ( opt_NoStateHack ) import FastString \end{code} @@ -67,10 +66,19 @@ should have arity 3, regardless of f's arity. Note [exprArity invariant] ~~~~~~~~~~~~~~~~~~~~~~~~~~ exprArity has the following invariant: - (exprArity e) = n, then manifestArity (etaExpand e n) = n -That is, if exprArity says "the arity is n" then etaExpand really can get -"n" manifest lambdas to the top. + * If typeArity (exprType e) = n, + then manifestArity (etaExpand e n) = n + + That is, etaExpand can always expand as much as typeArity says + So the case analysis in etaExpand and in typeArity must match + + * exprArity e <= typeArity (exprType e) + + * Hence if (exprArity e) = n, then manifestArity (etaExpand e n) = n + + That is, if exprArity says "the arity is n" then etaExpand really + can get "n" manifest lambdas to the top. Why is this important? Because - In TidyPgm we use exprArity to fix the *final arity* of @@ -101,14 +109,82 @@ exprArity e = go e go (Lam x e) | isId x = go e + 1 | otherwise = go e go (Note _ e) = go e - go (Cast e co) = go e `min` typeArity (snd (coercionKind co)) + go (Cast e co) = go e `min` length (typeArity (snd (coercionKind co))) -- Note [exprArity invariant] go (App e (Type _)) = go e go (App f a) | exprIsTrivial a = (go f - 1) `max` 0 -- See Note [exprArity for applications] go _ = 0 + + +typeArity :: Type -> [OneShot] +-- How many value arrows are visible in the type? +-- We look through foralls, and newtypes +-- See Note [exprArity invariant] +typeArity ty + | Just (_, ty') <- splitForAllTy_maybe ty + = typeArity ty' + + | Just (arg,res) <- splitFunTy_maybe ty + = isStateHackType arg : typeArity res + + | Just (tc,tys) <- splitTyConApp_maybe ty + , Just (ty', _) <- instNewTyCon_maybe tc tys + , not (isRecursiveTyCon tc) + , not (isClassTyCon tc) -- Do not eta-expand through newtype classes + -- See Note [Newtype classes and eta expansion] + = typeArity ty' + -- Important to look through non-recursive newtypes, so that, eg + -- (f x) where f has arity 2, f :: Int -> IO () + -- Here we want to get arity 1 for the result! + + | otherwise + = [] \end{code} +Note [Newtype classes and eta expansion] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +We have to be careful when eta-expanding through newtypes. In general +it's a good idea, but annoyingly it interacts badly with the class-op +rule mechanism. Consider + + class C a where { op :: a -> a } + instance C b => C [b] where + op x = ... + +These translate to + + co :: forall a. (a->a) ~ C a + + $copList :: C b -> [b] -> [b] + $copList d x = ... + + $dfList :: C b -> C [b] + {-# DFunUnfolding = [$copList] #-} + $dfList d = $copList d |> co@[b] + +Now suppose we have: + + dCInt :: C Int + + blah :: [Int] -> [Int] + blah = op ($dfList dCInt) + +Now we want the built-in op/$dfList rule will fire to give + blah = $copList dCInt + +But with eta-expansion 'blah' might (and in Trac #3772, which is +slightly more complicated, does) turn into + + blah = op (\eta. ($dfList dCInt |> sym co) eta) + +and now it is *much* harder for the op/$dfList rule to fire, becuase +exprIsConApp_maybe won't hold of the argument to op. I considered +trying to *make* it hold, but it's tricky and I gave up. + +The test simplCore/should_compile/T3722 is an excellent example. + + Note [exprArity for applications] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When we come to an application we check that the arg is trivial. @@ -133,24 +209,14 @@ When we come to an application we check that the arg is trivial. %************************************************************************ \begin{code} --- ^ The Arity returned is the number of value args the --- expression can be applied to without doing much work -exprEtaExpandArity :: DynFlags -> CoreExpr -> Arity --- exprEtaExpandArity is used when eta expanding --- e ==> \xy -> e x y -exprEtaExpandArity dflags e - = applyStateHack e (arityType dicts_cheap e) - where - dicts_cheap = dopt Opt_DictsCheap dflags - exprBotStrictness_maybe :: CoreExpr -> Maybe (Arity, StrictSig) -- A cheap and cheerful function that identifies bottoming functions -- and gives them a suitable strictness signatures. It's used during -- float-out exprBotStrictness_maybe e - = case arityType False e of - AT _ ATop -> Nothing - AT a ABot -> Just (a, mkStrictSig (mkTopDmdType (replicate a topDmd) BotRes)) + = case getBotArity (arityType False e) of + Nothing -> Nothing + Just ar -> Just (ar, mkStrictSig (mkTopDmdType (replicate ar topDmd) BotRes)) \end{code} Note [Definition of arity] @@ -234,40 +300,6 @@ we want to get: coerce T (\x::[T] -> (coerce ([T]->Int) e) x) And since negate has arity 2, you might try to eta expand. But you can't decopose Int to a function type. Hence the final case in eta_expand. -\begin{code} -applyStateHack :: CoreExpr -> ArityType -> Arity -applyStateHack e (AT orig_arity is_bot) - | opt_NoStateHack = orig_arity - | ABot <- is_bot = orig_arity -- Note [State hack and bottoming functions] - | otherwise = go orig_ty orig_arity - where -- Note [The state-transformer hack] - orig_ty = exprType e - go :: Type -> Arity -> Arity - go ty arity -- This case analysis should match that in eta_expand - | Just (_, ty') <- splitForAllTy_maybe ty = go ty' arity - | Just (arg,res) <- splitFunTy_maybe ty - , arity > 0 || isStateHackType arg = 1 + go res (arity-1) - --- See Note [trimCast] - | Just (tc,tys) <- splitTyConApp_maybe ty - , Just (ty', _) <- instNewTyCon_maybe tc tys - , not (isRecursiveTyCon tc) = go ty' arity - -- Important to look through non-recursive newtypes, so that, eg - -- (f x) where f has arity 2, f :: Int -> IO () - -- Here we want to get arity 1 for the result! -------- - -{- - = if arity > 0 then 1 + go res (arity-1) - else if isStateHackType arg then - pprTrace "applystatehack" (vcat [ppr orig_arity, ppr orig_ty, - ppr ty, ppr res, ppr e]) $ - 1 + go res (arity-1) - else WARN( arity > 0, ppr arity ) 0 --} - | otherwise = WARN( arity > 0, ppr arity <+> ppr ty) 0 -\end{code} - Note [The state-transformer hack] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Suppose we have @@ -320,90 +352,169 @@ Extrude g1.g3 And now we can repeat the whole loop. Aargh! The bug is in applying the state hack to a function which then swallows the argument. +This arose in another guise in Trac #3959. Here we had + + catch# (throw exn >> return ()) + +Note that (throw :: forall a e. Exn e => e -> a) is called with [a = IO ()]. +After inlining (>>) we get + + catch# (\_. throw {IO ()} exn) + +We must *not* eta-expand to + + catch# (\_ _. throw {...} exn) + +because 'catch#' expects to get a (# _,_ #) after applying its argument to +a State#, not another function! + +In short, we use the state hack to allow us to push let inside a lambda, +but not to introduce a new lambda. + + +Note [ArityType] +~~~~~~~~~~~~~~~~ +ArityType is the result of a compositional analysis on expressions, +from which we can decide the real arity of the expression (extracted +with function getArity). + +Here is what the fields mean. If e has ArityType + (AT as r), where n = length as, +then + + * If r is ABot then (e x1..xn) definitely diverges + Partial applications may or may not diverge + + * If r is ACheap then (e x1..x(n-1)) is cheap, + including any nested sub-expressions inside e + (say e is (f e1 e2) then e1,e2 are cheap too) + + * e, (e x1), ... (e x1 ... x(n-1)) are definitely really + functions, or bottom, not casts from a data type + So eta expansion is dynamically ok; + see Note [State hack and bottoming functions], + the part about catch# + +We regard ABot as stronger than ACheap; ie if ABot holds +we don't bother about ACheap + +Suppose f = \xy. x+y +Then f :: AT [False,False] ACheap + f v :: AT [False] ACheap + f :: AT [False] ATop +Note the ArityRes flag tells whether the whole expression is cheap. +Note also that having a non-empty 'as' doesn't mean it has that +arity; see (f ) which does not have arity 1! + +The key function getArity extracts the arity (which in turn guides +eta-expansion) from ArityType. + * If the term is cheap or diverges we can certainly eta expand it + e.g. (f x) where x has arity 2 + + * If its a function whose first arg is one-shot (probably via the + state hack) we can eta expand it + e.g. (getChar ) -------------------- Main arity code ---------------------------- \begin{code} --- If e has ArityType (AT n r), then the term 'e' --- * Must be applied to at least n *value* args --- before doing any significant work --- * It will not diverge before being applied to n --- value arguments --- * If 'r' is ABot, then it guarantees to diverge if --- applied to n arguments (or more) - -data ArityType = AT Arity ArityRes -data ArityRes = ATop -- Know nothing - | ABot -- Diverges +-- See Note [ArityType] +data ArityType = AT [OneShot] ArityRes + -- There is always an explicit lambda + -- to justify the [OneShot] + +type OneShot = Bool -- False <=> Know nothing + -- True <=> Can definitely float inside this lambda + -- The 'True' case can arise either because a binder + -- is marked one-shot, or because it's a state lambda + -- and we have the state hack on + +data ArityRes = ATop | ACheap | ABot vanillaArityType :: ArityType -vanillaArityType = AT 0 ATop -- Totally uninformative +vanillaArityType = AT [] ATop -- Totally uninformative -incArity :: ArityType -> ArityType -incArity (AT a r) = AT (a+1) r +-- ^ The Arity returned is the number of value args the [_$_] +-- expression can be applied to without doing much work +exprEtaExpandArity :: DynFlags -> CoreExpr -> Arity +-- exprEtaExpandArity is used when eta expanding +-- e ==> \xy -> e x y +exprEtaExpandArity dflags e + = case (arityType dicts_cheap e) of + AT (a:as) res | want_eta a res -> 1 + length as + _ -> 0 + where + want_eta one_shot ATop = one_shot + want_eta _ _ = True -decArity :: ArityType -> ArityType -decArity (AT 0 r) = AT 0 r -decArity (AT a r) = AT (a-1) r + dicts_cheap = dopt Opt_DictsCheap dflags -andArityType :: ArityType -> ArityType -> ArityType -- Used for branches of a 'case' -andArityType (AT a1 ATop) (AT a2 ATop) = AT (a1 `min` a2) ATop -andArityType (AT _ ABot) (AT a2 ATop) = AT a2 ATop -andArityType (AT a1 ATop) (AT _ ABot) = AT a1 ATop -andArityType (AT a1 ABot) (AT a2 ABot) = AT (a1 `max` a2) ABot +getBotArity :: ArityType -> Maybe Arity +-- Arity of a divergent function +getBotArity (AT as ABot) = Just (length as) +getBotArity _ = Nothing ---------------------------- -trimCast :: Coercion -> ArityType -> ArityType --- Trim the arity to be no more than allowed by the --- arrows in ty2, where co :: ty1~ty2 -trimCast _ at = at - -{- Omitting for now Note [trimCast] -trimCast co at@(AT ar _) - | ar > co_arity = AT co_arity ATop - | otherwise = at +arityLam :: Id -> ArityType -> ArityType +arityLam id (AT as r) = AT (isOneShotBndr id : as) r + +floatIn :: Bool -> ArityType -> ArityType +-- We have something like (let x = E in b), +-- where b has the given arity type. +floatIn c (AT as r) = AT as (extendArityRes r c) + +arityApp :: ArityType -> CoreExpr -> ArityType +-- Processing (fun arg) where at is the ArityType of fun, +arityApp (AT [] r) arg = AT [] (extendArityRes r (exprIsCheap arg)) +arityApp (AT (_:as) r) arg = AT as (extendArityRes r (exprIsCheap arg)) + +extendArityRes :: ArityRes -> Bool -> ArityRes +extendArityRes ABot _ = ABot +extendArityRes ACheap True = ACheap +extendArityRes _ _ = ATop + +andArityType :: ArityType -> ArityType -> ArityType -- Used for branches of a 'case' +andArityType (AT as1 r1) (AT as2 r2) + = AT (go_as as1 as2) (go_r r1 r2) where - (_, ty2) = coercionKind co - co_arity = typeArity ty2 --} + go_r ABot ABot = ABot + go_r ABot ACheap = ACheap + go_r ACheap ABot = ACheap + go_r ACheap ACheap = ACheap + go_r _ _ = ATop + + go_as (os1:as1) (os2:as2) = (os1 || os2) : go_as as1 as2 + go_as [] as2 = as2 + go_as as1 [] = as1 \end{code} -Note [trimCast] -~~~~~~~~~~~~~~~ -When you try putting trimCast back in, comment out the snippets -flagged by the other references to Note [trimCast] \begin{code} --------------------------- -trimArity :: Bool -> ArityType -> ArityType --- We have something like (let x = E in b), where b has the given --- arity type. Then --- * If E is cheap we can push it inside as far as we like --- * If b eventually diverges, we allow ourselves to push inside --- arbitrarily, even though that is not quite right -trimArity _cheap (AT a ABot) = AT a ABot -trimArity True (AT a ATop) = AT a ATop -trimArity False (AT _ ATop) = AT 0 ATop -- Bale out - ---------------------------- arityType :: Bool -> CoreExpr -> ArityType arityType _ (Var v) | Just strict_sig <- idStrictness_maybe v , (ds, res) <- splitStrictSig strict_sig - , isBotRes res - = AT (length ds) ABot -- Function diverges + = mk_arity (length ds) res | otherwise - = AT (idArity v) ATop + = mk_arity (idArity v) TopRes + + where + mk_arity id_arity res + | isBotRes res = AT (take id_arity one_shots) ABot + | id_arity>0 = AT (take id_arity one_shots) ACheap + | otherwise = AT [] ATop + + one_shots = typeArity (idType v) -- Lambdas; increase arity arityType dicts_cheap (Lam x e) - | isId x = incArity (arityType dicts_cheap e) + | isId x = arityLam x (arityType dicts_cheap e) | otherwise = arityType dicts_cheap e -- Applications; decrease arity arityType dicts_cheap (App fun (Type _)) = arityType dicts_cheap fun arityType dicts_cheap (App fun arg ) - = trimArity (exprIsCheap arg) (decArity (arityType dicts_cheap fun)) + = arityApp (arityType dicts_cheap fun) arg -- Case/Let; keep arity if either the expression is cheap -- or it's a 1-shot lambda @@ -413,11 +524,11 @@ arityType dicts_cheap (App fun arg ) -- f x y = case x of { (a,b) -> e } -- The difference is observable using 'seq' arityType dicts_cheap (Case scrut _ _ alts) - = trimArity (exprIsCheap scrut) + = floatIn (exprIsCheap scrut) (foldr1 andArityType [arityType dicts_cheap rhs | (_,_,rhs) <- alts]) arityType dicts_cheap (Let b e) - = trimArity (cheap_bind b) (arityType dicts_cheap e) + = floatIn (cheap_bind b) (arityType dicts_cheap e) where cheap_bind (NonRec b e) = is_cheap (b,e) cheap_bind (Rec prs) = all is_cheap prs @@ -443,9 +554,9 @@ arityType dicts_cheap (Let b e) -- See Note [Dictionary-like types] in TcType.lhs for why we use -- isDictLikeTy here rather than isDictTy -arityType dicts_cheap (Note _ e) = arityType dicts_cheap e -arityType dicts_cheap (Cast e co) = trimCast co (arityType dicts_cheap e) -arityType _ _ = vanillaArityType +arityType dicts_cheap (Note _ e) = arityType dicts_cheap e +arityType dicts_cheap (Cast e _) = arityType dicts_cheap e +arityType _ _ = vanillaArityType \end{code} @@ -589,7 +700,7 @@ mkEtaWW orig_n in_scope orig_ty where empty_subst = mkTvSubst in_scope emptyTvSubstEnv - go n subst ty eis + go n subst ty eis -- See Note [exprArity invariant] | n == 0 = (getTvInScope subst, reverse eis) @@ -603,7 +714,6 @@ mkEtaWW orig_n in_scope orig_ty -- Avoid free vars of the original expression = go (n-1) subst' res_ty (EtaVar eta_id' : eis) --- See Note [trimCast] | Just(ty',co) <- splitNewTypeRepCo_maybe ty = -- Given this: -- newtype T = MkT ([T] -> Int) @@ -612,7 +722,6 @@ mkEtaWW orig_n in_scope orig_ty -- We want to get -- coerce T (\x::[T] -> (coerce ([T]->Int) e) x) go n subst ty' (EtaCo (Type.substTy subst co) : eis) -------- | otherwise -- We have an expression of arity > 0, = WARN( True, ppr orig_n <+> ppr orig_ty ) diff --git a/compiler/types/Type.lhs b/compiler/types/Type.lhs index 579c5da..8817222 100644 --- a/compiler/types/Type.lhs +++ b/compiler/types/Type.lhs @@ -30,7 +30,7 @@ module Type ( mkFunTy, mkFunTys, splitFunTy, splitFunTy_maybe, splitFunTys, splitFunTysN, - funResultTy, funArgTy, zipFunTys, typeArity, + funResultTy, funArgTy, zipFunTys, mkTyConApp, mkTyConTy, tyConAppTyCon, tyConAppArgs, @@ -141,7 +141,6 @@ import VarSet import Name import Class import TyCon -import BasicTypes ( Arity ) -- others import StaticFlags @@ -498,14 +497,6 @@ funArgTy :: Type -> Type funArgTy ty | Just ty' <- coreView ty = funArgTy ty' funArgTy (FunTy arg _res) = arg funArgTy ty = pprPanic "funArgTy" (ppr ty) - -typeArity :: Type -> Arity --- How many value arrows are visible in the type? --- We look through foralls, but not through newtypes, dictionaries etc -typeArity ty | Just ty' <- coreView ty = typeArity ty' -typeArity (FunTy _ ty) = 1 + typeArity ty -typeArity (ForAllTy _ ty) = typeArity ty -typeArity _ = 0 \end{code} --------------------------------------------------------------------- -- 1.7.10.4