0250a30bb737c0b6eeb39806f198ee9eb2118510
[ghc-hetmet.git] / ghc / compiler / typecheck / TcImprove.lhs
1 \begin{code}
2 module TcImprove ( tcImprove ) where
3
4 #include "HsVersions.h"
5
6 import Type             ( tyVarsOfTypes )
7 import Class            ( classInstEnv, classExtraBigSig )
8 import Unify            ( matchTys )
9 import Subst            ( mkSubst, substTy )
10 import TcMonad
11 import TcType           ( zonkTcType, zonkTcTypes )
12 import TcUnify          ( unifyTauTyLists )
13 import Inst             ( Inst, LookupInstResult(..),
14                           lookupInst, isDict, getDictClassTys, getFunDepsOfLIE,
15                           zonkLIE {- for debugging -} )
16 import VarSet           ( emptyVarSet )
17 import VarEnv           ( emptyVarEnv )
18 import FunDeps          ( instantiateFdClassTys )
19 import Bag              ( bagToList )
20 import Outputable
21 import List             ( elemIndex )
22 \end{code}
23
24 Improvement goes here.
25
26 \begin{code}
27 tcImprove lie = iterImprove (getFunDepsOfLIE lie)
28
29 iterImprove cfdss
30   = instImprove cfdss                   `thenTc` \ change1 ->
31     selfImprove pairImprove cfdss       `thenTc` \ change2 ->
32     if change1 || change2 then
33         iterImprove cfdss
34     else
35         returnTc ()
36
37 instImprove (cfds@(clas, fds) : cfdss)
38   = instImprove1 cfds ins
39   where ins = classInstEnv clas
40 instImprove [] = returnTc False
41
42 instImprove1 cfds@(clas, fds1) ((free, ts, _) : ins)
43   = checkFds fds1 free fds2     `thenTc` \ changed ->
44     instImprove1 cfds ins       `thenTc` \ rest_changed ->
45     returnTc (changed || rest_changed)
46   where fds2 = instantiateFdClassTys clas ts
47 instImprove1 _ _ = returnTc False
48
49 selfImprove f [] = returnTc False
50 selfImprove f (cfds : cfdss)
51   = mapTc (f cfds) cfdss        `thenTc` \ changes ->
52     orTc changes                `thenTc` \ changed ->
53     selfImprove f cfdss         `thenTc` \ rest_changed ->
54     returnTc (changed || rest_changed)
55
56 pairImprove (clas1, fds1) (clas2, fds2)
57   = if clas1 == clas2 then
58         checkFds fds1 emptyVarSet fds2
59     else
60         returnTc False
61
62 checkFds [] free [] = returnTc False
63 checkFds (fd1 : fd1s) free (fd2 : fd2s) =
64     checkFd fd1 free fd2        `thenTc` \ change ->
65     checkFds fd1s free fd2s     `thenTc` \ changes ->
66     returnTc (change || changes)
67 --checkFds _ _ = returnTc False
68
69 checkFd (t_x, t_y) free (s_x, s_y)
70   -- we need to zonk each time because unification
71   -- may happen at any time
72   = zonkMatchTys t_x free s_x `thenTc` \ msubst ->
73     case msubst of
74       Just subst ->
75         let s_y' = map (substTy (mkSubst emptyVarEnv subst)) s_y in
76             zonkMatchTys t_y free s_y `thenTc` \ msubst2 ->
77                 case msubst2 of
78                   Just _ ->
79                     -- they're the same, nothing changes
80                     returnTc False
81                   Nothing ->
82                     unifyTauTyLists t_y s_y' `thenTc_`
83                     -- if we get here, something must have unified
84                     returnTc True
85       Nothing ->
86         returnTc False
87
88 zonkMatchTys ts1 free ts2
89   = mapTc zonkTcType ts1 `thenTc` \ ts1' ->
90     mapTc zonkTcType ts2 `thenTc` \ ts2' ->
91     --returnTc (ts1' == ts2')
92     case matchTys free ts2' ts1' of
93       Just (subst, []) -> returnTc (Just subst)
94       Nothing -> returnTc Nothing
95
96 {-
97 instImprove clas fds =
98     pprTrace "class inst env" (ppr (clas, classInstEnv clas)) $
99     zonkFunDeps fds `thenTc` \ fds' ->
100     pprTrace "lIEFDs" (ppr (clas, fds')) $
101     case lookupInstEnvFDs clas fds' of
102       Nothing -> returnTc ()
103       Just (t_y, s_y) ->
104         pprTrace "lIEFDs result" (ppr (t_y, s_y)) $
105         unifyTauTyLists t_y s_y
106
107 lookupInstEnvFDs clas fds
108   = find env
109   where
110     env = classInstEnv clas
111     (ctvs, fds, _, _, _, _) = classExtraBigSig clas
112     find [] = Nothing
113     find ((tpl_tyvars, tpl, val) : rest)
114       = let tplx = concatMap (\us -> thingy tpl us ctvs) (map fst fds)
115             tply = concatMap (\vs -> thingy tpl vs ctvs) (map snd fds)
116         in
117             case matchTys tpl_tyvars tplx tysx of
118               Nothing -> find rest
119               Just (tenv, leftovers) ->
120                 let subst = mkSubst (tyVarsOfTypes tys) tenv
121                 in
122                     -- this is the list of things that
123                     -- need to be unified
124                     Just (map (substTy subst) tply, tysy)
125     tysx = concatMap (\us -> thingy tys us ctvs) (map fst fds)
126     tysy = concatMap (\vs -> thingy tys vs ctvs) (map snd fds)
127     thingy f us ctvs
128       = map (f !!) is
129         where is = map (\u -> let Just i = elemIndex u ctvs in i) us
130 -}
131
132 {-
133   = let (clas, tys) = getDictClassTys dict
134     in
135         -- first, do instance-based improvement
136         instImprove clas tys `thenTc_`
137         -- OK, now do pairwise stuff
138         mapTc (f clas tys) dicts `thenTc` \ changes ->
139         foldrTc (\a b -> returnTc (a || b)) False changes `thenTc` \ changed ->
140         allDictPairs f dicts `thenTc` \ rest_changed ->
141         returnTc (changed || rest_changed)
142 -}
143
144 \end{code}
145
146 Utilities:
147
148 A monadic version of the standard Prelude `or' function.
149 \begin{code}
150 orTc bs = foldrTc (\a b -> returnTc (a || b)) False bs
151 \end{code}