Some refactoring and simplification in TcInteract.occurCheck
[ghc-hetmet.git] / compiler / typecheck / TcInteract.lhs
index d97002b..f0edcc9 100644 (file)
@@ -15,6 +15,7 @@ import Type
 import TypeRep 
 
 import Id 
+import VarEnv
 import Var
 
 import TcType
@@ -608,7 +609,7 @@ solveWithIdentity :: InertSet
 -- See [New Wanted Superclass Work] to see why solveWithIdentity 
 --     must work for Derived as well as Wanted
 solveWithIdentity inerts cv gw tv xi 
-  = do { tybnds <- getTcSTyBindsBag 
+  = do { tybnds <- getTcSTyBindsMap
        ; case occurCheck tybnds inerts tv xi of 
            Nothing              -> return Nothing 
            Just (xi_unflat,coi) -> solve_with xi_unflat coi }
@@ -640,7 +641,7 @@ solveWithIdentity inerts cv gw tv xi
                      -- See Note [Avoid double unifications] 
            ; return (Just cts) }
 
-occurCheck :: Bag (TcTyVar, TcType) -> InertSet
+occurCheck :: VarEnv (TcTyVar, TcType) -> InertSet
            -> TcTyVar -> TcType -> Maybe (TcType,CoercionI) 
 -- Traverse @ty@ to make sure that @tv@ does not appear under some flatten skolem. 
 -- If it appears under some flatten skolem look in that flatten skolem equivalence class 
@@ -651,8 +652,8 @@ occurCheck :: Bag (TcTyVar, TcType) -> InertSet
 --       coi :: ty' ~ ty 
 -- NB: The returned type ty' may not be flat!
 
-occurCheck ty_binds_bag inerts tv ty
-  = ok emptyVarSet ty 
+occurCheck ty_binds inerts the_tv the_ty
+  = ok emptyVarSet the_ty 
   where 
     -- If (fsk `elem` bad) then tv occurs in any rendering
     -- of the type under the expansion of fsk
@@ -677,32 +678,18 @@ occurCheck ty_binds_bag inerts tv ty
       = Just (ForAllTy tv1 ty1', mkForAllTyCoI tv1 coi) 
 
     -- Variable cases 
-    ok _bad this_ty@(TyVarTy tv') 
-      | not $ isTcTyVar tv' = Just (this_ty, IdCo this_ty) -- Bound variable
-      | tv == tv'           = Nothing                      -- Occurs check error
-  
-    ok bad (TyVarTy fsk) 
-      | FlatSkol zty <- tcTyVarDetails fsk 
-      = if fsk `elemVarSet` bad then 
-            -- its type has been checked 
-            go_down_eq_class bad $ getFskEqClass inerts fsk 
-        else 
-            -- its type is not yet checked
-            case ok bad zty of 
-              Nothing -> go_down_eq_class (bad `extendVarSet` fsk) $ 
-                         getFskEqClass inerts fsk 
-              Just (zty',ico) -> Just (zty',ico) 
+    ok bad this_ty@(TyVarTy tv) 
+      | tv == the_tv                                   = Nothing             -- Occurs check error
+      | not (isTcTyVar tv)                     = Just (this_ty, IdCo this_ty) -- Bound var
+      | FlatSkol zty <- tcTyVarDetails tv       = ok_fsk bad tv zty
+      | Just (_,ty) <- lookupVarEnv ty_binds tv = ok bad ty 
+      | otherwise                               = Just (this_ty, IdCo this_ty)
 
     -- Check if there exists a ty bind already, as a result of sneaky unification. 
-    ok bad this_ty@(TyVarTy tv0) 
-      = case Bag.foldlBag find_bind Nothing ty_binds_bag of 
-          Nothing -> Just (this_ty, IdCo this_ty)
-          Just ty0 -> ok bad ty0 
-      where find_bind Nothing (tvx,tyx) | tv0 == tvx = Just tyx
-            find_bind m _ = m 
     -- Fall through
     ok _bad _ty = Nothing 
 
+    -----------
     ok_pred bad (ClassP cn tys)
       | Just tys_cois <- allMaybes $ map (ok bad) tys 
       = let (tys', cois') = unzip tys_cois 
@@ -715,13 +702,25 @@ occurCheck ty_binds_bag inerts tv ty
       = Just (EqPred ty1' ty2', mkEqPredCoI coi1 coi2) 
     ok_pred _ _ = Nothing 
 
-    go_down_eq_class _bad_tvs [] = Nothing 
-    go_down_eq_class bad_tvs ((fsk1,co1):rest) 
-     | fsk1 `elemVarSet` bad_tvs = go_down_eq_class bad_tvs rest
-     | otherwise 
-     = case ok bad_tvs (TyVarTy fsk1) of 
-          Nothing -> go_down_eq_class (bad_tvs `extendVarSet` fsk1) rest 
-          Just (ty1,co1i') -> Just (ty1, mkTransCoI co1i' (ACo co1)) 
+    -----------
+    ok_fsk bad fsk zty
+      | fsk `elemVarSet` bad 
+            -- We are already trying to find a rendering of fsk, 
+           -- and to do that it seems we need a rendering, so fail
+      = Nothing
+      | otherwise 
+      = firstJusts (ok new_bad zty : map (go_under_fsk new_bad) fsk_equivs)
+      where
+        fsk_equivs = getFskEqClass inerts fsk 
+        new_bad    = bad `extendVarSetList` (fsk : map fst fsk_equivs)
+
+    -----------
+    go_under_fsk bad_tvs (fsk,co)
+      | FlatSkol zty <- tcTyVarDetails fsk
+      = case ok bad_tvs zty of
+           Nothing        -> Nothing
+           Just (ty,coi') -> Just (ty, mkTransCoI coi' (ACo co)) 
+      | otherwise = pprPanic "go_down_equiv" (ppr fsk)
 \end{code}