[project @ 2001-02-28 11:48:34 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 81799e5..bdef352 100644 (file)
@@ -8,45 +8,42 @@ module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
 
-import CmdLineOpts     ( opt_D_verbose_core2core, opt_D_dump_spec, opt_D_dump_rules )
-import Id              ( Id, idName, idType, mkTemplateLocals, mkUserLocal,
-                         idSpecialisation, setIdNoDiscard, isExportedId,
-                         modifyIdInfo, idUnfolding
+import CmdLineOpts     ( DynFlags, DynFlag(..) )
+import Id              ( Id, idName, idType, mkUserLocal,
+                         idSpecialisation, modifyIdInfo
                        )
 import IdInfo          ( zapSpecPragInfo )
 import VarSet
 import VarEnv
 
-import Type            ( Type, mkTyVarTy, splitSigmaTy, splitFunTysN,
-                         tyVarsOfType, tyVarsOfTypes, tyVarsOfTheta, applyTys,
-                         mkForAllTys, boxedTypeKind
+import Type            ( Type, mkTyVarTy, splitSigmaTy, 
+                         tyVarsOfTypes, tyVarsOfTheta, 
+                         mkForAllTys 
                        )
-import PprType          ( {- instance Outputable Type -} )
-import Subst           ( Subst, mkSubst, substTy, emptySubst, substBndrs, extendSubstList,
-                         substId, substAndCloneId, substAndCloneIds, lookupIdSubst
+import Subst           ( Subst, mkSubst, substTy, mkSubst, substBndrs, extendSubstList, mkInScopeSet,
+                         substId, substAndCloneId, substAndCloneIds, lookupIdSubst, substInScope
                        ) 
-import Var             ( TyVar, mkSysTyVar, setVarUnique )
 import VarSet
 import VarEnv
 import CoreSyn
 import CoreUtils       ( applyTypeToArgs )
 import CoreUnfold      ( certainlyWillInline )
 import CoreFVs         ( exprFreeVars, exprsFreeVars )
-import CoreLint                ( beginPass, endPass )
+import CoreLint                ( showPass, endPass )
 import PprCore         ( pprCoreRules )
-import Rules           ( addIdSpecialisations )
+import Rules           ( addIdSpecialisations, lookupRule )
 
 import UniqSupply      ( UniqSupply,
-                         UniqSM, initUs_, thenUs, thenUs_, returnUs, getUniqueUs, 
-                         getUs, setUs, uniqFromSupply, splitUniqSupply, mapUs
+                         UniqSM, initUs_, thenUs, thenUs, returnUs, getUniqueUs, 
+                         withUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
 import FiniteMap
-import Maybes          ( MaybeErr(..), catMaybes )
-import ErrUtils                ( dumpIfSet )
+import Maybes          ( catMaybes, maybeToBool )
+import ErrUtils                ( dumpIfSet_dyn )
 import Bag
 import List            ( partition )
-import Util            ( zipEqual, zipWithEqual, mapAccumL )
+import Util            ( zipEqual, zipWithEqual )
 import Outputable
 
 
@@ -580,24 +577,31 @@ Hence, the invariant is this:
 %************************************************************************
 
 \begin{code}
-specProgram :: UniqSupply -> [CoreBind] -> IO [CoreBind]
-specProgram us binds
+specProgram :: DynFlags -> UniqSupply -> [CoreBind] -> IO [CoreBind]
+specProgram dflags us binds
   = do
-       beginPass "Specialise"
+       showPass dflags "Specialise"
 
        let binds' = initSM us (go binds        `thenSM` \ (binds', uds') ->
                                returnSM (dumpAllDictBinds uds' binds'))
 
-       endPass "Specialise" (opt_D_dump_spec || opt_D_verbose_core2core) binds'
+       endPass dflags "Specialise" Opt_D_dump_spec binds'
 
-       dumpIfSet opt_D_dump_rules "Top-level specialisations"
+       dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
                  (vcat (map dump_specs (concat (map bindersOf binds'))))
 
        return binds'
   where
+       -- We need to start with a Subst that knows all the things
+       -- that are in scope, so that the substitution engine doesn't
+       -- accidentally re-use a unique that's already in use
+       -- Easiest thing is to do it all at once, as if all the top-level
+       -- decls were mutually recursive
+    top_subst      = mkSubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) emptySubstEnv
+
     go []          = returnSM ([], emptyUDs)
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
-                     specBind emptySubst bind uds      `thenSM` \ (bind', uds') ->
+                     specBind top_subst bind uds       `thenSM` \ (bind', uds') ->
                      returnSM (bind' ++ binds', uds')
 
 dump_specs var = pprCoreRules var (idSpecialisation var)
@@ -641,7 +645,7 @@ specExpr subst expr@(App fun arg)
                            returnSM (App fun' arg', uds_arg `plusUDs` uds_app)
 
     go (Var f)       args = case specVar subst f of
-                               Var f' -> returnSM (Var f', mkCallUDs f' args)
+                               Var f' -> returnSM (Var f', mkCallUDs subst f' args)
                                e'     -> returnSM (e', emptyUDs)       -- I don't expect this!
     go other        args = specExpr subst other
 
@@ -664,6 +668,7 @@ specExpr subst (Case scrut case_bndr alts)
     returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
   where
     (subst_alt, case_bndr') = substId subst case_bndr
+       -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
        = specExpr subst_rhs rhs                `thenSM` \ (rhs', uds) ->
@@ -795,9 +800,9 @@ specDefn subst calls (fn, rhs)
       -- Make a specialised version for each call in calls_for_me
     mapSM spec_call calls_for_me               `thenSM` \ stuff ->
     let
-       (spec_defns, spec_uds, spec_env_stuff) = unzip3 stuff
+       (spec_defns, spec_uds, spec_rules) = unzip3 stuff
 
-       fn' = addIdSpecialisations zapped_fn spec_env_stuff
+       fn' = addIdSpecialisations zapped_fn spec_rules
     in
     returnSM ((fn',rhs'), 
              spec_defns, 
@@ -813,10 +818,10 @@ specDefn subst calls (fn, rhs)
        -- It's role as a holder for a call instance is o'er
        -- But it might be alive for some other reason by now.
 
-    fn_type             = idType fn
-    (tyvars, theta, tau) = splitSigmaTy fn_type
-    n_tyvars            = length tyvars
-    n_dicts             = length theta
+    fn_type           = idType fn
+    (tyvars, theta, _) = splitSigmaTy fn_type
+    n_tyvars          = length tyvars
+    n_dicts           = length theta
 
     (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
     rhs_dicts = take n_dicts rhs_ids
@@ -830,10 +835,10 @@ specDefn subst calls (fn, rhs)
 
     ----------------------------------------------------------
        -- Specialise to one particular call pattern
-    spec_call :: ([Maybe Type], ([DictExpr], VarSet))          -- Call instance
-              -> SpecM ((Id,CoreExpr),                         -- Specialised definition
-                       UsageDetails,                           -- Usage details from specialised body
-                       ([CoreBndr], [CoreExpr], CoreExpr))     -- Info for the Id's SpecEnv
+    spec_call :: ([Maybe Type], ([DictExpr], VarSet))  -- Call instance
+              -> SpecM ((Id,CoreExpr),                 -- Specialised definition
+                       UsageDetails,                   -- Usage details from specialised body
+                       CoreRule)                       -- Info for the Id's SpecEnv
     spec_call (call_ts, (call_ds, call_fvs))
       = ASSERT( length call_ts == n_tyvars && length call_ds == n_dicts )
                -- Calls are only recorded for properly-saturated applications
@@ -875,9 +880,10 @@ specDefn subst calls (fn, rhs)
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
-           spec_env_rule = (poly_tyvars ++ rhs_dicts',
-                           inst_args, 
-                           mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+           spec_env_rule = Rule (_PK_ ("SPEC " ++ showSDoc (ppr fn)))
+                               (poly_tyvars ++ rhs_dicts')
+                               inst_args 
+                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
@@ -935,8 +941,8 @@ type CallInfo     = FiniteMap [Maybe Type]                  -- Nothing => unconstrained type ar
 unionCalls :: CallDetails -> CallDetails -> CallDetails
 unionCalls c1 c2 = plusFM_C plusFM c1 c2
 
-singleCall :: (Id, [Maybe Type], [DictExpr]) -> CallDetails
-singleCall (id, tys, dicts) 
+singleCall :: Id -> [Maybe Type] -> [DictExpr] -> CallDetails
+singleCall id tys dicts 
   = unitFM id (unitFM tys (dicts, call_fvs))
   where
     call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
@@ -962,21 +968,26 @@ callDetailsToList calls = [ (id,tys,dicts)
                            (tys,dicts) <- fmToList fm
                          ]
 
-mkCallUDs f args 
+mkCallUDs subst f args 
   | null theta
   || length spec_tys /= n_tyvars
   || length dicts    /= n_dicts
-  = emptyUDs   -- Not overloaded
+  || maybeToBool (lookupRule (substInScope subst) f args)
+       -- There's already a rule covering this call.  A typical case
+       -- is where there's an explicit user-provided rule.  Then
+       -- we don't want to create a specialised version 
+       -- of the function that overlaps.
+  = emptyUDs   -- Not overloaded, or no specialisation wanted
 
   | otherwise
   = MkUD {dict_binds = emptyBag, 
-         calls      = singleCall (f, spec_tys, dicts)
+         calls      = singleCall f spec_tys dicts
     }
   where
-    (tyvars, theta, tau) = splitSigmaTy (idType f)
-    constrained_tyvars   = tyVarsOfTheta theta 
-    n_tyvars            = length tyvars
-    n_dicts             = length theta
+    (tyvars, theta, _) = splitSigmaTy (idType f)
+    constrained_tyvars = tyVarsOfTheta theta 
+    n_tyvars          = length tyvars
+    n_dicts           = length theta
 
     spec_tys = [mk_spec_ty tv ty | (tv, Type ty) <- tyvars `zip` args]
     dicts    = [dict_expr | (_, dict_expr) <- theta `zip` (drop n_tyvars args)]
@@ -1083,11 +1094,8 @@ lookupId env id = case lookupVarEnv env id of
 type SpecM a = UniqSM a
 
 thenSM    = thenUs
-thenSM_    = thenUs_
 returnSM  = returnUs
 getUniqSM = getUniqueUs
-getUniqSupplySM = getUs
-setUniqSupplySM = setUs
 mapSM     = mapUs
 initSM   = initUs_
 
@@ -1100,29 +1108,25 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 -- Clone the binders of the bind; return new bind with the cloned binders
 -- Return the substitution to use for RHSs, and the one to use for the body
 cloneBindSM subst (NonRec bndr rhs)
-  = getUs      `thenUs` \ us ->
+  = withUs     $ \ us ->
     let
        (subst', us', bndr') = substAndCloneId subst us bndr
     in
-    setUs us'  `thenUs_`
-    returnUs (subst, subst', NonRec bndr' rhs)
+    ((subst, subst', NonRec bndr' rhs), us')
 
 cloneBindSM subst (Rec pairs)
-  = getUs      `thenUs` \ us ->
+  = withUs     $ \ us ->
     let
        (subst', us', bndrs') = substAndCloneIds subst us (map fst pairs)
     in
-    setUs us'  `thenUs_`
-    returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
+    ((subst', subst', Rec (bndrs' `zip` map snd pairs)), us')
 
 cloneBinders subst bndrs
-  = getUs      `thenUs` \ us ->
+  = withUs     $ \ us -> 
     let
        (subst', us', bndrs') = substAndCloneIds subst us bndrs
     in
-    setUs us'  `thenUs_`
-    returnUs (subst', bndrs')
-
+    ((subst', bndrs'), us')
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->
@@ -1130,17 +1134,8 @@ newIdSM old_id new_ty
        -- Give the new Id a similar occurrence name to the old one
        name   = idName old_id
        new_id = mkUserLocal (mkSpecOcc (nameOccName name)) uniq new_ty (getSrcLoc name)
-
-       -- If the old Id was exported, make the new one non-discardable,
-       -- else we will discard it since it doesn't seem to be called.
-       new_id' | isExportedId old_id = setIdNoDiscard new_id
-               | otherwise           = new_id
     in
-    returnSM new_id'
-
-newTyVarSM
-  = getUniqSM          `thenSM` \ uniq ->
-    returnSM (mkSysTyVar uniq boxedTypeKind)
+    returnSM new_id
 \end{code}