For a non-recursive let, make sure we extend the value environment
[ghc-hetmet.git] / compiler / specialise / Specialise.lhs
index f18c8f9..f6f85a1 100644 (file)
@@ -10,25 +10,29 @@ module Specialise ( specProgram ) where
 
 import Id
 import TcType
+import CoreMonad
 import CoreSubst 
-import CoreUnfold      ( mkUnfolding, mkInlineRule )
+import CoreUnfold
 import VarSet
 import VarEnv
 import CoreSyn
 import Rules
 import CoreUtils       ( exprIsTrivial, applyTypeToArgs, mkPiTypes )
 import CoreFVs         ( exprFreeVars, exprsFreeVars, idFreeVars )
-import UniqSupply      ( UniqSupply, UniqSM, initUs_, MonadUnique(..) )
+import UniqSupply      ( UniqSM, initUs_, MonadUnique(..) )
 import Name
 import MkId            ( voidArgId, realWorldPrimId )
-import FiniteMap
 import Maybes          ( catMaybes, isJust )
-import BasicTypes      ( isNeverActive, inlinePragmaActivation )
+import BasicTypes      
+import HscTypes
 import Bag
 import Util
 import Outputable
 import FastString
 
+import Data.Map (Map)
+import qualified Data.Map as Map
+import qualified FiniteMap as Map
 \end{code}
 
 %************************************************************************
@@ -556,24 +560,98 @@ Hence, the invariant is this:
 %************************************************************************
 
 \begin{code}
-specProgram :: UniqSupply -> [CoreBind] -> [CoreBind]
-specProgram us binds = initSM us $
-                       do { (binds', uds') <- go binds
-                         ; return (wrapDictBinds (ud_binds uds') binds') }
+specProgram :: ModGuts -> CoreM ModGuts
+specProgram guts 
+  = do { hpt_rules <- getRuleBase
+       ; let local_rules = mg_rules guts
+             rule_base = extendRuleBaseList hpt_rules (mg_rules guts)
+
+            -- Specialise the bindings of this module
+       ; (binds', uds) <- runSpecM (go (mg_binds guts))
+
+            -- Specialise imported functions 
+       ; (new_rules, spec_binds) <- specImports emptyVarSet rule_base uds
+
+       ; return (guts { mg_binds = spec_binds ++ binds'
+                      , mg_rules = local_rules ++ new_rules }) }
   where
        -- We need to start with a Subst that knows all the things
        -- that are in scope, so that the substitution engine doesn't
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst       = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
+    top_subst = mkEmptySubst $ mkInScopeSet $ mkVarSet $ 
+                bindersOfBinds $ mg_binds guts
 
     go []           = return ([], emptyUDs)
     go (bind:binds) = do (binds', uds) <- go binds
                          (bind', uds') <- specBind top_subst bind uds
                          return (bind' ++ binds', uds')
+
+specImports :: VarSet          -- Don't specialise these ones
+                               -- See Note [Avoiding recursive specialisation]
+            -> RuleBase                -- Rules from this module and the home package
+                               -- (but not external packages, which can change)
+            -> UsageDetails    -- Calls for imported things, and floating bindings
+            -> CoreM ( [CoreRule]   -- New rules
+                     , [CoreBind] ) -- Specialised bindings and floating bindings
+specImports done rb uds
+  = do { let import_calls = varEnvElts (ud_calls uds)
+       ; (rules, spec_binds) <- go rb import_calls
+       ; return (rules, wrapDictBinds (ud_binds uds) spec_binds) }
+  where
+    go _ [] = return ([], [])
+    go rb (CIS fn calls_for_fn : other_calls)
+      = do { (rules1, spec_binds1) <- specImport done rb fn (Map.toList calls_for_fn)
+           ; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls
+           ; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
+
+specImport :: VarSet               -- Don't specialise these
+                                   -- See Note [Avoiding recursive specialisation]
+           -> RuleBase             -- Rules from this module
+           -> Id -> [CallInfo]     -- Imported function and calls for it
+           -> CoreM ( [CoreRule]    -- New rules
+                    , [CoreBind] )  -- Specialised bindings
+specImport done rb fn calls_for_fn
+  | not (fn `elemVarSet` done)
+  , isInlinablePragma (idInlinePragma fn)
+  , Just rhs <- maybeUnfoldingTemplate (realIdUnfolding fn)
+  = do {     -- Get rules from the external package state
+                    -- We keep doing this in case we "page-fault in" 
+            -- more rules as we go along
+       ; hsc_env <- getHscEnv
+       ; eps <- liftIO $ hscEPS hsc_env 
+       ; let full_rb = unionRuleBase rb (eps_rule_base eps)
+             rules_for_fn = getRules full_rb fn 
+
+       ; (rules1, spec_pairs, uds) <- runSpecM $
+              specCalls emptySubst rules_for_fn calls_for_fn fn rhs
+       ; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs]
+                    -- After the rules kick in we may get recursion, but 
+            -- we rely on a global GlomBinds to sort that out later
+       
+             -- Now specialise any cascaded calls
+       ; (rules2, spec_binds2) <- specImports (extendVarSet done fn) 
+                                              (extendRuleBaseList rb rules1)
+                                              uds
+
+       ; return (rules2 ++ rules1, spec_binds2 ++ spec_binds1) }
+
+  | otherwise
+  = WARN( True, ptext (sLit "specImport discard") <+> ppr fn <+> ppr calls_for_fn )
+    return ([], [])    
 \end{code}
 
+Avoiding recursive specialisation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
+'f's RHS.  So we want to specialise g,h.  But we don't want to
+specialise f any more!  It's possible that f's RHS might have a
+recursive yet-more-specialised call, so we'd diverge in that case.
+And if the call is to the same type, one specialisation is enough.
+Avoiding this recursive specialisation loop is the reason for the 
+'done' VarSet passed to specImports and specImport.
+
 %************************************************************************
 %*                                                                     *
 \subsubsection{@specExpr@: the main function}
@@ -704,7 +782,7 @@ specCase subst scrut' case_bndr [(con, args, rhs)]
          loc  = getSrcSpan name
 
     add_unf sc_flt sc_rhs  -- Sole purpose: make sc_flt respond True to interestingDictId
-      = setIdUnfolding sc_flt (mkUnfolding False False sc_rhs)
+      = setIdUnfolding sc_flt (mkSimpleUnfolding sc_rhs)
 
     arg_set = mkVarSet args'
     is_flt_sc_arg var =  isId var
@@ -761,7 +839,7 @@ to substitute sc -> sc_flt in the RHS
 
 %************************************************************************
 %*                                                                     *
-\subsubsection{Dealing with a binding}
+                     Dealing with a binding
 %*                                                                     *
 %************************************************************************
 
@@ -861,6 +939,34 @@ specDefn :: Subst
                   UsageDetails)        -- Stuff to fling upwards from the specialised versions
 
 specDefn subst body_uds fn rhs
+  = do { let (body_uds_without_me, calls_for_me) = callsForMe fn body_uds
+             rules_for_me = idCoreRules fn
+       ; (rules, spec_defns, spec_uds) <- specCalls subst rules_for_me 
+                                                    calls_for_me fn rhs
+       ; return ( fn `addIdSpecialisations` rules
+                , spec_defns
+                , body_uds_without_me `plusUDs` spec_uds) }
+               -- It's important that the `plusUDs` is this way
+               -- round, because body_uds_without_me may bind
+               -- dictionaries that are used in calls_for_me passed
+               -- to specDefn.  So the dictionary bindings in
+               -- spec_uds may mention dictionaries bound in
+               -- body_uds_without_me
+
+---------------------------
+specCalls :: Subst
+          -> [CoreRule]                        -- Existing RULES for the fn
+         -> [CallInfo] 
+         -> Id -> CoreExpr
+         -> SpecM ([CoreRule],         -- New RULES for the fn
+                   [(Id,CoreExpr)],    -- Extra, specialised bindings
+                   UsageDetails)       -- New usage details from the specialised RHSs
+
+-- This function checks existing rules, and does not create
+-- duplicate ones. So the caller does not nneed to do this filtering.
+-- See 'already_covered'
+
+specCalls subst rules_for_me calls_for_me fn rhs
        -- The first case is the interesting one
   |  rhs_tyvars `lengthIs`     n_tyvars -- Rhs of fn's defn has right number of big lambdas
   && rhs_ids    `lengthAtLeast` n_dicts        -- and enough dict args
@@ -873,26 +979,16 @@ specDefn subst body_uds fn rhs
 --     See Note [Inline specialisation] for why we do not 
 --     switch off specialisation for inline functions
 
-  = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me) $
-    do {       -- Make a specialised version for each call in calls_for_me
-         stuff <- mapM spec_call calls_for_me
+  = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me $$ ppr rules_for_me) $
+    do { stuff <- mapM spec_call calls_for_me
        ; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff)
-             fn' = addIdSpecialisations fn spec_rules
-             final_uds = body_uds_without_me `plusUDs` plusUDList spec_uds 
-               -- It's important that the `plusUDs` is this way
-               -- round, because body_uds_without_me may bind
-               -- dictionaries that are used in calls_for_me passed
-               -- to specDefn.  So the dictionary bindings in
-               -- spec_uds may mention dictionaries bound in
-               -- body_uds_without_me
-
-       ; return (fn', spec_defns, final_uds) }
+       ; return (spec_rules, spec_defns, plusUDList spec_uds) }
 
   | otherwise  -- No calls or RHS doesn't fit our preconceptions
   = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn )
          -- Note [Specialisation shape]
     -- pprTrace "specDefn: none" (ppr fn $$ ppr calls_for_me) $
-    return (fn, [], body_uds_without_me)
+    return ([], [], emptyUDs)
   
   where
     fn_type           = idType fn
@@ -901,21 +997,17 @@ specDefn subst body_uds fn rhs
     (tyvars, theta, _) = tcSplitSigmaTy fn_type
     n_tyvars          = length tyvars
     n_dicts           = length theta
-    inl_act            = inlinePragmaActivation (idInlinePragma fn)
+    inl_prag           = idInlinePragma fn
+    inl_act            = inlinePragmaActivation inl_prag
+    is_local           = isLocalId fn
 
        -- Figure out whether the function has an INLINE pragma
        -- See Note [Inline specialisations]
-    fn_has_inline_rule :: Maybe Bool   -- Derive sat-flag from existing thing
-    fn_has_inline_rule = case isInlineRule_maybe fn_unf of
-                           Just (_,sat) -> Just sat
-                          Nothing      -> Nothing
 
     spec_arity = unfoldingArity fn_unf - n_dicts  -- Arity of the *specialised* inline rule
 
     (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
 
-    (body_uds_without_me, calls_for_me) = callsForMe fn body_uds
-
     rhs_dict_ids = take n_dicts rhs_ids
     body         = mkLams (drop n_dicts rhs_ids) rhs_body
                -- Glue back on the non-dict lambdas
@@ -924,7 +1016,7 @@ specDefn subst body_uds fn rhs
     already_covered args         -- Note [Specialisations already covered]
        = isJust (lookupRule (const True) realIdUnfolding 
                             (substInScope subst) 
-                                   fn args (idCoreRules fn))
+                                   fn args rules_for_me)
 
     mk_ty_args :: [Maybe Type] -> [CoreExpr]
     mk_ty_args call_ts = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts
@@ -988,8 +1080,8 @@ specDefn subst body_uds fn rhs
                -- The rule to put in the function's specialisation is:
                --      forall b, d1',d2'.  f t1 b t3 d1' d2' = f1 b  
                rule_name = mkFastString ("SPEC " ++ showSDoc (ppr fn <+> ppr spec_ty_args))
-               spec_env_rule = mkLocalRule
-                                 rule_name
+               spec_env_rule = mkRule True {- Auto generated -} is_local
+                                  rule_name
                                  inl_act       -- Note [Auto-specialisation and RULES]
                                  (idName fn)
                                  (poly_tyvars ++ inst_dict_ids)
@@ -999,25 +1091,23 @@ specDefn subst body_uds fn rhs
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
                final_uds = foldr consDictBind rhs_uds dx_binds
 
+               -- Add an InlineRule if the parent has one
+               -- See Note [Inline specialisations]
+               spec_unf
+                  = case inlinePragmaSpec inl_prag of
+                      Inline    -> mkInlineUnfolding (Just spec_arity) spec_rhs
+                      Inlinable -> mkInlinableUnfolding spec_rhs
+                      _         -> NoUnfolding
+
                -- Adding arity information just propagates it a bit faster
                --      See Note [Arity decrease] in Simplify
                -- Copy InlinePragma information from the parent Id.
                -- So if f has INLINE[1] so does spec_f
                spec_f_w_arity = spec_f `setIdArity`          max 0 (fn_arity - n_dicts)
-                                        `setInlineActivation` inl_act
+                                        `setInlinePragma` inl_prag
+                                        `setIdUnfolding`  spec_unf
 
-               -- Add an InlineRule if the parent has one
-               -- See Note [Inline specialisations]
-               final_spec_f 
-                  | Just sat <- fn_has_inline_rule
-                 = let 
-                       mb_spec_arity = if sat then Just spec_arity else Nothing
-                    in 
-                    spec_f_w_arity `setIdUnfolding` mkInlineRule spec_rhs mb_spec_arity
-                 | otherwise 
-                 = spec_f_w_arity
-
-          ; return (Just ((final_spec_f, spec_rhs), final_uds, spec_env_rule)) } }
+          ; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } }
       where
        my_zipEqual xs ys zs
         | debugIsOn && not (equalLength xs ys && equalLength ys zs)
@@ -1046,7 +1136,7 @@ bindAuxiliaryDicts subst triples = go subst [] triples
 
       | otherwise        = go subst_w_unf (NonRec dx_id dx : binds) pairs
       where
-        dx_id1 = dx_id `setIdUnfolding` mkUnfolding False False dx
+        dx_id1 = dx_id `setIdUnfolding` mkSimpleUnfolding dx
        subst_w_unf = extendIdSubst subst d (Var dx_id1)
                     -- Important!  We're going to substitute dx_id1 for d
             -- and we want it to look "interesting", else we won't gather *any*
@@ -1147,7 +1237,7 @@ group.  (In this case it'll unravel a short moment later.)
 
 
 Conclusion: we catch the nasty case using filter_dfuns in
-callsForMe To be honest I'm not 100% certain that this is 100%
+callsForMe. To be honest I'm not 100% certain that this is 100%
 right, but it works.  Sigh.
 
 
@@ -1321,18 +1411,22 @@ emptyUDs = MkUD { ud_binds = emptyBag, ud_calls = emptyVarEnv }
 type CallDetails  = IdEnv CallInfoSet
 newtype CallKey   = CallKey [Maybe Type]                       -- Nothing => unconstrained type argument
 
--- CallInfo uses a FiniteMap, thereby ensuring that
+-- CallInfo uses a Map, thereby ensuring that
 -- we record only one call instance for any key
 --
 -- The list of types and dictionaries is guaranteed to
 -- match the type of f
-type CallInfoSet = FiniteMap CallKey ([DictExpr], VarSet)
+data CallInfoSet = CIS Id (Map CallKey ([DictExpr], VarSet))
                        -- Range is dict args and the vars of the whole
                        -- call (including tyvars)
                        -- [*not* include the main id itself, of course]
 
 type CallInfo = (CallKey, ([DictExpr], VarSet))
 
+instance Outputable CallInfoSet where
+  ppr (CIS fn map) = hang (ptext (sLit "CIS") <+> ppr fn)
+                        2 (ppr map)
+
 instance Outputable CallKey where
   ppr (CallKey ts) = ppr ts
 
@@ -1350,22 +1444,23 @@ instance Ord CallKey where
                  cmp (Just t1) (Just t2) = tcCmpType t1 t2
 
 unionCalls :: CallDetails -> CallDetails -> CallDetails
-unionCalls c1 c2 = plusVarEnv_C plusFM c1 c2
+unionCalls c1 c2 = plusVarEnv_C unionCallInfoSet c1 c2
 
--- plusCalls :: UsageDetails -> CallDetails -> UsageDetails
--- plusCalls uds call_ds = uds { ud_calls = ud_calls uds `unionCalls` call_ds }
+unionCallInfoSet :: CallInfoSet -> CallInfoSet -> CallInfoSet
+unionCallInfoSet (CIS f calls1) (CIS _ calls2) = CIS f (calls1 `Map.union` calls2)
 
 callDetailsFVs :: CallDetails -> VarSet
 callDetailsFVs calls = foldVarEnv (unionVarSet . callInfoFVs) emptyVarSet calls
 
 callInfoFVs :: CallInfoSet -> VarSet
-callInfoFVs call_info = foldFM (\_ (_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info
+callInfoFVs (CIS _ call_info) = Map.foldRight (\(_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info
 
 ------------------------------------------------------------                   
 singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails
 singleCall id tys dicts 
   = MkUD {ud_binds = emptyBag, 
-         ud_calls = unitVarEnv id (unitFM (CallKey tys) (dicts, call_fvs)) }
+         ud_calls = unitVarEnv id $ CIS id $ 
+                     Map.singleton (CallKey tys) (dicts, call_fvs) }
   where
     call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
     tys_fvs  = tyVarsOfTypes (catMaybes tys)
@@ -1381,8 +1476,8 @@ singleCall id tys dicts
 
 mkCallUDs :: Id -> [CoreExpr] -> UsageDetails
 mkCallUDs f args 
-  | not (isLocalId f)  -- Imported from elsewhere
-  || null theta                -- Not overloaded
+  | not (want_calls_for f)  -- Imported from elsewhere
+  || null theta                    -- Not overloaded
   || not (all isClassPred theta)       
        -- Only specialise if all overloading is on class params. 
        -- In ptic, with implicit params, the type args
@@ -1409,6 +1504,8 @@ mkCallUDs f args
     mk_spec_ty tyvar ty 
        | tyvar `elemVarSet` constrained_tyvars = Just ty
        | otherwise                             = Nothing
+
+    want_calls_for f = isLocalId f || isInlinablePragma (idInlinePragma f)
 \end{code}
 
 Note [Interesting dictionary arguments]
@@ -1539,7 +1636,7 @@ callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
     uds_without_me = MkUD { ud_binds = orig_dbs, ud_calls = delVarEnv orig_calls fn }
     calls_for_me = case lookupVarEnv orig_calls fn of
                        Nothing -> []
-                       Just cs -> filter_dfuns (fmToList cs)
+                       Just (CIS _ calls) -> filter_dfuns (Map.toList calls)
 
     dep_set = foldlBag go (unitVarSet fn) orig_dbs
     go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set
@@ -1576,7 +1673,8 @@ deleteCallsMentioning bs calls
   = mapVarEnv filter_calls calls
   where
     filter_calls :: CallInfoSet -> CallInfoSet
-    filter_calls = filterFM (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs))
+    filter_calls (CIS f calls) = CIS f (Map.filter keep_call calls)
+    keep_call (_, fvs) = not (fvs `intersectsVarSet` bs)
 
 deleteCallsFor :: [Id] -> CallDetails -> CallDetails
 -- Remove calls *for* bs
@@ -1593,8 +1691,9 @@ deleteCallsFor bs calls = delVarEnvList calls bs
 \begin{code}
 type SpecM a = UniqSM a
 
-initSM :: UniqSupply -> SpecM a -> a
-initSM   = initUs_
+runSpecM:: SpecM a -> CoreM a
+runSpecM spec = do { us <- getUniqueSupplyM
+                   ; return (initUs_ us spec) }
 
 mapAndCombineSM :: (a -> SpecM (b, UsageDetails)) -> [a] -> SpecM ([b], UsageDetails)
 mapAndCombineSM _ []     = return ([], emptyUDs)