Change type of TcGadt.refineType, plus consequences
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 module TcGadt (
14         Refinement, emptyRefinement, gadtRefine, 
15         refineType, refineResType,
16         dataConCanMatch,
17         tcUnifyTys, BindFlag(..)
18   ) where
19
20 #include "HsVersions.h"
21
22 import HsSyn
23 import Coercion
24 import TypeRep
25 import DataCon
26 import Var
27 import VarEnv
28 import VarSet
29 import ErrUtils
30 import Maybes
31 import Control.Monad
32 import Outputable
33
34 #ifdef DEBUG
35 import Unique
36 import UniqFM
37 import TcType
38 #endif
39 \end{code}
40
41
42 %************************************************************************
43 %*                                                                      *
44                 What a refinement is
45 %*                                                                      *
46 %************************************************************************
47
48 \begin{code}
49 data Refinement = Reft InScopeSet InternalReft 
50 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
51 -- Not necessarily idemopotent
52
53 instance Outputable Refinement where
54   ppr (Reft in_scope env)
55     = ptext SLIT("Refinement") <+>
56         braces (ppr env)
57
58 emptyRefinement :: Refinement
59 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
60
61
62 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
63 -- Apply the refinement to the type.
64 -- If (refineType r ty) = (co, ty')
65 -- Then co :: ty:=:ty'
66 -- Nothing => the refinement does nothing to this type
67 refineType (Reft in_scope env) ty
68   | not (isEmptyVarEnv env),            -- Common case
69     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
70   = Just (substTy co_subst ty, substTy tv_subst ty)
71   | otherwise
72   = Nothing     -- The type doesn't mention any refined type variables
73   where
74     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
75     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
76  
77 refineResType :: Refinement -> Type -> (HsWrapper, Type)
78 -- Like refineType, but returns the 'sym' coercion
79 -- If (refineResType r ty) = (co, ty')
80 -- Then co :: ty':=:ty
81 -- It's convenient to return a HsWrapper here
82 refineResType reft ty
83   = case refineType reft ty of
84         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
85         Nothing        -> (idHsWrapper,             ty)
86 \end{code}
87
88
89 %************************************************************************
90 %*                                                                      *
91                 Generating a type refinement
92 %*                                                                      *
93 %************************************************************************
94
95 \begin{code}
96 gadtRefine :: Refinement
97            -> [TyVar]   -- Bind these by preference
98            -> [CoVar]
99            -> MaybeErr Message Refinement
100 \end{code}
101
102 (gadtRefine cvs) takes a list of coercion variables, and returns a
103 list of coercions, obtained by unifying the types equated by the
104 incoming coercions.  The returned coercions all have kinds of form
105 (a:=:ty), where a is a rigid type variable.
106
107 Example:
108   gadtRefine [c :: (a,Int):=:(Bool,b)]
109   = [ right (left c) :: a:=:Bool,       
110       sym (right c)  :: b:=:Int ]
111
112 That is, given evidence 'c' that (a,Int)=(Bool,b), it returns derived
113 evidence in easy-to-use form.  In particular, given any e::ty, we know 
114 that:
115         e `cast` ty[right (left c)/a, sym (right c)/b]
116         :: ty [Bool/a, Int/b]
117       
118 Two refinements:
119
120 - It can fail, if the coercion is unsatisfiable.
121
122 - It's biased, by being given a set of type variables to bind
123   when there is a choice. Example:
124         MkT :: forall a. a -> T [a]
125         f :: forall b. T [b] -> b
126         f x = let v = case x of { MkT y -> y }
127               in ...
128   Here we want to bind [a->b], not the other way round, because
129   in this example the return type is wobbly, and we want the
130   program to typecheck
131
132
133 -- E.g. (a, Bool, right (left c))
134 -- INVARIANT: in the triple (tv, ty, co), we have (co :: tv:=:ty)
135 -- The result is idempotent: the 
136
137 \begin{code}
138 gadtRefine (Reft in_scope env1) 
139            ex_tvs co_vars
140 -- Precondition: fvs( co_vars ) # env1
141 -- That is, the kinds of the co_vars are a
142 -- fixed point of the  incoming refinement
143
144   = ASSERT2( not $ any (`elemVarEnv` env1) (varSetElems $ tyVarsOfTypes $ map tyVarKind co_vars),
145              ppr env1 $$ ppr co_vars $$ ppr (map tyVarKind co_vars) )
146     initUM (tryToBind tv_set) $
147     do  {       -- Run the unifier, starting with an empty env
148         ; env2 <- foldM do_one emptyInternalReft co_vars
149
150                 -- Find the fixed point of the resulting 
151                 -- non-idempotent substitution
152         ; let tmp_env = env1 `plusVarEnv` env2
153               out_env = fixTvCoEnv in_scope' tmp_env
154         ; WARN( not (null (badReftElts tmp_env)), ppr (badReftElts tmp_env) $$ ppr tmp_env )
155           WARN( not (null (badReftElts out_env)), ppr (badReftElts out_env) $$ ppr out_env )
156           return (Reft in_scope' out_env) }
157   where
158     tv_set = mkVarSet ex_tvs
159     in_scope' = foldr extend in_scope co_vars
160     extend co_var in_scope
161         = extendInScopeSetSet (extendInScopeSet in_scope co_var)
162                               (tyVarsOfType (tyVarKind co_var))
163         
164     do_one reft co_var = unify reft (TyVarTy co_var) ty1 ty2
165         where
166            (ty1,ty2) = splitCoercionKind (tyVarKind co_var)
167 \end{code} 
168
169 %************************************************************************
170 %*                                                                      *
171                 Unification
172 %*                                                                      *
173 %************************************************************************
174
175 \begin{code}
176 tcUnifyTys :: (TyVar -> BindFlag)
177            -> [Type] -> [Type]
178            -> Maybe TvSubst     -- A regular one-shot substitution
179 -- The two types may have common type variables, and indeed do so in the
180 -- second call to tcUnifyTys in FunDeps.checkClsFD
181 --
182 -- We implement tcUnifyTys using the evidence-generating 'unify' function
183 -- in this module, even though we don't need to generate any evidence.
184 -- This is simply to avoid replicating all all the code for unify
185 tcUnifyTys bind_fn tys1 tys2
186   = maybeErrToMaybe $ initUM bind_fn $
187     do { reft <- unifyList emptyInternalReft cos tys1 tys2
188
189         -- Find the fixed point of the resulting non-idempotent substitution
190         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
191               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
192
193         ; return (mkTvSubst in_scope tv_env) }
194   where
195     tvs1 = tyVarsOfTypes tys1
196     tvs2 = tyVarsOfTypes tys2
197     cos  = zipWith mkUnsafeCoercion tys1 tys2
198
199
200 ----------------------------
201 fixTvCoEnv :: InScopeSet -> InternalReft -> InternalReft
202         -- Find the fixed point of a Refinement
203         -- (assuming it has no loops!)
204 fixTvCoEnv in_scope env
205   = fixpt
206   where
207     fixpt         = mapVarEnv step env
208
209     step (co, ty) = case refineType (Reft in_scope fixpt) ty of
210                         Nothing         -> (co,                     ty)
211                         Just (co', ty') -> (mkTransCoercion co co', ty')
212       -- Apply fixpt one step:
213       -- Use refineType to get a substituted type, ty', and a coercion, co_fn,
214       -- which justifies the substitution.  If the coercion is not the identity
215       -- then use transitivity with the original coercion
216
217 -----------------------------
218 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
219 fixTvSubstEnv in_scope env
220   = fixpt 
221   where
222     fixpt = mapVarEnv (substTy (mkTvSubst in_scope fixpt)) env
223
224 ----------------------------
225 dataConCanMatch :: DataCon -> [Type] -> Bool
226 -- Returns True iff the data con can match a scrutinee of type (T tys)
227 --                  where T is the type constructor for the data con
228 --
229 -- Instantiate the equations and try to unify them
230 dataConCanMatch con tys
231   = isJust (tcUnifyTys (\tv -> BindMe) 
232                        (map (substTyVar subst . fst) eq_spec)
233                        (map snd eq_spec))
234   where
235     dc_tvs  = dataConUnivTyVars con
236     eq_spec = dataConEqSpec con
237     subst   = zipTopTvSubst dc_tvs tys
238
239 ----------------------------
240 tryToBind :: TyVarSet -> TyVar -> BindFlag
241 tryToBind tv_set tv | tv `elemVarSet` tv_set = BindMe
242                     | otherwise              = AvoidMe
243
244
245 \end{code}
246
247
248 %************************************************************************
249 %*                                                                      *
250                 The workhorse
251 %*                                                                      *
252 %************************************************************************
253
254 \begin{code}
255 type InternalReft = TyVarEnv (Coercion, Type)
256
257 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
258 -- Not necessarily idemopotent
259
260 #ifdef DEBUG
261 badReftElts :: InternalReft -> [(Unique, (Coercion,Type))]
262 -- Return the BAD elements of the refinement
263 -- Should be empty; used in asserions only
264 badReftElts env
265   = filter (not . ok) (ufmToList env)
266   where
267     ok :: (Unique, (Coercion, Type)) -> Bool
268     ok (u, (co, ty)) | Just tv <- tcGetTyVar_maybe ty1
269                      = varUnique tv == u && ty `tcEqType` ty2 
270                      | otherwise = False
271         where
272           (ty1,ty2) = coercionKind co
273 #endif
274
275 emptyInternalReft :: InternalReft
276 emptyInternalReft = emptyVarEnv
277
278 unify :: InternalReft           -- An existing substitution to extend
279       -> Coercion       -- Witness of their equality 
280       -> Type -> Type   -- Types to be unified, and witness of their equality
281       -> UM InternalReft                -- Just the extended substitution, 
282                                 -- Nothing if unification failed
283 -- We do not require the incoming substitution to be idempotent,
284 -- nor guarantee that the outgoing one is.  That's fixed up by
285 -- the wrappers.
286
287 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
288 --                      co :: (ty1:=:ty2)
289
290 -- Respects newtypes, PredTypes
291
292 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
293                          unify_ subst co ty1 ty2
294
295 -- in unify_, any NewTcApps/Preds should be taken at face value
296 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
297 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
298
299 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
300 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
301
302 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
303
304 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
305   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
306
307 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
308   = do  { let [co1,co2] = decomposeCo 2 co
309         ; subst' <- unify subst co1 ty1a ty2a
310         ; unify subst' co2 ty1b ty2b }
311
312         -- Applications need a bit of care!
313         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
314         -- NB: we've already dealt with type variables and Notes,
315         -- so if one type is an App the other one jolly well better be too
316 unify_ subst co (AppTy ty1a ty1b) ty2
317   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
318   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
319         ; unify subst' (mkRightCoercion co) ty1b ty2b }
320
321 unify_ subst co ty1 (AppTy ty2a ty2b)
322   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
323   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
324         ; unify subst' (mkRightCoercion co) ty1b ty2b }
325
326 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
327         -- ForAlls??
328
329
330 ------------------------------
331 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
332   | c1 == c2 = unify_tys subst co tys1 tys2
333 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
334   | n1 == n2 = unify subst co t1 t2
335 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
336  
337 ------------------------------
338 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
339 unify_tys subst co xs ys
340   = unifyList subst (decomposeCo (length xs) co) xs ys
341
342 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
343 unifyList subst orig_cos orig_xs orig_ys
344   = go subst orig_cos orig_xs orig_ys
345   where
346     go subst _        []     []     = return subst
347     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
348                                          ; go subst' cos xs ys }
349     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
350
351 ---------------------------------
352 uVar :: Bool            -- Swapped
353      -> InternalReft    -- An existing substitution to extend
354      -> Coercion
355      -> TyVar           -- Type variable to be unified
356      -> Type            -- with this type
357      -> UM InternalReft
358
359 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
360 --      if swap=False   co :: (tv1:=:ty)
361 --      if swap=True    co :: (ty:=:tv1)
362
363 uVar swap subst co tv1 ty
364  = -- Check to see whether tv1 is refined by the substitution
365    case (lookupVarEnv subst tv1) of
366
367      -- Yes, call back into unify'
368      Just (co',ty')     -- co' :: (tv1:=:ty')
369         | swap          -- co :: (ty:=:tv1)
370         -> unify subst (mkTransCoercion co co') ty ty' 
371         | otherwise     -- co :: (tv1:=:ty)
372         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
373
374      -- No, continue
375      Nothing -> uUnrefined swap subst co
376                            tv1 ty ty
377
378
379 uUnrefined :: Bool                -- Whether the input is swapped
380            -> InternalReft        -- An existing substitution to extend
381            -> Coercion
382            -> TyVar               -- Type variable to be unified
383            -> Type                -- with this type
384            -> Type                -- (de-noted version)
385            -> UM InternalReft
386
387 -- We know that tv1 isn't refined
388 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
389 --      co :: tv1:=:ty2
390 -- and if the first argument is True instead, we know
391 --      co :: ty2:=:tv1
392
393 uUnrefined swap subst co tv1 ty2 ty2'
394   | Just ty2'' <- tcView ty2'
395   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
396                 -- This is essential, in case we have
397                 --      type Foo a = a
398                 -- and then unify a :=: Foo a
399
400 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
401   | tv1 == tv2          -- Same type variable
402   = return subst
403
404     -- Check to see whether tv2 is refined
405   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
406   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
407
408   -- So both are unrefined; next, see if the kinds force the direction
409   | eqKind k1 k2        -- Can update either; so check the bind-flags
410   = do  { b1 <- tvBindFlag tv1
411         ; b2 <- tvBindFlag tv2
412         ; case (b1,b2) of
413             (BindMe, _)          -> bind swap tv1 ty2
414
415             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
416             (AvoidMe, _)         -> bind swap tv1 ty2
417
418             (WildCard, WildCard) -> return subst
419             (WildCard, Skolem)   -> return subst
420             (WildCard, _)        -> bind (not swap) tv2 ty1
421
422             (Skolem, WildCard)   -> return subst
423             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
424             (Skolem, _)          -> bind (not swap) tv2 ty1
425         }
426
427   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
428   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
429
430   | otherwise = failWith (kindMisMatch tv1 ty2)
431   where
432     ty1 = TyVarTy tv1
433     k1 = tyVarKind tv1
434     k2 = tyVarKind tv2
435     bind swap tv ty = extendReft swap subst tv co ty
436
437 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
438   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
439   = failWith (occursCheck tv1 ty2)      -- Occurs check
440   | not (k2 `isSubKind` k1)
441   = failWith (kindMisMatch tv1 ty2)     -- Kind check
442   | otherwise
443   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
444   where
445     k1 = tyVarKind tv1
446     k2 = typeKind ty2'
447
448 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
449 -- Apply the non-idempotent substitution to a set of type variables,
450 -- remembering that the substitution isn't necessarily idempotent
451 substTvSet subst tvs
452   = foldVarSet (unionVarSet . get) emptyVarSet tvs
453   where
454     get tv = case lookupVarEnv subst tv of
455                 Nothing     -> unitVarSet tv
456                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
457
458 bindTv swap subst co tv ty      -- ty is not a type variable
459   = do  { b <- tvBindFlag tv
460         ; case b of
461             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
462             WildCard -> return subst
463             other    -> extendReft swap subst tv co ty
464         }
465
466 doSwap :: Bool -> Coercion -> Coercion
467 doSwap swap co = if swap then mkSymCoercion co else co
468
469 extendReft :: Bool 
470            -> InternalReft 
471            -> TyVar 
472            -> Coercion 
473            -> Type 
474            -> UM InternalReft
475 extendReft swap subst tv  co ty
476   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
477           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
478     return (extendVarEnv subst tv (co1, ty))
479   where
480     co1 = doSwap swap co
481
482 \end{code}
483
484 %************************************************************************
485 %*                                                                      *
486                 Unification monad
487 %*                                                                      *
488 %************************************************************************
489
490 \begin{code}
491 data BindFlag 
492   = BindMe      -- A regular type variable
493   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
494
495   | Skolem      -- This type variable is a skolem constant
496                 -- Don't bind it; it only matches itself
497
498   | WildCard    -- This type variable matches anything,
499                 -- and does not affect the substitution
500
501 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
502                          -> MaybeErr Message a }
503
504 instance Monad UM where
505   return a = UM (\tvs -> Succeeded a)
506   fail s   = UM (\tvs -> Failed (text s))
507   m >>= k  = UM (\tvs -> case unUM m tvs of
508                            Failed err -> Failed err
509                            Succeeded v  -> unUM (k v) tvs)
510
511 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
512 initUM badtvs um = unUM um badtvs
513
514 tvBindFlag :: TyVar -> UM BindFlag
515 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
516
517 failWith :: Message -> UM a
518 failWith msg = UM (\tv_fn -> Failed msg)
519
520 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
521 maybeErrToMaybe (Succeeded a) = Just a
522 maybeErrToMaybe (Failed m)    = Nothing
523 \end{code}
524
525
526 %************************************************************************
527 %*                                                                      *
528                 Error reporting
529         We go to a lot more trouble to tidy the types
530         in TcUnify.  Maybe we'll end up having to do that
531         here too, but I'll leave it for now.
532 %*                                                                      *
533 %************************************************************************
534
535 \begin{code}
536 misMatch t1 t2
537   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
538     ptext SLIT("and") <+> quotes (ppr t2)
539
540 lengthMisMatch tys1 tys2
541   = sep [ptext SLIT("Can't match unequal length lists"), 
542          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
543
544 kindMisMatch tv1 t2
545   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
546             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
547           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
548                 ptext SLIT("with") <+> quotes (ppr t2)]
549
550 occursCheck tv ty
551   = hang (ptext SLIT("Can't construct the infinite type"))
552        2 (ppr tv <+> equals <+> ppr ty)
553 \end{code}