Improved specialisation of recursive groups
[ghc-hetmet.git] / compiler / specialise / Specialise.lhs
index 3564c27..a5cffb1 100644 (file)
@@ -15,7 +15,7 @@ module Specialise ( specProgram ) where
 #include "HsVersions.h"
 
 import DynFlags        ( DynFlags, DynFlag(..) )
-import Id              ( Id, idName, idType, mkUserLocal, 
+import Id              ( Id, idName, idType, mkUserLocal, idCoreRules,
                          idInlinePragma, setInlinePragma ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, isClassPred,
@@ -25,6 +25,7 @@ import CoreSubst      ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
                          substBndr, substBndrs, substTy, substInScope,
                          cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
                        ) 
+import SimplUtils      ( interestingArg )
 import VarSet
 import VarEnv
 import CoreSyn
@@ -39,7 +40,7 @@ import UniqSupply     ( UniqSupply,
 import Name
 import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
-import Maybes          ( catMaybes, maybeToBool )
+import Maybes          ( catMaybes, isJust )
 import ErrUtils                ( dumpIfSet_dyn )
 import Bag
 import Util
@@ -486,8 +487,6 @@ of this is permanently ruled out.
 Still, this is no great hardship, because we intend to eliminate
 overloading altogether anyway!
 
-
-
 A note about non-tyvar dictionaries
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Some Ids have types like
@@ -512,7 +511,7 @@ Should we specialise wrt this compound-type dictionary?  We used to say
 But it is simpler and more uniform to specialise wrt these dicts too;
 and in future GHC is likely to support full fledged type signatures 
 like
-       f ;: Eq [(a,b)] => ...
+       f :: Eq [(a,b)] => ...
 
 
 %************************************************************************
@@ -641,7 +640,7 @@ specExpr subst expr@(App {})
                                return (App fun' arg', uds_arg `plusUDs` uds_app)
 
     go (Var f)       args = case specVar subst f of
-                                Var f' -> return (Var f', mkCallUDs subst f' args)
+                                Var f' -> return (Var f', mkCallUDs f' args)
                                 e'     -> return (e', emptyUDs)        -- I don't expect this!
     go other        _    = specExpr subst other
 
@@ -748,39 +747,72 @@ finishSpecBind bind
     add (NonRec b r, b_fvs) (prs, fvs) = ((b,r)  : prs, b_fvs `unionVarSet` fvs)
     add (Rec b_prs,  b_fvs) (prs, fvs) = (b_prs ++ prs, b_fvs `unionVarSet` fvs)
 
+---------------------------
 specBindItself :: Subst -> CoreBind -> CallDetails -> SpecM (CoreBind, UsageDetails)
 
 -- specBindItself deals with the RHS, specialising it according
 -- to the calls found in the body (if any)
-specBindItself rhs_subst (NonRec bndr rhs) call_info = do
-    ((bndr',rhs'), spec_defns, spec_uds) <- specDefn rhs_subst call_info (bndr,rhs)
-    let
-        new_bind | null spec_defns = NonRec bndr' rhs'
-                 | otherwise       = Rec ((bndr',rhs'):spec_defns)
+specBindItself rhs_subst (NonRec fn rhs) call_info
+  = do { (rhs', rhs_uds) <- specExpr rhs_subst rhs          -- Do RHS of original fn
+       ; (fn', spec_defns, spec_uds) <- specDefn rhs_subst call_info fn rhs
+       ; if null spec_defns then
+                   return (NonRec fn rhs', rhs_uds)
+        else 
+           return (Rec ((fn',rhs') : spec_defns), rhs_uds `plusUDs` spec_uds) }
                -- bndr' mentions the spec_defns in its SpecEnv
                -- Not sure why we couln't just put the spec_defns first
-    return (new_bind, spec_uds)
-
-specBindItself rhs_subst (Rec pairs) call_info = do
-    stuff <- mapM (specDefn rhs_subst call_info) pairs
-    let
-       (pairs', spec_defns_s, spec_uds_s) = unzip3 stuff
-       spec_defns = concat spec_defns_s
-       spec_uds   = plusUDList spec_uds_s
-        new_bind   = Rec (spec_defns ++ pairs')
-    return (new_bind, spec_uds)
-
-
-specDefn :: Subst                      -- Subst to use for RHS
+                 
+specBindItself rhs_subst (Rec pairs) call_info
+       -- Note [Specialising a recursive group]
+  = do { let (bndrs,rhss) = unzip pairs
+       ; (rhss', rhs_uds) <- mapAndCombineSM (specExpr rhs_subst) rhss
+       ; let all_calls = call_info `unionCalls` calls rhs_uds
+       ; (bndrs1, spec_defns1, spec_uds1) <- specDefns rhs_subst all_calls pairs
+
+       ; if null spec_defns1 then   -- Common case: no specialisation
+                   return (Rec (bndrs `zip` rhss'), rhs_uds)
+        else do                     -- Specialisation occurred; do it again
+       { (bndrs2, spec_defns2, spec_uds2) <- specDefns rhs_subst 
+                                                              (calls spec_uds1) (bndrs1 `zip` rhss)
+
+       ; let all_defns = spec_defns1 ++ spec_defns2 ++ zip bndrs2 rhss'
+             
+       ; return (Rec all_defns, rhs_uds `plusUDs` spec_uds1 `plusUDs` spec_uds2) } }
+
+
+---------------------------
+specDefns :: Subst
         -> CallDetails                 -- Info on how it is used in its scope
-        -> (Id, CoreExpr)              -- The thing being bound and its un-processed RHS
-        -> SpecM ((Id, CoreExpr),      -- The thing and its processed RHS
-                                       --      the Id may now have specialisations attached
+        -> [(Id,CoreExpr)]             -- The things being bound and their un-processed RHS
+        -> SpecM ([Id],                -- Original Ids with RULES added
+                  [(Id,CoreExpr)],     -- Extra, specialised bindings
+                  UsageDetails)        -- Stuff to fling upwards from the specialised versions
+
+-- Specialise a list of bindings (the contents of a Rec), but flowing usages
+-- upwards binding by binding.  Example: { f = ...g ...; g = ...f .... }
+-- Then if the input CallDetails has a specialised call for 'g', whose specialisation
+-- in turn generates a specialised call for 'f', we catch that in this one sweep.
+-- But not vice versa (it's a fixpoint problem).
+
+specDefns _subst _call_info []
+  = return ([], [], emptyUDs)
+specDefns subst call_info ((bndr,rhs):pairs)
+  = do { (bndrs', spec_defns, spec_uds) <- specDefns subst call_info pairs
+       ; let all_calls = call_info `unionCalls` calls spec_uds
+       ; (bndr', spec_defns1, spec_uds1) <- specDefn subst all_calls bndr rhs
+       ; return (bndr' : bndrs',
+                        spec_defns1 ++ spec_defns, 
+                spec_uds1 `plusUDs` spec_uds) }
+
+---------------------------
+specDefn :: Subst
+        -> CallDetails                 -- Info on how it is used in its scope
+        -> Id -> CoreExpr              -- The thing being bound and its un-processed RHS
+        -> SpecM (Id,                  -- Original Id with added RULES
                   [(Id,CoreExpr)],     -- Extra, specialised bindings
-                  UsageDetails         -- Stuff to fling upwards from the RHS and its
-           )                           --      specialised versions
+                  UsageDetails)        -- Stuff to fling upwards from the specialised versions
 
-specDefn subst calls (fn, rhs)
+specDefn subst calls fn rhs
        -- The first case is the interesting one
   |  rhs_tyvars `lengthIs`     n_tyvars -- Rhs of fn's defn has right number of big lambdas
   && rhs_ids    `lengthAtLeast` n_dicts        -- and enough dict args
@@ -788,27 +820,18 @@ specDefn subst calls (fn, rhs)
 
 --   && not (certainlyWillInline (idUnfolding fn))     -- And it's not small
 --     See Note [Inline specialisation] for why we do not 
---     switch off specialisation for inline functions = do
-  = do
-     -- Specialise the body of the function
-    (rhs', rhs_uds) <- specExpr subst rhs
-
-      -- Make a specialised version for each call in calls_for_me
-    stuff <- mapM spec_call calls_for_me
-    let
-        (spec_defns, spec_uds, spec_rules) = unzip3 stuff
-
-        fn' = addIdSpecialisations fn spec_rules
+--     switch off specialisation for inline functions
 
-    return ((fn',rhs'),
-              spec_defns,
-              rhs_uds `plusUDs` plusUDList spec_uds)
+  = do {       -- Make a specialised version for each call in calls_for_me
+         stuff <- mapM spec_call calls_for_me
+       ; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff)
+             fn' = addIdSpecialisations fn spec_rules
+       ; return (fn', spec_defns, plusUDList spec_uds) }
 
   | otherwise  -- No calls or RHS doesn't fit our preconceptions
   = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn )
          -- Note [Specialisation shape]
-    (do  { (rhs', rhs_uds) <- specExpr subst rhs
-       ; return ((fn, rhs'), [], rhs_uds) })
+    return (fn, [], emptyUDs)
   
   where
     fn_type           = idType fn
@@ -830,77 +853,84 @@ specDefn subst calls (fn, rhs)
                        Nothing -> []
                        Just cs -> fmToList cs
 
+    already_covered :: [CoreExpr] -> Bool
+    already_covered args         -- Note [Specialisations already covered]
+       = isJust (lookupRule (const True) (substInScope subst) 
+                                   fn args (idCoreRules fn))
+
+    mk_ty_args :: [Maybe Type] -> [CoreExpr]
+    mk_ty_args call_ts = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts
+              where
+                 mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
+                 mk_ty_arg _         (Just ty) = Type ty
+
     ----------------------------------------------------------
        -- Specialise to one particular call pattern
-    spec_call :: (CallKey, ([DictExpr], VarSet))       -- Call instance
-              -> SpecM ((Id,CoreExpr),                 -- Specialised definition
-                       UsageDetails,                   -- Usage details from specialised body
-                       CoreRule)                       -- Info for the Id's SpecEnv
+    spec_call :: (CallKey, ([DictExpr], VarSet))  -- Call instance
+              -> SpecM (Maybe ((Id,CoreExpr),    -- Specialised definition
+                              UsageDetails,      -- Usage details from specialised body
+                              CoreRule))         -- Info for the Id's SpecEnv
     spec_call (CallKey call_ts, (call_ds, _))
-      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts ) do
-               -- Calls are only recorded for properly-saturated applications
+      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts )
        
-       -- Suppose f's defn is  f = /\ a b c d -> \ d1 d2 -> rhs        
-        -- Supppose the call is for f [Just t1, Nothing, Just t3, Nothing] [dx1, dx2]
+       -- Suppose f's defn is  f = /\ a b c -> \ d1 d2 -> rhs  
+        -- Supppose the call is for f [Just t1, Nothing, Just t3] [dx1, dx2]
 
        -- Construct the new binding
        --      f1 = SUBST[a->t1,c->t3, d1->d1', d2->d2'] (/\ b d -> rhs)
        -- PLUS the usage-details
        --      { d1' = dx1; d2' = dx2 }
-       -- where d1', d2' are cloned versions of d1,d2, with the type substitution applied.
+       -- where d1', d2' are cloned versions of d1,d2, with the type substitution
+       -- applied.  These auxiliary bindings just avoid duplication of dx1, dx2
        --
        -- Note that the substitution is applied to the whole thing.
        -- This is convenient, but just slightly fragile.  Notably:
-       --      * There had better be no name clashes in a/b/c/d
-       --
-        let
-               -- poly_tyvars = [b,d] in the example above
+       --      * There had better be no name clashes in a/b/c
+        do { let
+               -- poly_tyvars = [b] in the example above
                -- spec_tyvars = [a,c] 
-               -- ty_args     = [t1,b,t3,d]
-          poly_tyvars = [tv | (tv, Nothing) <- rhs_tyvars `zip` call_ts]
-           spec_tyvars = [tv | (tv, Just _)  <- rhs_tyvars `zip` call_ts]
-          ty_args     = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts
-                      where
-                        mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
-                        mk_ty_arg _         (Just ty) = Type ty
-
-           spec_ty_args = [ty | Just ty <- call_ts]
-          rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` spec_ty_args)
-
-       (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts
-       let
-          inst_args = ty_args ++ map Var rhs_dicts'
-
-               -- Figure out the type of the specialised function
-          body_ty = applyTypeToArgs rhs fn_type inst_args
-          (lam_args, app_args)                 -- Add a dummy argument if body_ty is unlifted
-               | isUnLiftedType body_ty        -- C.f. WwLib.mkWorkerArgs
-               = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
-               | otherwise = (poly_tyvars, poly_tyvars)
-          spec_id_ty = mkPiTypes lam_args body_ty
-
-        spec_f <- newIdSM fn spec_id_ty
-        (spec_rhs, rhs_uds) <- specExpr rhs_subst' (mkLams lam_args body)
-       let
+               -- ty_args     = [t1,b,t3]
+               poly_tyvars   = [tv | (tv, Nothing) <- rhs_tyvars `zip` call_ts]
+               spec_tv_binds = [(tv,ty) | (tv, Just ty) <- rhs_tyvars `zip` call_ts]
+               spec_ty_args  = map snd spec_tv_binds
+               ty_args       = mk_ty_args call_ts
+               rhs_subst     = extendTvSubstList subst spec_tv_binds
+
+          ; (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts
+          ; let inst_args = ty_args ++ map Var rhs_dicts'
+
+          ; if already_covered inst_args then
+               return Nothing
+            else do
+          {    -- Figure out the type of the specialised function
+            let body_ty = applyTypeToArgs rhs fn_type inst_args
+                (lam_args, app_args)           -- Add a dummy argument if body_ty is unlifted
+                  | isUnLiftedType body_ty     -- C.f. WwLib.mkWorkerArgs
+                  = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
+                  | otherwise = (poly_tyvars, poly_tyvars)
+                spec_id_ty = mkPiTypes lam_args body_ty
+       
+           ; spec_f <- newIdSM fn spec_id_ty
+           ; (spec_rhs, rhs_uds) <- specExpr rhs_subst' (mkLams lam_args body)
+          ; let
                -- The rule to put in the function's specialisation is:
-               --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
-          rule_name = mkFastString ("SPEC " ++ showSDoc (ppr fn <+> ppr spec_ty_args))
-           spec_env_rule = mkLocalRule 
-                               rule_name
-                               inline_prag     -- Note [Auto-specialisation and RULES]
-                               (idName fn)
-                               (poly_tyvars ++ rhs_dicts')
-                               inst_args 
-                               (mkVarApps (Var spec_f) app_args)
+               --      forall b, d1',d2'.  f t1 b t3 d1' d2' = f1 b  
+               rule_name = mkFastString ("SPEC " ++ showSDoc (ppr fn <+> ppr spec_ty_args))
+               spec_env_rule = mkLocalRule
+                                 rule_name
+                                 inline_prag   -- Note [Auto-specialisation and RULES]
+                                 (idName fn)
+                                 (poly_tyvars ++ rhs_dicts')
+                                 inst_args 
+                                 (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
-          final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
-
-          spec_pr | inline_rhs = (spec_f `setInlinePragma` inline_prag, Note InlineMe spec_rhs)
-                  | otherwise  = (spec_f,                               spec_rhs)
+               final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
 
-        return (spec_pr, final_uds, spec_env_rule)
+               spec_pr | inline_rhs = (spec_f `setInlinePragma` inline_prag, Note InlineMe spec_rhs)
+                       | otherwise  = (spec_f,                               spec_rhs)
 
+          ; return (Just (spec_pr, final_uds, spec_env_rule)) } }
       where
        my_zipEqual doc xs ys 
         | debugIsOn && not (equalLength xs ys)
@@ -910,9 +940,58 @@ specDefn subst calls (fn, rhs)
                                                , ppr (idType fn), ppr theta
                                                , ppr n_dicts, ppr rhs_dicts 
                                                , ppr rhs])
-        | otherwise               = zipEqual doc xs ys
+        | otherwise               = zipEqual doc xs ys
 \end{code}
 
+Note [Specialising a recursive group]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+    let rec { f x = ...g x'...
+            ; g y = ...f y'.... }
+    in f 'a'
+Here we specialise 'f' at Char; but that is very likely to lead to 
+a specialisation of 'g' at Char.  We must do the latter, else the
+whole point of specialisation is lost.
+
+But we do not want to keep iterating to a fixpoint, because in the
+presence of polymorphic recursion we might generate an infinite number
+of specialisations.
+
+So we use the following heuristic:
+  * Arrange the rec block in dependency order, so far as possible
+    (the occurrence analyser already does this)
+
+  * Specialise it much like a sequence of lets
+
+  * Then go through the block a second time, feeding call-info from
+    the RHSs back in the bottom, as it were
+
+In effect, the ordering maxmimises the effectiveness of each sweep,
+and we do just two sweeps.   This should catch almost every case of 
+monomorphic recursion -- the exception could be a very knotted-up
+recursion with multiple cycles tied up together.
+
+This plan is implemented in the Rec case of specBindItself.
+Note [Specialisations already covered]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We obviously don't want to generate two specialisations for the same
+argument pattern.  There are two wrinkles
+
+1. We do the already-covered test in specDefn, not when we generate
+the CallInfo in mkCallUDs.  We used to test in the latter place, but
+we now iterate the specialiser somewhat, and the Id at the call site
+might therefore not have all the RULES that we can see in specDefn
+
+2. What about two specialisations where the second is an *instance*
+of the first?  If the more specific one shows up first, we'll generate
+specialisations for both.  If the *less* specific one shows up first,
+we *don't* currently generate a specialisation for the more specific
+one.  (See the call to lookupRule in already_covered.)  Reasons:
+  (a) lookupRule doesn't say which matches are exact (bad reason)
+  (b) if the earlier specialisation is user-provided, it's
+      far from clear that we should auto-specialise further
+
 Note [Auto-specialisation and RULES]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider:
@@ -1037,13 +1116,16 @@ emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM, ud_fvs = emptyVarSet }
 ------------------------------------------------------------                   
 type CallDetails  = FiniteMap Id CallInfo
 newtype CallKey   = CallKey [Maybe Type]                       -- Nothing => unconstrained type argument
-type CallInfo     = FiniteMap CallKey
-                             ([DictExpr], VarSet)              -- Dict args and the vars of the whole
-                                                               -- call (including tyvars)
-                                                               -- [*not* include the main id itself, of course]
-       -- The finite maps eliminate duplicates
-       -- The list of types and dictionaries is guaranteed to
-       -- match the type of f
+
+-- CallInfo uses a FiniteMap, thereby ensuring that
+-- we record only one call instance for any key
+--
+-- The list of types and dictionaries is guaranteed to
+-- match the type of f
+type CallInfo = FiniteMap CallKey ([DictExpr], VarSet)
+                       -- Range is dict args and the vars of the whole
+                       -- call (including tyvars)
+                       -- [*not* include the main id itself, of course]
 
 instance Outputable CallKey where
   ppr (CallKey ts) = ppr ts
@@ -1082,8 +1164,8 @@ singleCall id tys dicts
        --
        -- We don't include the 'id' itself.
 
-mkCallUDs :: Subst -> Id -> [CoreExpr] -> UsageDetails
-mkCallUDs subst f args 
+mkCallUDs :: Id -> [CoreExpr] -> UsageDetails
+mkCallUDs f args 
   | null theta
   || not (all isClassPred theta)       
        -- Only specialise if all overloading is on class params. 
@@ -1091,11 +1173,8 @@ mkCallUDs subst f args
        --  *don't* say what the value of the implicit param is!
   || not (spec_tys `lengthIs` n_tyvars)
   || not ( dicts   `lengthIs` n_dicts)
-  || maybeToBool (lookupRule (\_act -> True) (substInScope subst) emptyRuleBase f args)
-       -- There's already a rule covering this call.  A typical case
-       -- is where there's an explicit user-provided rule.  Then
-       -- we don't want to create a specialised version 
-       -- of the function that overlaps.
+  || not (any interestingArg dicts)    -- Note [Interesting dictionary arguments]
+  -- See also Note [Specialisations already covered]
   = emptyUDs   -- Not overloaded, or no specialisation wanted
 
   | otherwise
@@ -1112,8 +1191,21 @@ mkCallUDs subst f args
     mk_spec_ty tyvar ty 
        | tyvar `elemVarSet` constrained_tyvars = Just ty
        | otherwise                             = Nothing
+\end{code}
 
-------------------------------------------------------------                   
+Note [Interesting dictionary arguments]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this
+        \a.\d:Eq a.  let f = ... in ...(f d)...
+There really is not much point in specialising f wrt the dictionary d,
+because the code for the specialised f is not improved at all, because
+d is lambda-bound.  We simply get junk specialisations.
+
+We re-use the function SimplUtils.interestingArg function to determine
+what sort of dictionary arguments have *some* information in them.
+
+
+\begin{code}
 plusUDs :: UsageDetails -> UsageDetails -> UsageDetails
 plusUDs (MkUD {dict_binds = db1, calls = calls1, ud_fvs = fvs1})
        (MkUD {dict_binds = db2, calls = calls2, ud_fvs = fvs2})