Some refactoring and simplification in TcInteract.occurCheck
authorsimonpj@microsoft.com <unknown>
Thu, 7 Oct 2010 16:35:00 +0000 (16:35 +0000)
committersimonpj@microsoft.com <unknown>
Thu, 7 Oct 2010 16:35:00 +0000 (16:35 +0000)
compiler/cmm/CmmCPS.hs
compiler/main/CodeOutput.lhs
compiler/main/StaticFlags.hs
compiler/typecheck/TcErrors.lhs
compiler/typecheck/TcInteract.lhs
compiler/typecheck/TcSMonad.lhs
compiler/utils/Maybes.lhs

index 17c11ce..7bfdf84 100644 (file)
@@ -42,7 +42,7 @@ cmmCPS :: DynFlags -- ^ Dynamic flags: -dcmm-lint -ddump-cps-cmm
 cmmCPS dflags cmm_with_calls
   = do { when (dopt Opt_DoCmmLinting dflags) $
               do showPass dflags "CmmLint"
-                 case firstJust $ map cmmLint cmm_with_calls of
+                 case firstJusts $ map cmmLint cmm_with_calls of
                    Just err -> do printDump err
                                   ghcExit dflags 1
                    Nothing  -> return ()
index bc2dd1e..921bbde 100644 (file)
@@ -34,7 +34,7 @@ import Config
 import ErrUtils                ( dumpIfSet_dyn, showPass, ghcExit )
 import Outputable
 import Module
-import Maybes          ( firstJust )
+import Maybes          ( firstJusts )
 
 import Control.Exception
 import Control.Monad
@@ -69,7 +69,7 @@ codeOutput dflags this_mod location foreign_stubs pkg_deps flat_abstractC
     do { when (dopt Opt_DoCmmLinting dflags) $ do
                { showPass dflags "CmmLint"
                ; let lints = map cmmLint flat_abstractC
-               ; case firstJust lints of
+               ; case firstJusts lints of
                        Just err -> do { printDump err
                                       ; ghcExit dflags 1
                                       }
index bc2ae38..6e9e333 100644 (file)
@@ -84,7 +84,7 @@ module StaticFlags (
 import Config
 import FastString
 import Util
-import Maybes          ( firstJust )
+import Maybes          ( firstJusts )
 import Panic
 
 import Data.Maybe       ( listToMaybe )
@@ -138,7 +138,7 @@ lookUp     sw = sw `elem` packed_static_opts
 -- (lookup_str "foo") looks for the flag -foo=X or -fooX, 
 -- and returns the string X
 lookup_str sw 
-   = case firstJust (map (stripPrefix sw) staticFlags) of
+   = case firstJusts (map (stripPrefix sw) staticFlags) of
        Just ('=' : str) -> Just str
        Just str         -> Just str
        Nothing          -> Nothing     
index 9531a50..293b3a7 100644 (file)
@@ -721,8 +721,8 @@ wrapEqErrTcS fl ty1 ty2 thing_inside
        ; wrapErrTcS $ setCtFlavorLoc fl $ 
     do {   -- Apply the current substitition
            -- and zonk to get rid of flatten-skolems
-       ; ty_binds_bag <- readTcRef ty_binds_var
-       ; let subst = mkOpenTvSubst (mkVarEnv (bagToList ty_binds_bag))
+       ; ty_binds_map <- readTcRef ty_binds_var
+       ; let subst = mkOpenTvSubst (mapVarEnv snd ty_binds_map)
        ; env0 <- tcInitTidyEnv 
        ; (env1, ty1) <- zonkSubstTidy env0 subst ty1
        ; (env2, ty2) <- zonkSubstTidy env1 subst ty2
index d97002b..f0edcc9 100644 (file)
@@ -15,6 +15,7 @@ import Type
 import TypeRep 
 
 import Id 
+import VarEnv
 import Var
 
 import TcType
@@ -608,7 +609,7 @@ solveWithIdentity :: InertSet
 -- See [New Wanted Superclass Work] to see why solveWithIdentity 
 --     must work for Derived as well as Wanted
 solveWithIdentity inerts cv gw tv xi 
-  = do { tybnds <- getTcSTyBindsBag 
+  = do { tybnds <- getTcSTyBindsMap
        ; case occurCheck tybnds inerts tv xi of 
            Nothing              -> return Nothing 
            Just (xi_unflat,coi) -> solve_with xi_unflat coi }
@@ -640,7 +641,7 @@ solveWithIdentity inerts cv gw tv xi
                      -- See Note [Avoid double unifications] 
            ; return (Just cts) }
 
-occurCheck :: Bag (TcTyVar, TcType) -> InertSet
+occurCheck :: VarEnv (TcTyVar, TcType) -> InertSet
            -> TcTyVar -> TcType -> Maybe (TcType,CoercionI) 
 -- Traverse @ty@ to make sure that @tv@ does not appear under some flatten skolem. 
 -- If it appears under some flatten skolem look in that flatten skolem equivalence class 
@@ -651,8 +652,8 @@ occurCheck :: Bag (TcTyVar, TcType) -> InertSet
 --       coi :: ty' ~ ty 
 -- NB: The returned type ty' may not be flat!
 
-occurCheck ty_binds_bag inerts tv ty
-  = ok emptyVarSet ty 
+occurCheck ty_binds inerts the_tv the_ty
+  = ok emptyVarSet the_ty 
   where 
     -- If (fsk `elem` bad) then tv occurs in any rendering
     -- of the type under the expansion of fsk
@@ -677,32 +678,18 @@ occurCheck ty_binds_bag inerts tv ty
       = Just (ForAllTy tv1 ty1', mkForAllTyCoI tv1 coi) 
 
     -- Variable cases 
-    ok _bad this_ty@(TyVarTy tv') 
-      | not $ isTcTyVar tv' = Just (this_ty, IdCo this_ty) -- Bound variable
-      | tv == tv'           = Nothing                      -- Occurs check error
-  
-    ok bad (TyVarTy fsk) 
-      | FlatSkol zty <- tcTyVarDetails fsk 
-      = if fsk `elemVarSet` bad then 
-            -- its type has been checked 
-            go_down_eq_class bad $ getFskEqClass inerts fsk 
-        else 
-            -- its type is not yet checked
-            case ok bad zty of 
-              Nothing -> go_down_eq_class (bad `extendVarSet` fsk) $ 
-                         getFskEqClass inerts fsk 
-              Just (zty',ico) -> Just (zty',ico) 
+    ok bad this_ty@(TyVarTy tv) 
+      | tv == the_tv                                   = Nothing             -- Occurs check error
+      | not (isTcTyVar tv)                     = Just (this_ty, IdCo this_ty) -- Bound var
+      | FlatSkol zty <- tcTyVarDetails tv       = ok_fsk bad tv zty
+      | Just (_,ty) <- lookupVarEnv ty_binds tv = ok bad ty 
+      | otherwise                               = Just (this_ty, IdCo this_ty)
 
     -- Check if there exists a ty bind already, as a result of sneaky unification. 
-    ok bad this_ty@(TyVarTy tv0) 
-      = case Bag.foldlBag find_bind Nothing ty_binds_bag of 
-          Nothing -> Just (this_ty, IdCo this_ty)
-          Just ty0 -> ok bad ty0 
-      where find_bind Nothing (tvx,tyx) | tv0 == tvx = Just tyx
-            find_bind m _ = m 
     -- Fall through
     ok _bad _ty = Nothing 
 
+    -----------
     ok_pred bad (ClassP cn tys)
       | Just tys_cois <- allMaybes $ map (ok bad) tys 
       = let (tys', cois') = unzip tys_cois 
@@ -715,13 +702,25 @@ occurCheck ty_binds_bag inerts tv ty
       = Just (EqPred ty1' ty2', mkEqPredCoI coi1 coi2) 
     ok_pred _ _ = Nothing 
 
-    go_down_eq_class _bad_tvs [] = Nothing 
-    go_down_eq_class bad_tvs ((fsk1,co1):rest) 
-     | fsk1 `elemVarSet` bad_tvs = go_down_eq_class bad_tvs rest
-     | otherwise 
-     = case ok bad_tvs (TyVarTy fsk1) of 
-          Nothing -> go_down_eq_class (bad_tvs `extendVarSet` fsk1) rest 
-          Just (ty1,co1i') -> Just (ty1, mkTransCoI co1i' (ACo co1)) 
+    -----------
+    ok_fsk bad fsk zty
+      | fsk `elemVarSet` bad 
+            -- We are already trying to find a rendering of fsk, 
+           -- and to do that it seems we need a rendering, so fail
+      = Nothing
+      | otherwise 
+      = firstJusts (ok new_bad zty : map (go_under_fsk new_bad) fsk_equivs)
+      where
+        fsk_equivs = getFskEqClass inerts fsk 
+        new_bad    = bad `extendVarSetList` (fsk : map fst fsk_equivs)
+
+    -----------
+    go_under_fsk bad_tvs (fsk,co)
+      | FlatSkol zty <- tcTyVarDetails fsk
+      = case ok bad_tvs zty of
+           Nothing        -> Nothing
+           Just (ty,coi') -> Just (ty, mkTransCoI coi' (ACo co)) 
+      | otherwise = pprPanic "go_down_equiv" (ppr fsk)
 \end{code}
 
 
index f8b357a..a71548c 100644 (file)
@@ -31,7 +31,7 @@ module TcSMonad (
  
     getInstEnvs, getFamInstEnvs,                -- Getting the environments 
     getTopEnv, getGblEnv, getTcEvBinds, getUntouchablesTcS,
-    getTcEvBindsBag, getTcSContext, getTcSTyBinds, getTcSTyBindsBag,
+    getTcEvBindsBag, getTcSContext, getTcSTyBinds, getTcSTyBindsMap,
 
 
     newFlattenSkolemTy,                         -- Flatten skolems 
@@ -87,6 +87,7 @@ import TypeRep
 
 import Name
 import Var
+import VarEnv
 import Outputable
 import Bag
 import MonadUtils
@@ -336,7 +337,7 @@ data TcSEnv
       tcs_ev_binds :: EvBindsVar,
           -- Evidence bindings
 
-      tcs_ty_binds :: IORef (Bag (TcTyVar, TcType)),
+      tcs_ty_binds :: IORef (TyVarEnv (TcTyVar, TcType)),
           -- Global type bindings
 
       tcs_context :: SimplContext
@@ -415,7 +416,7 @@ runTcS :: SimplContext
        -> TcS a                       -- What to run
        -> TcM (a, Bag EvBind)
 runTcS context untouch tcs 
-  = do { ty_binds_var <- TcM.newTcRef emptyBag
+  = do { ty_binds_var <- TcM.newTcRef emptyVarEnv
        ; ev_binds_var@(EvBindsVar evb_ref _) <- TcM.newTcEvBinds
        ; let env = TcSEnv { tcs_ev_binds = ev_binds_var
                           , tcs_ty_binds = ty_binds_var
@@ -426,7 +427,7 @@ runTcS context untouch tcs
 
             -- Perform the type unifications required
        ; ty_binds <- TcM.readTcRef ty_binds_var
-       ; mapBagM_ do_unification ty_binds
+       ; mapM_ do_unification (varEnvElts ty_binds)
 
              -- And return
        ; ev_binds <- TcM.readTcRef evb_ref
@@ -454,7 +455,7 @@ tryTcS :: TcTyVarSet -> TcS a -> TcS a
 -- Like runTcS, but from within the TcS monad 
 -- Ignore all the evidence generated, and do not affect caller's evidence!
 tryTcS untch tcs 
-  = TcS (\env -> do { ty_binds_var <- TcM.newTcRef emptyBag
+  = TcS (\env -> do { ty_binds_var <- TcM.newTcRef emptyVarEnv
                     ; ev_binds_var <- TcM.newTcEvBinds
                     ; let env1 = env { tcs_ev_binds = ev_binds_var
                                      , tcs_ty_binds = ty_binds_var }
@@ -472,11 +473,11 @@ getTcSContext = TcS (return . tcs_context)
 getTcEvBinds :: TcS EvBindsVar
 getTcEvBinds = TcS (return . tcs_ev_binds) 
 
-getTcSTyBinds :: TcS (IORef (Bag (TcTyVar, TcType)))
+getTcSTyBinds :: TcS (IORef (TyVarEnv (TcTyVar, TcType)))
 getTcSTyBinds = TcS (return . tcs_ty_binds)
 
-getTcSTyBindsBag :: TcS (Bag (TcTyVar, TcType)) 
-getTcSTyBindsBag = getTcSTyBinds >>= wrapTcS . (TcM.readTcRef) 
+getTcSTyBindsMap :: TcS (TyVarEnv (TcTyVar, TcType)) 
+getTcSTyBindsMap = getTcSTyBinds >>= wrapTcS . (TcM.readTcRef) 
 
 
 getTcEvBindsBag :: TcS EvBindMap
@@ -499,7 +500,7 @@ setWantedTyBind tv ty
   = do { ref <- getTcSTyBinds
        ; wrapTcS $ 
          do { ty_binds <- TcM.readTcRef ref
-            ; TcM.writeTcRef ref (ty_binds `snocBag` (tv,ty)) } }
+            ; TcM.writeTcRef ref (extendVarEnv ty_binds tv (tv,ty)) } }
 
 setIPBind :: EvVar -> EvTerm -> TcS () 
 setIPBind = setEvBind 
index 1f443db..39e6185 100644 (file)
@@ -14,7 +14,7 @@ module Maybes (
         orElse,
         mapCatMaybes,
         allMaybes,
-        firstJust,
+        firstJust, firstJusts,
         expectJust,
         maybeToBool,
 
@@ -46,12 +46,14 @@ allMaybes (Just x  : ms) = case allMaybes ms of
                            Nothing -> Nothing
                            Just xs -> Just (x:xs)
 
+firstJust :: Maybe a -> Maybe a -> Maybe a
+firstJust (Just a) _ = Just a
+firstJust Nothing  b = b
+
 -- | Takes a list of @Maybes@ and returns the first @Just@ if there is one, or
 -- @Nothing@ otherwise.
-firstJust :: [Maybe a] -> Maybe a
-firstJust [] = Nothing
-firstJust (Just x  : _)  = Just x
-firstJust (Nothing : ms) = firstJust ms
+firstJusts :: [Maybe a] -> Maybe a
+firstJusts = foldr firstJust Nothing
 \end{code}
 
 \begin{code}
@@ -70,6 +72,7 @@ mapCatMaybes f (x:xs) = case f x of
 \end{code}
 
 \begin{code}
+
 orElse :: Maybe a -> a -> a
 (Just x) `orElse` _ = x
 Nothing  `orElse` y = y