af3bb3c55969d1bb866b4c7b4b4df9e43addae22
[ghc-hetmet.git] / utils / ext-core / Check.hs
1 {-# OPTIONS -Wall -fno-warn-name-shadowing #-}
2 module Check(
3   checkModule, envsModule,
4   checkExpr, checkType, 
5   primCoercionError, 
6   Menv, Venv, Tvenv, Envs(..),
7   CheckRes(..), splitTy, substl) where
8
9 import Maybe
10 import Control.Monad.Reader
11 -- just for printing warnings
12 import System.IO.Unsafe
13
14 import Core
15 import Printer()
16 import List
17 import Env
18 import PrimEnv
19
20 {- Checking is done in a simple error monad.  In addition to
21    allowing errors to be captured, this makes it easy to guarantee
22    that checking itself has been completed for an entire module. -}
23
24 {- We use the Reader monad transformer in order to thread the 
25    top-level module name throughout the computation simply.
26    This is so that checkExp can also be an entry point (we call it
27    from Prep.) -}
28 data CheckRes a = OkC a | FailC String
29 type CheckResult a = ReaderT (AnMname, Menv) CheckRes a
30 getMname :: CheckResult AnMname
31 getMname     = ask >>= (return . fst)
32 getGlobalEnv :: CheckResult Menv
33 getGlobalEnv = ask >>= (return . snd)
34
35 instance Monad CheckRes where
36   OkC a >>= k = k a
37   FailC s >>= _ = fail s
38   return = OkC
39   fail = FailC
40
41 require :: Bool -> String -> CheckResult ()
42 require False s = fail s
43 require True  _ = return ()
44
45 requireM :: CheckResult Bool -> String -> CheckResult ()
46 requireM cond s =
47   do b <- cond
48      require b s
49
50 {- Environments. -}
51 type Tvenv = Env Tvar Kind                    -- type variables  (local only)
52 type Tcenv = Env Tcon KindOrCoercion          -- type constructors
53 type Tsenv = Env Tcon ([Tvar],Ty)             -- type synonyms
54 type Cenv = Env Dcon Ty                       -- data constructors
55 type Venv = Env Var Ty                        -- values
56 type Menv = Env AnMname Envs                  -- modules
57 data Envs = Envs {tcenv_::Tcenv,tsenv_::Tsenv,cenv_::Cenv,venv_::Venv} -- all the exportable envs
58   deriving Show
59
60 {- Extend an environment, checking for illegal shadowing of identifiers (for term
61    variables -- shadowing type variables is allowed.) -}
62 data EnvType = Tv | NotTv
63   deriving Eq
64
65 extendM :: (Ord a, Show a) => EnvType -> Env a b -> (a,b) -> CheckResult (Env a b)
66 extendM envType env (k,d) = 
67    case elookup env k of
68      Just _ | envType == NotTv -> fail ("multiply-defined identifier: " 
69                                       ++ show k)
70      _ -> return (eextend env (k,d))
71
72 extendVenv :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
73 extendVenv = extendM NotTv
74
75 extendTvenv :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
76 extendTvenv = extendM Tv
77
78 lookupM :: (Ord a, Show a) => Env a b -> a -> CheckResult b
79 lookupM env k =   
80    case elookup env k of
81      Just v -> return v
82      Nothing -> fail ("undefined identifier: " ++ show k)
83             
84 {- Main entry point. -}
85 checkModule :: Menv -> Module -> CheckRes Menv
86 checkModule globalEnv (Module mn tdefs vdefgs) = 
87   runReaderT 
88     (do (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
89         (e_venv,_) <- foldM (checkVdefg True (tcenv,tsenv,eempty,cenv))
90                               (eempty,eempty) 
91                               vdefgs
92         return (eextend globalEnv 
93             (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv})))
94          -- avoid name shadowing
95     (mn, eremove globalEnv mn)
96
97 -- Like checkModule, but doesn't typecheck the code, instead just
98 -- returning declared types for top-level defns.
99 -- This is necessary in order to handle circular dependencies, but it's sort
100 -- of unpleasant.
101 envsModule :: Menv -> Module -> Menv
102 envsModule globalEnv (Module mn tdefs vdefgs) = 
103    let (tcenv, tsenv, cenv) = mkTypeEnvsNoChecking tdefs
104        e_venv               = foldr vdefgTypes eempty vdefgs in
105      eextend globalEnv (mn, 
106              (Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv}))
107         where vdefgTypes :: Vdefg -> Venv -> Venv
108               vdefgTypes (Nonrec (Vdef (v,t,_))) e =
109                              add [(v,t)] e
110               vdefgTypes (Rec vds) e = 
111                              add (map (\ (Vdef (v,t,_)) -> (v,t)) vds) e
112               add :: [(Qual Var,Ty)] -> Venv -> Venv
113               add pairs e = foldr addOne e pairs
114               addOne :: (Qual Var, Ty) -> Venv -> Venv
115               addOne ((Nothing,_),_) e = e
116               addOne ((Just _,v),t) e  = eextend e (v,t)
117
118 checkTdef0 :: (Tcenv,Tsenv) -> Tdef -> CheckResult (Tcenv,Tsenv)
119 checkTdef0 (tcenv,tsenv) tdef = ch tdef
120       where 
121         ch (Data (m,c) tbs _) = 
122             do mn <- getMname
123                requireModulesEq m mn "data type declaration" tdef False
124                tcenv' <- extendM NotTv tcenv (c, Kind k)
125                return (tcenv',tsenv)
126             where k = foldr Karrow Klifted (map snd tbs)
127         ch (Newtype (m,c) coVar tbs rhs) = 
128             do mn <- getMname
129                requireModulesEq m mn "newtype declaration" tdef False
130                tcenv' <- extendM NotTv tcenv (c, Kind k)
131                -- add newtype axiom to env
132                tcenv'' <- envPlusNewtype tcenv' (m,c) coVar tbs rhs
133                tsenv' <- case rhs of
134                            Nothing -> return tsenv
135                            Just rep -> extendM NotTv tsenv (c,(map fst tbs,rep))
136                return (tcenv'', tsenv')
137             where k = foldr Karrow Klifted (map snd tbs)
138
139 processTdef0NoChecking :: (Tcenv,Tsenv) -> Tdef -> (Tcenv,Tsenv)
140 processTdef0NoChecking (tcenv,tsenv) tdef = ch tdef
141       where 
142         ch (Data (_,c) tbs _) = (eextend tcenv (c, Kind k), tsenv)
143             where k = foldr Karrow Klifted (map snd tbs)
144         ch (Newtype tc@(_,c) coercion tbs rhs) = 
145             let tcenv' = eextend tcenv (c, Kind k)
146                 -- add newtype axiom to env
147                 tcenv'' = case rhs of
148                             Just rep ->
149                               eextend tcenv' 
150                                (snd coercion, Coercion $ DefinedCoercion tbs
151                                 (foldl Tapp (Tcon tc) (map Tvar (fst (unzip tbs))), rep))
152                             Nothing -> tcenv'
153                 tsenv' = maybe tsenv 
154                   (\ rep -> eextend tsenv (c, (map fst tbs, rep))) rhs in
155                (tcenv'', tsenv')
156             where k = foldr Karrow Klifted (map snd tbs)
157
158 envPlusNewtype :: Tcenv -> Qual Tcon -> Qual Tcon -> [Tbind] -> Maybe Ty 
159   -> CheckResult Tcenv
160 envPlusNewtype tcenv tyCon coVar tbs rhs =
161   case rhs of
162     Nothing -> return tcenv
163     Just rep -> extendM NotTv tcenv 
164                   (snd coVar, Coercion $ DefinedCoercion tbs
165                             (foldl Tapp (Tcon tyCon) 
166                                        (map Tvar (fst (unzip tbs))),
167                                        rep))
168     
169 checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
170 checkTdef tcenv cenv = ch
171        where 
172          ch (Data (_,c) utbs cdefs) = 
173             do cbinds <- mapM checkCdef cdefs
174                foldM (extendM NotTv) cenv cbinds
175             where checkCdef (cdef@(Constr (m,dcon) etbs ts)) =
176                     do mn <- getMname
177                        requireModulesEq m mn "constructor declaration" cdef 
178                          False 
179                        tvenv <- foldM (extendM Tv) eempty tbs 
180                        ks <- mapM (checkTy (tcenv,tvenv)) ts
181                        mapM_ (\k -> require (baseKind k)
182                                             ("higher-order kind in:\n" ++ show cdef ++ "\n" ++
183                                              "kind: " ++ show k) ) ks
184                        return (dcon,t mn) 
185                     where tbs = utbs ++ etbs
186                           t mn = foldr Tforall 
187                                   (foldr tArrow
188                                           (foldl Tapp (Tcon (Just mn,c))
189                                                  (map (Tvar . fst) utbs)) ts) tbs
190          ch (tdef@(Newtype tc _ tbs (Just t))) =  
191             do tvenv <- foldM (extendM Tv) eempty tbs
192                kRhs <- checkTy (tcenv,tvenv) t
193                require (kRhs `eqKind` Klifted) ("bad kind:\n" ++ show tdef)
194                kLhs <- checkTy (tcenv,tvenv) 
195                          (foldl Tapp (Tcon tc) (map Tvar (fst (unzip tbs))))
196                require (kLhs `eqKind` kRhs) 
197                   ("Kind mismatch in newtype axiom types: " ++ show tdef 
198                     ++ " kinds: " ++
199                    (show kLhs) ++ " and " ++ (show kRhs))
200                return cenv
201          ch (Newtype _ _ _ Nothing) =
202             {- should only occur for recursive Newtypes -}
203             return cenv
204
205 processCdef :: Cenv -> Tdef -> Cenv
206 processCdef cenv = ch
207   where
208     ch (Data (_,c) utbs cdefs) = do 
209        let cbinds = map checkCdef cdefs
210        foldl eextend cenv cbinds
211      where checkCdef (Constr (mn,dcon) etbs ts) =
212              (dcon,t mn) 
213             where tbs = utbs ++ etbs
214                   t mn = foldr Tforall 
215                           (foldr tArrow
216                             (foldl Tapp (Tcon (mn,c))
217                                (map (Tvar . fst) utbs)) ts) tbs
218     ch _ = cenv
219
220 mkTypeEnvs :: [Tdef] -> CheckResult (Tcenv, Tsenv, Cenv)
221 mkTypeEnvs tdefs = do
222   (tcenv, tsenv) <- foldM checkTdef0 (eempty,eempty) tdefs
223   cenv <- foldM (checkTdef tcenv) eempty tdefs
224   return (tcenv, tsenv, cenv)
225
226 mkTypeEnvsNoChecking :: [Tdef] -> (Tcenv, Tsenv, Cenv)
227 mkTypeEnvsNoChecking tdefs = 
228   let (tcenv, tsenv) = foldl processTdef0NoChecking (eempty,eempty) tdefs
229       cenv           = foldl processCdef eempty tdefs in
230     (tcenv, tsenv, cenv)
231
232 requireModulesEq :: Show a => Mname -> AnMname -> String -> a 
233                           -> Bool -> CheckResult ()
234 requireModulesEq (Just mn) m msg t _      = require (mn == m) (mkErrMsg msg t)
235 requireModulesEq Nothing _ msg t emptyOk  = require emptyOk (mkErrMsg msg t)
236
237 mkErrMsg :: Show a => String -> a -> String
238 mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t    
239
240 checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) 
241                -> Vdefg -> CheckResult (Venv,Venv)
242 checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg = do
243       mn <- getMname
244       case vdefg of
245         Rec vdefs ->
246             do (e_venv', l_venv') <- makeEnv mn vdefs
247                let env' = (tcenv,tsenv,tvenv,cenv,e_venv',l_venv')
248                mapM_ (checkVdef (\ vdef k -> require (k `eqKind` Klifted) 
249                         ("unlifted kind in:\n" ++ show vdef)) env') 
250                      vdefs
251                return (e_venv', l_venv')
252         Nonrec vdef ->
253             do let env' = (tcenv, tsenv, tvenv, cenv, e_venv, l_venv)
254                checkVdef (\ vdef k -> do
255                      require (not (k `eqKind` Kopen)) ("open kind in:\n" ++ show vdef)
256                      require ((not top_level) || (not (k `eqKind` Kunlifted))) 
257                        ("top-level unlifted kind in:\n" ++ show vdef)) env' vdef
258                makeEnv mn [vdef]
259
260   where makeEnv mn vdefs = do
261              ev <- foldM extendVenv e_venv e_vts
262              lv <- foldM extendVenv l_venv l_vts
263              return (ev, lv)
264            where e_vts = [ (v,t) | Vdef ((Just m,v),t,_) <- vdefs,
265                                      not (vdefIsMainWrapper mn (Just m))]
266                  l_vts = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
267         checkVdef checkKind env (vdef@(Vdef ((m,_),t,e))) = do
268           mn <- getMname
269           let isZcMain = vdefIsMainWrapper mn m
270           unless isZcMain $
271              requireModulesEq m mn "value definition" vdef True
272           k <- checkTy (tcenv,tvenv) t
273           checkKind vdef k
274           t' <- checkExp env e
275           requireM (equalTy tsenv t t') 
276                    ("declared type doesn't match expression type in:\n"  
277                     ++ show vdef ++ "\n" ++  
278                     "declared type: " ++ show t ++ "\n" ++
279                     "expression type: " ++ show t')
280     
281 vdefIsMainWrapper :: AnMname -> Mname -> Bool
282 vdefIsMainWrapper enclosing defining = 
283    enclosing == mainMname && defining == wrapperMainMname
284
285 checkExpr :: AnMname -> Menv -> [Tdef] -> Venv -> Tvenv 
286                -> Exp -> Ty
287 checkExpr mn menv tdefs venv tvenv e = case runReaderT (do
288   (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
289   -- Since the preprocessor calls checkExpr after code has been
290   -- typechecked, we expect to find the external env in the Menv.
291   case (elookup menv mn) of
292      Just thisEnv ->
293        checkExp (tcenv, tsenv, tvenv, cenv, (venv_ thisEnv), venv) e 
294      Nothing -> reportError e ("checkExpr: Environment for " ++ 
295                   show mn ++ " not found")) (mn,menv) of
296          OkC t -> t
297          FailC s -> reportError e s
298
299 checkType :: AnMname -> Menv -> [Tdef] -> Tvenv -> Ty -> Kind
300 checkType mn menv tdefs tvenv t = case runReaderT (do
301   (tcenv, _, _) <- mkTypeEnvs tdefs
302   checkTy (tcenv, tvenv) t) (mn, menv) of
303       OkC k -> k
304       FailC s -> reportError tvenv s
305
306 checkExp :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
307 checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch
308       where 
309         ch e0 =
310           case e0 of
311             Var qv -> 
312               qlookupM venv_ e_venv l_venv qv
313             Dcon qc ->
314               qlookupM cenv_ cenv eempty qc
315             Lit l -> 
316               checkLit l
317             Appt e t -> 
318               do t' <- ch e
319                  k' <- checkTy (tcenv,tvenv) t
320                  case t' of
321                    Tforall (tv,k) t0 ->
322                      do require (k' `subKindOf` k) 
323                                 ("kind doesn't match at type application in:\n" ++ show e0 ++ "\n" ++
324                                  "operator kind: " ++ show k ++ "\n" ++
325                                  "operand kind: " ++ show k') 
326                         return (substl [tv] [t] t0)
327                    _ -> fail ("bad operator type in type application:\n" ++ show e0 ++ "\n" ++
328                                "operator type: " ++ show t')
329             App e1 e2 -> 
330               do t1 <- ch e1
331                  t2 <- ch e2
332                  case t1 of
333                    Tapp(Tapp(Tcon tc) t') t0 | tc == tcArrow ->
334                         do requireM (equalTy tsenv t2 t') 
335                                     ("type doesn't match at application in:\n" ++ show e0 ++ "\n" ++ 
336                                      "operator type: " ++ show t' ++ "\n" ++ 
337                                      "operand type: " ++ show t2) 
338                            return t0
339                    _ -> fail ("bad operator type at application in:\n" ++ show e0 ++ "\n" ++
340                                "operator type: " ++ show t1)
341             Lam (Tb tb) e ->
342               do tvenv' <- extendTvenv tvenv tb 
343                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv) e 
344                  return (Tforall tb t)
345             Lam (Vb (vb@(_,vt))) e ->
346               do k <- checkTy (tcenv,tvenv) vt
347                  require (baseKind k)   
348                          ("higher-order kind in:\n" ++ show e0 ++ "\n" ++
349                           "kind: " ++ show k) 
350                  l_venv' <- extendVenv l_venv vb
351                  t <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') e
352                  require (not (isUtupleTy vt)) ("lambda-bound unboxed tuple in:\n" ++ show e0) 
353                  return (tArrow vt t)
354             Let vdefg e ->
355               do (e_venv',l_venv') <- checkVdefg False (tcenv,tsenv,tvenv,cenv) 
356                                         (e_venv,l_venv) vdefg 
357                  checkExp (tcenv,tsenv,tvenv,cenv,e_venv',l_venv') e
358             Case e (v,t) resultTy alts ->
359               do t' <- ch e 
360                  checkTy (tcenv,tvenv) t
361                  requireM (equalTy tsenv t t') 
362                           ("scrutinee declared type doesn't match expression type in:\n" ++ show e0 ++ "\n" ++
363                            "declared type: " ++ show t ++ "\n" ++
364                            "expression type: " ++ show t') 
365                  case (reverse alts) of
366                    (Acon c _ _ _):as ->
367                       let ok ((Acon c _ _ _):as) cs = do require (notElem c cs)
368                                                                  ("duplicate alternative in case:\n" ++ show e0) 
369                                                          ok as (c:cs)
370                           ok ((Alit _ _):_)      _  = fail ("invalid alternative in constructor case:\n" ++ show e0)
371                           ok [Adefault _]        _  = return ()
372                           ok (Adefault _:_)      _  = fail ("misplaced default alternative in case:\n" ++ show e0)
373                           ok []                  _  = return () 
374                       in ok as [c] 
375                    (Alit l _):as -> 
376                       let ok ((Acon _ _ _ _):_) _  = fail ("invalid alternative in literal case:\n" ++ show e0)
377                           ok ((Alit l _):as)    ls = do require (notElem l ls)
378                                                                 ("duplicate alternative in case:\n" ++ show e0) 
379                                                         ok as (l:ls)
380                           ok [Adefault _]       _  = return ()
381                           ok (Adefault _:_)     _  = fail ("misplaced default alternative in case:\n" ++ show e0)
382                           ok []                 _  = fail ("missing default alternative in literal case:\n" ++ show e0)
383                       in ok as [l] 
384                    [Adefault _] -> return ()
385                    _ -> fail ("no alternatives in case:\n" ++ show e0) 
386                  l_venv' <- extendVenv l_venv (v,t)
387                  t:ts <- mapM (checkAlt (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') t) alts  
388                  bs <- mapM (equalTy tsenv t) ts
389                  require (and bs)
390                          ("alternative types don't match in:\n" ++ show e0 ++ "\n" ++
391                           "types: " ++ show (t:ts))
392                  checkTy (tcenv,tvenv) resultTy
393                  require (t == resultTy) ("case alternative type doesn't " ++
394                    " match case return type in:\n" ++ show e0 ++ "\n" ++
395                    "alt type: " ++ show t ++ " return type: " ++ show resultTy)
396                  return t
397             c@(Cast e t) -> 
398               do eTy <- ch e 
399                  (fromTy, toTy) <- checkTyCo (tcenv,tvenv) t
400                  require (eTy == fromTy) ("Type mismatch in cast: c = "
401                              ++ show c ++ "\nand eTy = " ++ show eTy
402                              ++ "\n and " ++ show fromTy)
403                  return toTy
404             Note _ e -> 
405               ch e
406             External _ t -> 
407               do checkTy (tcenv,eempty) t {- external types must be closed -}
408                  return t
409     
410 checkAlt :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty 
411 checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
412       where 
413         ch a0 = 
414           case a0 of 
415             Acon qc etbs vbs e ->
416               do let uts = f t0                                      
417                        where f (Tapp t0 t) = f t0 ++ [t]
418                              f _ = []
419                  ct <- qlookupM cenv_ cenv eempty qc
420                  let (tbs,ct_args0,ct_res0) = splitTy ct
421                  {- get universals -}
422                  let (utbs,etbs') = splitAt (length uts) tbs
423                  let utvs = map fst utbs
424                  {- check existentials -}
425                  let (etvs,eks) = unzip etbs
426                  let (etvs',eks') = unzip etbs'
427                  require (all (uncurry eqKind)
428                             (zip eks eks'))  
429                          ("existential kinds don't match in:\n" ++ show a0 ++ "\n" ++
430                           "kinds declared in data constructor: " ++ show eks ++
431                           "kinds declared in case alternative: " ++ show eks') 
432                  tvenv' <- foldM extendTvenv tvenv etbs
433                  {- check term variables -}
434                  let vts = map snd vbs
435                  mapM_ (\vt -> require ((not . isUtupleTy) vt)
436                                        ("pattern-bound unboxed tuple in:\n" ++ show a0 ++ "\n" ++
437                                         "pattern type: " ++ show vt)) vts
438                  vks <- mapM (checkTy (tcenv,tvenv')) vts
439                  mapM_ (\vk -> require (baseKind vk)
440                                        ("higher-order kind in:\n" ++ show a0 ++ "\n" ++
441                                         "kind: " ++ show vk)) vks 
442                  let (ct_res:ct_args) = map (substl (utvs++etvs') (uts++(map Tvar etvs))) (ct_res0:ct_args0)
443                  zipWithM_ 
444                     (\ct_arg vt -> 
445                         requireM (equalTy tsenv ct_arg vt)
446                                  ("pattern variable type doesn't match constructor argument type in:\n" ++ show a0 ++ "\n" ++
447                                   "pattern variable type: " ++ show ct_arg ++ "\n" ++
448                                   "constructor argument type: " ++ show vt)) ct_args vts
449                  requireM (equalTy tsenv ct_res t0)
450                           ("pattern constructor type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
451                            "pattern constructor type: " ++ show ct_res ++ "\n" ++
452                            "scrutinee type: " ++ show t0) 
453                  l_venv' <- foldM extendVenv l_venv vbs
454                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv') e 
455                  checkTy (tcenv,tvenv) t  {- check that existentials don't escape in result type -}
456                  return t
457             Alit l e ->
458               do t <- checkLit l
459                  requireM (equalTy tsenv t t0)
460                          ("pattern type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
461                           "pattern type: " ++ show t ++ "\n" ++
462                           "scrutinee type: " ++ show t0) 
463                  checkExp env e
464             Adefault e ->
465               checkExp env e
466     
467 checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
468 checkTy es@(tcenv,tvenv) = ch
469      where
470        ch (Tvar tv) = lookupM tvenv tv
471        ch (Tcon qtc) = do
472          kOrC <- qlookupM tcenv_ tcenv eempty qtc
473          case kOrC of
474             Kind k -> return k
475             Coercion (DefinedCoercion [] (t1,t2)) -> return $ Keq t1 t2
476             Coercion _ -> fail ("Unsaturated coercion app: " ++ show qtc)
477        ch (t@(Tapp t1 t2)) = 
478              case splitTyConApp_maybe t of
479                Just (tc, tys) -> do
480                  tcK <- qlookupM tcenv_ tcenv eempty tc 
481                  case tcK of
482                    Kind _ -> checkTapp t1 t2
483                    Coercion (DefinedCoercion tbs (from,to)) -> do
484                      -- makes sure coercion is fully applied
485                      require (length tys == length tbs) $
486                         ("Arity mismatch in coercion app: " ++ show t)
487                      let (tvs, tks) = unzip tbs
488                      argKs <- mapM (checkTy es) tys
489                      let kPairs = zip argKs tks
490                      let kindsOk = all (uncurry eqKind) kPairs
491                      if not kindsOk &&
492                         all (uncurry subKindOf) kPairs
493                        -- GHC occasionally generates code like:
494                        -- :Co:TTypeable2 (->)
495                        -- where in this case, :Co:TTypeable2 expects an
496                        -- argument of kind (*->(*->*)) and (->) has kind
497                        -- (?->(?->*)). I'm not sure whether or not it's
498                        -- sound to apply an arbitrary coercion to an
499                        -- argument that's a subkind of what it expects.
500                        then warn $ "Applying coercion " ++ show tc ++
501                                " to arguments of kind " ++ show argKs
502                                ++ " when it expects: " ++ show tks
503                        else require kindsOk
504                         ("Kind mismatch in coercion app: " ++ show tks 
505                          ++ " and " ++ show argKs ++ " t = " ++ show t)
506                      return $ Keq (substl tvs tys from) (substl tvs tys to)
507                Nothing -> checkTapp t1 t2
508             where checkTapp t1 t2 = do 
509                     k1 <- ch t1
510                     k2 <- ch t2
511                     case k1 of
512                       Karrow k11 k12 -> do
513                          require (k2 `subKindOf` k11) kindError
514                          return k12
515                             where kindError = 
516                                     "kinds don't match in type application: "
517                                     ++ show t ++ "\n" ++
518                                     "operator kind: " ++ show k11 ++ "\n" ++
519                                     "operand kind: " ++ show k2 
520                       _ -> fail ("applied type has non-arrow kind: " ++ show t)
521                            
522        ch (Tforall tb t) = 
523             do tvenv' <- extendTvenv tvenv tb 
524                checkTy (tcenv,tvenv') t
525        ch (TransCoercion t1 t2) = do
526             (ty1,ty2) <- checkTyCo es t1
527             (ty3,ty4) <- checkTyCo es t2
528             require (ty2 == ty3) ("Types don't match in trans. coercion: " ++
529                         show ty2 ++ " and " ++ show ty3)
530             return $ Keq ty1 ty4
531        ch (SymCoercion t1) = do
532             (ty1,ty2) <- checkTyCo es t1
533             return $ Keq ty2 ty1
534        ch (UnsafeCoercion t1 t2) = do
535             checkTy es t1
536             checkTy es t2
537             return $ Keq t1 t2
538        ch (LeftCoercion t1) = do
539             k <- checkTyCo es t1
540             case k of
541               ((Tapp u _), (Tapp w _)) -> return $ Keq u w
542               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
543        ch (RightCoercion t1) = do
544             k <- checkTyCo es t1
545             case k of
546               ((Tapp _ v), (Tapp _ x)) -> return $ Keq v x
547               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
548        ch (InstCoercion ty arg) = do
549             forallK <- checkTyCo es ty
550             case forallK of
551               ((Tforall (v1,k1) b1), (Tforall (v2,k2) b2)) -> do
552                  require (k1 `eqKind` k2) ("Kind mismatch in argument of inst: "
553                                             ++ show ty)
554                  argK <- checkTy es arg
555                  require (argK `eqKind` k1) ("Kind mismatch in type being "
556                            ++ "instantiated: " ++ show arg)
557                  let newLhs = substl [v1] [arg] b1
558                  let newRhs = substl [v2] [arg] b2
559                  return $ Keq newLhs newRhs
560               _ -> fail ("Non-forall-ty in argument to inst: " ++ show ty)
561
562 checkTyCo :: (Tcenv, Tvenv) -> Ty -> CheckResult (Ty, Ty)
563 checkTyCo es@(tcenv,_) t@(Tapp t1 t2) = 
564   (case splitTyConApp_maybe t of
565     Just (tc, tys) -> do
566        tcK <- qlookupM tcenv_ tcenv eempty tc
567        case tcK of
568  -- todo: avoid duplicating this code
569  -- blah, this almost calls for a different syntactic form
570  -- (for a defined-coercion app): (TCoercionApp Tcon [Ty])
571          Coercion (DefinedCoercion tbs (from, to)) -> do
572            require (length tys == length tbs) $ 
573             ("Arity mismatch in coercion app: " ++ show t)
574            let (tvs, tks) = unzip tbs
575            argKs <- mapM (checkTy es) tys
576            let kPairs = zip argKs tks
577            let kindsOk = all (uncurry eqKind) kPairs
578            if not kindsOk &&
579                         all (uncurry subKindOf) kPairs
580                        -- see comments in checkTy about this
581                        then warn $ "Applying coercion " ++ show tc ++
582                                " to arguments of kind " ++ show argKs
583                                ++ " when it expects: " ++ show tks
584                        else require kindsOk
585                         ("Kind mismatch in coercion app: " ++ show tks 
586                          ++ " and " ++ show argKs ++ " t = " ++ show t)
587            return (substl tvs tys from, substl tvs tys to)
588          _ -> checkTapp t1 t2
589     _ -> checkTapp t1 t2)
590        where checkTapp t1 t2 = do
591                (lhsRator, rhsRator) <- checkTyCo es t1
592                (lhs, rhs) <- checkTyCo es t2
593                -- Comp rule from paper
594                checkTy es (Tapp lhsRator lhs)
595                checkTy es (Tapp rhsRator rhs)
596                return (Tapp lhsRator lhs, Tapp rhsRator rhs)
597 checkTyCo (tcenv, tvenv) (Tforall tb t) = do
598   tvenv' <- extendTvenv tvenv tb
599   (t1,t2) <- checkTyCo (tcenv, tvenv') t
600   return (Tforall tb t1, Tforall tb t2)
601 checkTyCo es t = do
602   k <- checkTy es t
603   case k of
604     Keq t1 t2 -> return (t1, t2)
605     -- otherwise, expand by the "refl" rule
606     _          -> return (t, t)
607
608 {- Type equality modulo newtype synonyms. -}
609 equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
610 equalTy tsenv t1 t2 = 
611             do t1' <- expand t1
612                t2' <- expand t2
613                return (t1' == t2')
614       where expand (Tvar v) = return (Tvar v)
615             expand (Tcon qtc) = return (Tcon qtc)
616             expand (Tapp t1 t2) = 
617               do t2' <- expand t2
618                  expapp t1 [t2']
619             expand (Tforall tb t) = 
620               do t' <- expand t
621                  return (Tforall tb t')
622             expand (TransCoercion t1 t2) = do
623                exp1 <- expand t1
624                exp2 <- expand t2
625                return $ TransCoercion exp1 exp2
626             expand (SymCoercion t) = do
627                exp <- expand t
628                return $ SymCoercion exp
629             expand (UnsafeCoercion t1 t2) = do
630                exp1 <- expand t1
631                exp2 <- expand t2
632                return $ UnsafeCoercion exp1 exp2
633             expand (LeftCoercion t1) = do
634                exp1 <- expand t1
635                return $ LeftCoercion exp1
636             expand (RightCoercion t1) = do
637                exp1 <- expand t1
638                return $ RightCoercion exp1
639             expand (InstCoercion t1 t2) = do
640                exp1 <- expand t1
641                exp2 <- expand t2
642                return $ InstCoercion exp1 exp2
643             expapp (t@(Tcon (m,tc))) ts = 
644               do env <- mlookupM tsenv_ tsenv eempty m
645                  case elookup env tc of 
646                     Just (formals,rhs) | (length formals) == (length ts) ->
647                          return (substl formals ts rhs)
648                     _ -> return (foldl Tapp t ts)
649             expapp (Tapp t1 t2) ts = 
650               do t2' <- expand t2
651                  expapp t1 (t2':ts)
652             expapp t ts = 
653               do t' <- expand t
654                  return (foldl Tapp t' ts)
655
656 mlookupM :: (Eq a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b -> Mname
657           -> CheckResult (Env a b)
658 mlookupM _ _ local_env    Nothing            = return local_env
659 mlookupM selector external_env local_env (Just m) = do
660   mn <- getMname
661   if m == mn
662      then return external_env
663      else do
664        globalEnv <- getGlobalEnv
665        case elookup globalEnv m of
666          Just env' -> return (selector env')
667          Nothing -> fail ("Check: undefined module name: "
668                       ++ show m ++ show (edomain local_env))
669
670 qlookupM :: (Ord a, Show a,Show b) => (Envs -> Env a b) -> Env a b -> Env a b
671                   -> Qual a -> CheckResult b
672 qlookupM selector external_env local_env (m,k) =
673       do env <- mlookupM selector external_env local_env m
674          lookupM env k
675
676 checkLit :: Lit -> CheckResult Ty
677 checkLit (Literal lit t) =
678   case lit of
679     Lint _ -> 
680           do require (t `elem` intLitTypes)
681                      ("invalid int literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
682              return t
683     Lrational _ ->
684           do require (t `elem` ratLitTypes)
685                      ("invalid rational literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
686              return t
687     Lchar _ -> 
688           do require (t `elem` charLitTypes)
689                      ("invalid char literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
690              return t   
691     Lstring _ ->
692           do require (t `elem` stringLitTypes)
693                      ("invalid string literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
694              return t
695
696 {- Utilities -}
697
698 {- Split off tbs, arguments and result of a (possibly abstracted)  arrow type -}
699 splitTy :: Ty -> ([Tbind],[Ty],Ty)
700 splitTy (Tforall tb t) = (tb:tbs,ts,tr) 
701                 where (tbs,ts,tr) = splitTy t
702 splitTy (Tapp(Tapp(Tcon tc) t0) t) | tc == tcArrow = (tbs,t0:ts,tr)
703                 where (tbs,ts,tr) = splitTy t
704 splitTy t = ([],[],t)
705
706
707 {- Simultaneous substitution on types for type variables,
708    renaming as neceessary to avoid capture.
709    No checks for correct kindedness. -}
710 substl :: [Tvar] -> [Ty] -> Ty -> Ty
711 substl tvs ts t = f (zip tvs ts) t
712   where 
713     f env t0 =
714      case t0 of
715        Tcon _ -> t0
716        Tvar v -> case lookup v env of
717                    Just t1 -> t1
718                    Nothing -> t0
719        Tapp t1 t2 -> Tapp (f env t1) (f env t2)
720        Tforall (t,k) t1 -> 
721          if t `elem` free then
722            Tforall (t',k) (f ((t,Tvar t'):env) t1)
723          else 
724            Tforall (t,k) (f (filter ((/=t).fst) env) t1)
725        TransCoercion t1 t2 -> TransCoercion (f env t1) (f env t2)
726        SymCoercion t1 -> SymCoercion (f env t1)
727        UnsafeCoercion t1 t2 -> UnsafeCoercion (f env t1) (f env t2)
728        LeftCoercion t1 -> LeftCoercion (f env t1)
729        RightCoercion t1 -> RightCoercion (f env t1)
730        InstCoercion t1 t2 -> InstCoercion (f env t1) (f env t2)
731      where free = foldr union [] (map (freeTvars.snd) env)
732            t' = freshTvar free 
733    
734 {- Return free tvars in a type -}
735 freeTvars :: Ty -> [Tvar]
736 freeTvars (Tcon _) = []
737 freeTvars (Tvar v) = [v]
738 freeTvars (Tapp t1 t2) = freeTvars t1 `union` freeTvars t2
739 freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1) 
740 freeTvars (TransCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
741 freeTvars (SymCoercion t) = freeTvars t
742 freeTvars (UnsafeCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
743 freeTvars (LeftCoercion t) = freeTvars t
744 freeTvars (RightCoercion t) = freeTvars t
745 freeTvars (InstCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
746
747 {- Return any tvar *not* in the argument list. -}
748 freshTvar :: [Tvar] -> Tvar
749 freshTvar tvs = maximum ("":tvs) ++ "x" -- one simple way!
750
751 primCoercionError :: Show a => a -> b
752 primCoercionError s = error $ "Bad coercion application: " ++ show s
753
754 -- todo
755 reportError :: Show a => a -> String -> b
756 reportError e s = error $ ("Core type error: checkExpr failed with "
757                    ++ s ++ " and " ++ show e)
758
759 warn :: String -> CheckResult ()
760 warn s = (unsafePerformIO $ putStrLn ("WARNING: " ++ s)) `seq` return ()