4c1f70ecf153e8a44a09bcc4d17c9657527b3f35
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 module TcGadt (
14         Refinement, emptyRefinement, gadtRefine, 
15         refineType, refineResType,
16         dataConCanMatch,
17         tcUnifyTys, BindFlag(..)
18   ) where
19
20 #include "HsVersions.h"
21
22 import HsSyn
23 import Coercion
24 import Type
25 import TypeRep
26 import DataCon
27 import Var
28 import VarEnv
29 import VarSet
30 import ErrUtils
31 import Maybes
32 import Control.Monad
33 import Outputable
34 import TcType
35
36 #ifdef DEBUG
37 import Unique
38 import UniqFM
39 #endif
40 \end{code}
41
42
43 %************************************************************************
44 %*                                                                      *
45                 What a refinement is
46 %*                                                                      *
47 %************************************************************************
48
49 \begin{code}
50 data Refinement = Reft InScopeSet InternalReft 
51
52 type InternalReft = TyVarEnv (Coercion, Type)
53 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
54 -- Not necessarily idemopotent
55
56 instance Outputable Refinement where
57   ppr (Reft in_scope env)
58     = ptext SLIT("Refinement") <+>
59         braces (ppr env)
60
61 emptyRefinement :: Refinement
62 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
63
64
65 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
66 -- Apply the refinement to the type.
67 -- If (refineType r ty) = (co, ty')
68 -- Then co :: ty:=:ty'
69 -- Nothing => the refinement does nothing to this type
70 refineType (Reft in_scope env) ty
71   | not (isEmptyVarEnv env),            -- Common case
72     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
73   = Just (substTy co_subst ty, substTy tv_subst ty)
74   | otherwise
75   = Nothing     -- The type doesn't mention any refined type variables
76   where
77     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
78     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
79  
80 refineResType :: Refinement -> Type -> (HsWrapper, Type)
81 -- Like refineType, but returns the 'sym' coercion
82 -- If (refineResType r ty) = (co, ty')
83 -- Then co :: ty':=:ty
84 -- It's convenient to return a HsWrapper here
85 refineResType reft ty
86   = case refineType reft ty of
87         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
88         Nothing        -> (idHsWrapper,             ty)
89 \end{code}
90
91
92 %************************************************************************
93 %*                                                                      *
94                 Generating a type refinement
95 %*                                                                      *
96 %************************************************************************
97
98 \begin{code}
99 gadtRefine :: Refinement
100            -> [TyVar]   -- Bind these by preference
101            -> [CoVar]
102            -> MaybeErr Message Refinement
103 \end{code}
104
105 (gadtRefine cvs) takes a list of coercion variables, and returns a
106 list of coercions, obtained by unifying the types equated by the
107 incoming coercions.  The returned coercions all have kinds of form
108 (a:=:ty), where a is a rigid type variable.
109
110 Example:
111   gadtRefine [c :: (a,Int):=:(Bool,b)]
112   = [ right (left c) :: a:=:Bool,       
113       sym (right c)  :: b:=:Int ]
114
115 That is, given evidence 'c' that (a,Int)=(Bool,b), it returns derived
116 evidence in easy-to-use form.  In particular, given any e::ty, we know 
117 that:
118         e `cast` ty[right (left c)/a, sym (right c)/b]
119         :: ty [Bool/a, Int/b]
120       
121 Two refinements:
122
123 - It can fail, if the coercion is unsatisfiable.
124
125 - It's biased, by being given a set of type variables to bind
126   when there is a choice. Example:
127         MkT :: forall a. a -> T [a]
128         f :: forall b. T [b] -> b
129         f x = let v = case x of { MkT y -> y }
130               in ...
131   Here we want to bind [a->b], not the other way round, because
132   in this example the return type is wobbly, and we want the
133   program to typecheck
134
135
136 -- E.g. (a, Bool, right (left c))
137 -- INVARIANT: in the triple (tv, ty, co), we have (co :: tv:=:ty)
138 -- The result is idempotent: the 
139
140 \begin{code}
141 gadtRefine (Reft in_scope env1) 
142            ex_tvs co_vars
143 -- Precondition: fvs( co_vars ) # env1
144 -- That is, the kinds of the co_vars are a
145 -- fixed point of the incoming refinement
146
147   = ASSERT2( not $ any (`elemVarEnv` env1) (varSetElems $ tyVarsOfTypes $ map tyVarKind co_vars),
148              ppr env1 $$ ppr co_vars $$ ppr (map tyVarKind co_vars) )
149     initUM (tryToBind tv_set) $
150     do  {       -- Run the unifier, starting with an empty env
151         ; env2 <- foldM do_one emptyInternalReft co_vars
152
153                 -- Find the fixed point of the resulting 
154                 -- non-idempotent substitution
155         ; let tmp_env = env1 `plusVarEnv` env2
156               out_env = fixTvCoEnv in_scope' tmp_env
157         ; WARN( not (null (badReftElts tmp_env)), ppr (badReftElts tmp_env) $$ ppr tmp_env )
158           WARN( not (null (badReftElts out_env)), ppr (badReftElts out_env) $$ ppr out_env )
159           return (Reft in_scope' out_env) }
160   where
161     tv_set = mkVarSet ex_tvs
162     in_scope' = foldr extend in_scope co_vars
163
164         -- For each co_var, add it *and* the tyvars it mentions, to in_scope
165     extend co_var in_scope
166         = extendInScopeSetSet in_scope $
167           extendVarSet (tyVarsOfType (tyVarKind co_var)) co_var
168         
169     do_one reft co_var = unify reft (TyVarTy co_var) ty1 ty2
170         where
171            (ty1,ty2) = splitCoercionKind (tyVarKind co_var)
172 \end{code} 
173
174 %************************************************************************
175 %*                                                                      *
176                 Unification
177 %*                                                                      *
178 %************************************************************************
179
180 \begin{code}
181 tcUnifyTys :: (TyVar -> BindFlag)
182            -> [Type] -> [Type]
183            -> Maybe TvSubst     -- A regular one-shot substitution
184 -- The two types may have common type variables, and indeed do so in the
185 -- second call to tcUnifyTys in FunDeps.checkClsFD
186 --
187 -- We implement tcUnifyTys using the evidence-generating 'unify' function
188 -- in this module, even though we don't need to generate any evidence.
189 -- This is simply to avoid replicating all all the code for unify
190 tcUnifyTys bind_fn tys1 tys2
191   = maybeErrToMaybe $ initUM bind_fn $
192     do { reft <- unifyList emptyInternalReft cos tys1 tys2
193
194         -- Find the fixed point of the resulting non-idempotent substitution
195         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
196               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
197
198         ; return (mkTvSubst in_scope tv_env) }
199   where
200     tvs1 = tyVarsOfTypes tys1
201     tvs2 = tyVarsOfTypes tys2
202     cos  = zipWith mkUnsafeCoercion tys1 tys2
203
204
205 ----------------------------
206 fixTvCoEnv :: InScopeSet -> InternalReft -> InternalReft
207         -- Find the fixed point of a Refinement
208         -- (assuming it has no loops!)
209 fixTvCoEnv in_scope env
210   = fixpt
211   where
212     fixpt         = mapVarEnv step env
213
214     step (co, ty) = case refineType (Reft in_scope fixpt) ty of
215                         Nothing         -> (co,                     ty)
216                         Just (co', ty') -> (mkTransCoercion co co', ty')
217       -- Apply fixpt one step:
218       -- Use refineType to get a substituted type, ty', and a coercion, co_fn,
219       -- which justifies the substitution.  If the coercion is not the identity
220       -- then use transitivity with the original coercion
221
222 -----------------------------
223 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
224 fixTvSubstEnv in_scope env
225   = fixpt 
226   where
227     fixpt = mapVarEnv (substTy (mkTvSubst in_scope fixpt)) env
228
229 ----------------------------
230 dataConCanMatch :: [Type] -> DataCon -> Bool
231 -- Returns True iff the data con can match a scrutinee of type (T tys)
232 --                  where T is the type constructor for the data con
233 --
234 -- Instantiate the equations and try to unify them
235 dataConCanMatch tys con
236   | null eq_spec      = True    -- Common
237   | all isTyVarTy tys = True    -- Also common
238   | otherwise
239   = isJust (tcUnifyTys (\tv -> BindMe) 
240                        (map (substTyVar subst . fst) eq_spec)
241                        (map snd eq_spec))
242   where
243     dc_tvs  = dataConUnivTyVars con
244     eq_spec = dataConEqSpec con
245     subst   = zipTopTvSubst dc_tvs tys
246
247 ----------------------------
248 tryToBind :: TyVarSet -> TyVar -> BindFlag
249 tryToBind tv_set tv | tv `elemVarSet` tv_set = BindMe
250                     | otherwise              = AvoidMe
251
252
253 \end{code}
254
255
256 %************************************************************************
257 %*                                                                      *
258                 The workhorse
259 %*                                                                      *
260 %************************************************************************
261
262 \begin{code}
263 #ifdef DEBUG
264 badReftElts :: InternalReft -> [(Unique, (Coercion,Type))]
265 -- Return the BAD elements of the refinement
266 -- Should be empty; used in asserions only
267 badReftElts env
268   = filter (not . ok) (ufmToList env)
269   where
270     ok :: (Unique, (Coercion, Type)) -> Bool
271     ok (u, (co, ty)) | Just tv <- tcGetTyVar_maybe ty1
272                      = varUnique tv == u && ty `tcEqType` ty2 
273                      | otherwise = False
274         where
275           (ty1,ty2) = coercionKind co
276 #endif
277
278 emptyInternalReft :: InternalReft
279 emptyInternalReft = emptyVarEnv
280
281 unify :: InternalReft           -- An existing substitution to extend
282       -> Coercion       -- Witness of their equality 
283       -> Type -> Type   -- Types to be unified, and witness of their equality
284       -> UM InternalReft                -- Just the extended substitution, 
285                                 -- Nothing if unification failed
286 -- We do not require the incoming substitution to be idempotent,
287 -- nor guarantee that the outgoing one is.  That's fixed up by
288 -- the wrappers.
289
290 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
291 --                      co :: (ty1:=:ty2)
292
293 -- Respects newtypes, PredTypes
294
295 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
296                          unify_ subst co ty1 ty2
297
298 -- in unify_, any NewTcApps/Preds should be taken at face value
299 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
300 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
301
302 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
303 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
304
305 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
306
307 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
308   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
309
310 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
311   = do  { let [co1,co2] = decomposeCo 2 co
312         ; subst' <- unify subst co1 ty1a ty2a
313         ; unify subst' co2 ty1b ty2b }
314
315         -- Applications need a bit of care!
316         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
317         -- NB: we've already dealt with type variables and Notes,
318         -- so if one type is an App the other one jolly well better be too
319 unify_ subst co (AppTy ty1a ty1b) ty2
320   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
321   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
322         ; unify subst' (mkRightCoercion co) ty1b ty2b }
323
324 unify_ subst co ty1 (AppTy ty2a ty2b)
325   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
326   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
327         ; unify subst' (mkRightCoercion co) ty1b ty2b }
328
329 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
330         -- ForAlls??
331
332
333 ------------------------------
334 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
335   | c1 == c2 = unify_tys subst co tys1 tys2
336 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
337   | n1 == n2 = unify subst co t1 t2
338 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
339  
340 ------------------------------
341 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
342 unify_tys subst co xs ys
343   = unifyList subst (decomposeCo (length xs) co) xs ys
344
345 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
346 unifyList subst orig_cos orig_xs orig_ys
347   = go subst orig_cos orig_xs orig_ys
348   where
349     go subst _        []     []     = return subst
350     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
351                                          ; go subst' cos xs ys }
352     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
353
354 ---------------------------------
355 uVar :: Bool            -- Swapped
356      -> InternalReft    -- An existing substitution to extend
357      -> Coercion
358      -> TyVar           -- Type variable to be unified
359      -> Type            -- with this type
360      -> UM InternalReft
361
362 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
363 --      if swap=False   co :: (tv1:=:ty)
364 --      if swap=True    co :: (ty:=:tv1)
365
366 uVar swap subst co tv1 ty
367  = -- Check to see whether tv1 is refined by the substitution
368    case (lookupVarEnv subst tv1) of
369
370      -- Yes, call back into unify'
371      Just (co',ty')     -- co' :: (tv1:=:ty')
372         | swap          -- co :: (ty:=:tv1)
373         -> unify subst (mkTransCoercion co co') ty ty' 
374         | otherwise     -- co :: (tv1:=:ty)
375         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
376
377      -- No, continue
378      Nothing -> uUnrefined swap subst co
379                            tv1 ty ty
380
381
382 uUnrefined :: Bool                -- Whether the input is swapped
383            -> InternalReft        -- An existing substitution to extend
384            -> Coercion
385            -> TyVar               -- Type variable to be unified
386            -> Type                -- with this type
387            -> Type                -- (de-noted version)
388            -> UM InternalReft
389
390 -- We know that tv1 isn't refined
391 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
392 --      co :: tv1:=:ty2
393 -- and if the first argument is True instead, we know
394 --      co :: ty2:=:tv1
395
396 uUnrefined swap subst co tv1 ty2 ty2'
397   | Just ty2'' <- tcView ty2'
398   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
399                 -- This is essential, in case we have
400                 --      type Foo a = a
401                 -- and then unify a :=: Foo a
402
403 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
404   | tv1 == tv2          -- Same type variable
405   = return subst
406
407     -- Check to see whether tv2 is refined
408   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
409   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
410
411   -- So both are unrefined; next, see if the kinds force the direction
412   | eqKind k1 k2        -- Can update either; so check the bind-flags
413   = do  { b1 <- tvBindFlag tv1
414         ; b2 <- tvBindFlag tv2
415         ; case (b1,b2) of
416             (BindMe, _)          -> bind swap tv1 ty2
417
418             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
419             (AvoidMe, _)         -> bind swap tv1 ty2
420
421             (WildCard, WildCard) -> return subst
422             (WildCard, Skolem)   -> return subst
423             (WildCard, _)        -> bind (not swap) tv2 ty1
424
425             (Skolem, WildCard)   -> return subst
426             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
427             (Skolem, _)          -> bind (not swap) tv2 ty1
428         }
429
430   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
431   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
432
433   | otherwise = failWith (kindMisMatch tv1 ty2)
434   where
435     ty1 = TyVarTy tv1
436     k1 = tyVarKind tv1
437     k2 = tyVarKind tv2
438     bind swap tv ty = extendReft swap subst tv co ty
439
440 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
441   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
442   = failWith (occursCheck tv1 ty2)      -- Occurs check
443   | not (k2 `isSubKind` k1)
444   = failWith (kindMisMatch tv1 ty2)     -- Kind check
445   | otherwise
446   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
447   where
448     k1 = tyVarKind tv1
449     k2 = typeKind ty2'
450
451 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
452 -- Apply the non-idempotent substitution to a set of type variables,
453 -- remembering that the substitution isn't necessarily idempotent
454 substTvSet subst tvs
455   = foldVarSet (unionVarSet . get) emptyVarSet tvs
456   where
457     get tv = case lookupVarEnv subst tv of
458                 Nothing     -> unitVarSet tv
459                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
460
461 bindTv swap subst co tv ty      -- ty is not a type variable
462   = do  { b <- tvBindFlag tv
463         ; case b of
464             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
465             WildCard -> return subst
466             other    -> extendReft swap subst tv co ty
467         }
468
469 doSwap :: Bool -> Coercion -> Coercion
470 doSwap swap co = if swap then mkSymCoercion co else co
471
472 extendReft :: Bool 
473            -> InternalReft 
474            -> TyVar 
475            -> Coercion 
476            -> Type 
477            -> UM InternalReft
478 extendReft swap subst tv  co ty
479   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
480           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
481     return (extendVarEnv subst tv (co1, ty))
482   where
483     co1 = doSwap swap co
484
485 \end{code}
486
487 %************************************************************************
488 %*                                                                      *
489                 Unification monad
490 %*                                                                      *
491 %************************************************************************
492
493 \begin{code}
494 data BindFlag 
495   = BindMe      -- A regular type variable
496   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
497
498   | Skolem      -- This type variable is a skolem constant
499                 -- Don't bind it; it only matches itself
500
501   | WildCard    -- This type variable matches anything,
502                 -- and does not affect the substitution
503
504 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
505                          -> MaybeErr Message a }
506
507 instance Monad UM where
508   return a = UM (\tvs -> Succeeded a)
509   fail s   = UM (\tvs -> Failed (text s))
510   m >>= k  = UM (\tvs -> case unUM m tvs of
511                            Failed err -> Failed err
512                            Succeeded v  -> unUM (k v) tvs)
513
514 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
515 initUM badtvs um = unUM um badtvs
516
517 tvBindFlag :: TyVar -> UM BindFlag
518 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
519
520 failWith :: Message -> UM a
521 failWith msg = UM (\tv_fn -> Failed msg)
522
523 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
524 maybeErrToMaybe (Succeeded a) = Just a
525 maybeErrToMaybe (Failed m)    = Nothing
526 \end{code}
527
528
529 %************************************************************************
530 %*                                                                      *
531                 Error reporting
532         We go to a lot more trouble to tidy the types
533         in TcUnify.  Maybe we'll end up having to do that
534         here too, but I'll leave it for now.
535 %*                                                                      *
536 %************************************************************************
537
538 \begin{code}
539 misMatch t1 t2
540   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
541     ptext SLIT("and") <+> quotes (ppr t2)
542
543 lengthMisMatch tys1 tys2
544   = sep [ptext SLIT("Can't match unequal length lists"), 
545          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
546
547 kindMisMatch tv1 t2
548   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
549             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
550           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
551                 ptext SLIT("with") <+> quotes (ppr t2)]
552
553 occursCheck tv ty
554   = hang (ptext SLIT("Can't construct the infinite type"))
555        2 (ppr tv <+> equals <+> ppr ty)
556 \end{code}