[project @ 2002-04-05 23:24:25 by sof]
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 6fbc5b9..16d3748 100644 (file)
@@ -9,41 +9,39 @@ module Specialise ( specProgram ) where
 #include "HsVersions.h"
 
 import CmdLineOpts     ( DynFlags, DynFlag(..) )
-import Id              ( Id, idName, idType, mkUserLocal,
-                         idSpecialisation, modifyIdInfo
-                       )
-import IdInfo          ( zapSpecPragInfo )
-import VarSet
-import VarEnv
-
-import Type            ( Type, mkTyVarTy, splitSigmaTy, 
+import Id              ( Id, idName, idType, mkUserLocal, idSpecialisation, isDataConWrapId )
+import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, 
-                         mkForAllTys 
+                         mkForAllTys, tcCmpType
                        )
-import Subst           ( Subst, mkSubst, substTy, mkSubst, substBndrs, extendSubstList, mkInScopeSet,
-                         substId, substAndCloneId, substAndCloneIds, lookupIdSubst, substInScope
+import Subst           ( Subst, mkSubst, substTy, mkSubst, extendSubstList, mkInScopeSet,
+                         simplBndr, simplBndrs, 
+                         substAndCloneId, substAndCloneIds, substAndCloneRecIds,
+                         lookupIdSubst, substInScope
                        ) 
+import Var             ( zapSpecPragmaId )
 import VarSet
 import VarEnv
 import CoreSyn
 import CoreUtils       ( applyTypeToArgs )
-import CoreUnfold      ( certainlyWillInline )
 import CoreFVs         ( exprFreeVars, exprsFreeVars )
+import CoreTidy                ( pprTidyIdRules )
 import CoreLint                ( showPass, endPass )
-import PprCore         ( pprCoreRules )
 import Rules           ( addIdSpecialisations, lookupRule )
 
 import UniqSupply      ( UniqSupply,
-                         UniqSM, initUs_, thenUs, thenUs_, returnUs, getUniqueUs, 
-                         withUs, mapUs
+                         UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
+                         getUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
 import FiniteMap
 import Maybes          ( catMaybes, maybeToBool )
 import ErrUtils                ( dumpIfSet_dyn )
+import BasicTypes      ( Activation( AlwaysActive ) )
 import Bag
 import List            ( partition )
-import Util            ( zipEqual, zipWithEqual )
+import Util            ( zipEqual, zipWithEqual, cmpList, lengthIs,
+                         equalLength, lengthAtLeast, notNull )
 import Outputable
 
 
@@ -588,7 +586,7 @@ specProgram dflags us binds
        endPass dflags "Specialise" Opt_D_dump_spec binds'
 
        dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
-                 (vcat (map dump_specs (concat (map bindersOf binds'))))
+                 (vcat (map pprTidyIdRules (concat (map bindersOf binds'))))
 
        return binds'
   where
@@ -603,8 +601,6 @@ specProgram dflags us binds
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
                      specBind top_subst bind uds       `thenSM` \ (bind', uds') ->
                      returnSM (bind' ++ binds', uds')
-
-dump_specs var = pprCoreRules var (idSpecialisation var)
 \end{code}
 
 %************************************************************************
@@ -658,7 +654,7 @@ specExpr subst e@(Lam _ _)
     returnSM (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
-    (subst', bndrs') = substBndrs subst bndrs
+    (subst', bndrs') = simplBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
@@ -667,7 +663,7 @@ specExpr subst (Case scrut case_bndr alts)
     mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
     returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
   where
-    (subst_alt, case_bndr') = substId subst case_bndr
+    (subst_alt, case_bndr') = simplBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
@@ -677,7 +673,7 @@ specExpr subst (Case scrut case_bndr alts)
          in
          returnSM ((con, args', rhs''), uds')
        where
-         (subst_rhs, args') = substBndrs subst_alt args
+         (subst_rhs, args') = simplBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
 specExpr subst (Let bind body)
@@ -788,10 +784,17 @@ specDefn :: Subst                 -- Subst to use for RHS
 
 specDefn subst calls (fn, rhs)
        -- The first case is the interesting one
-  |  n_tyvars == length rhs_tyvars     -- Rhs of fn's defn has right number of big lambdas
-  && n_dicts  <= length rhs_bndrs      -- and enough dict args
-  && not (null calls_for_me)           -- And there are some calls to specialise
-  && not (certainlyWillInline fn)      -- And it's not small
+  |  rhs_tyvars `lengthIs` n_tyvars    -- Rhs of fn's defn has right number of big lambdas
+  && rhs_bndrs  `lengthAtLeast` n_dicts        -- and enough dict args
+  && notNull calls_for_me              -- And there are some calls to specialise
+  && not (isDataConWrapId fn)          -- And it's not a data con wrapper, which have
+                                       -- stupid overloading that simply discard the dictionary
+
+-- At one time I tried not specialising small functions
+-- but sometimes there are big functions marked INLINE
+-- that we'd like to specialise.  In particular, dictionary
+-- functions, which Marcin is keen to inline
+--  && not (certainlyWillInline fn)    -- And it's not small
                                        -- If it's small, it's better just to inline
                                        -- it than to construct lots of specialisations
   =   -- Specialise the body of the function
@@ -800,9 +803,9 @@ specDefn subst calls (fn, rhs)
       -- Make a specialised version for each call in calls_for_me
     mapSM spec_call calls_for_me               `thenSM` \ stuff ->
     let
-       (spec_defns, spec_uds, spec_env_stuff) = unzip3 stuff
+       (spec_defns, spec_uds, spec_rules) = unzip3 stuff
 
-       fn' = addIdSpecialisations zapped_fn spec_env_stuff
+       fn' = addIdSpecialisations zapped_fn spec_rules
     in
     returnSM ((fn',rhs'), 
              spec_defns, 
@@ -813,17 +816,21 @@ specDefn subst calls (fn, rhs)
     returnSM ((zapped_fn, rhs'), [], rhs_uds)
   
   where
-    zapped_fn           = modifyIdInfo zapSpecPragInfo fn
+    zapped_fn           = zapSpecPragmaId fn
        -- If the fn is a SpecPragmaId, make it discardable
        -- It's role as a holder for a call instance is o'er
        -- But it might be alive for some other reason by now.
 
     fn_type           = idType fn
-    (tyvars, theta, _) = splitSigmaTy fn_type
+    (tyvars, theta, _) = tcSplitSigmaTy fn_type
     n_tyvars          = length tyvars
     n_dicts           = length theta
 
-    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
+       -- It's important that we "see past" any INLINE pragma
+       -- else we'll fail to specialise an INLINE thing
+    (inline_me, rhs')              = dropInline rhs
+    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs'
+
     rhs_dicts = take n_dicts rhs_ids
     rhs_bndrs = rhs_tyvars ++ rhs_dicts
     body      = mkLams (drop n_dicts rhs_ids) rhs_body
@@ -835,12 +842,12 @@ specDefn subst calls (fn, rhs)
 
     ----------------------------------------------------------
        -- Specialise to one particular call pattern
-    spec_call :: ([Maybe Type], ([DictExpr], VarSet))          -- Call instance
-              -> SpecM ((Id,CoreExpr),                         -- Specialised definition
-                       UsageDetails,                           -- Usage details from specialised body
-                       ([CoreBndr], [CoreExpr], CoreExpr))     -- Info for the Id's SpecEnv
-    spec_call (call_ts, (call_ds, call_fvs))
-      = ASSERT( length call_ts == n_tyvars && length call_ds == n_dicts )
+    spec_call :: (CallKey, ([DictExpr], VarSet))       -- Call instance
+              -> SpecM ((Id,CoreExpr),                 -- Specialised definition
+                       UsageDetails,                   -- Usage details from specialised body
+                       CoreRule)                       -- Info for the Id's SpecEnv
+    spec_call (CallKey call_ts, (call_ds, call_fvs))
+      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts )
                -- Calls are only recorded for properly-saturated applications
        
        -- Suppose f's defn is  f = /\ a b c d -> \ d1 d2 -> rhs        
@@ -880,21 +887,34 @@ specDefn subst calls (fn, rhs)
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
-           spec_env_rule = (poly_tyvars ++ rhs_dicts',
-                           inst_args, 
-                           mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+           spec_env_rule = Rule (_PK_ ("SPEC " ++ showSDoc (ppr fn)))
+                               AlwaysActive
+                               (poly_tyvars ++ rhs_dicts')
+                               inst_args 
+                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
+
+       -- NOTE: we don't add back in any INLINE pragma on the RHS, so even if
+       -- the original function said INLINE, the specialised copies won't.
+       -- The idea is that the point of inlining was precisely to specialise
+       -- the function at its call site, and that's not so important for the
+       -- specialised copies.   But it still smells like an ad hoc decision.
+
        in
-        returnSM ((spec_f, spec_rhs),
+        returnSM ((spec_f, spec_rhs),  
                  final_uds,
                  spec_env_rule)
 
       where
        my_zipEqual doc xs ys 
-        | length xs /= length ys = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
-        | otherwise              = zipEqual doc xs ys
+        | not (equalLength xs ys) = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
+        | otherwise               = zipEqual doc xs ys
+
+dropInline :: CoreExpr -> (Bool, CoreExpr) 
+dropInline (Note InlineMe rhs) = (True, rhs)
+dropInline rhs                = (False, rhs)
 \end{code}
 
 %************************************************************************
@@ -924,12 +944,13 @@ type DictExpr = CoreExpr
 emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM }
 
 type ProtoUsageDetails = ([DictBind],
-                         [(Id, [Maybe Type], ([DictExpr], VarSet))]
+                         [(Id, CallKey, ([DictExpr], VarSet))]
                         )
 
 ------------------------------------------------------------                   
 type CallDetails  = FiniteMap Id CallInfo
-type CallInfo     = FiniteMap [Maybe Type]                     -- Nothing => unconstrained type argument
+newtype CallKey   = CallKey [Maybe Type]                       -- Nothing => unconstrained type argument
+type CallInfo     = FiniteMap CallKey
                              ([DictExpr], VarSet)              -- Dict args and the vars of the whole
                                                                -- call (including tyvars)
                                                                -- [*not* include the main id itself, of course]
@@ -937,12 +958,25 @@ type CallInfo     = FiniteMap [Maybe Type]                        -- Nothing => unconstrained type ar
        -- The list of types and dictionaries is guaranteed to
        -- match the type of f
 
+-- Type isn't an instance of Ord, so that we can control which
+-- instance we use.  That's tiresome here.  Oh well
+instance Eq CallKey where
+  k1 == k2 = case k1 `compare` k2 of { EQ -> True; other -> False }
+
+instance Ord CallKey where
+  compare (CallKey k1) (CallKey k2) = cmpList cmp k1 k2
+               where
+                 cmp Nothing Nothing     = EQ
+                 cmp Nothing (Just t2)   = LT
+                 cmp (Just t1) Nothing   = GT
+                 cmp (Just t1) (Just t2) = tcCmpType t1 t2
+
 unionCalls :: CallDetails -> CallDetails -> CallDetails
 unionCalls c1 c2 = plusFM_C plusFM c1 c2
 
 singleCall :: Id -> [Maybe Type] -> [DictExpr] -> CallDetails
 singleCall id tys dicts 
-  = unitFM id (unitFM tys (dicts, call_fvs))
+  = unitFM id (unitFM (CallKey tys) (dicts, call_fvs))
   where
     call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
     tys_fvs  = tyVarsOfTypes (catMaybes tys)
@@ -964,14 +998,14 @@ listToCallDetails calls
 
 callDetailsToList calls = [ (id,tys,dicts)
                          | (id,fm) <- fmToList calls,
-                           (tys,dicts) <- fmToList fm
+                           (tys, dicts) <- fmToList fm
                          ]
 
 mkCallUDs subst f args 
   | null theta
-  || length spec_tys /= n_tyvars
-  || length dicts    /= n_dicts
-  || maybeToBool (lookupRule (substInScope subst) f args)
+  || not (spec_tys `lengthIs` n_tyvars)
+  || not ( dicts   `lengthIs` n_dicts)
+  || maybeToBool (lookupRule (\act -> True) (substInScope subst) f args)
        -- There's already a rule covering this call.  A typical case
        -- is where there's an explicit user-provided rule.  Then
        -- we don't want to create a specialised version 
@@ -983,7 +1017,7 @@ mkCallUDs subst f args
          calls      = singleCall f spec_tys dicts
     }
   where
-    (tyvars, theta, _) = splitSigmaTy (idType f)
+    (tyvars, theta, _) = tcSplitSigmaTy (idType f)
     constrained_tyvars = tyVarsOfTheta theta 
     n_tyvars          = length tyvars
     n_dicts           = length theta
@@ -1084,12 +1118,6 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs,
 %************************************************************************
 
 \begin{code}
-lookupId:: IdEnv Id -> Id -> Id
-lookupId env id = case lookupVarEnv env id of
-                       Nothing  -> id
-                       Just id' -> id'
-
-----------------------------------------
 type SpecM a = UniqSM a
 
 thenSM    = thenUs
@@ -1107,25 +1135,22 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 -- Clone the binders of the bind; return new bind with the cloned binders
 -- Return the substitution to use for RHSs, and the one to use for the body
 cloneBindSM subst (NonRec bndr rhs)
-  = withUs     $ \ us ->
+  = getUs      `thenUs` \ us ->
     let
-       (subst', us', bndr') = substAndCloneId subst us bndr
+       (subst', bndr') = substAndCloneId subst us bndr
     in
-    ((subst, subst', NonRec bndr' rhs), us')
+    returnUs (subst, subst', NonRec bndr' rhs)
 
 cloneBindSM subst (Rec pairs)
-  = withUs     $ \ us ->
+  = getUs      `thenUs` \ us ->
     let
-       (subst', us', bndrs') = substAndCloneIds subst us (map fst pairs)
+       (subst', bndrs') = substAndCloneRecIds subst us (map fst pairs)
     in
-    ((subst', subst', Rec (bndrs' `zip` map snd pairs)), us')
+    returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
 cloneBinders subst bndrs
-  = withUs     $ \ us -> 
-    let
-       (subst', us', bndrs') = substAndCloneIds subst us bndrs
-    in
-    ((subst', bndrs'), us')
+  = getUs      `thenUs` \ us ->
+    returnUs (substAndCloneIds subst us bndrs)
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->