Trim redundant import
[ghc-hetmet.git] / compiler / specialise / Specialise.lhs
index a5cffb1..4d8efdd 100644 (file)
@@ -14,25 +14,27 @@ module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
 
-import DynFlags        ( DynFlags, DynFlag(..) )
 import Id              ( Id, idName, idType, mkUserLocal, idCoreRules,
-                         idInlinePragma, setInlinePragma ) 
+                         idInlinePragma, setInlinePragma, setIdUnfolding,
+                         isLocalId ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, isClassPred,
                          tcCmpType, isUnLiftedType
                        )
 import CoreSubst       ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
                          substBndr, substBndrs, substTy, substInScope,
-                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
+                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs,
+                         extendIdSubst
                        ) 
+import CoreUnfold      ( mkUnfolding )
 import SimplUtils      ( interestingArg )
+import Var             ( DictId )
 import VarSet
 import VarEnv
 import CoreSyn
 import Rules
-import CoreUtils       ( applyTypeToArgs, mkPiTypes )
+import CoreUtils       ( exprIsTrivial, applyTypeToArgs, mkPiTypes )
 import CoreFVs         ( exprFreeVars, exprsFreeVars, idFreeVars )
-import CoreLint                ( showPass, endPass )
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs_,
                          MonadUnique(..)
@@ -41,7 +43,6 @@ import Name
 import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
 import Maybes          ( catMaybes, isJust )
-import ErrUtils                ( dumpIfSet_dyn )
 import Bag
 import Util
 import Outputable
@@ -574,20 +575,9 @@ Hence, the invariant is this:
 %************************************************************************
 
 \begin{code}
-specProgram :: DynFlags -> UniqSupply -> [CoreBind] -> IO [CoreBind]
-specProgram dflags us binds = do
-   
-       showPass dflags "Specialise"
-
-       let binds' = initSM us (do (binds', uds') <- go binds
-                                  return (dumpAllDictBinds uds' binds'))
-
-       endPass dflags "Specialise" Opt_D_dump_spec binds'
-
-       dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
-                     (pprRulesForUser (rulesOfBinds binds'))
-
-       return binds'
+specProgram :: UniqSupply -> [CoreBind] -> [CoreBind]
+specProgram us binds = initSM us (do (binds', uds') <- go binds
+                                    return (dumpAllDictBinds uds' binds'))
   where
        -- We need to start with a Subst that knows all the things
        -- that are in scope, so that the substitution engine doesn't
@@ -614,7 +604,7 @@ specVar subst v = lookupIdSubst subst v
 
 specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 -- We carry a substitution down:
---     a) we must clone any binding that might flaot outwards,
+--     a) we must clone any binding that might float outwards,
 --        to avoid name clashes
 --     b) we carry a type substitution to use when analysing
 --        the RHS of specialised bindings (no type-let!)
@@ -772,8 +762,9 @@ specBindItself rhs_subst (Rec pairs) call_info
        ; if null spec_defns1 then   -- Common case: no specialisation
                    return (Rec (bndrs `zip` rhss'), rhs_uds)
         else do                     -- Specialisation occurred; do it again
-       { (bndrs2, spec_defns2, spec_uds2) <- specDefns rhs_subst 
-                                                              (calls spec_uds1) (bndrs1 `zip` rhss)
+       { (bndrs2, spec_defns2, spec_uds2) <- 
+                         -- pprTrace "specB" (ppr bndrs $$ ppr rhs_uds) $
+                         specDefns rhs_subst (calls spec_uds1) (bndrs1 `zip` rhss)
 
        ; let all_defns = spec_defns1 ++ spec_defns2 ++ zip bndrs2 rhss'
              
@@ -845,8 +836,8 @@ specDefn subst calls fn rhs
     (inline_rhs, rhs_inside) = dropInline rhs
     (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs_inside
 
-    rhs_dicts = take n_dicts rhs_ids
-    body      = mkLams (drop n_dicts rhs_ids) rhs_body
+    rhs_dict_ids = take n_dicts rhs_ids
+    body         = mkLams (drop n_dicts rhs_ids) rhs_body
                -- Glue back on the non-dict lambdas
 
     calls_for_me = case lookupFM calls fn of
@@ -877,7 +868,7 @@ specDefn subst calls fn rhs
         -- Supppose the call is for f [Just t1, Nothing, Just t3] [dx1, dx2]
 
        -- Construct the new binding
-       --      f1 = SUBST[a->t1,c->t3, d1->d1', d2->d2'] (/\ b d -> rhs)
+       --      f1 = SUBST[a->t1,c->t3, d1->d1', d2->d2'] (/\ b -> rhs)
        -- PLUS the usage-details
        --      { d1' = dx1; d2' = dx2 }
        -- where d1', d2' are cloned versions of d1,d2, with the type substitution
@@ -896,8 +887,12 @@ specDefn subst calls fn rhs
                ty_args       = mk_ty_args call_ts
                rhs_subst     = extendTvSubstList subst spec_tv_binds
 
-          ; (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts
-          ; let inst_args = ty_args ++ map Var rhs_dicts'
+          ; (rhs_subst1, inst_dict_ids) <- cloneDictBndrs rhs_subst rhs_dict_ids
+                         -- Clone rhs_dicts, including instantiating their types
+
+          ; let (rhs_subst2, dx_binds) = bindAuxiliaryDicts rhs_subst1 $
+                                         (my_zipEqual rhs_dict_ids inst_dict_ids call_ds)
+                inst_args = ty_args ++ map Var inst_dict_ids
 
           ; if already_covered inst_args then
                return Nothing
@@ -910,8 +905,8 @@ specDefn subst calls fn rhs
                   | otherwise = (poly_tyvars, poly_tyvars)
                 spec_id_ty = mkPiTypes lam_args body_ty
        
-           ; spec_f <- newIdSM fn spec_id_ty
-           ; (spec_rhs, rhs_uds) <- specExpr rhs_subst' (mkLams lam_args body)
+           ; spec_f <- newSpecIdSM fn spec_id_ty
+           ; (spec_rhs, rhs_uds) <- specExpr rhs_subst2 (mkLams lam_args body)
           ; let
                -- The rule to put in the function's specialisation is:
                --      forall b, d1',d2'.  f t1 b t3 d1' d2' = f1 b  
@@ -920,27 +915,52 @@ specDefn subst calls fn rhs
                                  rule_name
                                  inline_prag   -- Note [Auto-specialisation and RULES]
                                  (idName fn)
-                                 (poly_tyvars ++ rhs_dicts')
+                                 (poly_tyvars ++ inst_dict_ids)
                                  inst_args 
                                  (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
-               final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
+               final_uds = foldr addDictBind rhs_uds dx_binds
 
                spec_pr | inline_rhs = (spec_f `setInlinePragma` inline_prag, Note InlineMe spec_rhs)
                        | otherwise  = (spec_f,                               spec_rhs)
 
           ; return (Just (spec_pr, final_uds, spec_env_rule)) } }
       where
-       my_zipEqual doc xs ys 
-        | debugIsOn && not (equalLength xs ys)
-             = pprPanic "my_zipEqual" (vcat 
-                                               [ ppr xs, ppr ys
-                                               , ppr fn <+> ppr call_ts
-                                               , ppr (idType fn), ppr theta
-                                               , ppr n_dicts, ppr rhs_dicts 
-                                               , ppr rhs])
-        | otherwise               = zipEqual doc xs ys
+       my_zipEqual xs ys zs
+        | debugIsOn && not (equalLength xs ys && equalLength ys zs)
+             = pprPanic "my_zipEqual" (vcat [ ppr xs, ppr ys
+                                           , ppr fn <+> ppr call_ts
+                                           , ppr (idType fn), ppr theta
+                                           , ppr n_dicts, ppr rhs_dict_ids 
+                                           , ppr rhs])
+        | otherwise = zip3 xs ys zs
+
+bindAuxiliaryDicts
+       :: Subst
+       -> [(DictId,DictId,CoreExpr)]   -- (orig_dict, inst_dict, dx)
+       -> (Subst,                      -- Substitute for all orig_dicts
+           [(DictId, CoreExpr)])       -- Auxiliary bindings
+-- Bind any dictionary arguments to fresh names, to preserve sharing
+-- Substitution already substitutes orig_dict -> inst_dict
+bindAuxiliaryDicts subst triples = go subst [] triples
+  where
+    go subst binds []    = (subst, binds)
+    go subst binds ((d, dx_id, dx) : pairs)
+      | exprIsTrivial dx = go (extendIdSubst subst d dx) binds pairs
+             -- No auxiliary binding necessary
+      | otherwise        = go subst_w_unf ((dx_id,dx) : binds) pairs
+      where
+        dx_id1 = dx_id `setIdUnfolding` mkUnfolding False dx
+       subst_w_unf = extendIdSubst subst d (Var dx_id1)
+                    -- Important!  We're going to substitute dx_id1 for d
+            -- and we want it to look "interesting", else we won't gather *any*
+            -- consequential calls. E.g.
+            --     f d = ...g d....
+            -- If we specialise f for a call (f (dfun dNumInt)), we'll get 
+            -- a consequent call (g d') with an auxiliary definition
+            --     d' = df dNumInt
+            -- We want that consequent call to look interesting
 \end{code}
 
 Note [Specialising a recursive group]
@@ -1030,7 +1050,7 @@ then its body must look like
 Reason: when specialising the body for a call (f ty dexp), we want to
 substitute dexp for d, and pick up specialised calls in the body of f.
 
-This doesn't always work.  One example I came across was htis:
+This doesn't always work.  One example I came across was this:
        newtype Gen a = MkGen{ unGen :: Int -> a }
 
        choose :: Eq a => a -> Gen a
@@ -1166,7 +1186,8 @@ singleCall id tys dicts
 
 mkCallUDs :: Id -> [CoreExpr] -> UsageDetails
 mkCallUDs f args 
-  | null theta
+  | not (isLocalId f)  -- Imported from elsewhere
+  || null theta                -- Not overloaded
   || not (all isClassPred theta)       
        -- Only specialise if all overloading is on class params. 
        -- In ptic, with implicit params, the type args
@@ -1175,10 +1196,12 @@ mkCallUDs f args
   || not ( dicts   `lengthIs` n_dicts)
   || not (any interestingArg dicts)    -- Note [Interesting dictionary arguments]
   -- See also Note [Specialisations already covered]
-  = emptyUDs   -- Not overloaded, or no specialisation wanted
+  = -- pprTrace "mkCallUDs: discarding" (vcat [ppr f, ppr args, ppr n_tyvars, ppr n_dicts, ppr (map interestingArg dicts)]) 
+    emptyUDs   -- Not overloaded, or no specialisation wanted
 
   | otherwise
-  = singleCall f spec_tys dicts
+  = -- pprTrace "mkCallUDs: keeping" (vcat [ppr f, ppr args, ppr n_tyvars, ppr n_dicts, ppr (map interestingArg dicts)]) 
+    singleCall f spec_tys dicts
   where
     (tyvars, theta, _) = tcSplitSigmaTy (idType f)
     constrained_tyvars = tyVarsOfTheta theta 
@@ -1327,19 +1350,20 @@ cloneBindSM subst (Rec pairs) = do
     let (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
     return (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
-cloneBinders :: Subst -> [CoreBndr] -> SpecM (Subst, [CoreBndr])
-cloneBinders subst bndrs = do
-    us <- getUniqueSupplyM
-    return (cloneIdBndrs subst us bndrs)
-
-newIdSM :: Id -> Type -> SpecM Id
-newIdSM old_id new_ty = do
-    uniq <- getUniqueM
-    let
-        -- Give the new Id a similar occurrence name to the old one
-        name   = idName old_id
-        new_id = mkUserLocal (mkSpecOcc (nameOccName name)) uniq new_ty (getSrcSpan name)
-    return new_id
+cloneDictBndrs :: Subst -> [CoreBndr] -> SpecM (Subst, [CoreBndr])
+cloneDictBndrs subst bndrs 
+  = do { us <- getUniqueSupplyM
+       ; return (cloneIdBndrs subst us bndrs) }
+
+newSpecIdSM :: Id -> Type -> SpecM Id
+    -- Give the new Id a similar occurrence name to the old one
+newSpecIdSM old_id new_ty
+  = do { uniq <- getUniqueM
+       ; let 
+           name    = idName old_id
+           new_occ = mkSpecOcc (nameOccName name)
+           new_id  = mkUserLocal new_occ uniq new_ty (getSrcSpan name)
+        ; return new_id }
 \end{code}