Major overhaul of the Simplifier
[ghc-hetmet.git] / compiler / specialise / SpecConstr.lhs
index afd53de..abf5360 100644 (file)
@@ -14,14 +14,13 @@ import CoreSyn
 import CoreLint                ( showPass, endPass )
 import CoreUtils       ( exprType, mkPiTypes )
 import CoreFVs                 ( exprsFreeVars )
 import CoreLint                ( showPass, endPass )
 import CoreUtils       ( exprType, mkPiTypes )
 import CoreFVs                 ( exprsFreeVars )
-import CoreSubst       ( Subst, mkSubst, substExpr )
 import CoreTidy                ( tidyRules )
 import PprCore         ( pprRules )
 import WwLib           ( mkWorkerArgs )
 import CoreTidy                ( tidyRules )
 import PprCore         ( pprRules )
 import WwLib           ( mkWorkerArgs )
-import DataCon         ( dataConRepArity, isVanillaDataCon, dataConTyVars )
-import Type            ( Type, tyConAppArgs, tyVarsOfTypes )
+import DataCon         ( dataConRepArity, dataConUnivTyVars )
+import Type            ( Type, tyConAppArgs )
+import Coercion                ( coercionKind )
 import Rules           ( matchN )
 import Rules           ( matchN )
-import Unify           ( coreRefineTys )
 import Id              ( Id, idName, idType, isDataConWorkId_maybe, 
                          mkUserLocal, mkSysLocal, idUnfolding, isLocalId )
 import Var             ( Var )
 import Id              ( Id, idName, idType, isDataConWorkId_maybe, 
                          mkUserLocal, mkSysLocal, idUnfolding, isLocalId )
 import Var             ( Var )
@@ -300,6 +299,24 @@ may avoid allocating it altogether.  Just like for constructors.
 
 Looks cool, but probably rare...but it might be easy to implement.
 
 
 Looks cool, but probably rare...but it might be easy to implement.
 
+
+Note [SpecConstr for casts]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider 
+    data family T a :: *
+    data instance T Int = T Int
+
+    foo n = ...
+       where
+         go (T 0) = 0
+         go (T n) = go (T (n-1))
+
+The recursive call ends up looking like 
+       go (T (I# ...) `cast` g)
+So we want to spot the construtor application inside the cast.
+That's why we have the Cast case in argToPat
+
+
 -----------------------------------------------------
                Stuff not yet handled
 -----------------------------------------------------
 -----------------------------------------------------
                Stuff not yet handled
 -----------------------------------------------------
@@ -429,12 +446,6 @@ data ConValue  = CV AltCon [CoreArg]
 instance Outputable ConValue where
    ppr (CV con args) = ppr con <+> interpp'SP args
 
 instance Outputable ConValue where
    ppr (CV con args) = ppr con <+> interpp'SP args
 
-refineConstrEnv :: Subst -> ConstrEnv -> ConstrEnv
--- The substitution is a type substitution only
-refineConstrEnv subst env = mapVarEnv refine_con_value env
-  where
-    refine_con_value (CV con args) = CV con (map (substExpr subst) args)
-
 emptyScEnv = SCE { scope = emptyVarEnv, cons = emptyVarEnv }
 
 data HowBound = RecFun -- These are the recursive functions for which 
 emptyScEnv = SCE { scope = emptyVarEnv, cons = emptyVarEnv }
 
 data HowBound = RecFun -- These are the recursive functions for which 
@@ -474,37 +485,25 @@ extendCaseBndrs env case_bndr scrut con alt_bndrs
                                [(b,how_bound) | b <- case_bndr:alt_bndrs] }
 
        -- Record RecArg for the components iff the scrutinee is RecArg
                                [(b,how_bound) | b <- case_bndr:alt_bndrs] }
 
        -- Record RecArg for the components iff the scrutinee is RecArg
+       -- I think the only reason for this is to keep the usage envt small
+       -- so is it worth it at all?
        --      [This comment looks plain wrong to me, so I'm ignoring it
        --           "Also forget if the scrutinee is a RecArg, because we're
        --           now in the branch of a case, and we don't want to
        --           record a non-scrutinee use of v if we have
        --              case v of { (a,b) -> ...(f v)... }" ]
        --      [This comment looks plain wrong to me, so I'm ignoring it
        --           "Also forget if the scrutinee is a RecArg, because we're
        --           now in the branch of a case, and we don't want to
        --           record a non-scrutinee use of v if we have
        --              case v of { (a,b) -> ...(f v)... }" ]
-    how_bound = case scrut of
-                 Var v -> lookupVarEnv cur_scope v `orElse` Other
-                 other -> Other
-
-    extend_data_con data_con
-       | isVanillaDataCon data_con = extendCons env1 scrut case_bndr (CV con vanilla_args)
-       | otherwise                 = extendCons env2 scrut case_bndr (CV con gadt_args)
-               -- Note env2 for GADTs
+    how_bound = get_how scrut
        where
        where
-    
-           vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
-                          map varToCoreExpr alt_bndrs
-
-           gadt_args = map (substExpr subst . varToCoreExpr) alt_bndrs
-               -- This call generates some bogus warnings from substExpr,
-               -- because it's inconvenient to put all the Ids in scope
-               -- Will be fixed when we move to FC
-
-           (alt_tvs, _) = span isTyVar alt_bndrs
-           Just (tv_subst, is_local) = coreRefineTys data_con alt_tvs (idType case_bndr)
-           subst = mkSubst in_scope tv_subst emptyVarEnv       -- No Id substitition
-           in_scope = mkInScopeSet (tyVarsOfTypes (varEnvElts tv_subst))
-       
-           env2 | is_local  = env1
-                | otherwise = env1 { cons = refineConstrEnv subst (cons env) }
+           get_how (Var v)    = lookupVarEnv cur_scope v `orElse` Other
+           get_how (Cast e _) = get_how e
+           get_how (Note _ e) = get_how e
+           get_how other      = Other
 
 
+    extend_data_con data_con = 
+      extendCons env1 scrut case_bndr (CV con vanilla_args)
+       where
+           vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
+                          varsToCoreExprs alt_bndrs
 
 extendCons :: ScEnv -> CoreExpr -> Id -> ConValue -> ScEnv
 extendCons env scrut case_bndr val
 
 extendCons :: ScEnv -> CoreExpr -> Id -> ConValue -> ScEnv
 extendCons env scrut case_bndr val
@@ -572,9 +571,10 @@ data ArgOcc = NoOcc        -- Doesn't occur at all; or a type argument
 
 {-     Note  [ScrutOcc]
 
 
 {-     Note  [ScrutOcc]
 
-An occurrence of ScrutOcc indicates that the thing is *only* taken apart or applied.
+An occurrence of ScrutOcc indicates that the thing, or a `cast` version of the thing,
+is *only* taken apart or applied.
 
 
-  Functions, litersl: ScrutOcc emptyUFM
+  Functions, literal: ScrutOcc emptyUFM
   Data constructors:  ScrutOcc subs,
 
 where (subs :: UniqFM [ArgOcc]) gives usage of the *pattern-bound* components,
   Data constructors:  ScrutOcc subs,
 
 where (subs :: UniqFM [ArgOcc]) gives usage of the *pattern-bound* components,
@@ -588,7 +588,7 @@ A pattern binds b, x::a, y::b, z::b->a, but not 'a'!
 -}
 
 instance Outputable ArgOcc where
 -}
 
 instance Outputable ArgOcc where
-  ppr (ScrutOcc xs) = ptext SLIT("scrut-occ") <> parens (ppr xs)
+  ppr (ScrutOcc xs) = ptext SLIT("scrut-occ") <> ppr xs
   ppr UnkOcc       = ptext SLIT("unk-occ")
   ppr BothOcc      = ptext SLIT("both-occ")
   ppr NoOcc                = ptext SLIT("no-occ")
   ppr UnkOcc       = ptext SLIT("unk-occ")
   ppr BothOcc      = ptext SLIT("both-occ")
   ppr NoOcc                = ptext SLIT("no-occ")
@@ -608,10 +608,7 @@ conArgOccs :: ArgOcc -> AltCon -> [ArgOcc]
 
 conArgOccs (ScrutOcc fm) (DataAlt dc) 
   | Just pat_arg_occs <- lookupUFM fm dc
 
 conArgOccs (ScrutOcc fm) (DataAlt dc) 
   | Just pat_arg_occs <- lookupUFM fm dc
-  = tyvar_unks ++ pat_arg_occs
-  where
-    tyvar_unks | isVanillaDataCon dc = [UnkOcc | tv <- dataConTyVars dc]
-              | otherwise           = []
+  = [UnkOcc | tv <- dataConUnivTyVars dc] ++ pat_arg_occs
 
 conArgOccs other con = repeat UnkOcc
 \end{code}
 
 conArgOccs other con = repeat UnkOcc
 \end{code}
@@ -636,6 +633,8 @@ scExpr env e@(Lit l)  = returnUs (nullUsage, e)
 scExpr env e@(Var v)  = returnUs (varUsage env v UnkOcc, e)
 scExpr env (Note n e) = scExpr env e   `thenUs` \ (usg,e') ->
                        returnUs (usg, Note n e')
 scExpr env e@(Var v)  = returnUs (varUsage env v UnkOcc, e)
 scExpr env (Note n e) = scExpr env e   `thenUs` \ (usg,e') ->
                        returnUs (usg, Note n e')
+scExpr env (Cast e co)= scExpr env e   `thenUs` \ (usg,e') ->
+                        returnUs (usg, Cast e' co)
 scExpr env (Lam b e)  = scExpr (extendBndr env b) e    `thenUs` \ (usg,e') ->
                        returnUs (usg, Lam b e')
 
 scExpr env (Lam b e)  = scExpr (extendBndr env b) e    `thenUs` \ (usg,e') ->
                        returnUs (usg, Lam b e')
 
@@ -689,9 +688,12 @@ scExpr env e@(App _ _)
 ----------------------
 scScrut :: ScEnv -> CoreExpr -> ArgOcc -> UniqSM (ScUsage, CoreExpr)
 -- Used for the scrutinee of a case, 
 ----------------------
 scScrut :: ScEnv -> CoreExpr -> ArgOcc -> UniqSM (ScUsage, CoreExpr)
 -- Used for the scrutinee of a case, 
--- or the function of an application
-scScrut env e@(Var v) occ = returnUs (varUsage env v occ, e)
-scScrut env e        occ = scExpr env e
+-- or the function of an application.
+-- Remember to look through casts
+scScrut env e@(Var v)   occ = returnUs (varUsage env v occ, e)
+scScrut env (Cast e co) occ = do { (usg, e') <- scScrut env e occ
+                                ; returnUs (usg, Cast e' co) }
+scScrut env e          occ = scExpr env e
 
 
 ----------------------
 
 
 ----------------------
@@ -752,7 +754,8 @@ specialise :: ScEnv
 specialise env fn bndrs body body_usg
   = do { let (_, bndr_occs) = lookupOccs body_usg bndrs
 
 specialise env fn bndrs body body_usg
   = do { let (_, bndr_occs) = lookupOccs body_usg bndrs
 
-       ; mb_calls <- mapM (callToPats (scope env) bndr_occs)
+       ; mb_calls <- -- pprTrace "specialise" (ppr fn <+> ppr bndrs <+> ppr bndr_occs) $
+                     mapM (callToPats (scope env) bndr_occs)
                           (lookupVarEnv (calls body_usg) fn `orElse` [])
 
        ; let good_calls :: [([Var], [CoreArg])]
                           (lookupVarEnv (calls body_usg) fn `orElse` [])
 
        ; let good_calls :: [([Var], [CoreArg])]
@@ -761,9 +764,8 @@ specialise env fn bndrs body body_usg
                         [ exprsFreeVars pats `delVarSetList` vs 
                         | (vs,pats) <- good_calls ]
              uniq_calls = nubBy (same_call in_scope) good_calls
                         [ exprsFreeVars pats `delVarSetList` vs 
                         | (vs,pats) <- good_calls ]
              uniq_calls = nubBy (same_call in_scope) good_calls
-    in
-    mapAndUnzipUs (spec_one env fn (mkLams bndrs body)) 
-                 (uniq_calls `zip` [1..]) }
+       ; mapAndUnzipUs (spec_one env fn (mkLams bndrs body)) 
+                       (uniq_calls `zip` [1..]) }
   where
        -- Two calls are the same if they match both ways
     same_call in_scope (vs1,as1)(vs2,as2)
   where
        -- Two calls are the same if they match both ways
     same_call in_scope (vs1,as1)(vs2,as2)
@@ -785,7 +787,8 @@ callToPats in_scope bndr_occs (con_env, args)
                -- Quantify over variables that are not in sccpe
                -- See Note [Shadowing] at the top
                
                -- Quantify over variables that are not in sccpe
                -- See Note [Shadowing] at the top
                
-       ; if or good_pats 
+       ; -- pprTrace "callToPats"  (ppr args $$ ppr prs $$ ppr bndr_occs) $
+         if or good_pats 
          then return (Just (qvars, pats))
          else return Nothing }
 
          then return (Just (qvars, pats))
          else return Nothing }
 
@@ -902,6 +905,20 @@ argToPat in_scope con_env (Var v) arg_occ
     then return (True, Var v)
     else wildCardPat (idType v)
 
     then return (True, Var v)
     else wildCardPat (idType v)
 
+argToPat in_scope con_env (Let _ arg) arg_occ
+  = argToPat in_scope con_env arg arg_occ
+       -- Look through let expressions
+       -- e.g.         f (let v = rhs in \y -> ...v...)
+       -- Here we can specialise for f (\y -> ...)
+       -- because the rule-matcher will look through the let.
+
+argToPat in_scope con_env (Cast arg co) arg_occ
+  = do { (interesting, arg') <- argToPat in_scope con_env arg arg_occ
+       ; if interesting then 
+               return (interesting, Cast arg' co)
+         else 
+               wildCardPat (snd (coercionKind co)) }
+
 argToPat in_scope con_env arg arg_occ
   | is_value_lam arg
   = return (True, arg)
 argToPat in_scope con_env arg arg_occ
   | is_value_lam arg
   = return (True, arg)
@@ -980,4 +997,5 @@ is_con_app_maybe env expr
 mk_con_app :: AltCon -> [CoreArg] -> CoreExpr
 mk_con_app (LitAlt lit)  []   = Lit lit
 mk_con_app (DataAlt con) args = mkConApp con args
 mk_con_app :: AltCon -> [CoreArg] -> CoreExpr
 mk_con_app (LitAlt lit)  []   = Lit lit
 mk_con_app (DataAlt con) args = mkConApp con args
+mk_con_app other args = panic "SpecConstr.mk_con_app"
 \end{code}
 \end{code}