[project @ 2000-07-07 12:13:43 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcImprove.lhs
1 \begin{code}
2 module TcImprove ( tcImprove ) where
3
4 #include "HsVersions.h"
5
6 import Name             ( Name )
7 import Class            ( Class, FunDep, className, classExtraBigSig )
8 import Unify            ( unifyTyListsX, matchTys )
9 import Subst            ( mkSubst, substTy )
10 import TcEnv            ( tcGetInstEnv, classInstEnv )
11 import TcMonad
12 import TcType           ( TcType, TcTyVar, TcTyVarSet, zonkTcType, zonkTcTypes )
13 import TcUnify          ( unifyTauTyLists )
14 import Inst             ( LIE, Inst, LookupInstResult(..),
15                           lookupInst, getFunDepsOfLIE, getIPsOfLIE,
16                           zonkLIE, zonkFunDeps {- for debugging -} )
17 import VarSet           ( VarSet, emptyVarSet, unionVarSet )
18 import VarEnv           ( emptyVarEnv )
19 import FunDeps          ( instantiateFdClassTys )
20 import Outputable
21 import List             ( elemIndex, nub )
22 \end{code}
23
24 \begin{code}
25 tcImprove :: LIE -> TcM s ()
26 -- Do unifications based on functional dependencies in the LIE
27 tcImprove lie 
28   = tcGetInstEnv        `thenNF_Tc` \ inst_env ->
29     let
30         nfdss, clas_nfdss, inst_nfdss, ip_nfdss :: [(TcTyVarSet, Name, [FunDep TcType])]
31         nfdss = ip_nfdss ++ clas_nfdss ++ inst_nfdss
32
33         cfdss :: [(Class, [FunDep TcType])]
34         cfdss      = getFunDepsOfLIE lie
35         clas_nfdss = [(emptyVarSet, className c, fds) | (c,fds) <- cfdss]
36
37         classes    = nub (map fst cfdss)
38         inst_nfdss = [ (free, className c, instantiateFdClassTys c ts)
39                      | c <- classes,
40                        (free, ts, i) <- classInstEnv inst_env c
41                      ]
42
43         ip_nfdss = [(emptyVarSet, n, [([], [ty])]) | (n,ty) <- getIPsOfLIE lie]
44
45         {- Example: we have
46                 class C a b c  |  a->b where ...
47                 instance C Int Bool c 
48         
49            Given the LIE        FD C (Int->t)
50            we get       clas_nfdss = [({}, C, [Int->t,     t->Int])
51                         inst_nfdss = [({c}, C, [Int->Bool, Bool->Int])]
52         
53            Another way would be to flatten a bit
54            we get       clas_nfdss = [({}, C, Int->t), ({}, C, t->Int)]
55                         inst_nfdss = [({c}, C, Int->Bool), ({c}, C, Bool->Int)]
56         
57            iterImprove then matches up the C and Int, and unifies t <-> Bool
58         -}      
59
60     in
61     iterImprove nfdss
62
63
64 iterImprove :: [(VarSet, Name, [FunDep TcType])] -> TcM s ()
65 iterImprove [] = returnTc ()
66 iterImprove cfdss
67   = selfImprove pairImprove cfdss       `thenTc` \ change2 ->
68     if {- change1 || -} change2 then
69         iterImprove cfdss
70     else
71         returnTc ()
72
73 -- ZZ this will do a lot of redundant checking wrt instances
74 -- it would do to make this operate over two lists, the first
75 -- with only clas_nfds and ip_nfds, and the second with everything
76 -- control would otherwise mimic the current loop, so that the
77 -- caller could control whether the redundant inst improvements
78 -- were avoided
79 -- you could then also use this to check for consistency of new instances
80
81 -- selfImprove is really just doing a cartesian product of all the fds
82 selfImprove f [] = returnTc False
83 selfImprove f (nfds : nfdss)
84   = mapTc (f nfds) nfdss        `thenTc` \ changes ->
85     selfImprove f nfdss         `thenTc` \ rest_changed ->
86     returnTc (or changes || rest_changed)
87
88 pairImprove (free1, n1, fds1) (free2, n2, fds2)
89   = if n1 == n2 then
90         checkFds (free1 `unionVarSet` free2) fds1 fds2
91     else
92         returnTc False
93
94 checkFds free [] [] = returnTc False
95 checkFds free (fd1 : fd1s) (fd2 : fd2s) =
96     checkFd free fd1 fd2        `thenTc` \ change ->
97     checkFds free fd1s fd2s     `thenTc` \ changes ->
98     returnTc (change || changes)
99 --checkFds _ _ = returnTc False
100
101 checkFd free (t_x, t_y) (s_x, s_y)
102   -- we need to zonk each time because unification
103   -- may happen at any time
104   = zonkUnifyTys free t_x s_x `thenTc` \ msubst ->
105     case msubst of
106       Just subst ->
107         let t_y' = map (substTy (mkSubst emptyVarEnv subst)) t_y
108             s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y
109         in
110             zonkEqTys t_y' s_y' `thenTc` \ eq ->
111             if eq then
112                 -- they're the same, nothing changes...
113                 returnTc False
114             else
115                 -- ZZ what happens if two instance vars unify?
116                 unifyTauTyLists t_y' s_y' `thenTc_`
117                 -- if we get here, something must have unified
118                 returnTc True
119       Nothing ->
120         returnTc False
121
122 zonkEqTys ts1 ts2
123   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
124     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
125     returnTc (ts1' == ts2')
126
127 zonkMatchTys ts1 free ts2
128   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
129     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
130     -- pprTrace "zMT" (ppr (ts1', free, ts2')) $
131     case matchTys free ts2' ts1' of
132       Just (subst, []) -> -- pprTrace "zMT match!" empty $
133                           returnTc (Just subst)
134       Nothing -> returnTc Nothing
135
136 zonkUnifyTys free ts1 ts2
137   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
138     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
139     -- pprTrace "zMT" (ppr (ts1', free, ts2')) $
140     case unifyTyListsX free ts2' ts1' of
141       Just subst -> returnTc (Just subst)
142       Nothing    -> returnTc Nothing
143 \end{code}