#include "HsVersions.h"
-import InstEnv ( InstEnv ) -- Reqd for 4.02; InstEnv is a synonym, and
- -- 4.02 doesn't "see" it soon enough
-
-import Type ( tyVarsOfTypes )
-import Class ( classInstEnv, classExtraBigSig )
-import Unify ( matchTys )
-import Subst ( mkSubst, substTy )
+import Name ( Name )
+import Class ( Class, FunDep, className )
+import Unify ( unifyTyListsX )
+import Subst ( mkSubst, emptyInScopeSet, substTy )
+import TcEnv ( tcGetInstEnv, classInstEnv )
import TcMonad
-import TcType ( zonkTcType, zonkTcTypes )
+import TcType ( TcType, TcTyVarSet, zonkTcType )
import TcUnify ( unifyTauTyLists )
-import Inst ( Inst, LookupInstResult(..),
- lookupInst, isDict, getDictClassTys, getFunDepsOfLIE,
- zonkLIE {- for debugging -} )
-import VarSet ( emptyVarSet )
-import VarEnv ( emptyVarEnv )
+import Inst ( LIE, getFunDepsOfLIE, getIPsOfLIE )
+import VarSet ( VarSet, emptyVarSet, unionVarSet )
import FunDeps ( instantiateFdClassTys )
-import Bag ( bagToList )
-import Outputable
-import List ( elemIndex )
+import List ( nub )
\end{code}
-Improvement goes here.
-
\begin{code}
-tcImprove lie = iterImprove (getFunDepsOfLIE lie)
+tcImprove :: LIE -> TcM s ()
+-- Do unifications based on functional dependencies in the LIE
+tcImprove lie
+ = tcGetInstEnv `thenNF_Tc` \ inst_env ->
+ let
+ nfdss, clas_nfdss, inst_nfdss, ip_nfdss :: [(TcTyVarSet, Name, [FunDep TcType])]
+ nfdss = ip_nfdss ++ clas_nfdss ++ inst_nfdss
+
+ cfdss :: [(Class, [FunDep TcType])]
+ cfdss = getFunDepsOfLIE lie
+ clas_nfdss = [(emptyVarSet, className c, fds) | (c,fds) <- cfdss]
+
+ classes = nub (map fst cfdss)
+ inst_nfdss = [ (free, className c, instantiateFdClassTys c ts)
+ | c <- classes,
+ (free, ts, i) <- classInstEnv inst_env c
+ ]
+
+ ip_nfdss = [(emptyVarSet, n, [([], [ty])]) | (n,ty) <- getIPsOfLIE lie]
+
+ {- Example: we have
+ class C a b c | a->b where ...
+ instance C Int Bool c
+
+ Given the LIE FD C (Int->t)
+ we get clas_nfdss = [({}, C, [Int->t, t->Int])
+ inst_nfdss = [({c}, C, [Int->Bool, Bool->Int])]
+
+ Another way would be to flatten a bit
+ we get clas_nfdss = [({}, C, Int->t), ({}, C, t->Int)]
+ inst_nfdss = [({c}, C, Int->Bool), ({c}, C, Bool->Int)]
+
+ iterImprove then matches up the C and Int, and unifies t <-> Bool
+ -}
+
+ in
+ iterImprove nfdss
+
+iterImprove :: [(VarSet, Name, [FunDep TcType])] -> TcM s ()
+iterImprove [] = returnTc ()
iterImprove cfdss
- = instImprove cfdss `thenTc` \ change1 ->
- selfImprove pairImprove cfdss `thenTc` \ change2 ->
- if change1 || change2 then
+ = selfImprove pairImprove cfdss `thenTc` \ change2 ->
+ if {- change1 || -} change2 then
iterImprove cfdss
else
returnTc ()
-instImprove (cfds@(clas, fds) : cfdss)
- = instImprove1 cfds ins
- where ins = classInstEnv clas
-instImprove [] = returnTc False
-
-instImprove1 cfds@(clas, fds1) ((free, ts, _) : ins)
- = checkFds fds1 free fds2 `thenTc` \ changed ->
- instImprove1 cfds ins `thenTc` \ rest_changed ->
- returnTc (changed || rest_changed)
- where fds2 = instantiateFdClassTys clas ts
-instImprove1 _ _ = returnTc False
+-- ZZ this will do a lot of redundant checking wrt instances
+-- it would do to make this operate over two lists, the first
+-- with only clas_nfds and ip_nfds, and the second with everything
+-- control would otherwise mimic the current loop, so that the
+-- caller could control whether the redundant inst improvements
+-- were avoided
+-- you could then also use this to check for consistency of new instances
+-- selfImprove is really just doing a cartesian product of all the fds
selfImprove f [] = returnTc False
-selfImprove f (cfds : cfdss)
- = mapTc (f cfds) cfdss `thenTc` \ changes ->
- orTc changes `thenTc` \ changed ->
- selfImprove f cfdss `thenTc` \ rest_changed ->
- returnTc (changed || rest_changed)
-
-pairImprove (clas1, fds1) (clas2, fds2)
- = if clas1 == clas2 then
- checkFds fds1 emptyVarSet fds2
+selfImprove f (nfds : nfdss)
+ = mapTc (f nfds) nfdss `thenTc` \ changes ->
+ selfImprove f nfdss `thenTc` \ rest_changed ->
+ returnTc (or changes || rest_changed)
+
+pairImprove (free1, n1, fds1) (free2, n2, fds2)
+ = if n1 == n2 then
+ checkFds (free1 `unionVarSet` free2) fds1 fds2
else
returnTc False
-checkFds [] free [] = returnTc False
-checkFds (fd1 : fd1s) free (fd2 : fd2s) =
- checkFd fd1 free fd2 `thenTc` \ change ->
- checkFds fd1s free fd2s `thenTc` \ changes ->
+checkFds free [] [] = returnTc False
+checkFds free (fd1 : fd1s) (fd2 : fd2s) =
+ checkFd free fd1 fd2 `thenTc` \ change ->
+ checkFds free fd1s fd2s `thenTc` \ changes ->
returnTc (change || changes)
--checkFds _ _ = returnTc False
-checkFd (t_x, t_y) free (s_x, s_y)
+checkFd free (t_x, t_y) (s_x, s_y)
-- we need to zonk each time because unification
-- may happen at any time
- = zonkMatchTys t_x free s_x `thenTc` \ msubst ->
+ = zonkUnifyTys free t_x s_x `thenTc` \ msubst ->
case msubst of
Just subst ->
- let s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y in
- zonkMatchTys t_y free s_y `thenTc` \ msubst2 ->
- case msubst2 of
- Just _ ->
- -- they're the same, nothing changes
- returnTc False
- Nothing ->
- unifyTauTyLists t_y s_y' `thenTc_`
- -- if we get here, something must have unified
- returnTc True
+ let full_subst = mkSubst emptyInScopeSet subst
+ t_y' = map (substTy full_subst) t_y
+ s_y' = map (substTy full_subst) s_y
+ in
+ zonkEqTys t_y' s_y' `thenTc` \ eq ->
+ if eq then
+ -- they're the same, nothing changes...
+ returnTc False
+ else
+ -- ZZ what happens if two instance vars unify?
+ unifyTauTyLists t_y' s_y' `thenTc_`
+ -- if we get here, something must have unified
+ returnTc True
Nothing ->
returnTc False
-zonkMatchTys ts1 free ts2
+zonkEqTys ts1 ts2
= mapTc zonkTcType ts1 `thenTc` \ ts1' ->
mapTc zonkTcType ts2 `thenTc` \ ts2' ->
- --returnTc (ts1' == ts2')
- case matchTys free ts2' ts1' of
- Just (subst, []) -> returnTc (Just subst)
- Nothing -> returnTc Nothing
-
-{-
-instImprove clas fds =
- pprTrace "class inst env" (ppr (clas, classInstEnv clas)) $
- zonkFunDeps fds `thenTc` \ fds' ->
- pprTrace "lIEFDs" (ppr (clas, fds')) $
- case lookupInstEnvFDs clas fds' of
- Nothing -> returnTc ()
- Just (t_y, s_y) ->
- pprTrace "lIEFDs result" (ppr (t_y, s_y)) $
- unifyTauTyLists t_y s_y
-
-lookupInstEnvFDs clas fds
- = find env
- where
- env = classInstEnv clas
- (ctvs, fds, _, _, _, _) = classExtraBigSig clas
- find [] = Nothing
- find ((tpl_tyvars, tpl, val) : rest)
- = let tplx = concatMap (\us -> thingy tpl us ctvs) (map fst fds)
- tply = concatMap (\vs -> thingy tpl vs ctvs) (map snd fds)
- in
- case matchTys tpl_tyvars tplx tysx of
- Nothing -> find rest
- Just (tenv, leftovers) ->
- let subst = mkSubst (tyVarsOfTypes tys) tenv
- in
- -- this is the list of things that
- -- need to be unified
- Just (map (substTy subst) tply, tysy)
- tysx = concatMap (\us -> thingy tys us ctvs) (map fst fds)
- tysy = concatMap (\vs -> thingy tys vs ctvs) (map snd fds)
- thingy f us ctvs
- = map (f !!) is
- where is = map (\u -> let Just i = elemIndex u ctvs in i) us
--}
-
-{-
- = let (clas, tys) = getDictClassTys dict
- in
- -- first, do instance-based improvement
- instImprove clas tys `thenTc_`
- -- OK, now do pairwise stuff
- mapTc (f clas tys) dicts `thenTc` \ changes ->
- foldrTc (\a b -> returnTc (a || b)) False changes `thenTc` \ changed ->
- allDictPairs f dicts `thenTc` \ rest_changed ->
- returnTc (changed || rest_changed)
--}
+ returnTc (ts1' == ts2')
-\end{code}
-
-Utilities:
-
-A monadic version of the standard Prelude `or' function.
-\begin{code}
-orTc bs = foldrTc (\a b -> returnTc (a || b)) False bs
+zonkUnifyTys free ts1 ts2
+ = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
+ mapTc zonkTcType ts2 `thenTc` \ ts2' ->
+ -- pprTrace "zMT" (ppr (ts1', free, ts2')) $
+ case unifyTyListsX free ts2' ts1' of
+ Just subst -> returnTc (Just subst)
+ Nothing -> returnTc Nothing
\end{code}