Make -fliberate-case work for GADTs
[ghc-hetmet.git] / ghc / compiler / specialise / SpecConstr.lhs
index 45f9469..74944da 100644 (file)
@@ -12,31 +12,32 @@ module SpecConstr(
 
 import CoreSyn
 import CoreLint                ( showPass, endPass )
-import CoreUtils       ( exprType, eqExpr, mkPiTypes )
+import CoreUtils       ( exprType, tcEqExpr, mkPiTypes )
 import CoreFVs                 ( exprsFreeVars )
+import CoreSubst       ( Subst, mkSubst, substExpr )
+import CoreTidy                ( tidyRules )
+import PprCore         ( pprRules )
 import WwLib           ( mkWorkerArgs )
-import DataCon         ( dataConRepArity )
-import Type            ( tyConAppArgs )
-import PprCore         ( pprCoreRules )
-import Id              ( Id, idName, idType, idSpecialisation,
-                         isDataConId_maybe,
+import DataCon         ( dataConRepArity, isVanillaDataCon )
+import Type            ( tyConAppArgs, tyVarsOfTypes )
+import Unify           ( coreRefineTys )
+import Id              ( Id, idName, idType, isDataConWorkId_maybe, 
                          mkUserLocal, mkSysLocal )
 import Var             ( Var )
 import VarEnv
 import VarSet
 import Name            ( nameOccName, nameSrcLoc )
-import Rules           ( addIdSpecialisations )
+import Rules           ( addIdSpecialisations, mkLocalRule, rulesOfBinds )
 import OccName         ( mkSpecOcc )
 import ErrUtils                ( dumpIfSet_dyn )
-import CmdLineOpts     ( DynFlags, DynFlag(..) )
+import DynFlags                ( DynFlags, DynFlag(..) )
 import BasicTypes      ( Activation(..) )
-import Outputable
-
 import Maybes          ( orElse )
-import Util            ( mapAccumL, lengthAtLeast )
+import Util            ( mapAccumL, lengthAtLeast, notNull )
 import List            ( nubBy, partition )
 import UniqSupply
 import Outputable
+import FastString
 \end{code}
 
 -----------------------------------------------------
@@ -181,7 +182,7 @@ specConstrProgram dflags us binds
        endPass dflags "SpecConstr" Opt_D_dump_spec binds'
 
        dumpIfSet_dyn dflags Opt_D_dump_rules "Top-level specialisations"
-                 (vcat (map dump_specs (concat (map bindersOf binds'))))
+                 (pprRules (tidyRules emptyTidyEnv (rulesOfBinds binds')))
 
        return binds'
   where
@@ -189,8 +190,6 @@ specConstrProgram dflags us binds
     go env (bind:binds) = scBind env bind      `thenUs` \ (env', _, bind') ->
                          go env' binds         `thenUs` \ binds' ->
                          returnUs (bind' : binds')
-
-dump_specs var = pprCoreRules var (idSpecialisation var)
 \end{code}
 
 
@@ -207,10 +206,17 @@ data ScEnv = SCE { scope :: VarEnv HowBound,
                   cons  :: ConstrEnv
             }
 
-type ConstrEnv = IdEnv (AltCon, [CoreArg])
+type ConstrEnv = IdEnv ConValue
+data ConValue  = CV AltCon [CoreArg]
        -- Variables known to be bound to a constructor
        -- in a particular case alternative
 
+refineConstrEnv :: Subst -> ConstrEnv -> ConstrEnv
+-- The substitution is a type substitution only
+refineConstrEnv subst env = mapVarEnv refine_con_value env
+  where
+    refine_con_value (CV con args) = CV con (map (substExpr subst) args)
+
 emptyScEnv = SCE { scope = emptyVarEnv, cons = emptyVarEnv }
 
 data HowBound = RecFun         -- These are the recursive functions for which 
@@ -242,24 +248,47 @@ extendCaseBndrs :: ScEnv -> Id -> CoreExpr -> AltCon -> [Var] -> ScEnv
 extendCaseBndrs env case_bndr scrut DEFAULT alt_bndrs
   = extendBndrs env (case_bndr : alt_bndrs)
 
-extendCaseBndrs env case_bndr scrut con alt_bndrs
-  = case scrut of
+extendCaseBndrs env case_bndr scrut con@(LitAlt lit) alt_bndrs
+  = ASSERT( null alt_bndrs ) extendAlt env case_bndr scrut (CV con []) []
+
+extendCaseBndrs env case_bndr scrut con@(DataAlt data_con) alt_bndrs
+  | isVanillaDataCon data_con
+  = extendAlt env case_bndr scrut (CV con vanilla_args) alt_bndrs
+    
+  | otherwise  -- GADT
+  = extendAlt env1 case_bndr scrut (CV con gadt_args) alt_bndrs
+  where
+    vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
+                  map varToCoreExpr alt_bndrs
+
+    gadt_args = map (substExpr subst . varToCoreExpr) alt_bndrs
+
+    (alt_tvs, _) = span isTyVar alt_bndrs
+    Just (tv_subst, is_local) = coreRefineTys data_con alt_tvs (idType case_bndr)
+    subst = mkSubst in_scope tv_subst emptyVarEnv      -- No Id substitition
+    in_scope = mkInScopeSet (tyVarsOfTypes (varEnvElts tv_subst))
+
+    env1 | is_local  = env
+        | otherwise = env { cons = refineConstrEnv subst (cons env) }
+
+
+
+extendAlt :: ScEnv -> Id -> CoreExpr -> ConValue -> [Var] -> ScEnv
+extendAlt env case_bndr scrut val alt_bndrs
+  = let 
+       env1 = SCE { scope = extendVarEnvList (scope env) [(b,Other) | b <- case_bndr : alt_bndrs],
+                   cons  = extendVarEnv     (cons  env) case_bndr val }
+    in
+    case scrut of
        Var v ->   -- Bind the scrutinee in the ConstrEnv if it's a variable
                   -- Also forget if the scrutinee is a RecArg, because we're
                   -- now in the branch of a case, and we don't want to
                   -- record a non-scrutinee use of v if we have
                   --   case v of { (a,b) -> ...(f v)... }
                 SCE { scope = extendVarEnv (scope env1) v Other,
-                      cons  = extendVarEnv (cons env1)  v (con,args) }
+                      cons  = extendVarEnv (cons env1)  v val }
        other -> env1
 
-  where
-    env1 = SCE { scope = extendVarEnvList (scope env) [(b,Other) | b <- case_bndr : alt_bndrs],
-                cons  = extendVarEnv     (cons  env) case_bndr (con,args) }
-
-    args = map Type (tyConAppArgs (idType case_bndr)) ++
-          map varToCoreExpr alt_bndrs
-
     -- When we encounter a recursive function binding
     -- f = \x y -> ...
     -- we want to extend the scope env with bindings 
@@ -336,11 +365,11 @@ scExpr env (Note n e) = scExpr env e      `thenUs` \ (usg,e') ->
 scExpr env (Lam b e)  = scExpr (extendBndr env b) e    `thenUs` \ (usg,e') ->
                        returnUs (usg, Lam b e')
 
-scExpr env (Case scrut b alts) 
+scExpr env (Case scrut b ty alts) 
   = sc_scrut scrut             `thenUs` \ (scrut_usg, scrut') ->
     mapAndUnzipUs sc_alt alts  `thenUs` \ (alts_usgs, alts') ->
     returnUs (combineUsages alts_usgs `combineUsage` scrut_usg,
-             Case scrut' b alts')
+             Case scrut' b ty alts')
   where
     sc_scrut e@(Var v) = returnUs (varUsage env v CaseScrut, e)
     sc_scrut e        = scExpr env e
@@ -376,7 +405,7 @@ scExpr env e@(App _ _)
 ----------------------
 scBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, ScUsage, CoreBind)
 scBind env (Rec [(fn,rhs)])
-  | not (null val_bndrs)
+  | notNull val_bndrs
   = scExpr env_fn_body body            `thenUs` \ (usg, body') ->
     let
        SCU { calls = calls, occs = occs } = usg
@@ -443,7 +472,7 @@ specialise env fn bndrs body (SCU {calls=calls, occs=occs})
                  (nubBy same_call good_calls `zip` [1..])
   where
     n_bndrs  = length bndrs
-    same_call as1 as2 = and (zipWith eqExpr as1 as2)
+    same_call as1 as2 = and (zipWith tcEqExpr as1 as2)
 
 ---------------------
 good_arg :: ConstrEnv -> IdEnv ArgOcc -> (CoreBndr, CoreArg) -> Bool
@@ -510,11 +539,11 @@ spec_one env fn rhs (pats, rule_number)
                -- Usual w/w hack to avoid generating 
                -- a spec_rhs of unlifted type and no args
        
-       rule_name = _PK_ ("SC:" ++ showSDoc (ppr fn <> int rule_number))
+       rule_name = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number))
        spec_rhs  = mkLams spec_lam_args spec_body
        spec_id   = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc
-       rule      = Rule rule_name specConstrActivation
-                        bndrs pats (mkVarApps (Var spec_id) spec_call_args)
+       rule_rhs  = mkVarApps (Var spec_id) spec_call_args
+       rule      = mkLocalRule rule_name specConstrActivation fn_name bndrs pats rule_rhs
     in
     returnUs (rule, (spec_id, spec_rhs))
 
@@ -546,12 +575,12 @@ they are constructor applications.
     -- placeholder variables.  For example:
     --    C a (D (f x) (g y))  ==>  C p1 (D p2 p3)
 
-argToPat   :: ConstrEnv -> UniqSupply -> CoreArg   -> (UniqSupply, CoreExpr)
+argToPat   :: ConstrEnv -> UniqSupply -> CoreArg -> (UniqSupply, CoreExpr)
 argToPat env us (Type ty) 
   = (us, Type ty)
 
 argToPat env us arg
-  | Just (dc,args) <- is_con_app_maybe env arg
+  | Just (CV dc args) <- is_con_app_maybe env arg
   = let
        (us',args') = argsToPats env us args
     in
@@ -561,7 +590,7 @@ argToPat env us (Var v)     -- Don't uniqify existing vars,
   = (us, Var v)                -- so that we can spot when we pass them twice
 
 argToPat env us arg
-  = (us1, Var (mkSysLocal SLIT("sc") (uniqFromSupply us2) (exprType arg)))
+  = (us1, Var (mkSysLocal FSLIT("sc") (uniqFromSupply us2) (exprType arg)))
   where
     (us1,us2) = splitUniqSupply us
 
@@ -571,7 +600,7 @@ argsToPats env us args = mapAccumL (argToPat env) us args
 
 
 \begin{code}
-is_con_app_maybe :: ConstrEnv -> CoreExpr -> Maybe (AltCon, [CoreExpr])
+is_con_app_maybe :: ConstrEnv -> CoreExpr -> Maybe ConValue
 is_con_app_maybe env (Var v)
   = lookupVarEnv env v
        -- You might think we could look in the idUnfolding here
@@ -579,14 +608,14 @@ is_con_app_maybe env (Var v)
        -- case we are in, which is the whole point
 
 is_con_app_maybe env (Lit lit)
-  = Just (LitAlt lit, [])
+  = Just (CV (LitAlt lit) [])
 
 is_con_app_maybe env expr
   = case collectArgs expr of
-       (Var fun, args) | Just con <- isDataConId_maybe fun,
+       (Var fun, args) | Just con <- isDataConWorkId_maybe fun,
                          args `lengthAtLeast` dataConRepArity con
                -- Might be > because the arity excludes type args
-                       -> Just (DataAlt con,args)
+                       -> Just (CV (DataAlt con) args)
 
        other -> Nothing