Type families: fixed all non-termination in the testsuite
[ghc-hetmet.git] / compiler / typecheck / TcTyFuns.lhs
index 981845a..0b026e1 100644 (file)
@@ -30,6 +30,7 @@ import Type
 import TypeRep         ( Type(..) )
 import TyCon
 import HsSyn
+import Id
 import VarEnv
 import VarSet
 import Var
@@ -319,6 +320,11 @@ propagateEqs eqCfg@(EqConfig {eqs = todoEqs})
 -- set of instances are the locals (without equalities) and the second set are
 -- all residual wanteds, including equalities. 
 --
+-- Remove all identity dictinary bindings (i.e., those whose source and target
+-- dictionary are the same).  This is important for termination, as
+-- TcSimplify.reduceContext takes the presence of dictionary bindings as an
+-- indicator that there was some improvement.
+--
 finaliseEqsAndDicts :: EqConfig 
                     -> TcM ([Inst], [Inst], TcDictBinds, Bool)
 finaliseEqsAndDicts (EqConfig { eqs     = eqs
@@ -328,11 +334,18 @@ finaliseEqsAndDicts (EqConfig { eqs     = eqs
                               })
   = do { (eqs', subst_binds, locals', wanteds') <- substitute eqs locals wanteds
        ; (eqs'', improved) <- instantiateAndExtract eqs'
-       ; return (locals', 
-                 eqs'' ++ wanteds', 
-                 subst_binds `unionBags` binds, 
-                 improved)
+       ; final_binds <- filterM nonTrivialDictBind $
+                          bagToList (subst_binds `unionBags` binds)
+       ; return (locals', eqs'' ++ wanteds', listToBag final_binds, improved)
        }
+  where
+    nonTrivialDictBind (L _ (VarBind { var_id = ide1
+                                     , var_rhs = L _ (HsWrap _ (HsVar ide2))}))
+      = do { ty1 <- zonkTcType (idType ide1)
+           ; ty2 <- zonkTcType (idType ide2)
+           ; return $ not (ty1 `tcEqType` ty2)
+           }
+    nonTrivialDictBind _ = return True
 \end{code}
 
 
@@ -415,6 +428,10 @@ In a corresponding manner, normDict normalises class dictionaries by
 extracting any synonym family applications and generation appropriate normal
 equalities. 
 
+Whenever we encounter a loopy equality (of the form a ~ T .. (F ...a...) ...),
+we drop that equality and raise an error if it is a wanted or a warning if it
+is a local.
+
 \begin{code}
 normEqInst :: Inst -> TcM ([RewriteInst], TyVarSet)
 -- Normalise one equality.
@@ -449,8 +466,18 @@ normEqInst inst
                  eqTys     = (ty1', ty2')
            ; (co', ty12_eqs') <- adjustCoercions co rewriteCo eqTys ty12_eqs
            ; eqs <- checkOrientation ty1' ty2' co' inst
-           ; return $ (eqs ++ ty12_eqs',
-                       ty1_skolems `unionVarSet` ty2_skolems)
+           ; if isLoopyEquality eqs ty12_eqs' 
+             then do { if isWantedCo (tci_co inst)
+                       then
+                          addErrCtxt (ptext (sLit "Rejecting loopy equality")) $
+                            eqInstMisMatch inst
+                       else
+                         warnDroppingLoopyEquality ty1 ty2
+                     ; return ([], emptyVarSet)         -- drop the equality
+                     }
+             else
+               return (eqs ++ ty12_eqs',
+                      ty1_skolems `unionVarSet` ty2_skolems)
            }
 
     mkRewriteFam con args ty2 co
@@ -474,6 +501,18 @@ normEqInst inst
                        unionVarSets (ty2_skolems:args_skolemss))
            }
 
+    -- If the original equality has the form a ~ T .. (F ...a...) ..., we will
+    -- have a variable equality with 'a' on the lhs as the first equality.
+    -- Then, check whether 'a' occurs in the lhs of any family equality
+    -- generated by flattening.
+    isLoopyEquality (RewriteVar {rwi_var = tv}:_) eqs
+      = any inRewriteFam eqs
+      where
+        inRewriteFam (RewriteFam {rwi_args = args}) 
+          = tv `elemVarSet` tyVarsOfTypes args
+        inRewriteFam _ = False
+    isLoopyEquality _ _ = False
+
 normDict :: Bool -> Inst -> TcM (Inst, [RewriteInst], TcDictBinds, TyVarSet)
 -- Normalise one dictionary or IP constraint.
 normDict isWanted inst@(Dict {tci_pred = ClassP clas args})
@@ -482,9 +521,14 @@ normDict isWanted inst@(Dict {tci_pred = ClassP clas args})
        ; let rewriteCo = PredTy $ ClassP clas cargs
              eqs       = concat args_eqss
              pred'     = ClassP clas args'
-       ; (inst', bind, eqs') <- mkDictBind inst isWanted rewriteCo pred' eqs
+       ; if null eqs
+         then  -- don't generate a binding if there is nothing to flatten
+           return (inst, [], emptyBag, emptyVarSet)
+         else do {
+       ; (inst', bind) <- mkDictBind inst isWanted rewriteCo pred'
+       ; eqs' <- if isWanted then return eqs else mapM wantedToLocal eqs
        ; return (inst', eqs', bind, unionVarSets args_skolemss)
-       }
+       }}
 normDict isWanted inst
   = return (inst, [], emptyBag, emptyVarSet)
 -- !!!TODO: Still need to normalise IP constraints.
@@ -660,13 +704,9 @@ mkDictBind :: Inst                 -- original instance
            -> Bool                 -- is this a wanted contraint?
            -> Coercion             -- coercion witnessing the rewrite
            -> PredType             -- coerced predicate
-           -> [RewriteInst]        -- equalities from flattening
            -> TcM (Inst,           -- new inst
-                   TcDictBinds,    -- binding for coerced dictionary
-                   [RewriteInst])  -- final equalities from flattening
-mkDictBind dict _isWanted _rewriteCo _pred []
-  = return (dict, emptyBag, [])    -- don't generate binding for an id coercion
-mkDictBind dict isWanted rewriteCo pred eqs
+                   TcDictBinds)    -- binding for coerced dictionary
+mkDictBind dict isWanted rewriteCo pred
   = do { dict' <- newDictBndr loc pred
          -- relate the old inst to the new one
          -- target_dict = source_dict `cast` st_co
@@ -683,8 +723,7 @@ mkDictBind dict isWanted rewriteCo pred eqs
              cast_expr = HsWrap (WpCast st_co) expr
              rhs       = L (instLocSpan loc) cast_expr
              binds     = instToDictBind target_dict rhs
-       ; eqs' <- if isWanted then return eqs else mapM wantedToLocal eqs
-       ; return (dict', binds, eqs')
+       ; return (dict', binds)
        }
   where
     loc = tci_loc dict
@@ -955,7 +994,9 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       = return (res, binds, locals, wanteds)
     subst (eq@(RewriteVar {rwi_var = tv, rwi_right = ty, rwi_co = co}):eqs) 
           res binds locals wanteds
-      = do { let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
+      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr tv <+>
+                       ptext (sLit "->") <+> ppr ty
+           ; let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
                  tySubst = zipOpenTvSubst [tv] [ty]
            ; eqs'               <- mapM (substEq eq coSubst tySubst) eqs
            ; res'               <- mapM (substEq eq coSubst tySubst) res
@@ -1004,7 +1045,7 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       = do { let co1Subst = mkSymCoercion $ 
                               PredTy (substPred coSubst (tci_pred dict))
                  pred'    = substPred tySubst (tci_pred dict)
-           ; (dict', binds, _) <- mkDictBind dict isWanted co1Subst pred' []
+           ; (dict', binds) <- mkDictBind dict isWanted co1Subst pred'
            ; return (binds, dict')
            }
 
@@ -1206,3 +1247,16 @@ misMatchMsg env0 (ty_act, ty_exp)
 
     ppr_extra env _ty = (env, empty)           -- Normal case
 \end{code}
+
+Warn of loopy local equalities that were dropped.
+
+\begin{code}
+warnDroppingLoopyEquality :: TcType -> TcType -> TcM ()
+warnDroppingLoopyEquality ty1 ty2 
+  = do { env0 <- tcInitTidyEnv
+       ; ty1 <- zonkTcType ty1
+       ; ty2 <- zonkTcType ty2
+       ; addWarnTc $ hang (ptext (sLit "Dropping loopy given equality"))
+                      2 (ppr ty1 <+> text "~" <+> ppr ty2)
+       }
+\end{code}