Minor refactoring
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 module TcGadt (
14         Refinement, emptyRefinement, gadtRefine, 
15         refineType, refineResType,
16         dataConCanMatch,
17         tcUnifyTys, BindFlag(..)
18   ) where
19
20 #include "HsVersions.h"
21
22 import HsSyn
23 import Coercion
24 import TypeRep
25 import DataCon
26 import Var
27 import VarEnv
28 import VarSet
29 import ErrUtils
30 import Maybes
31 import Control.Monad
32 import Outputable
33 import TcType
34
35 #ifdef DEBUG
36 import Unique
37 import UniqFM
38 #endif
39 \end{code}
40
41
42 %************************************************************************
43 %*                                                                      *
44                 What a refinement is
45 %*                                                                      *
46 %************************************************************************
47
48 \begin{code}
49 data Refinement = Reft InScopeSet InternalReft 
50
51 type InternalReft = TyVarEnv (Coercion, Type)
52 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
53 -- Not necessarily idemopotent
54
55 instance Outputable Refinement where
56   ppr (Reft in_scope env)
57     = ptext SLIT("Refinement") <+>
58         braces (ppr env)
59
60 emptyRefinement :: Refinement
61 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
62
63
64 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
65 -- Apply the refinement to the type.
66 -- If (refineType r ty) = (co, ty')
67 -- Then co :: ty:=:ty'
68 -- Nothing => the refinement does nothing to this type
69 refineType (Reft in_scope env) ty
70   | not (isEmptyVarEnv env),            -- Common case
71     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
72   = Just (substTy co_subst ty, substTy tv_subst ty)
73   | otherwise
74   = Nothing     -- The type doesn't mention any refined type variables
75   where
76     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
77     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
78  
79 refineResType :: Refinement -> Type -> (HsWrapper, Type)
80 -- Like refineType, but returns the 'sym' coercion
81 -- If (refineResType r ty) = (co, ty')
82 -- Then co :: ty':=:ty
83 -- It's convenient to return a HsWrapper here
84 refineResType reft ty
85   = case refineType reft ty of
86         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
87         Nothing        -> (idHsWrapper,             ty)
88 \end{code}
89
90
91 %************************************************************************
92 %*                                                                      *
93                 Generating a type refinement
94 %*                                                                      *
95 %************************************************************************
96
97 \begin{code}
98 gadtRefine :: Refinement
99            -> [TyVar]   -- Bind these by preference
100            -> [CoVar]
101            -> MaybeErr Message Refinement
102 \end{code}
103
104 (gadtRefine cvs) takes a list of coercion variables, and returns a
105 list of coercions, obtained by unifying the types equated by the
106 incoming coercions.  The returned coercions all have kinds of form
107 (a:=:ty), where a is a rigid type variable.
108
109 Example:
110   gadtRefine [c :: (a,Int):=:(Bool,b)]
111   = [ right (left c) :: a:=:Bool,       
112       sym (right c)  :: b:=:Int ]
113
114 That is, given evidence 'c' that (a,Int)=(Bool,b), it returns derived
115 evidence in easy-to-use form.  In particular, given any e::ty, we know 
116 that:
117         e `cast` ty[right (left c)/a, sym (right c)/b]
118         :: ty [Bool/a, Int/b]
119       
120 Two refinements:
121
122 - It can fail, if the coercion is unsatisfiable.
123
124 - It's biased, by being given a set of type variables to bind
125   when there is a choice. Example:
126         MkT :: forall a. a -> T [a]
127         f :: forall b. T [b] -> b
128         f x = let v = case x of { MkT y -> y }
129               in ...
130   Here we want to bind [a->b], not the other way round, because
131   in this example the return type is wobbly, and we want the
132   program to typecheck
133
134
135 -- E.g. (a, Bool, right (left c))
136 -- INVARIANT: in the triple (tv, ty, co), we have (co :: tv:=:ty)
137 -- The result is idempotent: the 
138
139 \begin{code}
140 gadtRefine (Reft in_scope env1) 
141            ex_tvs co_vars
142 -- Precondition: fvs( co_vars ) # env1
143 -- That is, the kinds of the co_vars are a
144 -- fixed point of the incoming refinement
145
146   = ASSERT2( not $ any (`elemVarEnv` env1) (varSetElems $ tyVarsOfTypes $ map tyVarKind co_vars),
147              ppr env1 $$ ppr co_vars $$ ppr (map tyVarKind co_vars) )
148     initUM (tryToBind tv_set) $
149     do  {       -- Run the unifier, starting with an empty env
150         ; env2 <- foldM do_one emptyInternalReft co_vars
151
152                 -- Find the fixed point of the resulting 
153                 -- non-idempotent substitution
154         ; let tmp_env = env1 `plusVarEnv` env2
155               out_env = fixTvCoEnv in_scope' tmp_env
156         ; WARN( not (null (badReftElts tmp_env)), ppr (badReftElts tmp_env) $$ ppr tmp_env )
157           WARN( not (null (badReftElts out_env)), ppr (badReftElts out_env) $$ ppr out_env )
158           return (Reft in_scope' out_env) }
159   where
160     tv_set = mkVarSet ex_tvs
161     in_scope' = foldr extend in_scope co_vars
162
163         -- For each co_var, add it *and* the tyvars it mentions, to in_scope
164     extend co_var in_scope
165         = extendInScopeSetSet in_scope $
166           extendVarSet (tyVarsOfType (tyVarKind co_var)) co_var
167         
168     do_one reft co_var = unify reft (TyVarTy co_var) ty1 ty2
169         where
170            (ty1,ty2) = splitCoercionKind (tyVarKind co_var)
171 \end{code} 
172
173 %************************************************************************
174 %*                                                                      *
175                 Unification
176 %*                                                                      *
177 %************************************************************************
178
179 \begin{code}
180 tcUnifyTys :: (TyVar -> BindFlag)
181            -> [Type] -> [Type]
182            -> Maybe TvSubst     -- A regular one-shot substitution
183 -- The two types may have common type variables, and indeed do so in the
184 -- second call to tcUnifyTys in FunDeps.checkClsFD
185 --
186 -- We implement tcUnifyTys using the evidence-generating 'unify' function
187 -- in this module, even though we don't need to generate any evidence.
188 -- This is simply to avoid replicating all all the code for unify
189 tcUnifyTys bind_fn tys1 tys2
190   = maybeErrToMaybe $ initUM bind_fn $
191     do { reft <- unifyList emptyInternalReft cos tys1 tys2
192
193         -- Find the fixed point of the resulting non-idempotent substitution
194         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
195               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
196
197         ; return (mkTvSubst in_scope tv_env) }
198   where
199     tvs1 = tyVarsOfTypes tys1
200     tvs2 = tyVarsOfTypes tys2
201     cos  = zipWith mkUnsafeCoercion tys1 tys2
202
203
204 ----------------------------
205 fixTvCoEnv :: InScopeSet -> InternalReft -> InternalReft
206         -- Find the fixed point of a Refinement
207         -- (assuming it has no loops!)
208 fixTvCoEnv in_scope env
209   = fixpt
210   where
211     fixpt         = mapVarEnv step env
212
213     step (co, ty) = case refineType (Reft in_scope fixpt) ty of
214                         Nothing         -> (co,                     ty)
215                         Just (co', ty') -> (mkTransCoercion co co', ty')
216       -- Apply fixpt one step:
217       -- Use refineType to get a substituted type, ty', and a coercion, co_fn,
218       -- which justifies the substitution.  If the coercion is not the identity
219       -- then use transitivity with the original coercion
220
221 -----------------------------
222 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
223 fixTvSubstEnv in_scope env
224   = fixpt 
225   where
226     fixpt = mapVarEnv (substTy (mkTvSubst in_scope fixpt)) env
227
228 ----------------------------
229 dataConCanMatch :: DataCon -> [Type] -> Bool
230 -- Returns True iff the data con can match a scrutinee of type (T tys)
231 --                  where T is the type constructor for the data con
232 --
233 -- Instantiate the equations and try to unify them
234 dataConCanMatch con tys
235   = isJust (tcUnifyTys (\tv -> BindMe) 
236                        (map (substTyVar subst . fst) eq_spec)
237                        (map snd eq_spec))
238   where
239     dc_tvs  = dataConUnivTyVars con
240     eq_spec = dataConEqSpec con
241     subst   = zipTopTvSubst dc_tvs tys
242
243 ----------------------------
244 tryToBind :: TyVarSet -> TyVar -> BindFlag
245 tryToBind tv_set tv | tv `elemVarSet` tv_set = BindMe
246                     | otherwise              = AvoidMe
247
248
249 \end{code}
250
251
252 %************************************************************************
253 %*                                                                      *
254                 The workhorse
255 %*                                                                      *
256 %************************************************************************
257
258 \begin{code}
259 #ifdef DEBUG
260 badReftElts :: InternalReft -> [(Unique, (Coercion,Type))]
261 -- Return the BAD elements of the refinement
262 -- Should be empty; used in asserions only
263 badReftElts env
264   = filter (not . ok) (ufmToList env)
265   where
266     ok :: (Unique, (Coercion, Type)) -> Bool
267     ok (u, (co, ty)) | Just tv <- tcGetTyVar_maybe ty1
268                      = varUnique tv == u && ty `tcEqType` ty2 
269                      | otherwise = False
270         where
271           (ty1,ty2) = coercionKind co
272 #endif
273
274 emptyInternalReft :: InternalReft
275 emptyInternalReft = emptyVarEnv
276
277 unify :: InternalReft           -- An existing substitution to extend
278       -> Coercion       -- Witness of their equality 
279       -> Type -> Type   -- Types to be unified, and witness of their equality
280       -> UM InternalReft                -- Just the extended substitution, 
281                                 -- Nothing if unification failed
282 -- We do not require the incoming substitution to be idempotent,
283 -- nor guarantee that the outgoing one is.  That's fixed up by
284 -- the wrappers.
285
286 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
287 --                      co :: (ty1:=:ty2)
288
289 -- Respects newtypes, PredTypes
290
291 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
292                          unify_ subst co ty1 ty2
293
294 -- in unify_, any NewTcApps/Preds should be taken at face value
295 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
296 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
297
298 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
299 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
300
301 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
302
303 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
304   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
305
306 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
307   = do  { let [co1,co2] = decomposeCo 2 co
308         ; subst' <- unify subst co1 ty1a ty2a
309         ; unify subst' co2 ty1b ty2b }
310
311         -- Applications need a bit of care!
312         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
313         -- NB: we've already dealt with type variables and Notes,
314         -- so if one type is an App the other one jolly well better be too
315 unify_ subst co (AppTy ty1a ty1b) ty2
316   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
317   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
318         ; unify subst' (mkRightCoercion co) ty1b ty2b }
319
320 unify_ subst co ty1 (AppTy ty2a ty2b)
321   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
322   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
323         ; unify subst' (mkRightCoercion co) ty1b ty2b }
324
325 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
326         -- ForAlls??
327
328
329 ------------------------------
330 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
331   | c1 == c2 = unify_tys subst co tys1 tys2
332 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
333   | n1 == n2 = unify subst co t1 t2
334 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
335  
336 ------------------------------
337 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
338 unify_tys subst co xs ys
339   = unifyList subst (decomposeCo (length xs) co) xs ys
340
341 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
342 unifyList subst orig_cos orig_xs orig_ys
343   = go subst orig_cos orig_xs orig_ys
344   where
345     go subst _        []     []     = return subst
346     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
347                                          ; go subst' cos xs ys }
348     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
349
350 ---------------------------------
351 uVar :: Bool            -- Swapped
352      -> InternalReft    -- An existing substitution to extend
353      -> Coercion
354      -> TyVar           -- Type variable to be unified
355      -> Type            -- with this type
356      -> UM InternalReft
357
358 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
359 --      if swap=False   co :: (tv1:=:ty)
360 --      if swap=True    co :: (ty:=:tv1)
361
362 uVar swap subst co tv1 ty
363  = -- Check to see whether tv1 is refined by the substitution
364    case (lookupVarEnv subst tv1) of
365
366      -- Yes, call back into unify'
367      Just (co',ty')     -- co' :: (tv1:=:ty')
368         | swap          -- co :: (ty:=:tv1)
369         -> unify subst (mkTransCoercion co co') ty ty' 
370         | otherwise     -- co :: (tv1:=:ty)
371         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
372
373      -- No, continue
374      Nothing -> uUnrefined swap subst co
375                            tv1 ty ty
376
377
378 uUnrefined :: Bool                -- Whether the input is swapped
379            -> InternalReft        -- An existing substitution to extend
380            -> Coercion
381            -> TyVar               -- Type variable to be unified
382            -> Type                -- with this type
383            -> Type                -- (de-noted version)
384            -> UM InternalReft
385
386 -- We know that tv1 isn't refined
387 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
388 --      co :: tv1:=:ty2
389 -- and if the first argument is True instead, we know
390 --      co :: ty2:=:tv1
391
392 uUnrefined swap subst co tv1 ty2 ty2'
393   | Just ty2'' <- tcView ty2'
394   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
395                 -- This is essential, in case we have
396                 --      type Foo a = a
397                 -- and then unify a :=: Foo a
398
399 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
400   | tv1 == tv2          -- Same type variable
401   = return subst
402
403     -- Check to see whether tv2 is refined
404   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
405   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
406
407   -- So both are unrefined; next, see if the kinds force the direction
408   | eqKind k1 k2        -- Can update either; so check the bind-flags
409   = do  { b1 <- tvBindFlag tv1
410         ; b2 <- tvBindFlag tv2
411         ; case (b1,b2) of
412             (BindMe, _)          -> bind swap tv1 ty2
413
414             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
415             (AvoidMe, _)         -> bind swap tv1 ty2
416
417             (WildCard, WildCard) -> return subst
418             (WildCard, Skolem)   -> return subst
419             (WildCard, _)        -> bind (not swap) tv2 ty1
420
421             (Skolem, WildCard)   -> return subst
422             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
423             (Skolem, _)          -> bind (not swap) tv2 ty1
424         }
425
426   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
427   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
428
429   | otherwise = failWith (kindMisMatch tv1 ty2)
430   where
431     ty1 = TyVarTy tv1
432     k1 = tyVarKind tv1
433     k2 = tyVarKind tv2
434     bind swap tv ty = extendReft swap subst tv co ty
435
436 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
437   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
438   = failWith (occursCheck tv1 ty2)      -- Occurs check
439   | not (k2 `isSubKind` k1)
440   = failWith (kindMisMatch tv1 ty2)     -- Kind check
441   | otherwise
442   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
443   where
444     k1 = tyVarKind tv1
445     k2 = typeKind ty2'
446
447 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
448 -- Apply the non-idempotent substitution to a set of type variables,
449 -- remembering that the substitution isn't necessarily idempotent
450 substTvSet subst tvs
451   = foldVarSet (unionVarSet . get) emptyVarSet tvs
452   where
453     get tv = case lookupVarEnv subst tv of
454                 Nothing     -> unitVarSet tv
455                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
456
457 bindTv swap subst co tv ty      -- ty is not a type variable
458   = do  { b <- tvBindFlag tv
459         ; case b of
460             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
461             WildCard -> return subst
462             other    -> extendReft swap subst tv co ty
463         }
464
465 doSwap :: Bool -> Coercion -> Coercion
466 doSwap swap co = if swap then mkSymCoercion co else co
467
468 extendReft :: Bool 
469            -> InternalReft 
470            -> TyVar 
471            -> Coercion 
472            -> Type 
473            -> UM InternalReft
474 extendReft swap subst tv  co ty
475   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
476           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
477     return (extendVarEnv subst tv (co1, ty))
478   where
479     co1 = doSwap swap co
480
481 \end{code}
482
483 %************************************************************************
484 %*                                                                      *
485                 Unification monad
486 %*                                                                      *
487 %************************************************************************
488
489 \begin{code}
490 data BindFlag 
491   = BindMe      -- A regular type variable
492   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
493
494   | Skolem      -- This type variable is a skolem constant
495                 -- Don't bind it; it only matches itself
496
497   | WildCard    -- This type variable matches anything,
498                 -- and does not affect the substitution
499
500 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
501                          -> MaybeErr Message a }
502
503 instance Monad UM where
504   return a = UM (\tvs -> Succeeded a)
505   fail s   = UM (\tvs -> Failed (text s))
506   m >>= k  = UM (\tvs -> case unUM m tvs of
507                            Failed err -> Failed err
508                            Succeeded v  -> unUM (k v) tvs)
509
510 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
511 initUM badtvs um = unUM um badtvs
512
513 tvBindFlag :: TyVar -> UM BindFlag
514 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
515
516 failWith :: Message -> UM a
517 failWith msg = UM (\tv_fn -> Failed msg)
518
519 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
520 maybeErrToMaybe (Succeeded a) = Just a
521 maybeErrToMaybe (Failed m)    = Nothing
522 \end{code}
523
524
525 %************************************************************************
526 %*                                                                      *
527                 Error reporting
528         We go to a lot more trouble to tidy the types
529         in TcUnify.  Maybe we'll end up having to do that
530         here too, but I'll leave it for now.
531 %*                                                                      *
532 %************************************************************************
533
534 \begin{code}
535 misMatch t1 t2
536   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
537     ptext SLIT("and") <+> quotes (ppr t2)
538
539 lengthMisMatch tys1 tys2
540   = sep [ptext SLIT("Can't match unequal length lists"), 
541          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
542
543 kindMisMatch tv1 t2
544   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
545             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
546           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
547                 ptext SLIT("with") <+> quotes (ppr t2)]
548
549 occursCheck tv ty
550   = hang (ptext SLIT("Can't construct the infinite type"))
551        2 (ppr tv <+> equals <+> ppr ty)
552 \end{code}