Make various assertions work when !DEBUG
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 {-# OPTIONS -w #-}
14 -- The above warning supression flag is a temporary kludge.
15 -- While working on this module you are encouraged to remove it and fix
16 -- any warnings in the module. See
17 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
18 -- for details
19
20 module TcGadt (
21         Refinement, emptyRefinement, isEmptyRefinement, 
22         gadtRefine, 
23         refineType, refinePred, refineResType,
24         tcUnifyTys, BindFlag(..)
25   ) where
26
27 #include "HsVersions.h"
28
29 import HsSyn
30 import Coercion
31 import Type
32
33 import TypeRep
34 import Var
35 import VarEnv
36 import VarSet
37 import ErrUtils
38 import Maybes
39 import Control.Monad
40 import Outputable
41 import TcType
42 import Unique
43 import UniqFM
44 \end{code}
45
46
47 %************************************************************************
48 %*                                                                      *
49                 What a refinement is
50 %*                                                                      *
51 %************************************************************************
52
53 \begin{code}
54 data Refinement = Reft InScopeSet InternalReft 
55
56 type InternalReft = TyVarEnv (Coercion, Type)
57 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
58 -- Not necessarily idemopotent
59
60 instance Outputable Refinement where
61   ppr (Reft in_scope env)
62     = ptext SLIT("Refinement") <+>
63         braces (ppr env)
64
65 emptyRefinement :: Refinement
66 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
67
68 isEmptyRefinement :: Refinement -> Bool
69 isEmptyRefinement (Reft _ env) = isEmptyVarEnv env
70
71 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
72 -- Apply the refinement to the type.
73 -- If (refineType r ty) = (co, ty')
74 -- Then co :: ty:=:ty'
75 -- Nothing => the refinement does nothing to this type
76 refineType (Reft in_scope env) ty
77   | not (isEmptyVarEnv env),            -- Common case
78     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
79   = Just (substTy co_subst ty, substTy tv_subst ty)
80   | otherwise
81   = Nothing     -- The type doesn't mention any refined type variables
82   where
83     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
84     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
85  
86 refinePred :: Refinement -> PredType -> Maybe (Coercion, PredType)
87 refinePred (Reft in_scope env) pred
88   | not (isEmptyVarEnv env),            -- Common case
89     any (`elemVarEnv` env) (varSetElems (tyVarsOfPred pred))
90   = Just (mkPredTy (substPred co_subst pred), substPred tv_subst pred)
91   | otherwise
92   = Nothing     -- The type doesn't mention any refined type variables
93   where
94     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
95     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
96  
97 refineResType :: Refinement -> Type -> (HsWrapper, Type)
98 -- Like refineType, but returns the 'sym' coercion
99 -- If (refineResType r ty) = (co, ty')
100 -- Then co :: ty':=:ty
101 -- It's convenient to return a HsWrapper here
102 refineResType reft ty
103   = case refineType reft ty of
104         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
105         Nothing        -> (idHsWrapper,             ty)
106 \end{code}
107
108
109 %************************************************************************
110 %*                                                                      *
111                 Generating a type refinement
112 %*                                                                      *
113 %************************************************************************
114
115 \begin{code}
116 gadtRefine :: Refinement
117            -> [TyVar]   -- Bind these by preference
118            -> [CoVar]
119            -> MaybeErr Message Refinement
120 \end{code}
121
122 (gadtRefine cvs) takes a list of coercion variables, and returns a
123 list of coercions, obtained by unifying the types equated by the
124 incoming coercions.  The returned coercions all have kinds of form
125 (a:=:ty), where a is a rigid type variable.
126
127 Example:
128   gadtRefine [c :: (a,Int):=:(Bool,b)]
129   = [ right (left c) :: a:=:Bool,       
130       sym (right c)  :: b:=:Int ]
131
132 That is, given evidence 'c' that (a,Int)=(Bool,b), it returns derived
133 evidence in easy-to-use form.  In particular, given any e::ty, we know 
134 that:
135         e `cast` ty[right (left c)/a, sym (right c)/b]
136         :: ty [Bool/a, Int/b]
137       
138 Two refinements:
139
140 - It can fail, if the coercion is unsatisfiable.
141
142 - It's biased, by being given a set of type variables to bind
143   when there is a choice. Example:
144         MkT :: forall a. a -> T [a]
145         f :: forall b. T [b] -> b
146         f x = let v = case x of { MkT y -> y }
147               in ...
148   Here we want to bind [a->b], not the other way round, because
149   in this example the return type is wobbly, and we want the
150   program to typecheck
151
152
153 -- E.g. (a, Bool, right (left c))
154 -- INVARIANT: in the triple (tv, ty, co), we have (co :: tv:=:ty)
155 -- The result is idempotent: the 
156
157 \begin{code}
158 gadtRefine (Reft in_scope env1) 
159            ex_tvs co_vars
160 -- Precondition: fvs( co_vars ) # env1
161 -- That is, the kinds of the co_vars are a
162 -- fixed point of the incoming refinement
163
164   = ASSERT2( not $ any (`elemVarEnv` env1) (varSetElems $ tyVarsOfTypes $ map tyVarKind co_vars),
165              ppr env1 $$ ppr co_vars $$ ppr (map tyVarKind co_vars) )
166     initUM (tryToBind tv_set) $
167     do  {       -- Run the unifier, starting with an empty env
168         ; env2 <- foldM do_one emptyInternalReft co_vars
169
170                 -- Find the fixed point of the resulting 
171                 -- non-idempotent substitution
172         ; let tmp_env = env1 `plusVarEnv` env2
173               out_env = fixTvCoEnv in_scope' tmp_env
174         ; WARN( not (null (badReftElts tmp_env)), ppr (badReftElts tmp_env) $$ ppr tmp_env )
175           WARN( not (null (badReftElts out_env)), ppr (badReftElts out_env) $$ ppr out_env )
176           return (Reft in_scope' out_env) }
177   where
178     tv_set = mkVarSet ex_tvs
179     in_scope' = foldr extend in_scope co_vars
180
181         -- For each co_var, add it *and* the tyvars it mentions, to in_scope
182     extend co_var in_scope
183         = extendInScopeSetSet in_scope $
184           extendVarSet (tyVarsOfType (tyVarKind co_var)) co_var
185         
186     do_one reft co_var = unify reft (TyVarTy co_var) ty1 ty2
187         where
188            (ty1,ty2) = splitCoercionKind (tyVarKind co_var)
189 \end{code}
190
191 %************************************************************************
192 %*                                                                      *
193                 Unification
194 %*                                                                      *
195 %************************************************************************
196
197 \begin{code}
198 tcUnifyTys :: (TyVar -> BindFlag)
199            -> [Type] -> [Type]
200            -> Maybe TvSubst     -- A regular one-shot substitution
201 -- The two types may have common type variables, and indeed do so in the
202 -- second call to tcUnifyTys in FunDeps.checkClsFD
203 --
204 -- We implement tcUnifyTys using the evidence-generating 'unify' function
205 -- in this module, even though we don't need to generate any evidence.
206 -- This is simply to avoid replicating all all the code for unify
207 tcUnifyTys bind_fn tys1 tys2
208   = maybeErrToMaybe $ initUM bind_fn $
209     do { reft <- unifyList emptyInternalReft cos tys1 tys2
210
211         -- Find the fixed point of the resulting non-idempotent substitution
212         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
213               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
214
215         ; return (mkTvSubst in_scope tv_env) }
216   where
217     tvs1 = tyVarsOfTypes tys1
218     tvs2 = tyVarsOfTypes tys2
219     cos  = zipWith mkUnsafeCoercion tys1 tys2
220
221
222 ----------------------------
223 fixTvCoEnv :: InScopeSet -> InternalReft -> InternalReft
224         -- Find the fixed point of a Refinement
225         -- (assuming it has no loops!)
226 fixTvCoEnv in_scope env
227   = fixpt
228   where
229     fixpt         = mapVarEnv step env
230
231     step (co, ty) = case refineType (Reft in_scope fixpt) ty of
232                         Nothing         -> (co,                     ty)
233                         Just (co', ty') -> (mkTransCoercion co co', ty')
234       -- Apply fixpt one step:
235       -- Use refineType to get a substituted type, ty', and a coercion, co_fn,
236       -- which justifies the substitution.  If the coercion is not the identity
237       -- then use transitivity with the original coercion
238
239 -----------------------------
240 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
241 fixTvSubstEnv in_scope env
242   = fixpt 
243   where
244     fixpt = mapVarEnv (substTy (mkTvSubst in_scope fixpt)) env
245
246 ----------------------------
247 tryToBind :: TyVarSet -> TyVar -> BindFlag
248 tryToBind tv_set tv | tv `elemVarSet` tv_set = BindMe
249                     | otherwise              = AvoidMe
250
251 \end{code}
252
253
254 %************************************************************************
255 %*                                                                      *
256                 The workhorse
257 %*                                                                      *
258 %************************************************************************
259
260 \begin{code}
261 badReftElts :: InternalReft -> [(Unique, (Coercion,Type))]
262 -- Return the BAD elements of the refinement
263 -- Should be empty; used in asserions only
264 badReftElts env
265   = filter (not . ok) (ufmToList env)
266   where
267     ok :: (Unique, (Coercion, Type)) -> Bool
268     ok (u, (co, ty)) | Just tv <- tcGetTyVar_maybe ty1
269                      = varUnique tv == u && ty `tcEqType` ty2 
270                      | otherwise = False
271         where
272           (ty1,ty2) = coercionKind co
273
274 emptyInternalReft :: InternalReft
275 emptyInternalReft = emptyVarEnv
276
277 unify :: InternalReft           -- An existing substitution to extend
278       -> Coercion       -- Witness of their equality 
279       -> Type -> Type   -- Types to be unified, and witness of their equality
280       -> UM InternalReft                -- Just the extended substitution, 
281                                 -- Nothing if unification failed
282 -- We do not require the incoming substitution to be idempotent,
283 -- nor guarantee that the outgoing one is.  That's fixed up by
284 -- the wrappers.
285
286 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
287 --                      co :: (ty1:=:ty2)
288
289 -- Respects newtypes, PredTypes
290
291 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
292                          unify_ subst co ty1 ty2
293
294 -- in unify_, any NewTcApps/Preds should be taken at face value
295 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
296 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
297
298 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
299 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
300
301 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
302
303 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
304   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
305
306 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
307   = do  { let [co1,co2] = decomposeCo 2 co
308         ; subst' <- unify subst co1 ty1a ty2a
309         ; unify subst' co2 ty1b ty2b }
310
311         -- Applications need a bit of care!
312         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
313         -- NB: we've already dealt with type variables and Notes,
314         -- so if one type is an App the other one jolly well better be too
315 unify_ subst co (AppTy ty1a ty1b) ty2
316   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
317   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
318         ; unify subst' (mkRightCoercion co) ty1b ty2b }
319
320 unify_ subst co ty1 (AppTy ty2a ty2b)
321   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
322   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
323         ; unify subst' (mkRightCoercion co) ty1b ty2b }
324
325 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
326         -- ForAlls??
327
328
329 ------------------------------
330 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
331   | c1 == c2 = unify_tys subst co tys1 tys2
332 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
333   | n1 == n2 = unify subst co t1 t2
334 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
335  
336 ------------------------------
337 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
338 unify_tys subst co xs ys
339   = unifyList subst (decomposeCo (length xs) co) xs ys
340
341 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
342 unifyList subst orig_cos orig_xs orig_ys
343   = go subst orig_cos orig_xs orig_ys
344   where
345     go subst _        []     []     = return subst
346     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
347                                          ; go subst' cos xs ys }
348     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
349
350 ---------------------------------
351 uVar :: Bool            -- Swapped
352      -> InternalReft    -- An existing substitution to extend
353      -> Coercion
354      -> TyVar           -- Type variable to be unified
355      -> Type            -- with this type
356      -> UM InternalReft
357
358 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
359 --      if swap=False   co :: (tv1:=:ty)
360 --      if swap=True    co :: (ty:=:tv1)
361
362 uVar swap subst co tv1 ty
363  = -- Check to see whether tv1 is refined by the substitution
364    case (lookupVarEnv subst tv1) of
365
366      -- Yes, call back into unify'
367      Just (co',ty')     -- co' :: (tv1:=:ty')
368         | swap          -- co :: (ty:=:tv1)
369         -> unify subst (mkTransCoercion co co') ty ty' 
370         | otherwise     -- co :: (tv1:=:ty)
371         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
372
373      -- No, continue
374      Nothing -> uUnrefined swap subst co
375                            tv1 ty ty
376
377
378 uUnrefined :: Bool                -- Whether the input is swapped
379            -> InternalReft        -- An existing substitution to extend
380            -> Coercion
381            -> TyVar               -- Type variable to be unified
382            -> Type                -- with this type
383            -> Type                -- (de-noted version)
384            -> UM InternalReft
385
386 -- We know that tv1 isn't refined
387 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
388 --      co :: tv1:=:ty2
389 -- and if the first argument is True instead, we know
390 --      co :: ty2:=:tv1
391
392 uUnrefined swap subst co tv1 ty2 ty2'
393   | Just ty2'' <- tcView ty2'
394   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
395                 -- This is essential, in case we have
396                 --      type Foo a = a
397                 -- and then unify a :=: Foo a
398
399 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
400   | tv1 == tv2          -- Same type variable
401   = return subst
402
403     -- Check to see whether tv2 is refined
404   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
405   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
406
407   -- So both are unrefined; next, see if the kinds force the direction
408   | eqKind k1 k2        -- Can update either; so check the bind-flags
409   = do  { b1 <- tvBindFlag tv1
410         ; b2 <- tvBindFlag tv2
411         ; case (b1,b2) of
412             (BindMe, _)          -> bind swap tv1 ty2
413
414             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
415             (AvoidMe, _)         -> bind swap tv1 ty2
416
417             (WildCard, WildCard) -> return subst
418             (WildCard, Skolem)   -> return subst
419             (WildCard, _)        -> bind (not swap) tv2 ty1
420
421             (Skolem, WildCard)   -> return subst
422             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
423             (Skolem, _)          -> bind (not swap) tv2 ty1
424         }
425
426   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
427   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
428
429   | otherwise = failWith (kindMisMatch tv1 ty2)
430   where
431     ty1 = TyVarTy tv1
432     k1 = tyVarKind tv1
433     k2 = tyVarKind tv2
434     bind swap tv ty = extendReft swap subst tv co ty
435
436 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
437   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
438   = failWith (occursCheck tv1 ty2)      -- Occurs check
439   | not (k2 `isSubKind` k1)
440   = failWith (kindMisMatch tv1 ty2)     -- Kind check
441   | otherwise
442   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
443   where
444     k1 = tyVarKind tv1
445     k2 = typeKind ty2'
446
447 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
448 -- Apply the non-idempotent substitution to a set of type variables,
449 -- remembering that the substitution isn't necessarily idempotent
450 substTvSet subst tvs
451   = foldVarSet (unionVarSet . get) emptyVarSet tvs
452   where
453     get tv = case lookupVarEnv subst tv of
454                 Nothing     -> unitVarSet tv
455                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
456
457 bindTv swap subst co tv ty      -- ty is not a type variable
458   = do  { b <- tvBindFlag tv
459         ; case b of
460             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
461             WildCard -> return subst
462             other    -> extendReft swap subst tv co ty
463         }
464
465 doSwap :: Bool -> Coercion -> Coercion
466 doSwap swap co = if swap then mkSymCoercion co else co
467
468 extendReft :: Bool 
469            -> InternalReft 
470            -> TyVar 
471            -> Coercion 
472            -> Type 
473            -> UM InternalReft
474 extendReft swap subst tv  co ty
475   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
476           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
477     return (extendVarEnv subst tv (co1, ty))
478   where
479     co1 = doSwap swap co
480
481 \end{code}
482
483 %************************************************************************
484 %*                                                                      *
485                 Unification monad
486 %*                                                                      *
487 %************************************************************************
488
489 \begin{code}
490 data BindFlag 
491   = BindMe      -- A regular type variable
492   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
493
494   | Skolem      -- This type variable is a skolem constant
495                 -- Don't bind it; it only matches itself
496
497   | WildCard    -- This type variable matches anything,
498                 -- and does not affect the substitution
499
500 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
501                          -> MaybeErr Message a }
502
503 instance Monad UM where
504   return a = UM (\tvs -> Succeeded a)
505   fail s   = UM (\tvs -> Failed (text s))
506   m >>= k  = UM (\tvs -> case unUM m tvs of
507                            Failed err -> Failed err
508                            Succeeded v  -> unUM (k v) tvs)
509
510 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
511 initUM badtvs um = unUM um badtvs
512
513 tvBindFlag :: TyVar -> UM BindFlag
514 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
515
516 failWith :: Message -> UM a
517 failWith msg = UM (\tv_fn -> Failed msg)
518
519 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
520 maybeErrToMaybe (Succeeded a) = Just a
521 maybeErrToMaybe (Failed m)    = Nothing
522 \end{code}
523
524
525 %************************************************************************
526 %*                                                                      *
527                 Error reporting
528         We go to a lot more trouble to tidy the types
529         in TcUnify.  Maybe we'll end up having to do that
530         here too, but I'll leave it for now.
531 %*                                                                      *
532 %************************************************************************
533
534 \begin{code}
535 misMatch t1 t2
536   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
537     ptext SLIT("and") <+> quotes (ppr t2)
538
539 lengthMisMatch tys1 tys2
540   = sep [ptext SLIT("Can't match unequal length lists"), 
541          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
542
543 kindMisMatch tv1 t2
544   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
545             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
546           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
547                 ptext SLIT("with") <+> quotes (ppr t2)]
548
549 occursCheck tv ty
550   = hang (ptext SLIT("Can't construct the infinite type"))
551        2 (ppr tv <+> equals <+> ppr ty)
552 \end{code}