Use implication constraints to improve type inference
[ghc-hetmet.git] / compiler / types / FunDeps.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 2000
4 %
5
6 FunDeps - functional dependencies
7
8 It's better to read it as: "if we know these, then we're going to know these"
9
10 \begin{code}
11 module FunDeps (
12         Equation, pprEquation,
13         oclose, grow, improve, improveOne,
14         checkInstCoverage, checkFunDeps,
15         pprFundeps
16     ) where
17
18 #include "HsVersions.h"
19
20 import Name
21 import Var
22 import Class
23 import TcGadt
24 import Type
25 import Coercion
26 import TcType
27 import InstEnv
28 import VarSet
29 import VarEnv
30 import Outputable
31 import Util
32 import ListSetOps
33
34 import Data.List        ( tails )
35 import Data.Maybe       ( isJust )
36 \end{code}
37
38
39 %************************************************************************
40 %*                                                                      *
41 \subsection{Close type variables}
42 %*                                                                      *
43 %************************************************************************
44
45 (oclose preds tvs) closes the set of type variables tvs, 
46 wrt functional dependencies in preds.  The result is a superset
47 of the argument set.  For example, if we have
48         class C a b | a->b where ...
49 then
50         oclose [C (x,y) z, C (x,p) q] {x,y} = {x,y,z}
51 because if we know x and y then that fixes z.
52
53 Using oclose
54 ~~~~~~~~~~~~
55 oclose is used
56
57 a) When determining ambiguity.  The type
58         forall a,b. C a b => a
59 is not ambiguous (given the above class decl for C) because
60 a determines b.  
61
62 b) When generalising a type T.  Usually we take FV(T) \ FV(Env),
63 but in fact we need
64         FV(T) \ (FV(Env)+)
65 where the '+' is the oclosure operation.  Notice that we do not 
66 take FV(T)+.  This puzzled me for a bit.  Consider
67
68         f = E
69
70 and suppose e have that E :: C a b => a, and suppose that b is
71 free in the environment. Then we quantify over 'a' only, giving
72 the type forall a. C a b => a.  Since a->b but we don't have b->a,
73 we might have instance decls like
74         instance C Bool Int where ...
75         instance C Char Int where ...
76 so knowing that b=Int doesn't fix 'a'; so we quantify over it.
77
78                 ---------------
79                 A WORRY: ToDo!
80                 ---------------
81 If we have      class C a b => D a b where ....
82                 class D a b | a -> b where ...
83 and the preds are [C (x,y) z], then we want to see the fd in D,
84 even though it is not explicit in C, giving [({x,y},{z})]
85
86 Similarly for instance decls?  E.g. Suppose we have
87         instance C a b => Eq (T a b) where ...
88 and we infer a type t with constraints Eq (T a b) for a particular
89 expression, and suppose that 'a' is free in the environment.  
90 We could generalise to
91         forall b. Eq (T a b) => t
92 but if we reduced the constraint, to C a b, we'd see that 'a' determines
93 b, so that a better type might be
94         t (with free constraint C a b) 
95 Perhaps it doesn't matter, because we'll still force b to be a
96 particular type at the call sites.  Generalising over too many
97 variables (provided we don't shadow anything by quantifying over a
98 variable that is actually free in the envt) may postpone errors; it
99 won't hide them altogether.
100
101
102 \begin{code}
103 oclose :: [PredType] -> TyVarSet -> TyVarSet
104 oclose preds fixed_tvs
105   | null tv_fds = fixed_tvs     -- Fast escape hatch for common case
106   | otherwise   = loop fixed_tvs
107   where
108     loop fixed_tvs
109         | new_fixed_tvs `subVarSet` fixed_tvs = fixed_tvs
110         | otherwise                           = loop new_fixed_tvs
111         where
112           new_fixed_tvs = foldl extend fixed_tvs tv_fds
113
114     extend fixed_tvs (ls,rs) | ls `subVarSet` fixed_tvs = fixed_tvs `unionVarSet` rs
115                              | otherwise                = fixed_tvs
116
117     tv_fds  :: [(TyVarSet,TyVarSet)]
118         -- In our example, tv_fds will be [ ({x,y}, {z}), ({x,p},{q}) ]
119         -- Meaning "knowing x,y fixes z, knowing x,p fixes q"
120     tv_fds  = [ (tyVarsOfTypes xs, tyVarsOfTypes ys)
121               | ClassP cls tys <- preds,                -- Ignore implicit params
122                 let (cls_tvs, cls_fds) = classTvsFds cls,
123                 fd <- cls_fds,
124                 let (xs,ys) = instFD fd cls_tvs tys
125               ]
126 \end{code}
127
128 \begin{code}
129 grow :: [PredType] -> TyVarSet -> TyVarSet
130 -- See Note [Ambiguity] in TcSimplify
131 grow preds fixed_tvs 
132   | null preds = fixed_tvs
133   | otherwise  = loop fixed_tvs
134   where
135     loop fixed_tvs
136         | new_fixed_tvs `subVarSet` fixed_tvs = fixed_tvs
137         | otherwise                           = loop new_fixed_tvs
138         where
139           new_fixed_tvs = foldl extend fixed_tvs pred_sets
140
141     extend fixed_tvs pred_tvs 
142         | fixed_tvs `intersectsVarSet` pred_tvs = fixed_tvs `unionVarSet` pred_tvs
143         | otherwise                             = fixed_tvs
144
145     pred_sets = [tyVarsOfPred pred | pred <- preds]
146 \end{code}
147     
148 %************************************************************************
149 %*                                                                      *
150 \subsection{Generate equations from functional dependencies}
151 %*                                                                      *
152 %************************************************************************
153
154
155 \begin{code}
156 ----------
157 type Equation = (TyVarSet, [(Type, Type)])
158 -- These pairs of types should be equal, for some
159 -- substitution of the tyvars in the tyvar set
160 -- INVARIANT: corresponding types aren't already equal
161
162 -- It's important that we have a *list* of pairs of types.  Consider
163 --      class C a b c | a -> b c where ...
164 --      instance C Int x x where ...
165 -- Then, given the constraint (C Int Bool v) we should improve v to Bool,
166 -- via the equation ({x}, [(Bool,x), (v,x)])
167 -- This would not happen if the class had looked like
168 --      class C a b c | a -> b, a -> c
169
170 -- To "execute" the equation, make fresh type variable for each tyvar in the set,
171 -- instantiate the two types with these fresh variables, and then unify.
172 --
173 -- For example, ({a,b}, (a,Int,b), (Int,z,Bool))
174 -- We unify z with Int, but since a and b are quantified we do nothing to them
175 -- We usually act on an equation by instantiating the quantified type varaibles
176 -- to fresh type variables, and then calling the standard unifier.
177
178 pprEquation (qtvs, pairs) 
179   = vcat [ptext SLIT("forall") <+> braces (pprWithCommas ppr (varSetElems qtvs)),
180           nest 2 (vcat [ ppr t1 <+> ptext SLIT(":=:") <+> ppr t2 | (t1,t2) <- pairs])]
181
182 ----------
183 type Pred_Loc = (PredType, SDoc)        -- SDoc says where the Pred comes from
184
185 improve :: (Class -> [Instance])                -- Gives instances for given class
186         -> [Pred_Loc]                           -- Current constraints; 
187         -> [(Equation,Pred_Loc,Pred_Loc)]       -- Derived equalities that must also hold
188                                                 -- (NB the above INVARIANT for type Equation)
189                                                 -- The Pred_Locs explain which two predicates were
190                                                 -- combined (for error messages)
191 \end{code}
192
193 Given a bunch of predicates that must hold, such as
194
195         C Int t1, C Int t2, C Bool t3, ?x::t4, ?x::t5
196
197 improve figures out what extra equations must hold.
198 For example, if we have
199
200         class C a b | a->b where ...
201
202 then improve will return
203
204         [(t1,t2), (t4,t5)]
205
206 NOTA BENE:
207
208   * improve does not iterate.  It's possible that when we make
209     t1=t2, for example, that will in turn trigger a new equation.
210     This would happen if we also had
211         C t1 t7, C t2 t8
212     If t1=t2, we also get t7=t8.
213
214     improve does *not* do this extra step.  It relies on the caller
215     doing so.
216
217   * The equations unify types that are not already equal.  So there
218     is no effect iff the result of improve is empty
219
220
221
222 \begin{code}
223 improve inst_env preds
224   = [ eqn | group <- equivClassesByUniq (predTyUnique . fst) (filterEqPreds preds),
225             eqn   <- checkGroup inst_env group ]
226   where 
227     filterEqPreds = filter (not . isEqPred . fst)
228         -- Equality predicates don't have uniques
229         -- In any case, improvement *generates*, rather than
230         -- *consumes*, equality constraints
231
232 improveOne :: (Class -> [Instance])
233            -> Pred_Loc
234            -> [Pred_Loc]
235            -> [(Equation,Pred_Loc,Pred_Loc)]
236
237 -- Just do improvement triggered by a single, distinguised predicate
238
239 improveOne inst_env pred@(IParam ip ty, _) preds
240   = [ ((emptyVarSet, [(ty,ty2)]), pred, p2) 
241     | p2@(IParam ip2 ty2, _) <- preds
242     , ip==ip2
243     , not (ty `tcEqType` ty2)]
244
245 improveOne inst_env pred@(ClassP cls tys, _) preds
246   | tys `lengthAtLeast` 2
247   = instance_eqns ++ pairwise_eqns
248         -- NB: we put the instance equations first.   This biases the 
249         -- order so that we first improve individual constraints against the
250         -- instances (which are perhaps in a library and less likely to be
251         -- wrong; and THEN perform the pairwise checks.
252         -- The other way round, it's possible for the pairwise check to succeed
253         -- and cause a subsequent, misleading failure of one of the pair with an
254         -- instance declaration.  See tcfail143.hs for an example
255   where
256     (cls_tvs, cls_fds) = classTvsFds cls
257     instances          = inst_env cls
258     rough_tcs          = roughMatchTcs tys
259
260         -- NOTE that we iterate over the fds first; they are typically
261         -- empty, which aborts the rest of the loop.
262     pairwise_eqns :: [(Equation,Pred_Loc,Pred_Loc)]
263     pairwise_eqns       -- This group comes from pairwise comparison
264       = [ (eqn, pred, p2)
265         | fd <- cls_fds
266         , p2@(ClassP cls2 tys2, _) <- preds
267         , cls == cls2
268         , eqn <- checkClsFD emptyVarSet fd cls_tvs tys tys2
269         ]
270
271     instance_eqns :: [(Equation,Pred_Loc,Pred_Loc)]
272     instance_eqns       -- This group comes from comparing with instance decls
273       = [ (eqn, p_inst, pred)
274         | fd <- cls_fds         -- Iterate through the fundeps first, 
275                                 -- because there often are none!
276         , let rough_fd_tcs = trimRoughMatchTcs cls_tvs fd rough_tcs
277         , ispec@(Instance { is_tvs = qtvs, is_tys = tys_inst, 
278                             is_tcs = mb_tcs_inst }) <- instances
279         , not (instanceCantMatch mb_tcs_inst rough_tcs)
280         , eqn <- checkClsFD qtvs fd cls_tvs tys_inst tys
281         , let p_inst = (mkClassPred cls tys_inst, 
282                         ptext SLIT("arising from the instance declaration at")
283                         <+> ppr (getSrcLoc ispec))
284         ]
285
286 improveOne inst_env eq_pred preds
287   = []
288
289 ----------
290 checkGroup :: (Class -> [Instance])
291            -> [Pred_Loc]
292            -> [(Equation, Pred_Loc, Pred_Loc)]
293   -- The preds are all for the same class or implicit param
294
295 checkGroup inst_env (p1@(IParam _ ty, _) : ips)
296   =     -- For implicit parameters, all the types must match
297     [ ((emptyVarSet, [(ty,ty')]), p1, p2) 
298     | p2@(IParam _ ty', _) <- ips, not (ty `tcEqType` ty')]
299
300 checkGroup inst_env clss@((ClassP cls _, _) : _)
301   =     -- For classes life is more complicated  
302         -- Suppose the class is like
303         --      classs C as | (l1 -> r1), (l2 -> r2), ... where ...
304         -- Then FOR EACH PAIR (ClassP c tys1, ClassP c tys2) in the list clss
305         -- we check whether
306         --      U l1[tys1/as] = U l2[tys2/as]
307         --  (where U is a unifier)
308         -- 
309         -- If so, we return the pair
310         --      U r1[tys1/as] = U l2[tys2/as]
311         --
312         -- We need to do something very similar comparing each predicate
313         -- with relevant instance decls
314
315     instance_eqns ++ pairwise_eqns
316         -- NB: we put the instance equations first.   This biases the 
317         -- order so that we first improve individual constraints against the
318         -- instances (which are perhaps in a library and less likely to be
319         -- wrong; and THEN perform the pairwise checks.
320         -- The other way round, it's possible for the pairwise check to succeed
321         -- and cause a subsequent, misleading failure of one of the pair with an
322         -- instance declaration.  See tcfail143.hs for an exmample
323
324   where
325     (cls_tvs, cls_fds) = classTvsFds cls
326     instances          = inst_env cls
327
328         -- NOTE that we iterate over the fds first; they are typically
329         -- empty, which aborts the rest of the loop.
330     pairwise_eqns :: [(Equation,Pred_Loc,Pred_Loc)]
331     pairwise_eqns       -- This group comes from pairwise comparison
332       = [ (eqn, p1, p2)
333         | fd <- cls_fds,
334           p1@(ClassP _ tys1, _) : rest <- tails clss,
335           p2@(ClassP _ tys2, _) <- rest,
336           eqn <- checkClsFD emptyVarSet fd cls_tvs tys1 tys2
337         ]
338
339     instance_eqns :: [(Equation,Pred_Loc,Pred_Loc)]
340     instance_eqns       -- This group comes from comparing with instance decls
341       = [ (eqn, p1, p2)
342         | fd <- cls_fds,        -- Iterate through the fundeps first, 
343                                 -- because there often are none!
344           p2@(ClassP _ tys2, _) <- clss,
345           let rough_tcs2 = trimRoughMatchTcs cls_tvs fd (roughMatchTcs tys2),
346           ispec@(Instance { is_tvs = qtvs, is_tys = tys1, 
347                             is_tcs = mb_tcs1 }) <- instances,
348           not (instanceCantMatch mb_tcs1 rough_tcs2),
349           eqn <- checkClsFD qtvs fd cls_tvs tys1 tys2,
350           let p1 = (mkClassPred cls tys1, 
351                     ptext SLIT("arising from the instance declaration at") <+> 
352                         ppr (getSrcLoc ispec))
353         ]
354 ----------
355 checkClsFD :: TyVarSet                  -- Quantified type variables; see note below
356            -> FunDep TyVar -> [TyVar]   -- One functional dependency from the class
357            -> [Type] -> [Type]
358            -> [Equation]
359
360 checkClsFD qtvs fd clas_tvs tys1 tys2
361 -- 'qtvs' are the quantified type variables, the ones which an be instantiated 
362 -- to make the types match.  For example, given
363 --      class C a b | a->b where ...
364 --      instance C (Maybe x) (Tree x) where ..
365 --
366 -- and an Inst of form (C (Maybe t1) t2), 
367 -- then we will call checkClsFD with
368 --
369 --      qtvs = {x}, tys1 = [Maybe x,  Tree x]
370 --                  tys2 = [Maybe t1, t2]
371 --
372 -- We can instantiate x to t1, and then we want to force
373 --      (Tree x) [t1/x]  :=:   t2
374 --
375 -- This function is also used when matching two Insts (rather than an Inst
376 -- against an instance decl. In that case, qtvs is empty, and we are doing
377 -- an equality check
378 -- 
379 -- This function is also used by InstEnv.badFunDeps, which needs to *unify*
380 -- For the one-sided matching case, the qtvs are just from the template,
381 -- so we get matching
382 --
383   = ASSERT2( length tys1 == length tys2     && 
384              length tys1 == length clas_tvs 
385             , ppr tys1 <+> ppr tys2 )
386
387     case tcUnifyTys bind_fn ls1 ls2 of
388         Nothing  -> []
389         Just subst | isJust (tcUnifyTys bind_fn rs1' rs2') 
390                         -- Don't include any equations that already hold. 
391                         -- Reason: then we know if any actual improvement has happened,
392                         --         in which case we need to iterate the solver
393                         -- In making this check we must taking account of the fact that any 
394                         -- qtvs that aren't already instantiated can be instantiated to anything 
395                         -- at all
396                   -> []
397
398                   | otherwise   -- Aha!  A useful equation
399                   -> [ (qtvs', zip rs1' rs2')]
400                         -- We could avoid this substTy stuff by producing the eqn
401                         -- (qtvs, ls1++rs1, ls2++rs2)
402                         -- which will re-do the ls1/ls2 unification when the equation is
403                         -- executed.  What we're doing instead is recording the partial
404                         -- work of the ls1/ls2 unification leaving a smaller unification problem
405                   where
406                     rs1'  = substTys subst rs1 
407                     rs2'  = substTys subst rs2
408                     qtvs' = filterVarSet (`notElemTvSubst` subst) qtvs
409                         -- qtvs' are the quantified type variables
410                         -- that have not been substituted out
411                         --      
412                         -- Eg.  class C a b | a -> b
413                         --      instance C Int [y]
414                         -- Given constraint C Int z
415                         -- we generate the equation
416                         --      ({y}, [y], z)
417   where
418     bind_fn tv | tv `elemVarSet` qtvs = BindMe
419                | otherwise            = Skolem
420
421     (ls1, rs1) = instFD fd clas_tvs tys1
422     (ls2, rs2) = instFD fd clas_tvs tys2
423
424 instFD :: FunDep TyVar -> [TyVar] -> [Type] -> FunDep Type
425 instFD (ls,rs) tvs tys
426   = (map lookup ls, map lookup rs)
427   where
428     env       = zipVarEnv tvs tys
429     lookup tv = lookupVarEnv_NF env tv
430 \end{code}
431
432 \begin{code}
433 checkInstCoverage :: Class -> [Type] -> Bool
434 -- Check that the Coverage Condition is obeyed in an instance decl
435 -- For example, if we have 
436 --      class theta => C a b | a -> b
437 --      instance C t1 t2 
438 -- Then we require fv(t2) `subset` fv(t1)
439 -- See Note [Coverage Condition] below
440
441 checkInstCoverage clas inst_taus
442   = all fundep_ok fds
443   where
444     (tyvars, fds) = classTvsFds clas
445     fundep_ok fd  = tyVarsOfTypes rs `subVarSet` tyVarsOfTypes ls
446                  where
447                    (ls,rs) = instFD fd tyvars inst_taus
448 \end{code}
449
450 Note [Coverage condition]
451 ~~~~~~~~~~~~~~~~~~~~~~~~~
452 For the coverage condition, we used to require only that 
453         fv(t2) `subset` oclose(fv(t1), theta)
454
455 Example:
456         class Mul a b c | a b -> c where
457                 (.*.) :: a -> b -> c
458
459         instance Mul Int Int Int where (.*.) = (*)
460         instance Mul Int Float Float where x .*. y = fromIntegral x * y
461         instance Mul a b c => Mul a [b] [c] where x .*. v = map (x.*.) v
462
463 In the third instance, it's not the case that fv([c]) `subset` fv(a,[b]).
464 But it is the case that fv([c]) `subset` oclose( theta, fv(a,[b]) )
465
466 But it is a mistake to accept the instance because then this defn:
467         f = \ b x y -> if b then x .*. [y] else y
468 makes instance inference go into a loop, because it requires the constraint
469         Mul a [b] b
470
471
472 %************************************************************************
473 %*                                                                      *
474         Check that a new instance decl is OK wrt fundeps
475 %*                                                                      *
476 %************************************************************************
477
478 Here is the bad case:
479         class C a b | a->b where ...
480         instance C Int Bool where ...
481         instance C Int Char where ...
482
483 The point is that a->b, so Int in the first parameter must uniquely
484 determine the second.  In general, given the same class decl, and given
485
486         instance C s1 s2 where ...
487         instance C t1 t2 where ...
488
489 Then the criterion is: if U=unify(s1,t1) then U(s2) = U(t2).
490
491 Matters are a little more complicated if there are free variables in
492 the s2/t2.  
493
494         class D a b c | a -> b
495         instance D a b => D [(a,a)] [b] Int
496         instance D a b => D [a]     [b] Bool
497
498 The instance decls don't overlap, because the third parameter keeps
499 them separate.  But we want to make sure that given any constraint
500         D s1 s2 s3
501 if s1 matches 
502
503
504 \begin{code}
505 checkFunDeps :: (InstEnv, InstEnv) -> Instance
506              -> Maybe [Instance]        -- Nothing  <=> ok
507                                         -- Just dfs <=> conflict with dfs
508 -- Check wheher adding DFunId would break functional-dependency constraints
509 -- Used only for instance decls defined in the module being compiled
510 checkFunDeps inst_envs ispec
511   | null bad_fundeps = Nothing
512   | otherwise        = Just bad_fundeps
513   where
514     (ins_tvs, _, clas, ins_tys) = instanceHead ispec
515     ins_tv_set   = mkVarSet ins_tvs
516     cls_inst_env = classInstances inst_envs clas
517     bad_fundeps  = badFunDeps cls_inst_env clas ins_tv_set ins_tys
518
519 badFunDeps :: [Instance] -> Class
520            -> TyVarSet -> [Type]        -- Proposed new instance type
521            -> [Instance]
522 badFunDeps cls_insts clas ins_tv_set ins_tys 
523   = [ ispec | fd <- fds,        -- fds is often empty
524               let trimmed_tcs = trimRoughMatchTcs clas_tvs fd rough_tcs,
525               ispec@(Instance { is_tcs = mb_tcs, is_tvs = tvs, 
526                                 is_tys = tys }) <- cls_insts,
527                 -- Filter out ones that can't possibly match, 
528                 -- based on the head of the fundep
529               not (instanceCantMatch trimmed_tcs mb_tcs),       
530               notNull (checkClsFD (tvs `unionVarSet` ins_tv_set) 
531                                    fd clas_tvs tys ins_tys)
532     ]
533   where
534     (clas_tvs, fds) = classTvsFds clas
535     rough_tcs = roughMatchTcs ins_tys
536
537 trimRoughMatchTcs :: [TyVar] -> FunDep TyVar -> [Maybe Name] -> [Maybe Name]
538 -- Computing rough_tcs for a particular fundep
539 --      class C a b c | a c -> b where ... 
540 -- For each instance .... => C ta tb tc
541 -- we want to match only on the types ta, tb; so our
542 -- rough-match thing must similarly be filtered.  
543 -- Hence, we Nothing-ise the tb type right here
544 trimRoughMatchTcs clas_tvs (ltvs,_) mb_tcs
545   = zipWith select clas_tvs mb_tcs
546   where
547     select clas_tv mb_tc | clas_tv `elem` ltvs = mb_tc
548                          | otherwise           = Nothing
549 \end{code}
550
551
552