Make -fliberate-case work for GADTs
[ghc-hetmet.git] / ghc / compiler / specialise / Specialise.lhs
index 8bacb9e..0e66b0b 100644 (file)
@@ -8,32 +8,31 @@ module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
 
-import CmdLineOpts     ( DynFlags, DynFlag(..) )
+import DynFlags        ( DynFlags, DynFlag(..) )
 import Id              ( Id, idName, idType, mkUserLocal ) 
 import TcType          ( Type, mkTyVarTy, tcSplitSigmaTy, 
                          tyVarsOfTypes, tyVarsOfTheta, isClassPred,
-                         mkForAllTys, tcCmpType
+                         tcCmpType, isUnLiftedType
                        )
-import Subst           ( Subst, mkSubst, substTy, mkSubst, extendSubstList, mkInScopeSet,
-                         simplBndr, simplBndrs, 
-                         substAndCloneId, substAndCloneIds, substAndCloneRecIds,
-                         lookupIdSubst, substInScope
+import CoreSubst       ( Subst, mkEmptySubst, extendTvSubstList, lookupIdSubst,
+                         substBndr, substBndrs, substTy, substInScope,
+                         cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
                        ) 
-import Var             ( zapSpecPragmaId )
 import VarSet
 import VarEnv
 import CoreSyn
-import CoreUtils       ( applyTypeToArgs )
-import CoreFVs         ( exprFreeVars, exprsFreeVars )
-import CoreTidy                ( pprTidyIdRules )
+import CoreUtils       ( applyTypeToArgs, mkPiTypes )
+import CoreFVs         ( exprFreeVars, exprsFreeVars, idRuleVars )
+import CoreTidy                ( tidyRules )
 import CoreLint                ( showPass, endPass )
-import Rules           ( addIdSpecialisations, lookupRule )
-
+import Rules           ( addIdSpecialisations, mkLocalRule, lookupRule, emptyRuleBase, rulesOfBinds )
+import PprCore         ( pprRules )
 import UniqSupply      ( UniqSupply,
                          UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
                          getUs, mapUs
                        )
 import Name            ( nameOccName, mkSpecOcc, getSrcLoc )
+import MkId            ( voidArgId, realWorldPrimId )
 import FiniteMap
 import Maybes          ( catMaybes, maybeToBool )
 import ErrUtils                ( dumpIfSet_dyn )
@@ -586,7 +585,7 @@ specProgram dflags us binds
        endPass dflags "Specialise" Opt_D_dump_spec binds'
 
        dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
-                 (vcat (map pprTidyIdRules (concat (map bindersOf binds'))))
+                 (pprRules (tidyRules emptyTidyEnv (rulesOfBinds binds')))
 
        return binds'
   where
@@ -595,7 +594,7 @@ specProgram dflags us binds
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst      = mkSubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) emptySubstEnv
+    top_subst      = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
 
     go []          = returnSM ([], emptyUDs)
     go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
@@ -611,9 +610,7 @@ specProgram dflags us binds
 
 \begin{code}
 specVar :: Subst -> Id -> CoreExpr
-specVar subst v = case lookupIdSubst subst v of
-                       DoneEx e   -> e
-                       DoneId v _ -> Var v
+specVar subst v = lookupIdSubst subst v
 
 specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 -- We carry a substitution down:
@@ -654,16 +651,16 @@ specExpr subst e@(Lam _ _)
     returnSM (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
-    (subst', bndrs') = simplBndrs subst bndrs
+    (subst', bndrs') = substBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
-specExpr subst (Case scrut case_bndr alts)
-  = specExpr subst scrut                       `thenSM` \ (scrut', uds_scrut) ->
+specExpr subst (Case scrut case_bndr ty alts)
+  = specExpr subst scrut               `thenSM` \ (scrut', uds_scrut) ->
     mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
-    returnSM (Case scrut' case_bndr' alts', uds_scrut `plusUDs` uds_alts)
+    returnSM (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
   where
-    (subst_alt, case_bndr') = simplBndr subst case_bndr
+    (subst_alt, case_bndr') = substBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
     spec_alt (con, args, rhs)
@@ -673,7 +670,7 @@ specExpr subst (Case scrut case_bndr alts)
          in
          returnSM ((con, args', rhs''), uds')
        where
-         (subst_rhs, args') = simplBndrs subst_alt args
+         (subst_rhs, args') = substBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
 specExpr subst (Let bind body)
@@ -803,7 +800,7 @@ specDefn subst calls (fn, rhs)
     let
        (spec_defns, spec_uds, spec_rules) = unzip3 stuff
 
-       fn' = addIdSpecialisations zapped_fn spec_rules
+       fn' = addIdSpecialisations fn spec_rules
     in
     returnSM ((fn',rhs'), 
              spec_defns, 
@@ -811,23 +808,18 @@ specDefn subst calls (fn, rhs)
 
   | otherwise  -- No calls or RHS doesn't fit our preconceptions
   = specExpr subst rhs                 `thenSM` \ (rhs', rhs_uds) ->
-    returnSM ((zapped_fn, rhs'), [], rhs_uds)
+    returnSM ((fn, rhs'), [], rhs_uds)
   
   where
-    zapped_fn           = zapSpecPragmaId fn
-       -- If the fn is a SpecPragmaId, make it discardable
-       -- It's role as a holder for a call instance is o'er
-       -- But it might be alive for some other reason by now.
-
     fn_type           = idType fn
     (tyvars, theta, _) = tcSplitSigmaTy fn_type
     n_tyvars          = length tyvars
     n_dicts           = length theta
 
+    (rhs_tyvars, rhs_ids, rhs_body) 
+       = collectTyAndValBinders (dropInline rhs)
        -- It's important that we "see past" any INLINE pragma
        -- else we'll fail to specialise an INLINE thing
-    (inline_me, rhs')              = dropInline rhs
-    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs'
 
     rhs_dicts = take n_dicts rhs_ids
     rhs_bndrs = rhs_tyvars ++ rhs_dicts
@@ -871,25 +863,30 @@ specDefn subst calls (fn, rhs)
                       where
                         mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
                         mk_ty_arg rhs_tyvar (Just ty) = Type ty
-          rhs_subst  = extendSubstList subst spec_tyvars [DoneTy ty | Just ty <- call_ts]
+          rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts])
        in
        cloneBinders rhs_subst rhs_dicts                `thenSM` \ (rhs_subst', rhs_dicts') ->
        let
           inst_args = ty_args ++ map Var rhs_dicts'
 
                -- Figure out the type of the specialised function
-          spec_id_ty = mkForAllTys poly_tyvars (applyTypeToArgs rhs fn_type inst_args)
+          body_ty = applyTypeToArgs rhs fn_type inst_args
+          (lam_args, app_args)                 -- Add a dummy argument if body_ty is unlifted
+               | isUnLiftedType body_ty        -- C.f. WwLib.mkWorkerArgs
+               = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
+               | otherwise = (poly_tyvars, poly_tyvars)
+          spec_id_ty = mkPiTypes lam_args body_ty
        in
        newIdSM fn spec_id_ty                           `thenSM` \ spec_f ->
-       specExpr rhs_subst' (mkLams poly_tyvars body)   `thenSM` \ (spec_rhs, rhs_uds) ->       
+       specExpr rhs_subst' (mkLams lam_args body)      `thenSM` \ (spec_rhs, rhs_uds) ->       
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
-           spec_env_rule = Rule (mkFastString ("SPEC " ++ showSDoc (ppr fn)))
-                               AlwaysActive
+           spec_env_rule = mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr fn)))
+                               AlwaysActive (idName fn)
                                (poly_tyvars ++ rhs_dicts')
                                inst_args 
-                               (mkTyApps (Var spec_f) (map mkTyVarTy poly_tyvars))
+                               (mkVarApps (Var spec_f) app_args)
 
                -- Add the { d1' = dx1; d2' = dx2 } usage stuff
           final_uds = foldr addDictBind rhs_uds (my_zipEqual "spec_call" rhs_dicts' call_ds)
@@ -910,9 +907,9 @@ specDefn subst calls (fn, rhs)
         | not (equalLength xs ys) = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
         | otherwise               = zipEqual doc xs ys
 
-dropInline :: CoreExpr -> (Bool, CoreExpr) 
-dropInline (Note InlineMe rhs) = (True, rhs)
-dropInline rhs                = (False, rhs)
+dropInline :: CoreExpr -> CoreExpr
+dropInline (Note InlineMe rhs) = rhs
+dropInline rhs                = rhs
 \end{code}
 
 %************************************************************************
@@ -1004,10 +1001,10 @@ mkCallUDs subst f args
   || not (all isClassPred theta)       
        -- Only specialise if all overloading is on class params. 
        -- In ptic, with implicit params, the type args
-       -- *don't* say what the value of the implicit param is!
+       --  *don't* say what the value of the implicit param is!
   || not (spec_tys `lengthIs` n_tyvars)
   || not ( dicts   `lengthIs` n_dicts)
-  || maybeToBool (lookupRule (\act -> True) (substInScope subst) f args)
+  || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args)
        -- There's already a rule covering this call.  A typical case
        -- is where there's an explicit user-provided rule.  Then
        -- we don't want to create a specialised version 
@@ -1047,11 +1044,16 @@ zapCalls ids uds = uds {calls = delListFromFM (calls uds) ids}
 
 mkDB bind = (bind, bind_fvs bind)
 
-bind_fvs (NonRec bndr rhs) = exprFreeVars rhs
+bind_fvs (NonRec bndr rhs) = pair_fvs (bndr,rhs)
 bind_fvs (Rec prs)        = foldl delVarSet rhs_fvs bndrs
                           where
                             bndrs = map fst prs
-                            rhs_fvs = unionVarSets [exprFreeVars rhs | (bndr,rhs) <- prs]
+                            rhs_fvs = unionVarSets (map pair_fvs prs)
+
+pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idRuleVars bndr
+       -- Don't forget variables mentioned in the
+       -- rules of the bndr.  C.f. OccAnal.addRuleUsage
+
 
 addDictBind (dict,rhs) uds = uds { dict_binds = mkDB (NonRec dict rhs) `consBag` dict_binds uds }
 
@@ -1138,20 +1140,20 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 cloneBindSM subst (NonRec bndr rhs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndr') = substAndCloneId subst us bndr
+       (subst', bndr') = cloneIdBndr subst us bndr
     in
     returnUs (subst, subst', NonRec bndr' rhs)
 
 cloneBindSM subst (Rec pairs)
   = getUs      `thenUs` \ us ->
     let
-       (subst', bndrs') = substAndCloneRecIds subst us (map fst pairs)
+       (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
     in
     returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
 
 cloneBinders subst bndrs
   = getUs      `thenUs` \ us ->
-    returnUs (substAndCloneIds subst us bndrs)
+    returnUs (cloneIdBndrs subst us bndrs)
 
 newIdSM old_id new_ty
   = getUniqSM          `thenSM` \ uniq ->