Make -fliberate-case work for GADTs
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 0428772..0e66b0b 100644 (file)
@@ -8,41 +8,41 @@ module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
 
-import CmdLineOpts     ( DynFlags, DynFlag(..) )
-import Id              ( Id, idName, idType, mkUserLocal, idSpecialisation, isDataConWrapId )
+import DynFlags        ( DynFlags, DynFlag(..) )
+import Id              ( Id, idName, idType, mkUserLocal ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
-                         tyVarsOfTypes, tyVarsOfTheta, 
-                         mkForAllTys, tcCmpType
+                         tyVarsOfTypes, tyVarsOfTheta, isClassPred,
+                         tcCmpType, isUnLiftedType
                        )
-import Subst           ( Subst, mkSubst, substTy, mkSubst, extendSubstList, mkInScopeSet,
-                         simplBndr, simplBndrs, 
-                         substAndCloneId, substAndCloneIds, substAndCloneRecIds,
-                         lookupIdSubst, substInScope
+import CoreSubst       ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
+                         substBndr, substBndrs, substTy, substInScope,
+                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
                        ) 
-import Var             ( zapSpecPragmaId )
 import VarSet
 import VarEnv
 import CoreSyn
-import CoreUtils       ( applyTypeToArgs )
-import CoreFVs         ( exprFreeVars, exprsFreeVars )
+import CoreUtils       ( applyTypeToArgs, mkPiTypes )
+import CoreFVs         ( exprFreeVars, exprsFreeVars, idRuleVars )
+import CoreTidy                ( tidyRules )
 import CoreLint                ( showPass, endPass )
-import PprCore         ( pprCoreRules )
-import Rules           ( addIdSpecialisations, lookupRule )
-
+import Rules           ( addIdSpecialisations, mkLocalRule, lookupRule, emptyRuleBase, rulesOfBinds )
+import PprCore         ( pprRules )
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
                          getUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
+import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
 import Maybes          ( catMaybes, maybeToBool )
 import ErrUtils                ( dumpIfSet_dyn )
 import BasicTypes      ( Activation( AlwaysActive ) )
 import Bag
 import List            ( partition )
-import Util            ( zipEqual, zipWithEqual, cmpList )
+import Util            ( zipEqual, zipWithEqual, cmpList, lengthIs,
+                         equalLength, lengthAtLeast, notNull )
 import Outputable
-
+import FastString
 
 infixr 9 `thenSM`
 \end{code}
@@ -585,7 +585,7 @@ specProgram dflags us binds
        endPass dflags "Specialise" Opt_D_dump_spec binds'
 
        dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
-                 (vcat (map dump_specs (concat (map bindersOf binds'))))
+                 (pprRules (tidyRules emptyTidyEnv (rulesOfBinds binds')))
 
        return binds'
   where
@@ -594,14 +594,12 @@ specProgram dflags us binds
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst      = mkSubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) emptySubstEnv
+    top_subst      = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
 
     go []          = returnSM ([], emptyUDs)
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
                      specBind top_subst bind uds       `thenSM` \ (bind', uds') ->
                      returnSM (bind' ++ binds', uds')
-
-dump_specs var = pprCoreRules var (idSpecialisation var)
 \end{code}
 
 %************************************************************************
@@ -612,9 +610,7 @@ dump_specs var = pprCoreRules var (idSpecialisation var)
 
 \begin{code}
 specVar :: Subst -> Id -> CoreExpr
-specVar subst v = case lookupIdSubst subst v of
-                       DoneEx e   -> e
-                       DoneId v _ -> Var v
+specVar subst v = lookupIdSubst subst v
 
 specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 -- We carry a substitution down:
@@ -655,16 +651,16 @@ specExpr subst e@(Lam _ _)
     returnSM (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
-    (subst', bndrs') = simplBndrs subst bndrs
+    (subst', bndrs') = substBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
-specExpr subst (Case scrut case_bndr alts)
-  = specExpr subst scrut                       `thenSM` \ (scrut', uds_scrut) ->
+specExpr subst (Case scrut case_bndr ty alts)
+  = specExpr subst scrut               `thenSM` \ (scrut', uds_scrut) ->
     mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
-    returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
+    returnSM (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
   where
-    (subst_alt, case_bndr') = simplBndr subst case_bndr
+    (subst_alt, case_bndr') = substBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
@@ -674,7 +670,7 @@ specExpr subst (Case scrut case_bndr alts)
          in
          returnSM ((con, args', rhs''), uds')
        where
-         (subst_rhs, args') = simplBndrs subst_alt args
+         (subst_rhs, args') = substBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
 specExpr subst (Let bind body)
@@ -785,11 +781,9 @@ specDefn :: Subst                  -- Subst to use for RHS
 
 specDefn subst calls (fn, rhs)
        -- The first case is the interesting one
-  |  n_tyvars == length rhs_tyvars     -- Rhs of fn's defn has right number of big lambdas
-  && n_dicts  <= length rhs_bndrs      -- and enough dict args
-  && not (null calls_for_me)           -- And there are some calls to specialise
-  && not (isDataConWrapId fn)          -- And it's not a data con wrapper, which have
-                                       -- stupid overloading that simply discard the dictionary
+  |  rhs_tyvars `lengthIs` n_tyvars    -- Rhs of fn's defn has right number of big lambdas
+  && rhs_bndrs  `lengthAtLeast` n_dicts        -- and enough dict args
+  && notNull calls_for_me              -- And there are some calls to specialise
 
 -- At one time I tried not specialising small functions
 -- but sometimes there are big functions marked INLINE
@@ -806,7 +800,7 @@ specDefn subst calls (fn, rhs)
     let
        (spec_defns, spec_uds, spec_rules) = unzip3 stuff
 
-       fn' = addIdSpecialisations zapped_fn spec_rules
+       fn' = addIdSpecialisations fn spec_rules
     in
     returnSM ((fn',rhs'), 
              spec_defns, 
@@ -814,23 +808,18 @@ specDefn subst calls (fn, rhs)
 
   | otherwise  -- No calls or RHS doesn't fit our preconceptions
   = specExpr subst rhs                 `thenSM` \ (rhs', rhs_uds) ->
-    returnSM ((zapped_fn, rhs'), [], rhs_uds)
+    returnSM ((fn, rhs'), [], rhs_uds)
   
   where
-    zapped_fn           = zapSpecPragmaId fn
-       -- If the fn is a SpecPragmaId, make it discardable
-       -- It's role as a holder for a call instance is o'er
-       -- But it might be alive for some other reason by now.
-
     fn_type           = idType fn
     (tyvars, theta, _) = tcSplitSigmaTy fn_type
     n_tyvars          = length tyvars
     n_dicts           = length theta
 
+    (rhs_tyvars, rhs_ids, rhs_body) 
+       = collectTyAndValBinders (dropInline rhs)
        -- It's important that we "see past" any INLINE pragma
        -- else we'll fail to specialise an INLINE thing
-    (inline_me, rhs')              = dropInline rhs
-    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs'
 
     rhs_dicts = take n_dicts rhs_ids
     rhs_bndrs = rhs_tyvars ++ rhs_dicts
@@ -848,7 +837,7 @@ specDefn subst calls (fn, rhs)
                        UsageDetails,                   -- Usage details from specialised body
                        CoreRule)                       -- Info for the Id's SpecEnv
     spec_call (CallKey call_ts, (call_ds, call_fvs))
-      = ASSERT( length call_ts == n_tyvars && length call_ds == n_dicts )
+      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts )
                -- Calls are only recorded for properly-saturated applications
        
        -- Suppose f's defn is  f = /\ a b c d -> \ d1 d2 -> rhs        
@@ -874,25 +863,30 @@ specDefn subst calls (fn, rhs)
                       where
                         mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
                         mk_ty_arg rhs_tyvar (Just ty) = Type ty
-          rhs_subst  = extendSubstList subst spec_tyvars [DoneTy ty | Just ty <- call_ts]
+          rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts])
        in
        cloneBinders rhs_subst rhs_dicts                `thenSM` \ (rhs_subst', rhs_dicts') ->
        let
           inst_args = ty_args ++ map Var rhs_dicts'
 
                -- Figure out the type of the specialised function
-          spec_id_ty = mkForAllTys poly_tyvars (applyTypeToArgs rhs fn_type inst_args)
+          body_ty = applyTypeToArgs rhs fn_type inst_args
+          (lam_args, app_args)                 -- Add a dummy argument if body_ty is unlifted
+               | isUnLiftedType body_ty        -- C.f. WwLib.mkWorkerArgs
+               = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
+               | otherwise = (poly_tyvars, poly_tyvars)
+          spec_id_ty = mkPiTypes lam_args body_ty
        in
        newIdSM fn spec_id_ty                           `thenSM` \ spec_f ->
-       specExpr rhs_subst' (mkLams poly_tyvars body)   `thenSM` \ (spec_rhs, rhs_uds) ->       
+       specExpr rhs_subst' (mkLams lam_args body)      `thenSM` \ (spec_rhs, rhs_uds) ->       
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
-           spec_env_rule = Rule (_PK_ ("SPEC " ++ showSDoc (ppr fn)))
-                               AlwaysActive
+           spec_env_rule = mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr fn)))
+                               AlwaysActive (idName fn)
                                (poly_tyvars ++ rhs_dicts')
                                inst_args 
-                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+                               (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
@@ -910,12 +904,12 @@ specDefn subst calls (fn, rhs)
 
       where
        my_zipEqual doc xs ys 
-        | length xs /= length ys = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
-        | otherwise              = zipEqual doc xs ys
+        | not (equalLength xs ys) = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
+        | otherwise               = zipEqual doc xs ys
 
-dropInline :: CoreExpr -> (Bool, CoreExpr) 
-dropInline (Note InlineMe rhs) = (True, rhs)
-dropInline rhs                = (False, rhs)
+dropInline :: CoreExpr -> CoreExpr
+dropInline (Note InlineMe rhs) = rhs
+dropInline rhs                = rhs
 \end{code}
 
 %************************************************************************
@@ -1004,9 +998,13 @@ callDetailsToList calls = [ (id,tys,dicts)
 
 mkCallUDs subst f args 
   | null theta
-  || length spec_tys /= n_tyvars
-  || length dicts    /= n_dicts
-  || maybeToBool (lookupRule (\act -> True) (substInScope subst) f args)
+  || not (all isClassPred theta)       
+       -- Only specialise if all overloading is on class params. 
+       -- In ptic, with implicit params, the type args
+       --  *don't* say what the value of the implicit param is!
+  || not (spec_tys `lengthIs` n_tyvars)
+  || not ( dicts   `lengthIs` n_dicts)
+  || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args)
        -- There's already a rule covering this call.  A typical case
        -- is where there's an explicit user-provided rule.  Then
        -- we don't want to create a specialised version 
@@ -1026,10 +1024,9 @@ mkCallUDs subst f args
     spec_tys = [mk_spec_ty tv ty | (tv, Type ty) <- tyvars `zip` args]
     dicts    = [dict_expr | (_, dict_expr) <- theta `zip` (drop n_tyvars args)]
     
-    mk_spec_ty tyvar ty | tyvar `elemVarSet` constrained_tyvars
-                       = Just ty
-                       | otherwise
-                       = Nothing
+    mk_spec_ty tyvar ty 
+       | tyvar `elemVarSet` constrained_tyvars = Just ty
+       | otherwise                             = Nothing
 
 ------------------------------------------------------------                   
 plusUDs :: UsageDetails -> UsageDetails -> UsageDetails
@@ -1047,11 +1044,16 @@ zapCalls ids uds = uds {calls = delListFromFM (calls uds) ids}
 
 mkDB bind = (bind, bind_fvs bind)
 
-bind_fvs (NonRec bndr rhs) = exprFreeVars rhs
+bind_fvs (NonRec bndr rhs) = pair_fvs (bndr,rhs)
 bind_fvs (Rec prs)        = foldl delVarSet rhs_fvs bndrs
                           where
                             bndrs = map fst prs
-                            rhs_fvs = unionVarSets [exprFreeVars rhs | (bndr,rhs) <- prs]
+                            rhs_fvs = unionVarSets (map pair_fvs prs)
+
+pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idRuleVars bndr
+       -- Don't forget variables mentioned in the
+       -- rules of the bndr.  C.f. OccAnal.addRuleUsage
+
 
 addDictBind (dict,rhs) uds = uds { dict_binds = mkDB (NonRec dict rhs) `consBag` dict_binds uds }
 
@@ -1105,7 +1107,7 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs,
     dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs)
        | dump_idset `intersectsVarSet` fvs     -- Dump it
        = (free_dbs, dump_dbs `snocBag` db,
-          dump_idset `unionVarSet` mkVarSet (bindersOf bind))
+          extendVarSetList dump_idset (bindersOf bind))
 
        | otherwise     -- Don't dump it
        = (free_dbs `snocBag` db, dump_dbs, dump_idset)
@@ -1138,20 +1140,20 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 cloneBindSM subst (NonRec bndr rhs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndr') = substAndCloneId subst us bndr
+       (subst', bndr') = cloneIdBndr subst us bndr
     in
     returnUs (subst, subst', NonRec bndr' rhs)
 
 cloneBindSM subst (Rec pairs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndrs') = substAndCloneRecIds subst us (map fst pairs)
+       (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
     in
     returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
 cloneBinders subst bndrs
   = getUs      `thenUs` \ us ->
-    returnUs (substAndCloneIds subst us bndrs)
+    returnUs (cloneIdBndrs subst us bndrs)
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->