Type families: fixed all non-termination in the testsuite
authorManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sun, 14 Sep 2008 12:06:38 +0000 (12:06 +0000)
committerManuel M T Chakravarty <chak@cse.unsw.edu.au>
Sun, 14 Sep 2008 12:06:38 +0000 (12:06 +0000)
compiler/typecheck/TcSimplify.lhs
compiler/typecheck/TcTyFuns.lhs

index b074437..3c7df83 100644 (file)
@@ -1098,7 +1098,7 @@ checkLoop env wanteds
        
                ; (improved, binds, irreds) <- reduceContext env' wanteds'
 
-               ; if not improved then
+               ; if null irreds || not improved then
                    return (irreds, binds)
                  else do
        
@@ -1783,6 +1783,8 @@ reduceContext env wanteds0
            wanteds', 
            normalise_binds,
            eq_improved)     <- tcReduceEqs givens wanteds
+       ; traceTc $ text "reduceContext: tcReduceEqs" <+> vcat
+                     [ppr givens', ppr wanteds', ppr normalise_binds]
 
           -- Build the Avail mapping from "given_dicts"
        ; (init_state, _) <- getLIE $ do 
@@ -1797,7 +1799,7 @@ reduceContext env wanteds0
        ; (dict_binds, 
            bound_dicts, 
            dict_irreds)       <- extractResults avails wanted_dicts
-       ; traceTc $ text "reduceContext extractresults" <+> vcat
+       ; traceTc $ text "reduceContext: extractResults" <+> vcat
                      [ppr avails, ppr wanted_dicts, ppr dict_binds]
 
          -- Solve the wanted *implications*.  In doing so, we can provide
@@ -1812,34 +1814,21 @@ reduceContext env wanteds0
              implic_irreds = concat implic_irreds_s
 
           -- Collect all irreducible instances, and determine whether we should
-          -- go round again.  We do so in either of three cases:
+          -- go round again.  We do so in either of two cases:
           -- (1) If dictionary reduction or equality solving led to
           --     improvement (i.e., instantiated type variables).
-          -- (2) If we managed to normalise any dicts, there is merit in going
-          --     around gain, because reduceList may be able to get further.
-          -- (3) If we uncovered extra equalities.  We will try to solve them
+          -- (2) If we uncovered extra equalities.  We will try to solve them
           --     in the next iteration.
+
        ; let all_irreds       = dict_irreds ++ implic_irreds ++ extra_eqs
-              improvedFlexible = availsImproved avails ||
-                                 eq_improved
-              improvedDicts    = not $ isEmptyBag normalise_binds
+             avails_improved  = availsImproved avails
+              improvedFlexible = avails_improved || eq_improved
               extraEqs         = (not . null) extra_eqs
-              improved         = improvedFlexible || improvedDicts || extraEqs
-
-{- Old story
-         -- Figure out whether we should go round again.  We do so in either
-          -- two cases:
-          -- (1) If any of the mutable tyvars in givens or irreds has been
-          --     filled in by improvement, there is merit in going around 
-          --     again, because we may make further progress.
-          -- (2) If we managed to normalise any dicts, there is merit in going
-          --     around gain, because reduceList may be able to get further.
-
-       ; improvedMetaTy <- anyM isFilledMetaTyVar $ varSetElems $
-                           tyVarsOfInsts (givens ++ all_irreds)
-        ; let improvedDicts = not $ isEmptyBag normalise_binds
-              improved      = improvedMetaTy || improvedDicts
- -}
+              improved         = improvedFlexible || extraEqs
+              --
+              improvedHint  = (if avails_improved then " [AVAILS]" else "") ++
+                              (if eq_improved then " [EQ]" else "") ++
+                              (if extraEqs then " [EXTRA EQS]" else "")
 
        ; traceTc (text "reduceContext end" <+> (vcat [
             text "----------------------",
@@ -1848,7 +1837,7 @@ reduceContext env wanteds0
             text "wanted" <+> ppr wanteds0,
             text "----",
             text "avails" <+> pprAvails avails,
-            text "improved =" <+> ppr improved,
+            text "improved =" <+> ppr improved <+> text improvedHint,
             text "(all) irreds = " <+> ppr all_irreds,
             text "dict-binds = " <+> ppr dict_binds,
             text "implic-binds = " <+> ppr implic_binds,
@@ -1873,33 +1862,44 @@ tcImproveOne avails inst
                -- Avails has all the superclasses etc (good)
                -- It also has all the intermediates of the deduction (good)
                -- It does not have duplicates (good)
-               -- NB that (?x::t1) and (?x::t2) will be held separately in avails
-               --    so that improve will see them separate
+               -- NB that (?x::t1) and (?x::t2) will be held separately in 
+                --    avails so that improve will see them separate
        ; traceTc (text "improveOne" <+> ppr inst)
        ; unifyEqns eqns }
 
-unifyEqns :: [(Equation,(PredType,SDoc),(PredType,SDoc))] 
+unifyEqns :: [(Equation, (PredType, SDoc), (PredType, SDoc))] 
          -> TcM ImprovementDone
 unifyEqns [] = return False
 unifyEqns eqns
   = do { traceTc (ptext (sLit "Improve:") <+> vcat (map pprEquationDoc eqns))
-        ; mapM_ unify eqns
-       ; return True }
+        ; improved <- mapM unify eqns
+       ; return $ or improved
+        }
   where
     unify ((qtvs, pairs), what1, what2)
-         = addErrCtxtM (mkEqnMsg what1 what2) $ do
-           (_, _, tenv) <- tcInstTyVars (varSetElems qtvs)
-           mapM_ (unif_pr tenv) pairs
-    unif_pr tenv (ty1,ty2) =  unifyType (substTy tenv ty1) (substTy tenv ty2)
+         = addErrCtxtM (mkEqnMsg what1 what2) $ 
+             do { let freeTyVars = unionVarSets (map tvs_pr pairs) 
+                                   `minusVarSet` qtvs
+                ; (_, _, tenv) <- tcInstTyVars (varSetElems qtvs)
+                ; mapM_ (unif_pr tenv) pairs
+                ; anyM isFilledMetaTyVar $ varSetElems freeTyVars
+                }
+
+    unif_pr tenv (ty1, ty2) = unifyType (substTy tenv ty1) (substTy tenv ty2)
+
+    tvs_pr (ty1, ty2) = tyVarsOfType ty1 `unionVarSet` tyVarsOfType ty2
 
 pprEquationDoc :: (Equation, (PredType, SDoc), (PredType, SDoc)) -> SDoc
-pprEquationDoc (eqn, (p1, _), (p2, _)) = vcat [pprEquation eqn, nest 2 (ppr p1), nest 2 (ppr p2)]
+pprEquationDoc (eqn, (p1, _), (p2, _)) 
+  = vcat [pprEquation eqn, nest 2 (ppr p1), nest 2 (ppr p2)]
 
 mkEqnMsg :: (TcPredType, SDoc) -> (TcPredType, SDoc) -> TidyEnv
          -> TcM (TidyEnv, SDoc)
 mkEqnMsg (pred1,from1) (pred2,from2) tidy_env
-  = do { pred1' <- zonkTcPredType pred1; pred2' <- zonkTcPredType pred2
-       ; let { pred1'' = tidyPred tidy_env pred1'; pred2'' = tidyPred tidy_env pred2' }
+  = do { pred1' <- zonkTcPredType pred1
+        ; pred2' <- zonkTcPredType pred2
+       ; let { pred1'' = tidyPred tidy_env pred1'
+              ; pred2'' = tidyPred tidy_env pred2' }
        ; let msg = vcat [ptext (sLit "When using functional dependencies to combine"),
                          nest 2 (sep [ppr pred1'' <> comma, nest 2 from1]), 
                          nest 2 (sep [ppr pred2'' <> comma, nest 2 from2])]
index 981845a..0b026e1 100644 (file)
@@ -30,6 +30,7 @@ import Type
 import TypeRep         ( Type(..) )
 import TyCon
 import HsSyn
+import Id
 import VarEnv
 import VarSet
 import Var
@@ -319,6 +320,11 @@ propagateEqs eqCfg@(EqConfig {eqs = todoEqs})
 -- set of instances are the locals (without equalities) and the second set are
 -- all residual wanteds, including equalities. 
 --
+-- Remove all identity dictinary bindings (i.e., those whose source and target
+-- dictionary are the same).  This is important for termination, as
+-- TcSimplify.reduceContext takes the presence of dictionary bindings as an
+-- indicator that there was some improvement.
+--
 finaliseEqsAndDicts :: EqConfig 
                     -> TcM ([Inst], [Inst], TcDictBinds, Bool)
 finaliseEqsAndDicts (EqConfig { eqs     = eqs
@@ -328,11 +334,18 @@ finaliseEqsAndDicts (EqConfig { eqs     = eqs
                               })
   = do { (eqs', subst_binds, locals', wanteds') <- substitute eqs locals wanteds
        ; (eqs'', improved) <- instantiateAndExtract eqs'
-       ; return (locals', 
-                 eqs'' ++ wanteds', 
-                 subst_binds `unionBags` binds, 
-                 improved)
+       ; final_binds <- filterM nonTrivialDictBind $
+                          bagToList (subst_binds `unionBags` binds)
+       ; return (locals', eqs'' ++ wanteds', listToBag final_binds, improved)
        }
+  where
+    nonTrivialDictBind (L _ (VarBind { var_id = ide1
+                                     , var_rhs = L _ (HsWrap _ (HsVar ide2))}))
+      = do { ty1 <- zonkTcType (idType ide1)
+           ; ty2 <- zonkTcType (idType ide2)
+           ; return $ not (ty1 `tcEqType` ty2)
+           }
+    nonTrivialDictBind _ = return True
 \end{code}
 
 
@@ -415,6 +428,10 @@ In a corresponding manner, normDict normalises class dictionaries by
 extracting any synonym family applications and generation appropriate normal
 equalities. 
 
+Whenever we encounter a loopy equality (of the form a ~ T .. (F ...a...) ...),
+we drop that equality and raise an error if it is a wanted or a warning if it
+is a local.
+
 \begin{code}
 normEqInst :: Inst -> TcM ([RewriteInst], TyVarSet)
 -- Normalise one equality.
@@ -449,8 +466,18 @@ normEqInst inst
                  eqTys     = (ty1', ty2')
            ; (co', ty12_eqs') <- adjustCoercions co rewriteCo eqTys ty12_eqs
            ; eqs <- checkOrientation ty1' ty2' co' inst
-           ; return $ (eqs ++ ty12_eqs',
-                       ty1_skolems `unionVarSet` ty2_skolems)
+           ; if isLoopyEquality eqs ty12_eqs' 
+             then do { if isWantedCo (tci_co inst)
+                       then
+                          addErrCtxt (ptext (sLit "Rejecting loopy equality")) $
+                            eqInstMisMatch inst
+                       else
+                         warnDroppingLoopyEquality ty1 ty2
+                     ; return ([], emptyVarSet)         -- drop the equality
+                     }
+             else
+               return (eqs ++ ty12_eqs',
+                      ty1_skolems `unionVarSet` ty2_skolems)
            }
 
     mkRewriteFam con args ty2 co
@@ -474,6 +501,18 @@ normEqInst inst
                        unionVarSets (ty2_skolems:args_skolemss))
            }
 
+    -- If the original equality has the form a ~ T .. (F ...a...) ..., we will
+    -- have a variable equality with 'a' on the lhs as the first equality.
+    -- Then, check whether 'a' occurs in the lhs of any family equality
+    -- generated by flattening.
+    isLoopyEquality (RewriteVar {rwi_var = tv}:_) eqs
+      = any inRewriteFam eqs
+      where
+        inRewriteFam (RewriteFam {rwi_args = args}) 
+          = tv `elemVarSet` tyVarsOfTypes args
+        inRewriteFam _ = False
+    isLoopyEquality _ _ = False
+
 normDict :: Bool -> Inst -> TcM (Inst, [RewriteInst], TcDictBinds, TyVarSet)
 -- Normalise one dictionary or IP constraint.
 normDict isWanted inst@(Dict {tci_pred = ClassP clas args})
@@ -482,9 +521,14 @@ normDict isWanted inst@(Dict {tci_pred = ClassP clas args})
        ; let rewriteCo = PredTy $ ClassP clas cargs
              eqs       = concat args_eqss
              pred'     = ClassP clas args'
-       ; (inst', bind, eqs') <- mkDictBind inst isWanted rewriteCo pred' eqs
+       ; if null eqs
+         then  -- don't generate a binding if there is nothing to flatten
+           return (inst, [], emptyBag, emptyVarSet)
+         else do {
+       ; (inst', bind) <- mkDictBind inst isWanted rewriteCo pred'
+       ; eqs' <- if isWanted then return eqs else mapM wantedToLocal eqs
        ; return (inst', eqs', bind, unionVarSets args_skolemss)
-       }
+       }}
 normDict isWanted inst
   = return (inst, [], emptyBag, emptyVarSet)
 -- !!!TODO: Still need to normalise IP constraints.
@@ -660,13 +704,9 @@ mkDictBind :: Inst                 -- original instance
            -> Bool                 -- is this a wanted contraint?
            -> Coercion             -- coercion witnessing the rewrite
            -> PredType             -- coerced predicate
-           -> [RewriteInst]        -- equalities from flattening
            -> TcM (Inst,           -- new inst
-                   TcDictBinds,    -- binding for coerced dictionary
-                   [RewriteInst])  -- final equalities from flattening
-mkDictBind dict _isWanted _rewriteCo _pred []
-  = return (dict, emptyBag, [])    -- don't generate binding for an id coercion
-mkDictBind dict isWanted rewriteCo pred eqs
+                   TcDictBinds)    -- binding for coerced dictionary
+mkDictBind dict isWanted rewriteCo pred
   = do { dict' <- newDictBndr loc pred
          -- relate the old inst to the new one
          -- target_dict = source_dict `cast` st_co
@@ -683,8 +723,7 @@ mkDictBind dict isWanted rewriteCo pred eqs
              cast_expr = HsWrap (WpCast st_co) expr
              rhs       = L (instLocSpan loc) cast_expr
              binds     = instToDictBind target_dict rhs
-       ; eqs' <- if isWanted then return eqs else mapM wantedToLocal eqs
-       ; return (dict', binds, eqs')
+       ; return (dict', binds)
        }
   where
     loc = tci_loc dict
@@ -955,7 +994,9 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       = return (res, binds, locals, wanteds)
     subst (eq@(RewriteVar {rwi_var = tv, rwi_right = ty, rwi_co = co}):eqs) 
           res binds locals wanteds
-      = do { let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
+      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr tv <+>
+                       ptext (sLit "->") <+> ppr ty
+           ; let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
                  tySubst = zipOpenTvSubst [tv] [ty]
            ; eqs'               <- mapM (substEq eq coSubst tySubst) eqs
            ; res'               <- mapM (substEq eq coSubst tySubst) res
@@ -1004,7 +1045,7 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       = do { let co1Subst = mkSymCoercion $ 
                               PredTy (substPred coSubst (tci_pred dict))
                  pred'    = substPred tySubst (tci_pred dict)
-           ; (dict', binds, _) <- mkDictBind dict isWanted co1Subst pred' []
+           ; (dict', binds) <- mkDictBind dict isWanted co1Subst pred'
            ; return (binds, dict')
            }
 
@@ -1206,3 +1247,16 @@ misMatchMsg env0 (ty_act, ty_exp)
 
     ppr_extra env _ty = (env, empty)           -- Normal case
 \end{code}
+
+Warn of loopy local equalities that were dropped.
+
+\begin{code}
+warnDroppingLoopyEquality :: TcType -> TcType -> TcM ()
+warnDroppingLoopyEquality ty1 ty2 
+  = do { env0 <- tcInitTidyEnv
+       ; ty1 <- zonkTcType ty1
+       ; ty2 <- zonkTcType ty2
+       ; addWarnTc $ hang (ptext (sLit "Dropping loopy given equality"))
+                      2 (ppr ty1 <+> text "~" <+> ppr ty2)
+       }
+\end{code}