74f38b997d7480d14220b910d0512efcc618a46e
[ghc-hetmet.git] / ghc / compiler / typecheck / TcImprove.lhs
1 \begin{code}
2 module TcImprove ( tcImprove ) where
3
4 #include "HsVersions.h"
5
6 import Name             ( Name )
7 import Class            ( Class, FunDep, className, classInstEnv, classExtraBigSig )
8 import Unify            ( unifyTyListsX, matchTys )
9 import Subst            ( mkSubst, substTy )
10 import TcMonad
11 import TcType           ( TcType, TcTyVar, TcTyVarSet, zonkTcType, zonkTcTypes )
12 import TcUnify          ( unifyTauTyLists )
13 import Inst             ( LIE, Inst, LookupInstResult(..),
14                           lookupInst, getFunDepsOfLIE, getIPsOfLIE,
15                           zonkLIE, zonkFunDeps {- for debugging -} )
16 import InstEnv          ( InstEnv )             -- Reqd for 4.02; InstEnv is a synonym, and
17                                                 -- 4.02 doesn't "see" it soon enough
18 import VarSet           ( VarSet, emptyVarSet, unionVarSet )
19 import VarEnv           ( emptyVarEnv )
20 import FunDeps          ( instantiateFdClassTys )
21 import Outputable
22 import List             ( elemIndex, nub )
23 \end{code}
24
25 \begin{code}
26 tcImprove :: LIE -> TcM s ()
27 -- Do unifications based on functional dependencies in the LIE
28 tcImprove lie 
29   | null nfdss = returnTc ()
30   | otherwise  = iterImprove nfdss
31   where
32         nfdss, clas_nfdss, inst_nfdss, ip_nfdss :: [(TcTyVarSet, Name, [FunDep TcType])]
33         nfdss = ip_nfdss ++ clas_nfdss ++ inst_nfdss
34
35         cfdss :: [(Class, [FunDep TcType])]
36         cfdss = getFunDepsOfLIE lie
37         clas_nfdss = map (\(c, fds) -> (emptyVarSet, className c, fds)) cfdss
38
39         classes = nub (map fst cfdss)
40         inst_nfdss = concatMap getInstNfdssOf classes
41
42         ips = getIPsOfLIE lie
43         ip_nfdss = map (\(n, ty) -> (emptyVarSet, n, [([], [ty])])) ips
44
45 {- Example: we have
46         class C a b c  |  a->b where ...
47         instance C Int Bool c 
48
49    Given the LIE        FD C (Int->t)
50    we get       clas_nfdss = [({}, C, [Int->t,     t->Int])
51                 inst_nfdss = [({c}, C, [Int->Bool, Bool->Int])]
52
53    Another way would be to flatten a bit
54    we get       clas_nfdss = [({}, C, Int->t), ({}, C, t->Int)]
55                 inst_nfdss = [({c}, C, Int->Bool), ({c}, C, Bool->Int)]
56
57    iterImprove then matches up the C and Int, and unifies t <-> Bool
58 -}
59
60 getInstNfdssOf :: Class -> [(TcTyVarSet, Name, [FunDep TcType])]
61 getInstNfdssOf clas 
62   = [ (free, nm, instantiateFdClassTys clas ts)
63     | (free, ts, i) <- classInstEnv clas
64     ]
65   where
66         nm = className clas
67
68 iterImprove :: [(VarSet, Name, [FunDep TcType])] -> TcM s ()
69 iterImprove [] = returnTc ()
70 iterImprove cfdss
71   = selfImprove pairImprove cfdss       `thenTc` \ change2 ->
72     if {- change1 || -} change2 then
73         iterImprove cfdss
74     else
75         returnTc ()
76
77 -- ZZ this will do a lot of redundant checking wrt instances
78 -- it would do to make this operate over two lists, the first
79 -- with only clas_nfds and ip_nfds, and the second with everything
80 -- control would otherwise mimic the current loop, so that the
81 -- caller could control whether the redundant inst improvements
82 -- were avoided
83 -- you could then also use this to check for consistency of new instances
84
85 -- selfImprove is really just doing a cartesian product of all the fds
86 selfImprove f [] = returnTc False
87 selfImprove f (nfds : nfdss)
88   = mapTc (f nfds) nfdss        `thenTc` \ changes ->
89     selfImprove f nfdss         `thenTc` \ rest_changed ->
90     returnTc (or changes || rest_changed)
91
92 pairImprove (free1, n1, fds1) (free2, n2, fds2)
93   = if n1 == n2 then
94         checkFds (free1 `unionVarSet` free2) fds1 fds2
95     else
96         returnTc False
97
98 checkFds free [] [] = returnTc False
99 checkFds free (fd1 : fd1s) (fd2 : fd2s) =
100     checkFd free fd1 fd2        `thenTc` \ change ->
101     checkFds free fd1s fd2s     `thenTc` \ changes ->
102     returnTc (change || changes)
103 --checkFds _ _ = returnTc False
104
105 checkFd free (t_x, t_y) (s_x, s_y)
106   -- we need to zonk each time because unification
107   -- may happen at any time
108   = zonkUnifyTys free t_x s_x `thenTc` \ msubst ->
109     case msubst of
110       Just subst ->
111         let t_y' = map (substTy (mkSubst emptyVarEnv subst)) t_y
112             s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y
113         in
114             zonkEqTys t_y' s_y' `thenTc` \ eq ->
115             if eq then
116                 -- they're the same, nothing changes...
117                 returnTc False
118             else
119                 -- ZZ what happens if two instance vars unify?
120                 unifyTauTyLists t_y' s_y' `thenTc_`
121                 -- if we get here, something must have unified
122                 returnTc True
123       Nothing ->
124         returnTc False
125
126 zonkEqTys ts1 ts2
127   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
128     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
129     returnTc (ts1' == ts2')
130
131 zonkMatchTys ts1 free ts2
132   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
133     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
134     -- pprTrace "zMT" (ppr (ts1', free, ts2')) $
135     case matchTys free ts2' ts1' of
136       Just (subst, []) -> -- pprTrace "zMT match!" empty $
137                           returnTc (Just subst)
138       Nothing -> returnTc Nothing
139
140 zonkUnifyTys free ts1 ts2
141   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
142     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
143     -- pprTrace "zMT" (ppr (ts1', free, ts2')) $
144     case unifyTyListsX free ts2' ts1' of
145       Just subst -> returnTc (Just subst)
146       Nothing    -> returnTc Nothing
147 \end{code}