Use implication constraints to improve type inference
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 module TcGadt (
14         Refinement, emptyRefinement, isEmptyRefinement, 
15         gadtRefine, 
16         refineType, refinePred, refineResType,
17         dataConCanMatch,
18         tcUnifyTys, BindFlag(..)
19   ) where
20
21 #include "HsVersions.h"
22
23 import HsSyn
24 import Coercion
25 import Type
26
27 import TypeRep
28 import DataCon
29 import Var
30 import VarEnv
31 import VarSet
32 import ErrUtils
33 import Maybes
34 import Control.Monad
35 import Outputable
36 import TcType
37
38 #ifdef DEBUG
39 import Unique
40 import UniqFM
41 #endif
42 \end{code}
43
44
45 %************************************************************************
46 %*                                                                      *
47                 What a refinement is
48 %*                                                                      *
49 %************************************************************************
50
51 \begin{code}
52 data Refinement = Reft InScopeSet InternalReft 
53
54 type InternalReft = TyVarEnv (Coercion, Type)
55 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
56 -- Not necessarily idemopotent
57
58 instance Outputable Refinement where
59   ppr (Reft in_scope env)
60     = ptext SLIT("Refinement") <+>
61         braces (ppr env)
62
63 emptyRefinement :: Refinement
64 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
65
66 isEmptyRefinement :: Refinement -> Bool
67 isEmptyRefinement (Reft _ env) = isEmptyVarEnv env
68
69 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
70 -- Apply the refinement to the type.
71 -- If (refineType r ty) = (co, ty')
72 -- Then co :: ty:=:ty'
73 -- Nothing => the refinement does nothing to this type
74 refineType (Reft in_scope env) ty
75   | not (isEmptyVarEnv env),            -- Common case
76     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
77   = Just (substTy co_subst ty, substTy tv_subst ty)
78   | otherwise
79   = Nothing     -- The type doesn't mention any refined type variables
80   where
81     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
82     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
83  
84 refinePred :: Refinement -> PredType -> Maybe (Coercion, PredType)
85 refinePred (Reft in_scope env) pred
86   | not (isEmptyVarEnv env),            -- Common case
87     any (`elemVarEnv` env) (varSetElems (tyVarsOfPred pred))
88   = Just (mkPredTy (substPred co_subst pred), substPred tv_subst pred)
89   | otherwise
90   = Nothing     -- The type doesn't mention any refined type variables
91   where
92     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
93     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
94  
95 refineResType :: Refinement -> Type -> (HsWrapper, Type)
96 -- Like refineType, but returns the 'sym' coercion
97 -- If (refineResType r ty) = (co, ty')
98 -- Then co :: ty':=:ty
99 -- It's convenient to return a HsWrapper here
100 refineResType reft ty
101   = case refineType reft ty of
102         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
103         Nothing        -> (idHsWrapper,             ty)
104 \end{code}
105
106
107 %************************************************************************
108 %*                                                                      *
109                 Generating a type refinement
110 %*                                                                      *
111 %************************************************************************
112
113 \begin{code}
114 gadtRefine :: Refinement
115            -> [TyVar]   -- Bind these by preference
116            -> [CoVar]
117            -> MaybeErr Message Refinement
118 \end{code}
119
120 (gadtRefine cvs) takes a list of coercion variables, and returns a
121 list of coercions, obtained by unifying the types equated by the
122 incoming coercions.  The returned coercions all have kinds of form
123 (a:=:ty), where a is a rigid type variable.
124
125 Example:
126   gadtRefine [c :: (a,Int):=:(Bool,b)]
127   = [ right (left c) :: a:=:Bool,       
128       sym (right c)  :: b:=:Int ]
129
130 That is, given evidence 'c' that (a,Int)=(Bool,b), it returns derived
131 evidence in easy-to-use form.  In particular, given any e::ty, we know 
132 that:
133         e `cast` ty[right (left c)/a, sym (right c)/b]
134         :: ty [Bool/a, Int/b]
135       
136 Two refinements:
137
138 - It can fail, if the coercion is unsatisfiable.
139
140 - It's biased, by being given a set of type variables to bind
141   when there is a choice. Example:
142         MkT :: forall a. a -> T [a]
143         f :: forall b. T [b] -> b
144         f x = let v = case x of { MkT y -> y }
145               in ...
146   Here we want to bind [a->b], not the other way round, because
147   in this example the return type is wobbly, and we want the
148   program to typecheck
149
150
151 -- E.g. (a, Bool, right (left c))
152 -- INVARIANT: in the triple (tv, ty, co), we have (co :: tv:=:ty)
153 -- The result is idempotent: the 
154
155 \begin{code}
156 gadtRefine (Reft in_scope env1) 
157            ex_tvs co_vars
158 -- Precondition: fvs( co_vars ) # env1
159 -- That is, the kinds of the co_vars are a
160 -- fixed point of the incoming refinement
161
162   = ASSERT2( not $ any (`elemVarEnv` env1) (varSetElems $ tyVarsOfTypes $ map tyVarKind co_vars),
163              ppr env1 $$ ppr co_vars $$ ppr (map tyVarKind co_vars) )
164     initUM (tryToBind tv_set) $
165     do  {       -- Run the unifier, starting with an empty env
166         ; env2 <- foldM do_one emptyInternalReft co_vars
167
168                 -- Find the fixed point of the resulting 
169                 -- non-idempotent substitution
170         ; let tmp_env = env1 `plusVarEnv` env2
171               out_env = fixTvCoEnv in_scope' tmp_env
172         ; WARN( not (null (badReftElts tmp_env)), ppr (badReftElts tmp_env) $$ ppr tmp_env )
173           WARN( not (null (badReftElts out_env)), ppr (badReftElts out_env) $$ ppr out_env )
174           return (Reft in_scope' out_env) }
175   where
176     tv_set = mkVarSet ex_tvs
177     in_scope' = foldr extend in_scope co_vars
178
179         -- For each co_var, add it *and* the tyvars it mentions, to in_scope
180     extend co_var in_scope
181         = extendInScopeSetSet in_scope $
182           extendVarSet (tyVarsOfType (tyVarKind co_var)) co_var
183         
184     do_one reft co_var = unify reft (TyVarTy co_var) ty1 ty2
185         where
186            (ty1,ty2) = splitCoercionKind (tyVarKind co_var)
187 \end{code} 
188
189 %************************************************************************
190 %*                                                                      *
191                 Unification
192 %*                                                                      *
193 %************************************************************************
194
195 \begin{code}
196 tcUnifyTys :: (TyVar -> BindFlag)
197            -> [Type] -> [Type]
198            -> Maybe TvSubst     -- A regular one-shot substitution
199 -- The two types may have common type variables, and indeed do so in the
200 -- second call to tcUnifyTys in FunDeps.checkClsFD
201 --
202 -- We implement tcUnifyTys using the evidence-generating 'unify' function
203 -- in this module, even though we don't need to generate any evidence.
204 -- This is simply to avoid replicating all all the code for unify
205 tcUnifyTys bind_fn tys1 tys2
206   = maybeErrToMaybe $ initUM bind_fn $
207     do { reft <- unifyList emptyInternalReft cos tys1 tys2
208
209         -- Find the fixed point of the resulting non-idempotent substitution
210         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
211               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
212
213         ; return (mkTvSubst in_scope tv_env) }
214   where
215     tvs1 = tyVarsOfTypes tys1
216     tvs2 = tyVarsOfTypes tys2
217     cos  = zipWith mkUnsafeCoercion tys1 tys2
218
219
220 ----------------------------
221 fixTvCoEnv :: InScopeSet -> InternalReft -> InternalReft
222         -- Find the fixed point of a Refinement
223         -- (assuming it has no loops!)
224 fixTvCoEnv in_scope env
225   = fixpt
226   where
227     fixpt         = mapVarEnv step env
228
229     step (co, ty) = case refineType (Reft in_scope fixpt) ty of
230                         Nothing         -> (co,                     ty)
231                         Just (co', ty') -> (mkTransCoercion co co', ty')
232       -- Apply fixpt one step:
233       -- Use refineType to get a substituted type, ty', and a coercion, co_fn,
234       -- which justifies the substitution.  If the coercion is not the identity
235       -- then use transitivity with the original coercion
236
237 -----------------------------
238 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
239 fixTvSubstEnv in_scope env
240   = fixpt 
241   where
242     fixpt = mapVarEnv (substTy (mkTvSubst in_scope fixpt)) env
243
244 ----------------------------
245 dataConCanMatch :: [Type] -> DataCon -> Bool
246 -- Returns True iff the data con can match a scrutinee of type (T tys)
247 --                  where T is the type constructor for the data con
248 --
249 -- Instantiate the equations and try to unify them
250 dataConCanMatch tys con
251   | null eq_spec      = True    -- Common
252   | all isTyVarTy tys = True    -- Also common
253   | otherwise
254   = isJust (tcUnifyTys (\tv -> BindMe) 
255                        (map (substTyVar subst . fst) eq_spec)
256                        (map snd eq_spec))
257   where
258     dc_tvs  = dataConUnivTyVars con
259     eq_spec = dataConEqSpec con
260     subst   = zipTopTvSubst dc_tvs tys
261
262 ----------------------------
263 tryToBind :: TyVarSet -> TyVar -> BindFlag
264 tryToBind tv_set tv | tv `elemVarSet` tv_set = BindMe
265                     | otherwise              = AvoidMe
266
267
268 \end{code}
269
270
271 %************************************************************************
272 %*                                                                      *
273                 The workhorse
274 %*                                                                      *
275 %************************************************************************
276
277 \begin{code}
278 #ifdef DEBUG
279 badReftElts :: InternalReft -> [(Unique, (Coercion,Type))]
280 -- Return the BAD elements of the refinement
281 -- Should be empty; used in asserions only
282 badReftElts env
283   = filter (not . ok) (ufmToList env)
284   where
285     ok :: (Unique, (Coercion, Type)) -> Bool
286     ok (u, (co, ty)) | Just tv <- tcGetTyVar_maybe ty1
287                      = varUnique tv == u && ty `tcEqType` ty2 
288                      | otherwise = False
289         where
290           (ty1,ty2) = coercionKind co
291 #endif
292
293 emptyInternalReft :: InternalReft
294 emptyInternalReft = emptyVarEnv
295
296 unify :: InternalReft           -- An existing substitution to extend
297       -> Coercion       -- Witness of their equality 
298       -> Type -> Type   -- Types to be unified, and witness of their equality
299       -> UM InternalReft                -- Just the extended substitution, 
300                                 -- Nothing if unification failed
301 -- We do not require the incoming substitution to be idempotent,
302 -- nor guarantee that the outgoing one is.  That's fixed up by
303 -- the wrappers.
304
305 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
306 --                      co :: (ty1:=:ty2)
307
308 -- Respects newtypes, PredTypes
309
310 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
311                          unify_ subst co ty1 ty2
312
313 -- in unify_, any NewTcApps/Preds should be taken at face value
314 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
315 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
316
317 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
318 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
319
320 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
321
322 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
323   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
324
325 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
326   = do  { let [co1,co2] = decomposeCo 2 co
327         ; subst' <- unify subst co1 ty1a ty2a
328         ; unify subst' co2 ty1b ty2b }
329
330         -- Applications need a bit of care!
331         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
332         -- NB: we've already dealt with type variables and Notes,
333         -- so if one type is an App the other one jolly well better be too
334 unify_ subst co (AppTy ty1a ty1b) ty2
335   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
336   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
337         ; unify subst' (mkRightCoercion co) ty1b ty2b }
338
339 unify_ subst co ty1 (AppTy ty2a ty2b)
340   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
341   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
342         ; unify subst' (mkRightCoercion co) ty1b ty2b }
343
344 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
345         -- ForAlls??
346
347
348 ------------------------------
349 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
350   | c1 == c2 = unify_tys subst co tys1 tys2
351 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
352   | n1 == n2 = unify subst co t1 t2
353 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
354  
355 ------------------------------
356 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
357 unify_tys subst co xs ys
358   = unifyList subst (decomposeCo (length xs) co) xs ys
359
360 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
361 unifyList subst orig_cos orig_xs orig_ys
362   = go subst orig_cos orig_xs orig_ys
363   where
364     go subst _        []     []     = return subst
365     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
366                                          ; go subst' cos xs ys }
367     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
368
369 ---------------------------------
370 uVar :: Bool            -- Swapped
371      -> InternalReft    -- An existing substitution to extend
372      -> Coercion
373      -> TyVar           -- Type variable to be unified
374      -> Type            -- with this type
375      -> UM InternalReft
376
377 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
378 --      if swap=False   co :: (tv1:=:ty)
379 --      if swap=True    co :: (ty:=:tv1)
380
381 uVar swap subst co tv1 ty
382  = -- Check to see whether tv1 is refined by the substitution
383    case (lookupVarEnv subst tv1) of
384
385      -- Yes, call back into unify'
386      Just (co',ty')     -- co' :: (tv1:=:ty')
387         | swap          -- co :: (ty:=:tv1)
388         -> unify subst (mkTransCoercion co co') ty ty' 
389         | otherwise     -- co :: (tv1:=:ty)
390         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
391
392      -- No, continue
393      Nothing -> uUnrefined swap subst co
394                            tv1 ty ty
395
396
397 uUnrefined :: Bool                -- Whether the input is swapped
398            -> InternalReft        -- An existing substitution to extend
399            -> Coercion
400            -> TyVar               -- Type variable to be unified
401            -> Type                -- with this type
402            -> Type                -- (de-noted version)
403            -> UM InternalReft
404
405 -- We know that tv1 isn't refined
406 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
407 --      co :: tv1:=:ty2
408 -- and if the first argument is True instead, we know
409 --      co :: ty2:=:tv1
410
411 uUnrefined swap subst co tv1 ty2 ty2'
412   | Just ty2'' <- tcView ty2'
413   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
414                 -- This is essential, in case we have
415                 --      type Foo a = a
416                 -- and then unify a :=: Foo a
417
418 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
419   | tv1 == tv2          -- Same type variable
420   = return subst
421
422     -- Check to see whether tv2 is refined
423   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
424   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
425
426   -- So both are unrefined; next, see if the kinds force the direction
427   | eqKind k1 k2        -- Can update either; so check the bind-flags
428   = do  { b1 <- tvBindFlag tv1
429         ; b2 <- tvBindFlag tv2
430         ; case (b1,b2) of
431             (BindMe, _)          -> bind swap tv1 ty2
432
433             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
434             (AvoidMe, _)         -> bind swap tv1 ty2
435
436             (WildCard, WildCard) -> return subst
437             (WildCard, Skolem)   -> return subst
438             (WildCard, _)        -> bind (not swap) tv2 ty1
439
440             (Skolem, WildCard)   -> return subst
441             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
442             (Skolem, _)          -> bind (not swap) tv2 ty1
443         }
444
445   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
446   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
447
448   | otherwise = failWith (kindMisMatch tv1 ty2)
449   where
450     ty1 = TyVarTy tv1
451     k1 = tyVarKind tv1
452     k2 = tyVarKind tv2
453     bind swap tv ty = extendReft swap subst tv co ty
454
455 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
456   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
457   = failWith (occursCheck tv1 ty2)      -- Occurs check
458   | not (k2 `isSubKind` k1)
459   = failWith (kindMisMatch tv1 ty2)     -- Kind check
460   | otherwise
461   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
462   where
463     k1 = tyVarKind tv1
464     k2 = typeKind ty2'
465
466 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
467 -- Apply the non-idempotent substitution to a set of type variables,
468 -- remembering that the substitution isn't necessarily idempotent
469 substTvSet subst tvs
470   = foldVarSet (unionVarSet . get) emptyVarSet tvs
471   where
472     get tv = case lookupVarEnv subst tv of
473                 Nothing     -> unitVarSet tv
474                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
475
476 bindTv swap subst co tv ty      -- ty is not a type variable
477   = do  { b <- tvBindFlag tv
478         ; case b of
479             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
480             WildCard -> return subst
481             other    -> extendReft swap subst tv co ty
482         }
483
484 doSwap :: Bool -> Coercion -> Coercion
485 doSwap swap co = if swap then mkSymCoercion co else co
486
487 extendReft :: Bool 
488            -> InternalReft 
489            -> TyVar 
490            -> Coercion 
491            -> Type 
492            -> UM InternalReft
493 extendReft swap subst tv  co ty
494   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
495           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
496     return (extendVarEnv subst tv (co1, ty))
497   where
498     co1 = doSwap swap co
499
500 \end{code}
501
502 %************************************************************************
503 %*                                                                      *
504                 Unification monad
505 %*                                                                      *
506 %************************************************************************
507
508 \begin{code}
509 data BindFlag 
510   = BindMe      -- A regular type variable
511   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
512
513   | Skolem      -- This type variable is a skolem constant
514                 -- Don't bind it; it only matches itself
515
516   | WildCard    -- This type variable matches anything,
517                 -- and does not affect the substitution
518
519 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
520                          -> MaybeErr Message a }
521
522 instance Monad UM where
523   return a = UM (\tvs -> Succeeded a)
524   fail s   = UM (\tvs -> Failed (text s))
525   m >>= k  = UM (\tvs -> case unUM m tvs of
526                            Failed err -> Failed err
527                            Succeeded v  -> unUM (k v) tvs)
528
529 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
530 initUM badtvs um = unUM um badtvs
531
532 tvBindFlag :: TyVar -> UM BindFlag
533 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
534
535 failWith :: Message -> UM a
536 failWith msg = UM (\tv_fn -> Failed msg)
537
538 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
539 maybeErrToMaybe (Succeeded a) = Just a
540 maybeErrToMaybe (Failed m)    = Nothing
541 \end{code}
542
543
544 %************************************************************************
545 %*                                                                      *
546                 Error reporting
547         We go to a lot more trouble to tidy the types
548         in TcUnify.  Maybe we'll end up having to do that
549         here too, but I'll leave it for now.
550 %*                                                                      *
551 %************************************************************************
552
553 \begin{code}
554 misMatch t1 t2
555   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
556     ptext SLIT("and") <+> quotes (ppr t2)
557
558 lengthMisMatch tys1 tys2
559   = sep [ptext SLIT("Can't match unequal length lists"), 
560          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
561
562 kindMisMatch tv1 t2
563   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
564             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
565           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
566                 ptext SLIT("with") <+> quotes (ppr t2)]
567
568 occursCheck tv ty
569   = hang (ptext SLIT("Can't construct the infinite type"))
570        2 (ppr tv <+> equals <+> ppr ty)
571 \end{code}