Remove GADT refinements, part 4
[ghc-hetmet.git] / compiler / typecheck / TcGadt.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 %************************************************************************
7 %*                                                                      *
8                 Type refinement for GADTs
9 %*                                                                      *
10 %************************************************************************
11
12 \begin{code}
13 module TcGadt (
14         Refinement, emptyRefinement, isEmptyRefinement, 
15         matchRefine, 
16         refineType, refinePred, refineResType,
17         tcUnifyTys, BindFlag(..)
18   ) where
19
20 #include "HsVersions.h"
21
22 import HsSyn
23 import Coercion
24 import Type
25
26 import TypeRep
27 import Var
28 import VarEnv
29 import VarSet
30 import ErrUtils
31 import Maybes
32 import Control.Monad
33 import Outputable
34 import TcType
35 import UniqFM
36 import FastString
37 \end{code}
38
39
40 %************************************************************************
41 %*                                                                      *
42                 What a refinement is
43 %*                                                                      *
44 %************************************************************************
45
46 \begin{code}
47 data Refinement = Reft InScopeSet InternalReft 
48
49 type InternalReft = TyVarEnv (Coercion, Type)
50 -- INVARIANT:   a->(co,ty)   then   co :: (a:=:ty)
51 -- Not necessarily idemopotent
52
53 instance Outputable Refinement where
54   ppr (Reft _in_scope env)
55     = ptext SLIT("Refinement") <+>
56         braces (ppr env)
57
58 emptyRefinement :: Refinement
59 emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
60
61 isEmptyRefinement :: Refinement -> Bool
62 isEmptyRefinement (Reft _ env) = isEmptyVarEnv env
63
64 refineType :: Refinement -> Type -> Maybe (Coercion, Type)
65 -- Apply the refinement to the type.
66 -- If (refineType r ty) = (co, ty')
67 -- Then co :: ty:=:ty'
68 -- Nothing => the refinement does nothing to this type
69 refineType (Reft in_scope env) ty
70   | not (isEmptyVarEnv env),            -- Common case
71     any (`elemVarEnv` env) (varSetElems (tyVarsOfType ty))
72   = Just (substTy co_subst ty, substTy tv_subst ty)
73   | otherwise
74   = Nothing     -- The type doesn't mention any refined type variables
75   where
76     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
77     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
78  
79 refinePred :: Refinement -> PredType -> Maybe (Coercion, PredType)
80 refinePred (Reft in_scope env) pred
81   | not (isEmptyVarEnv env),            -- Common case
82     any (`elemVarEnv` env) (varSetElems (tyVarsOfPred pred))
83   = Just (mkPredTy (substPred co_subst pred), substPred tv_subst pred)
84   | otherwise
85   = Nothing     -- The type doesn't mention any refined type variables
86   where
87     tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
88     co_subst = mkTvSubst in_scope (mapVarEnv fst env)
89  
90 refineResType :: Refinement -> Type -> (HsWrapper, Type)
91 -- Like refineType, but returns the 'sym' coercion
92 -- If (refineResType r ty) = (co, ty')
93 -- Then co :: ty':=:ty
94 -- It's convenient to return a HsWrapper here
95 refineResType reft ty
96   = case refineType reft ty of
97         Just (co, ty1) -> (WpCo (mkSymCoercion co), ty1)
98         Nothing        -> (idHsWrapper,             ty)
99 \end{code}
100
101
102 %************************************************************************
103 %*                                                                      *
104                 Simple generation of a type refinement
105 %*                                                                      *
106 %************************************************************************
107
108 \begin{code}
109 matchRefine :: [CoVar] -> Refinement
110 \end{code}
111
112 Given a list of coercions, where for each coercion c::(ty1~ty2), the type ty2
113 is a specialisation of ty1, produce a type refinement that maps the variables
114 of ty1 to the corresponding sub-terms of ty2 using appropriate coercions; eg,
115
116   matchRefine (co :: [(a, b)] ~ [(c, Maybe d)])
117     = { right (left (right co)) :: a ~ c
118       , right (right co)        :: b ~ Maybe d
119       }
120
121 Precondition: The rhs types must indeed be a specialisation of the lhs types;
122   i.e., some free variables of the lhs are replaced with either distinct free 
123   variables or proper type terms to obtain the rhs.  (We don't perform full
124   unification or type matching here!)
125
126 NB: matchRefine does *not* expand the type synonyms.
127
128 \begin{code}
129 matchRefine co_vars 
130   = Reft in_scope (foldr plusVarEnv emptyVarEnv (map refineOne co_vars))
131   where
132     in_scope = foldr extend emptyInScopeSet co_vars
133
134         -- For each co_var, add it *and* the tyvars it mentions, to in_scope
135     extend co_var in_scope
136       = extendInScopeSetSet in_scope $
137           extendVarSet (tyVarsOfType (tyVarKind co_var)) co_var
138
139     refineOne co_var = refine (TyVarTy co_var) ty1 ty2
140       where
141         (ty1, ty2) = splitCoercionKind (tyVarKind co_var)
142
143     refine co (TyVarTy tv) ty                     = unitVarEnv tv (co, ty)
144     refine co (TyConApp _ tys) (TyConApp _ tys')  = refineArgs co tys tys'
145     refine co (NoteTy _ ty) ty'                   = refine co ty ty'
146     refine co ty (NoteTy _ ty')                   = refine co ty ty'
147     refine _  (PredTy _) (PredTy _)               = 
148       error "TcGadt.matchRefine: PredTy"
149     refine co (FunTy arg res) (FunTy arg' res')   =
150       refine (mkRightCoercion (mkLeftCoercion co)) arg arg' 
151       `plusVarEnv` 
152       refine (mkRightCoercion co) res res'
153     refine co (AppTy fun arg) (AppTy fun' arg')   = 
154       refine (mkLeftCoercion co) fun fun' 
155       `plusVarEnv`
156       refine (mkRightCoercion co) arg arg'
157     refine co (ForAllTy tv ty) (ForAllTy _tv ty') =
158       refine (mkForAllCoercion tv co) ty ty' `delVarEnv` tv
159     refine _ _ _ = error "RcGadt.matchRefine: mismatch"
160
161     refineArgs :: Coercion -> [Type] -> [Type] -> InternalReft
162     refineArgs co tys tys' = 
163       fst $ foldr refineArg (emptyVarEnv, id) (zip tys tys')
164       where
165         refineArg (ty, ty') (reft, coWrapper) 
166           = (refine (mkRightCoercion (coWrapper co)) ty ty' `plusVarEnv` reft, 
167              mkLeftCoercion . coWrapper)
168 \end{code}
169
170
171 %************************************************************************
172 %*                                                                      *
173                 Unification
174 %*                                                                      *
175 %************************************************************************
176
177 \begin{code}
178 tcUnifyTys :: (TyVar -> BindFlag)
179            -> [Type] -> [Type]
180            -> Maybe TvSubst     -- A regular one-shot substitution
181 -- The two types may have common type variables, and indeed do so in the
182 -- second call to tcUnifyTys in FunDeps.checkClsFD
183 --
184 -- We implement tcUnifyTys using the evidence-generating 'unify' function
185 -- in this module, even though we don't need to generate any evidence.
186 -- This is simply to avoid replicating all all the code for unify
187 tcUnifyTys bind_fn tys1 tys2
188   = maybeErrToMaybe $ initUM bind_fn $
189     do { reft <- unifyList emptyInternalReft cos tys1 tys2
190
191         -- Find the fixed point of the resulting non-idempotent substitution
192         ; let in_scope = mkInScopeSet (tvs1 `unionVarSet` tvs2)
193               tv_env   = fixTvSubstEnv in_scope (mapVarEnv snd reft)
194
195         ; return (mkTvSubst in_scope tv_env) }
196   where
197     tvs1 = tyVarsOfTypes tys1
198     tvs2 = tyVarsOfTypes tys2
199     cos  = zipWith mkUnsafeCoercion tys1 tys2
200
201
202 ----------------------------
203 -- XXX Can we do this more nicely, by exploiting laziness?
204 -- Or avoid needing it in the first place?
205 fixTvSubstEnv :: InScopeSet -> TvSubstEnv -> TvSubstEnv
206 fixTvSubstEnv in_scope env = f env
207   where
208     f e = let e' = mapUFM (substTy (mkTvSubst in_scope e)) e
209           in if and $ eltsUFM $ intersectUFM_C tcEqType e e'
210              then e
211              else f e'
212
213 \end{code}
214
215
216 %************************************************************************
217 %*                                                                      *
218                 The workhorse
219 %*                                                                      *
220 %************************************************************************
221
222 \begin{code}
223 emptyInternalReft :: InternalReft
224 emptyInternalReft = emptyVarEnv
225
226 unify :: InternalReft           -- An existing substitution to extend
227       -> Coercion       -- Witness of their equality 
228       -> Type -> Type   -- Types to be unified, and witness of their equality
229       -> UM InternalReft                -- Just the extended substitution, 
230                                 -- Nothing if unification failed
231 -- We do not require the incoming substitution to be idempotent,
232 -- nor guarantee that the outgoing one is.  That's fixed up by
233 -- the wrappers.
234
235 -- PRE-CONDITION: in the call (unify r co ty1 ty2), we know that
236 --                      co :: (ty1:=:ty2)
237
238 -- Respects newtypes, PredTypes
239
240 unify subst co ty1 ty2 = -- pprTrace "unify" (ppr subst <+> pprParendType ty1 <+> pprParendType ty2) $
241                          unify_ subst co ty1 ty2
242
243 -- in unify_, any NewTcApps/Preds should be taken at face value
244 unify_ subst co (TyVarTy tv1) ty2  = uVar False subst co tv1 ty2
245 unify_ subst co ty1 (TyVarTy tv2)  = uVar True  subst co tv2 ty1
246
247 unify_ subst co ty1 ty2 | Just ty1' <- tcView ty1 = unify subst co ty1' ty2
248 unify_ subst co ty1 ty2 | Just ty2' <- tcView ty2 = unify subst co ty1 ty2'
249
250 unify_ subst co (PredTy p1) (PredTy p2) = unify_pred subst co p1 p2
251
252 unify_ subst co t1@(TyConApp tyc1 tys1) t2@(TyConApp tyc2 tys2) 
253   | tyc1 == tyc2 = unify_tys subst co tys1 tys2
254
255 unify_ subst co (FunTy ty1a ty1b) (FunTy ty2a ty2b) 
256   = do  { let [co1,co2] = decomposeCo 2 co
257         ; subst' <- unify subst co1 ty1a ty2a
258         ; unify subst' co2 ty1b ty2b }
259
260         -- Applications need a bit of care!
261         -- They can match FunTy and TyConApp, so use splitAppTy_maybe
262         -- NB: we've already dealt with type variables and Notes,
263         -- so if one type is an App the other one jolly well better be too
264 unify_ subst co (AppTy ty1a ty1b) ty2
265   | Just (ty2a, ty2b) <- repSplitAppTy_maybe ty2
266   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
267         ; unify subst' (mkRightCoercion co) ty1b ty2b }
268
269 unify_ subst co ty1 (AppTy ty2a ty2b)
270   | Just (ty1a, ty1b) <- repSplitAppTy_maybe ty1
271   = do  { subst' <- unify subst (mkLeftCoercion co) ty1a ty2a
272         ; unify subst' (mkRightCoercion co) ty1b ty2b }
273
274 unify_ subst co ty1 ty2 = failWith (misMatch ty1 ty2)
275         -- ForAlls??
276
277
278 ------------------------------
279 unify_pred subst co (ClassP c1 tys1) (ClassP c2 tys2)
280   | c1 == c2 = unify_tys subst co tys1 tys2
281 unify_pred subst co (IParam n1 t1) (IParam n2 t2)
282   | n1 == n2 = unify subst co t1 t2
283 unify_pred subst co p1 p2 = failWith (misMatch (PredTy p1) (PredTy p2))
284  
285 ------------------------------
286 unify_tys :: InternalReft -> Coercion -> [Type] -> [Type] -> UM InternalReft
287 unify_tys subst co xs ys
288   = unifyList subst (decomposeCo (length xs) co) xs ys
289
290 unifyList :: InternalReft -> [Coercion] -> [Type] -> [Type] -> UM InternalReft
291 unifyList subst orig_cos orig_xs orig_ys
292   = go subst orig_cos orig_xs orig_ys
293   where
294     go subst _        []     []     = return subst
295     go subst (co:cos) (x:xs) (y:ys) = do { subst' <- unify subst co x y
296                                          ; go subst' cos xs ys }
297     go subst _ _ _ = failWith (lengthMisMatch orig_xs orig_ys)
298
299 ---------------------------------
300 uVar :: Bool            -- Swapped
301      -> InternalReft    -- An existing substitution to extend
302      -> Coercion
303      -> TyVar           -- Type variable to be unified
304      -> Type            -- with this type
305      -> UM InternalReft
306
307 -- PRE-CONDITION: in the call (uVar swap r co tv1 ty), we know that
308 --      if swap=False   co :: (tv1:=:ty)
309 --      if swap=True    co :: (ty:=:tv1)
310
311 uVar swap subst co tv1 ty
312  = -- Check to see whether tv1 is refined by the substitution
313    case (lookupVarEnv subst tv1) of
314
315      -- Yes, call back into unify'
316      Just (co',ty')     -- co' :: (tv1:=:ty')
317         | swap          -- co :: (ty:=:tv1)
318         -> unify subst (mkTransCoercion co co') ty ty' 
319         | otherwise     -- co :: (tv1:=:ty)
320         -> unify subst (mkTransCoercion (mkSymCoercion co') co) ty' ty
321
322      -- No, continue
323      Nothing -> uUnrefined swap subst co
324                            tv1 ty ty
325
326
327 uUnrefined :: Bool                -- Whether the input is swapped
328            -> InternalReft        -- An existing substitution to extend
329            -> Coercion
330            -> TyVar               -- Type variable to be unified
331            -> Type                -- with this type
332            -> Type                -- (de-noted version)
333            -> UM InternalReft
334
335 -- We know that tv1 isn't refined
336 -- PRE-CONDITION: in the call (uUnrefined False r co tv1 ty2 ty2'), we know that
337 --      co :: tv1:=:ty2
338 -- and if the first argument is True instead, we know
339 --      co :: ty2:=:tv1
340
341 uUnrefined swap subst co tv1 ty2 ty2'
342   | Just ty2'' <- tcView ty2'
343   = uUnrefined swap subst co tv1 ty2 ty2''      -- Unwrap synonyms
344                 -- This is essential, in case we have
345                 --      type Foo a = a
346                 -- and then unify a :=: Foo a
347
348 uUnrefined swap subst co tv1 ty2 (TyVarTy tv2)
349   | tv1 == tv2          -- Same type variable
350   = return subst
351
352     -- Check to see whether tv2 is refined
353   | Just (co',ty') <- lookupVarEnv subst tv2    -- co' :: tv2:=:ty'
354   = uUnrefined False subst (mkTransCoercion (doSwap swap co) co') tv1 ty' ty'
355
356   -- So both are unrefined; next, see if the kinds force the direction
357   | eqKind k1 k2        -- Can update either; so check the bind-flags
358   = do  { b1 <- tvBindFlag tv1
359         ; b2 <- tvBindFlag tv2
360         ; case (b1,b2) of
361             (BindMe, _)          -> bind swap tv1 ty2
362
363             (AvoidMe, BindMe)    -> bind (not swap) tv2 ty1
364             (AvoidMe, _)         -> bind swap tv1 ty2
365
366             (WildCard, WildCard) -> return subst
367             (WildCard, Skolem)   -> return subst
368             (WildCard, _)        -> bind (not swap) tv2 ty1
369
370             (Skolem, WildCard)   -> return subst
371             (Skolem, Skolem)     -> failWith (misMatch ty1 ty2)
372             (Skolem, _)          -> bind (not swap) tv2 ty1
373         }
374
375   | k1 `isSubKind` k2 = bindTv (not swap) subst co tv2 ty1  -- Must update tv2
376   | k2 `isSubKind` k1 = bindTv swap subst co tv1 ty2        -- Must update tv1
377
378   | otherwise = failWith (kindMisMatch tv1 ty2)
379   where
380     ty1 = TyVarTy tv1
381     k1 = tyVarKind tv1
382     k2 = tyVarKind tv2
383     bind swap tv ty = extendReft swap subst tv co ty
384
385 uUnrefined swap subst co tv1 ty2 ty2'   -- ty2 is not a type variable
386   | tv1 `elemVarSet` substTvSet subst (tyVarsOfType ty2')
387   = failWith (occursCheck tv1 ty2)      -- Occurs check
388   | not (k2 `isSubKind` k1)
389   = failWith (kindMisMatch tv1 ty2)     -- Kind check
390   | otherwise
391   = bindTv swap subst co tv1 ty2                -- Bind tyvar to the synonym if poss
392   where
393     k1 = tyVarKind tv1
394     k2 = typeKind ty2'
395
396 substTvSet :: InternalReft -> TyVarSet -> TyVarSet
397 -- Apply the non-idempotent substitution to a set of type variables,
398 -- remembering that the substitution isn't necessarily idempotent
399 substTvSet subst tvs
400   = foldVarSet (unionVarSet . get) emptyVarSet tvs
401   where
402     get tv = case lookupVarEnv subst tv of
403                 Nothing     -> unitVarSet tv
404                 Just (_,ty) -> substTvSet subst (tyVarsOfType ty)
405
406 bindTv swap subst co tv ty      -- ty is not a type variable
407   = do  { b <- tvBindFlag tv
408         ; case b of
409             Skolem   -> failWith (misMatch (TyVarTy tv) ty)
410             WildCard -> return subst
411             other    -> extendReft swap subst tv co ty
412         }
413
414 doSwap :: Bool -> Coercion -> Coercion
415 doSwap swap co = if swap then mkSymCoercion co else co
416
417 extendReft :: Bool 
418            -> InternalReft 
419            -> TyVar 
420            -> Coercion 
421            -> Type 
422            -> UM InternalReft
423 extendReft swap subst tv  co ty
424   = ASSERT2( (coercionKindPredTy co1 `tcEqType` mkCoKind (mkTyVarTy tv) ty), 
425           (text "Refinement invariant failure: co = " <+> ppr co1  <+> ppr (coercionKindPredTy co1) $$ text "subst = " <+> ppr tv <+> ppr (mkCoKind (mkTyVarTy tv) ty)) )
426     return (extendVarEnv subst tv (co1, ty))
427   where
428     co1 = doSwap swap co
429
430 \end{code}
431
432 %************************************************************************
433 %*                                                                      *
434                 Unification monad
435 %*                                                                      *
436 %************************************************************************
437
438 \begin{code}
439 data BindFlag 
440   = BindMe      -- A regular type variable
441   | AvoidMe     -- Like BindMe but, given the choice, avoid binding it
442
443   | Skolem      -- This type variable is a skolem constant
444                 -- Don't bind it; it only matches itself
445
446   | WildCard    -- This type variable matches anything,
447                 -- and does not affect the substitution
448
449 newtype UM a = UM { unUM :: (TyVar -> BindFlag)
450                          -> MaybeErr Message a }
451
452 instance Monad UM where
453   return a = UM (\tvs -> Succeeded a)
454   fail s   = UM (\tvs -> Failed (text s))
455   m >>= k  = UM (\tvs -> case unUM m tvs of
456                            Failed err -> Failed err
457                            Succeeded v  -> unUM (k v) tvs)
458
459 initUM :: (TyVar -> BindFlag) -> UM a -> MaybeErr Message a
460 initUM badtvs um = unUM um badtvs
461
462 tvBindFlag :: TyVar -> UM BindFlag
463 tvBindFlag tv = UM (\tv_fn -> Succeeded (tv_fn tv))
464
465 failWith :: Message -> UM a
466 failWith msg = UM (\tv_fn -> Failed msg)
467
468 maybeErrToMaybe :: MaybeErr fail succ -> Maybe succ
469 maybeErrToMaybe (Succeeded a) = Just a
470 maybeErrToMaybe (Failed m)    = Nothing
471 \end{code}
472
473
474 %************************************************************************
475 %*                                                                      *
476                 Error reporting
477         We go to a lot more trouble to tidy the types
478         in TcUnify.  Maybe we'll end up having to do that
479         here too, but I'll leave it for now.
480 %*                                                                      *
481 %************************************************************************
482
483 \begin{code}
484 misMatch t1 t2
485   = ptext SLIT("Can't match types") <+> quotes (ppr t1) <+> 
486     ptext SLIT("and") <+> quotes (ppr t2)
487
488 lengthMisMatch tys1 tys2
489   = sep [ptext SLIT("Can't match unequal length lists"), 
490          nest 2 (ppr tys1), nest 2 (ppr tys2) ]
491
492 kindMisMatch tv1 t2
493   = vcat [ptext SLIT("Can't match kinds") <+> quotes (ppr (tyVarKind tv1)) <+> 
494             ptext SLIT("and") <+> quotes (ppr (typeKind t2)),
495           ptext SLIT("when matching") <+> quotes (ppr tv1) <+> 
496                 ptext SLIT("with") <+> quotes (ppr t2)]
497
498 occursCheck tv ty
499   = hang (ptext SLIT("Can't construct the infinite type"))
500        2 (ppr tv <+> equals <+> ppr ty)
501 \end{code}