2 module TcImprove ( tcImprove ) where
4 #include "HsVersions.h"
6 import Type ( tyVarsOfTypes )
7 import Class ( classInstEnv, classExtraBigSig )
8 import Unify ( matchTys )
9 import Subst ( mkSubst, substTy )
11 import TcType ( zonkTcType, zonkTcTypes )
12 import TcUnify ( unifyTauTyLists )
13 import Inst ( Inst, LookupInstResult(..),
14 lookupInst, isDict, getDictClassTys, getFunDepsOfLIE,
15 zonkLIE {- for debugging -} )
16 import VarSet ( emptyVarSet )
17 import VarEnv ( emptyVarEnv )
18 import FunDeps ( instantiateFdClassTys )
19 import Bag ( bagToList )
21 import List ( elemIndex )
24 Improvement goes here.
-- Entry point of improvement: extract the functional-dependency
-- information from the given LIE with getFunDepsOfLIE and hand the result
-- to iterImprove.
tcImprove lie
  = iterImprove class_fundeps
  where
    class_fundeps = getFunDepsOfLIE lie
-- One round of improvement over the constraints in cfdss: first the
-- instance-based pass (instImprove), then the pairwise pass
-- (selfImprove pairImprove).  change1 and change2 are the Bool results
-- of the two passes.
-- NOTE(review): the equation head and both branches of the `if` below are
-- not visible in this excerpt; presumably this is iterImprove (the function
-- tcImprove calls) and it runs another round while either pass reports a
-- change -- confirm against the full file.
30 = instImprove cfdss `thenTc` \ change1 ->
31 selfImprove pairImprove cfdss `thenTc` \ change2 ->
32 if change1 || change2 then
-- Instance-based improvement over a list of (class, fundeps) constraints.
-- Each constraint is compared, via instImprove1, against the entries of its
-- class's instance environment (classInstEnv).
-- Returns True if any constraint reported a change, False for an empty list.
-- Every constraint in the list is visited: the tail is processed
-- recursively rather than being dropped after the first element.
instImprove [] = returnTc False
instImprove (cfds@(clas, _) : cfdss)
  = instImprove1 cfds ins `thenTc` \ changed ->
    instImprove cfdss     `thenTc` \ rest_changed ->
    returnTc (changed || rest_changed)
  where ins = classInstEnv clas
-- Compare one (class, fundeps) constraint against each entry of an
-- instance list in turn.  An entry is a triple; its first component is
-- passed to checkFds as the set of free variables and its second component
-- (the instance's types) is used to instantiate the class's dependencies
-- with instantiateFdClassTys.  The third component is ignored.
-- Returns True if any check reported a change; an empty instance list
-- yields False.
instImprove1 constraint@(cls, wanted_fds) insts
  = case insts of
      [] -> returnTc False
      ((tpl_vars, inst_tys, _) : more_insts) ->
        let inst_fds = instantiateFdClassTys cls inst_tys in
        checkFds wanted_fds tpl_vars inst_fds `thenTc` \ here_changed ->
        instImprove1 constraint more_insts    `thenTc` \ later_changed ->
        returnTc (here_changed || later_changed)
-- Apply a pairwise improvement action to every pair drawn from the list:
-- each element is combined with all the elements that follow it, and then
-- the same is done for the remainder of the list.
-- Returns True if any application reported a change; an empty list
-- yields False.
selfImprove improve constraints
  = case constraints of
      [] -> returnTc False
      (this : others) ->
        mapTc (improve this) others `thenTc` \ results ->
        orTc results                `thenTc` \ here_changed ->
        selfImprove improve others  `thenTc` \ later_changed ->
        returnTc (here_changed || later_changed)
-- Pairwise improvement of two (class, fundeps) constraints.  Their
-- dependencies are compared with checkFds, using an empty set of free
-- variables, only when both constraints are for the same class.
-- NOTE(review): the `else` branch is not visible in this excerpt;
-- presumably it reports "no change" -- confirm against the full file.
56 pairImprove (clas1, fds1) (clas2, fds2)
57 = if clas1 == clas2 then
58 checkFds fds1 emptyVarSet fds2
-- Walk two lists of functional dependencies in lock step, comparing the
-- entries position by position with checkFd.  `free` is passed through to
-- every comparison unchanged.
-- Returns True if any position reported a change.
-- There is deliberately no equation for lists of different lengths: such
-- a call matches neither equation below.
checkFds [] _ [] = returnTc False
checkFds (want : wants) free (have : haves)
  = checkFd want free have    `thenTc` \ head_changed ->
    checkFds wants free haves `thenTc` \ tail_changed ->
    returnTc (head_changed || tail_changed)
-- Compare a single functional dependency (t_x, t_y) against a
-- corresponding dependency (s_x, s_y); `free` is handed to zonkMatchTys
-- for both matches.
-- NOTE(review): the `case` lines that scrutinise msubst and msubst2, and
-- the results returned by each branch, are not visible in this excerpt;
-- the notes below describe only the visible lines.
69 checkFd (t_x, t_y) free (s_x, s_y)
70 -- we need to zonk each time because unification
71 -- may happen at any time
72 = zonkMatchTys t_x free s_x `thenTc` \ msubst ->
-- s_y' is s_y with the substitution `subst` applied to every type.
75 let s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y in
-- NOTE(review): this second match is against s_y, not the substituted
-- s_y' computed just above -- confirm that is intended.
76 zonkMatchTys t_y free s_y `thenTc` \ msubst2 ->
79 -- they're the same, nothing changes
-- Otherwise unify t_y with the substituted s_y'; the result of the
-- unification is discarded (thenTc_).
82 unifyTauTyLists t_y s_y' `thenTc_`
83 -- if we get here, something must have unified
-- Zonk both lists of types, then call matchTys with `free`, the zonked
-- ts2 and the zonked ts1 (in that order).
-- Returns Just the substitution when matchTys succeeds with nothing left
-- over, and Nothing otherwise.
zonkMatchTys ts1 free ts2
  = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
    mapTc zonkTcType ts2 `thenTc` \ ts2' ->
    --returnTc (ts1' == ts2')
    case matchTys free ts2' ts1' of
      Just (subst, []) -> returnTc (Just subst)
      -- Either matchTys failed outright, or it succeeded only partially
      -- and left some types over.  The partial case is treated as "no
      -- match" so that the case expression is total instead of failing
      -- with a pattern-match error at run time.
      _                -> returnTc Nothing
-- NOTE(review): a second instImprove that takes a class and its fundeps as
-- two separate arguments; its arity differs from the list-based
-- instImprove defined earlier in the file, so presumably this is older,
-- disabled code -- confirm before relying on it.
-- Visible behaviour: trace the class's instance environment, zonk the
-- fundeps, trace them, then look them up with lookupInstEnvFDs.
97 instImprove clas fds =
98 pprTrace "class inst env" (ppr (clas, classInstEnv clas)) $
99 zonkFunDeps fds `thenTc` \ fds' ->
100 pprTrace "lIEFDs" (ppr (clas, fds')) $
101 case lookupInstEnvFDs clas fds' of
-- A failed lookup does nothing.
102 Nothing -> returnTc ()
-- NOTE(review): the pattern line that binds t_y and s_y is not visible in
-- this excerpt; the two lists are traced and then unified.
104 pprTrace "lIEFDs result" (ppr (t_y, s_y)) $
105 unifyTauTyLists t_y s_y
-- NOTE(review): several lines of this definition (its right-hand side, the
-- other equations/alternatives of the local helpers, and the head of the
-- helper that owns the final `where`) are not visible in this excerpt.
-- The notes below describe only the visible lines.
-- NOTE(review): `tys` is used below but is not bound on any visible line,
-- and the local `fds` from classExtraBigSig shadows the parameter of the
-- same name -- confirm against the full file.
107 lookupInstEnvFDs clas fds
-- env: the class's instance environment.
110 env = classInstEnv clas
-- ctvs and fds are the first two components of classExtraBigSig; each
-- element of fds is a pair whose halves are selected with fst/snd below.
111 (ctvs, fds, _, _, _, _) = classExtraBigSig clas
-- find examines the entries of a list of (tpl_tyvars, tpl, val) triples.
113 find ((tpl_tyvars, tpl, val) : rest)
-- tplx / tply: the template restricted (via thingy) to the first / second
-- halves of the dependencies.
114 = let tplx = concatMap (\us -> thingy tpl us ctvs) (map fst fds)
115 tply = concatMap (\vs -> thingy tpl vs ctvs) (map snd fds)
117 case matchTys tpl_tyvars tplx tysx of
119 Just (tenv, leftovers) ->
120 let subst = mkSubst (tyVarsOfTypes tys) tenv
122 -- this is the list of things that
123 -- need to be unified
124 Just (map (substTy subst) tply, tysy)
-- tysx / tysy: the same restriction applied to tys.
125 tysx = concatMap (\us -> thingy tys us ctvs) (map fst fds)
126 tysy = concatMap (\vs -> thingy tys vs ctvs) (map snd fds)
-- is: the position of each u within ctvs.  The `Just i` pattern has no
-- fallback, so evaluating an index fails if u is not an element of ctvs.
129 where is = map (\u -> let Just i = elemIndex u ctvs in i) us
-- NOTE(review): the equation head is not visible in this excerpt;
-- judging by the recursive call below this is allDictPairs, applied to a
-- function f and a list whose head is `dict` and whose tail is `dicts`
-- -- confirm against the full file.
-- NOTE(review): instImprove is called here with two arguments, unlike the
-- list-based instImprove defined earlier in the file -- presumably this
-- belongs to the same older code as the two-argument version.
133 = let (clas, tys) = getDictClassTys dict
135 -- first, do instance-based improvement
136 instImprove clas tys `thenTc_`
137 -- OK, now do pairwise stuff
-- Apply f to this dictionary's class and types against each remaining
-- dictionary, OR the results together, then recurse on the remainder.
138 mapTc (f clas tys) dicts `thenTc` \ changes ->
139 foldrTc (\a b -> returnTc (a || b)) False changes `thenTc` \ changed ->
140 allDictPairs f dicts `thenTc` \ rest_changed ->
141 returnTc (changed || rest_changed)
148 A monadic version of the standard Prelude `or' function.
-- Monadic disjunction: fold the list of Bools from the right with (||),
-- starting from False, and return the result in the monad.
orTc flags
  = foldrTc step False flags
  where
    step flag acc = returnTc (flag || acc)