Fix a bug in the handling of implication constraints (Trac #1430)
authorsimonpj@microsoft.com <unknown>
Tue, 19 Jun 2007 16:26:13 +0000 (16:26 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 19 Jun 2007 16:26:13 +0000 (16:26 +0000)
Trac #1430 showed up quite a nasty bug in the handling of implication
constraints when we are *inferring* the type of a function.
See Note [Inference and implication constraints]:

  We can't (or at least don't) abstract over implications.  But we might
  have an implication constraint (perhaps arising from a nested pattern
  match) like
   C a => D a
  when we are now trying to quantify over 'a'.  Our best approximation
  is to make (D a) part of the inferred context, so we can use that to
  discharge the implication. Hence getImplicWanteds.

My solution is not marvellous, but it's better than before.  I transferred
function getDefaultableDicts from Inst to TcSimplify (since it's only
called there).  Many of the remaining 50 new lines are comments.  But
there is undoubtedly more code than before (sigh).

Test is tc228.

compiler/typecheck/Inst.lhs
compiler/typecheck/TcSimplify.lhs

index 5c6d8fe..962e4e0 100644 (file)
@@ -31,7 +31,7 @@ module Inst (
 
        isDict, isClassDict, isMethod, isImplicInst,
        isIPDict, isInheritableInst, isMethodOrLit,
-       isTyVarDict, isMethodFor, getDefaultableDicts,
+       isTyVarDict, isMethodFor, 
 
        zonkInst, zonkInsts,
        instToId, instToVar, instName,
@@ -54,7 +54,6 @@ import FunDeps
 import TcMType
 import TcType
 import Type
-import Class
 import Unify
 import Module
 import Coercion
@@ -211,26 +210,6 @@ isMethodOrLit (LitInst {}) = True
 isMethodOrLit other        = False
 \end{code}
 
-\begin{code}
-getDefaultableDicts :: [Inst] -> ([(Inst, Class, TcTyVar)], TcTyVarSet)
--- Look for free dicts of the form (C tv), even inside implications
--- *and* the set of tyvars mentioned by all *other* constaints
--- This disgustingly ad-hoc function is solely to support defaulting
-getDefaultableDicts insts
-  = (concat ps, unionVarSets tvs)
-  where
-    (ps, tvs) = mapAndUnzip get insts
-    get d@(Dict {tci_pred = ClassP cls [ty]})
-       | Just tv <- tcGetTyVar_maybe ty = ([(d,cls,tv)], emptyVarSet)
-       | otherwise                      = ([], tyVarsOfType ty)
-    get (ImplicInst {tci_tyvars = tvs, tci_wanted = wanteds})
-       = ([ up | up@(_,_,tv) <- ups, not (tv `elemVarSet` tv_set)],
-          ftvs `minusVarSet` tv_set)
-       where
-          tv_set = mkVarSet tvs
-          (ups, ftvs) = getDefaultableDicts wanteds
-    get inst = ([], tyVarsOfInst inst)
-\end{code}
 
 %************************************************************************
 %*                                                                     *
@@ -303,7 +282,7 @@ instCallDicts loc (pred : preds)
        ; return (dict:dicts, co_fn <.> WpApp (instToId dict)) }
 
 -------------
-cloneDict :: Inst -> TcM Inst  -- Only used for linear implicit params
+cloneDict :: Inst -> TcM Inst
 cloneDict dict@(Dict nm ty loc) = do { uniq <- newUnique
                                     ; return (dict {tci_name = setNameUnique nm uniq}) }
 cloneDict other = pprPanic "cloneDict" (ppr other)
index 373a174..adf0f78 100644 (file)
@@ -681,32 +681,77 @@ tcSimplifyInfer doc tau_tvs wanted
        ; traceTc (text "infer" <+> (ppr preds $$ ppr (grow preds tau_tvs') $$ ppr gbl_tvs $$ 
                   ppr (oclose preds gbl_tvs) $$ ppr free1 $$ ppr bound))
        ; let try_me inst = ReduceMe AddSCs
-       ; (irreds, binds) <- checkLoop (mkRedEnv doc try_me []) bound
-       ; qtvs' <- zonkQuantifiedTyVars (varSetElems qtvs)
+             red_env     = mkRedEnv doc try_me []
+       ; (irreds1, binds1) <- checkLoop red_env bound
+
+               -- Note [Inference and implication constraints]
+               -- By putting extra_dicts first, we make them available
+               -- to solve the implication constraints
+       ; let extra_dicts = getImplicWanteds qtvs irreds1
+       ; (irreds2, binds2) <- if null extra_dicts 
+                              then return (irreds1, emptyBag)
+                              else do { extra_dicts' <- mapM cloneDict extra_dicts
+                                      ; checkLoop red_env (extra_dicts' ++ irreds1) }
+
+               -- By now improvment may have taken place, and we must *not*
+               -- quantify over any variable free in the environment
+               -- tc137 (function h inside g) is an example
+       ; gbl_tvs <- tcGetGlobalTyVars
+       ; qtvs1 <- zonkTcTyVarsAndFV (varSetElems qtvs)
+       ; qtvs2 <- zonkQuantifiedTyVars (varSetElems (qtvs1 `minusVarSet` gbl_tvs))
 
                -- Do not quantify over constraints that *now* do not
                -- mention quantified type variables, because they are
-               -- simply ambiguous.  Example:
+               -- simply ambiguous (or might be bound further out).  Example:
                --      f :: Eq b => a -> (a, b)
                --      g x = fst (f x)
                -- From the RHS of g we get the MethodInst f77 :: alpha -> (alpha, beta)
-               -- We decide to quantify over 'alpha' alone, bur free1 does not include f77
+               -- We decide to quantify over 'alpha' alone, but free1 does not include f77
                -- because f77 mentions 'alpha'.  Then reducing leaves only the (ambiguous)
                -- constraint (Eq beta), which we dump back into the free set
                -- See test tcfail181
-       ; let (free2, irreds2) = partition (isFreeWhenInferring (mkVarSet qtvs')) irreds
-       ; extendLIEs free2
+       ; let (free3, irreds3) = partition (isFreeWhenInferring (mkVarSet qtvs2)) irreds2
+       ; extendLIEs free3
        
-               -- We can't abstract over implications
-       ; let (dicts, implics) = partition isDict irreds2
+               -- We can't abstract over any remaining unsolved 
+               -- implications so instead just float them outwards. Ugh.
+       ; let (q_dicts, implics) = partition isDict irreds3
        ; loc <- getInstLoc (ImplicOrigin doc)
-       ; implic_bind <- bindIrreds loc qtvs' dicts implics
+       ; implic_bind <- bindIrreds loc qtvs2 q_dicts implics
 
-       ; return (qtvs', dicts, binds `unionBags` implic_bind) }
+       ; return (qtvs2, q_dicts, binds1 `unionBags` binds2 `unionBags` implic_bind) }
        -- NB: when we are done, we might have some bindings, but
        -- the final qtvs might be empty.  See Note [NO TYVARS] below.
+
+getImplicWanteds :: TcTyVarSet -> [Inst] -> [Inst]
+-- See Note [Inference and implication constraints]
+-- Find the wanted constraints in implication constraints that mention the 
+-- quantified type variables, and are not bound by forall's in the constraint itself
+-- Returns only Dicts
+getImplicWanteds qtvs implics
+  = concatMap get implics
+  where
+    get d@(Dict {}) | tyVarsOfInst d `intersectsVarSet` qtvs = [d]
+                   | otherwise                              = []
+    get (ImplicInst {tci_tyvars = tvs, tci_wanted = wanteds})
+       = [ d | let tv_set = mkVarSet tvs
+             , d <- getImplicWanteds qtvs wanteds 
+             , not (tyVarsOfInst d `intersectsVarSet` tv_set)]
 \end{code}
 
+Note [Inference and implication constraints]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
+We can't (or at least don't) abstract over implications.  But we might
+have an implication constraint (perhaps arising from a nested pattern
+match) like
+       C a => D a
+when we are now trying to quantify over 'a'.  Our best approximation
+is to make (D a) part of the inferred context, so we can use that to
+discharge the implication. Hence getImplicWanteds.
+
+See Trac #1430 and test tc228.
+
+
 \begin{code}
 -----------------------------------------------------------
 -- tcSimplifyInferCheck is used when we know the constraints we are to simplify
@@ -968,7 +1013,7 @@ Inside the pattern match, which binds (a:*, x:a), we know that
 Hence we have a dictionary for Show [a] available; and indeed we 
 need it.  We are going to build an implication contraint
        forall a. (b~[a]) => Show [a]
-Later, we will solve this constraint using the knowledge (Show b)
+Later, we will solve this constraint using the knowledg e(Show b)
        
 But we MUST NOT reduce (Show [a]) to (Show a), else the whole
 thing becomes insoluble.  So we simplify gently (get rid of literals
@@ -1351,7 +1396,7 @@ tcSimplifyRuleLhs wanteds
                                 -- to fromInteger; this looks fragile to me
             ; lookup_result <- lookupSimpleInst w'
             ; case lookup_result of
-                GenInst ws' rhs -> go dicts (addBind binds w rhs) (ws' ++ ws)
+                GenInst ws' rhs -> go dicts (addBind binds (instToId w) rhs) (ws' ++ ws)
                 NoInstance      -> pprPanic "tcSimplifyRuleLhs" (ppr w)
          }
 \end{code}
@@ -1592,6 +1637,7 @@ reduceContext env wanteds
             text "----",
             text "avails" <+> pprAvails avails,
             text "improved =" <+> ppr improved,
+            text "irreds = " <+> ppr irreds,
             text "----------------------"
             ]))
 
@@ -1860,6 +1906,9 @@ reduceImplication env orig_avails reft tvs extra_givens wanteds inst_loc
                -- Extract the binding
        ; (binds, irreds) <- extractResults avails wanteds
  
+       ; traceTc (text "reduceImplication result" <+> vcat
+                       [ ppr irreds, ppr binds])
+
                -- We always discard the extra avails we've generated;
                -- but we remember if we have done any (global) improvement
        ; let ret_avails = updateImprovement orig_avails avails
@@ -1868,16 +1917,12 @@ reduceImplication env orig_avails reft tvs extra_givens wanteds inst_loc
                return (ret_avails, NoInstance)
          else do
        { (implic_insts, bind) <- makeImplicationBind inst_loc tvs reft extra_givens irreds
-                       -- This binding is useless if the recursive simplification
-                       -- made no progress; but currently we don't try to optimise that
-                       -- case.  After all, we only try hard to reduce at top level, or
-                       -- when inferring types.
 
        ; let   dict_ids = map instToId extra_givens
                co  = mkWpTyLams tvs <.> mkWpLams dict_ids <.> WpLet (binds `unionBags` bind)
                rhs = mkHsWrap co payload
                loc = instLocSpan inst_loc
-               payload | isSingleton wanteds = HsVar (instToId (head wanteds))
+               payload | [wanted] <- wanteds = HsVar (instToId wanted)
                        | otherwise = ExplicitTuple (map (L loc . HsVar . instToId) wanteds) Boxed
 
                -- If there are any irreds, we back off and return NoInstance
@@ -1940,7 +1985,7 @@ type ImprovementDone = Bool       -- True <=> some unification has happened
 
 type AvailEnv = FiniteMap Inst AvailHow
 data AvailHow
-  = IsIrred            -- Used for irreducible dictionaries,
+  = IsIrred TcId       -- Used for irreducible dictionaries,
                        -- which are going to be lambda bound
 
   | Given TcId                 -- Used for dictionaries for which we have a binding
@@ -1963,7 +2008,7 @@ instance Outputable AvailHow where
 
 -------------------------
 pprAvail :: AvailHow -> SDoc
-pprAvail IsIrred       = text "Irred"
+pprAvail (IsIrred x)   = text "Irred" <+> ppr x
 pprAvail (Given x)     = text "Given" <+> ppr x
 pprAvail (Rhs rhs bs)   = text "Rhs" <+> ppr rhs <+> braces (ppr bs)
 
@@ -2026,25 +2071,30 @@ extractResults (Avails _ avails) wanteds
          Nothing    -> pprTrace "Urk: extractResults" (ppr w) $
                        go avails binds irreds ws
 
-         Just IsIrred -> go (add_given avails w) binds (w:irreds) ws
-
          Just (Given id) 
-               | id == instToId w
-               -> go avails binds irreds ws 
+               | id == w_id -> go avails binds irreds ws 
+               | otherwise  -> go avails (addBind binds w_id (nlHsVar id)) irreds ws
                -- The sought Id can be one of the givens, via a superclass chain
                -- and then we definitely don't want to generate an x=x binding!
 
-               | otherwise
-               -> go avails (addBind binds w (nlHsVar id)) irreds ws
+         Just (IsIrred id) 
+               | id == w_id -> go (add_given avails w) binds           (w:irreds) ws
+               | otherwise  -> go avails (addBind binds w_id (nlHsVar id)) irreds ws
+               -- The add_given handles the case where we want (Ord a, Eq a), and we
+               -- don't want to emit *two* Irreds for Ord a, one via the superclass chain
+               -- This showed up in a dupliated Ord constraint in the error message for 
+               --      test tcfail043
 
          Just (Rhs rhs ws') -> go (add_given avails w) new_binds irreds (ws' ++ ws)
-                            where
-                               new_binds = addBind binds w rhs
+                            where      
+                               new_binds = addBind binds w_id rhs
+      where
+       w_id = instToId w       
 
     add_given avails w = extendAvailEnv avails w (Given (instToId w))
+       -- Don't add the same binding twice
 
-addBind binds inst rhs = binds `unionBags` unitBag (L (instSpan inst) 
-                                                     (VarBind (instToId inst) rhs))
+addBind binds id rhs = binds `unionBags` unitBag (L (getSrcSpan id) (VarBind id rhs))
 \end{code}
 
 
@@ -2116,7 +2166,7 @@ than with the Avails handling stuff in TcSimplify
 \begin{code}
 addIrred :: WantSCs -> Avails -> Inst -> TcM Avails
 addIrred want_scs avails irred = ASSERT2( not (irred `elemAvails` avails), ppr irred $$ ppr avails )
-                                addAvailAndSCs want_scs avails irred IsIrred
+                                addAvailAndSCs want_scs avails irred (IsIrred (instToId irred))
 
 addAvailAndSCs :: WantSCs -> Avails -> Inst -> AvailHow -> TcM Avails
 addAvailAndSCs want_scs avails inst avail
@@ -2226,9 +2276,10 @@ tc_simplify_top doc interactive wanteds
                -- NB: irreds are already zonked
        ; dflags <- getDOpts
        ; disambiguate interactive dflags irreds1       -- Does unification
-       ; (irreds2, binds2) <- topCheckLoop doc irreds1
+                                                       -- hence try again
 
-               -- Deal with implicit parameter
+               -- Deal with implicit parameters
+       ; (irreds2, binds2) <- topCheckLoop doc irreds1
        ; let (bad_ips, non_ips) = partition isIPDict irreds2
              (ambigs, others)   = partition isTyVarDict non_ips
 
@@ -2322,6 +2373,7 @@ disambiguate interactive dflags insts
    is_std_class cls = isStandardClass cls || (ovl_strings && (cls `hasKey` isStringClassKey))
        -- Similarly is_std_class
 
+-----------------------
 disambigGroup :: [Type]                        -- The default types
              -> [(Inst,Class,TcTyVar)] -- All standard classes of form (C a)
              -> TcM () -- Just does unification, to fix the default types
@@ -2347,6 +2399,7 @@ disambigGroup default_tys dicts
           ; unifyType default_ty (mkTyVarTy tyvar) }
 
 
+-----------------------
 getDefaultTys :: Bool -> Bool -> TcM [Type]
 getDefaultTys extended_deflts ovl_strings
   = do { mb_defaults <- getDeclaredDefaultTys
@@ -2368,6 +2421,26 @@ getDefaultTys extended_deflts ovl_strings
   where
     opt_deflt True  ty = [ty]
     opt_deflt False ty = []
+
+-----------------------
+getDefaultableDicts :: [Inst] -> ([(Inst, Class, TcTyVar)], TcTyVarSet)
+-- Look for free dicts of the form (C tv), even inside implications
+-- *and* the set of tyvars mentioned by all *other* constaints
+-- This disgustingly ad-hoc function is solely to support defaulting
+getDefaultableDicts insts
+  = (concat ps, unionVarSets tvs)
+  where
+    (ps, tvs) = mapAndUnzip get insts
+    get d@(Dict {tci_pred = ClassP cls [ty]})
+       | Just tv <- tcGetTyVar_maybe ty = ([(d,cls,tv)], emptyVarSet)
+       | otherwise                      = ([], tyVarsOfType ty)
+    get (ImplicInst {tci_tyvars = tvs, tci_wanted = wanteds})
+       = ([ up | up@(_,_,tv) <- ups, not (tv `elemVarSet` tv_set)],
+          ftvs `minusVarSet` tv_set)
+       where
+          tv_set = mkVarSet tvs
+          (ups, ftvs) = getDefaultableDicts wanteds
+    get inst = ([], tyVarsOfInst inst)
 \end{code}
 
 Note [Default unitTy]