Fixed performance bug in ext-core preprocessor
[ghc-hetmet.git] / utils / ext-core / Language / Core / Check.hs
1 {-# OPTIONS -Wall -fno-warn-name-shadowing #-}
2 module Language.Core.Check(
3   checkModule, envsModule,
4   checkExpr, checkType, 
5   primCoercionError, 
6   Menv, Venv, Tvenv, Envs(..),
7   CheckRes(..), splitTy, substl,
8   mkTypeEnvsNoChecking) where
9
10 import Language.Core.Core
11 import Language.Core.Printer()
12 import Language.Core.PrimEnv
13 import Language.Core.Env
14 import Language.Core.Environments
15
16 import Control.Monad.Reader
17 import Data.List
18 import Data.Maybe
19
20 {- Checking is done in a simple error monad.  In addition to
21    allowing errors to be captured, this makes it easy to guarantee
22    that checking itself has been completed for an entire module. -}
23
24 {- We use the Reader monad transformer in order to thread the 
25    top-level module name throughout the computation simply.
26    This is so that checkExp can also be an entry point (we call it
27    from Prep.) -}
28 data CheckRes a = OkC a | FailC String
29 type CheckResult a = ReaderT (AnMname, Menv) CheckRes a
30 getMname :: CheckResult AnMname
31 getMname     = ask >>= (return . fst)
32 getGlobalEnv :: CheckResult Menv
33 getGlobalEnv = ask >>= (return . snd)
34
35 instance Monad CheckRes where
36   OkC a >>= k = k a
37   FailC s >>= _ = fail s
38   return = OkC
39   fail = FailC
40
41 require :: Bool -> String -> CheckResult ()
42 require False s = fail s
43 require True  _ = return ()
44
45
46 extendM :: (Ord a, Show a) => EnvType -> Env a b -> (a,b) -> CheckResult (Env a b)
47 extendM envType env (k,d) = 
48    case elookup env k of
49      Just _ | envType == NotTv -> fail ("multiply-defined identifier: " 
50                                       ++ show k)
51      _ -> return (eextend env (k,d))
52
53 extendVenv :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
54 extendVenv = extendM NotTv
55
56 extendTvenv :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
57 extendTvenv = extendM Tv
58
59 lookupM :: (Ord a, Show a) => Env a b -> a -> CheckResult b
60 lookupM env k =   
61    case elookup env k of
62      Just v -> return v
63      Nothing -> fail ("undefined identifier: " ++ show k ++ " e = " ++ show (edomain env))
64             
65 {- Main entry point. -}
66 checkModule :: Menv -> Module -> CheckRes Menv
67 checkModule globalEnv (Module mn tdefs vdefgs) = 
68   runReaderT 
69     (do (tcenv, cenv) <- mkTypeEnvs tdefs
70         (e_venv,_) <- foldM (checkVdefg True (tcenv,eempty,cenv))
71                               (eempty,eempty) 
72                               vdefgs
73         return (eextend globalEnv 
74             (mn,Envs{tcenv_=tcenv,cenv_=cenv,venv_=e_venv})))
75          -- avoid name shadowing
76     (mn, eremove globalEnv mn)
77
78 -- Like checkModule, but doesn't typecheck the code, instead just
79 -- returning declared types for top-level defns.
80 -- This is necessary in order to handle circular dependencies, but it's sort
81 -- of unpleasant.
82 envsModule :: Menv -> Module -> Menv
83 envsModule globalEnv (Module mn tdefs vdefgs) = 
84    let (tcenv, cenv) = mkTypeEnvsNoChecking tdefs
85        e_venv               = foldr vdefgTypes eempty vdefgs in
86      eextend globalEnv (mn, 
87              (Envs{tcenv_=tcenv,cenv_=cenv,venv_=e_venv}))
88         where vdefgTypes :: Vdefg -> Venv -> Venv
89               vdefgTypes (Nonrec (Vdef (v,t,_))) e =
90                              add [(v,t)] e
91               vdefgTypes (Rec vds) e = 
92                              add (map (\ (Vdef (v,t,_)) -> (v,t)) vds) e
93               add :: [(Qual Var,Ty)] -> Venv -> Venv
94               add pairs e = foldr addOne e pairs
95               addOne :: (Qual Var, Ty) -> Venv -> Venv
96               addOne ((Nothing,_),_) e = e
97               addOne ((Just _,v),t) e  = eextend e (v,t)
98
99 checkTdef0 :: Tcenv -> Tdef -> CheckResult Tcenv
100 checkTdef0 tcenv tdef = ch tdef
101       where 
102         ch (Data (m,c) tbs _) = 
103             do mn <- getMname
104                requireModulesEq m mn "data type declaration" tdef False
105                extendM NotTv tcenv (c, Kind k)
106             where k = foldr Karrow Klifted (map snd tbs)
107         ch (Newtype (m,c) coVar tbs rhs) = 
108             do mn <- getMname
109                requireModulesEq m mn "newtype declaration" tdef False
110                tcenv' <- extendM NotTv tcenv (c, Kind k)
111                -- add newtype axiom to env
112                tcenv'' <- envPlusNewtype tcenv' (m,c) coVar tbs rhs
113                return tcenv''
114             where k = foldr Karrow Klifted (map snd tbs)
115
116 processTdef0NoChecking :: Tcenv -> Tdef -> Tcenv
117 processTdef0NoChecking tcenv tdef = ch tdef
118       where 
119         ch (Data (_,c) tbs _) = eextend tcenv (c, Kind k)
120             where k = foldr Karrow Klifted (map snd tbs)
121         ch (Newtype tc@(_,c) coercion tbs rhs) = 
122             let tcenv' = eextend tcenv (c, Kind k) in
123                 -- add newtype axiom to env
124                 eextend tcenv'
125                   (snd coercion, Coercion $ DefinedCoercion tbs
126                     (foldl Tapp (Tcon tc) (map Tvar (fst (unzip tbs))), rhs))
127             where k = foldr Karrow Klifted (map snd tbs)
128
129 envPlusNewtype :: Tcenv -> Qual Tcon -> Qual Tcon -> [Tbind] -> Ty
130   -> CheckResult Tcenv
131 envPlusNewtype tcenv tyCon coVar tbs rep = extendM NotTv tcenv
132                   (snd coVar, Coercion $ DefinedCoercion tbs
133                             (foldl Tapp (Tcon tyCon) 
134                                        (map Tvar (fst (unzip tbs))),
135                                        rep))
136     
137 checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
138 checkTdef tcenv cenv = ch
139        where 
140          ch (Data (_,c) utbs cdefs) = 
141             do cbinds <- mapM checkCdef cdefs
142                foldM (extendM NotTv) cenv cbinds
143             where checkCdef (cdef@(Constr (m,dcon) etbs ts)) =
144                     do mn <- getMname
145                        requireModulesEq m mn "constructor declaration" cdef 
146                          False 
147                        tvenv <- foldM (extendM Tv) eempty tbs 
148                        ks <- mapM (checkTy (tcenv,tvenv)) ts
149                        mapM_ (\k -> require (baseKind k)
150                                             ("higher-order kind in:\n" ++ show cdef ++ "\n" ++
151                                              "kind: " ++ show k) ) ks
152                        return (dcon,t mn) 
153                     where tbs = utbs ++ etbs
154                           t mn = foldr Tforall 
155                                   (foldr tArrow
156                                           (foldl Tapp (Tcon (Just mn,c))
157                                                  (map (Tvar . fst) utbs)) ts) tbs
158          ch (tdef@(Newtype tc _ tbs t)) =  
159             do tvenv <- foldM (extendM Tv) eempty tbs
160                kRhs <- checkTy (tcenv,tvenv) t
161                require (kRhs `eqKind` Klifted) ("bad kind:\n" ++ show tdef)
162                kLhs <- checkTy (tcenv,tvenv) 
163                          (foldl Tapp (Tcon tc) (map Tvar (fst (unzip tbs))))
164                require (kLhs `eqKind` kRhs) 
165                   ("Kind mismatch in newtype axiom types: " ++ show tdef 
166                     ++ " kinds: " ++
167                    (show kLhs) ++ " and " ++ (show kRhs))
168                return cenv
169
170 processCdef :: Cenv -> Tdef -> Cenv
171 processCdef cenv = ch
172   where
173     ch (Data (_,c) utbs cdefs) = do 
174        let cbinds = map checkCdef cdefs
175        foldl eextend cenv cbinds
176      where checkCdef (Constr (mn,dcon) etbs ts) =
177              (dcon,t mn) 
178             where tbs = utbs ++ etbs
179                   t mn = foldr Tforall 
180                           (foldr tArrow
181                             (foldl Tapp (Tcon (mn,c))
182                                (map (Tvar . fst) utbs)) ts) tbs
183     ch _ = cenv
184
185 mkTypeEnvs :: [Tdef] -> CheckResult (Tcenv, Cenv)
186 mkTypeEnvs tdefs = do
187   tcenv <- foldM checkTdef0 eempty tdefs
188   cenv <- foldM (checkTdef tcenv) eempty tdefs
189   return (tcenv, cenv)
190
191 mkTypeEnvsNoChecking :: [Tdef] -> (Tcenv, Cenv)
192 mkTypeEnvsNoChecking tdefs = 
193   let tcenv = foldl processTdef0NoChecking eempty tdefs
194       cenv  = foldl processCdef eempty tdefs in
195     (tcenv, cenv)
196
197 requireModulesEq :: Show a => Mname -> AnMname -> String -> a 
198                           -> Bool -> CheckResult ()
199 requireModulesEq (Just mn) m msg t _      = require (mn == m) (mkErrMsg msg t)
200 requireModulesEq Nothing _ msg t emptyOk  = require emptyOk (mkErrMsg msg t)
201
202 mkErrMsg :: Show a => String -> a -> String
203 mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t    
204
205 checkVdefg :: Bool -> (Tcenv,Tvenv,Cenv) -> (Venv,Venv)
206                -> Vdefg -> CheckResult (Venv,Venv)
207 checkVdefg top_level (tcenv,tvenv,cenv) (e_venv,l_venv) vdefg = do
208       mn <- getMname
209       case vdefg of
210         Rec vdefs ->
211             do (e_venv', l_venv') <- makeEnv mn vdefs
212                let env' = (tcenv,tvenv,cenv,e_venv',l_venv')
213                mapM_ (checkVdef (\ vdef k -> require (k `eqKind` Klifted) 
214                         ("unlifted kind in:\n" ++ show vdef)) env') 
215                      vdefs
216                return (e_venv', l_venv')
217         Nonrec vdef ->
218             do let env' = (tcenv, tvenv, cenv, e_venv, l_venv)
219                checkVdef (\ vdef k -> do
220                      require (not (k `eqKind` Kopen)) ("open kind in:\n" ++ show vdef)
221                      require ((not top_level) || (not (k `eqKind` Kunlifted))) 
222                        ("top-level unlifted kind in:\n" ++ show vdef)) env' vdef
223                makeEnv mn [vdef]
224
225   where makeEnv mn vdefs = do
226              ev <- foldM extendVenv e_venv e_vts
227              lv <- foldM extendVenv l_venv l_vts
228              return (ev, lv)
229            where e_vts = [ (v,t) | Vdef ((Just m,v),t,_) <- vdefs,
230                                      not (vdefIsMainWrapper mn (Just m))]
231                  l_vts = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
232         checkVdef checkKind env (vdef@(Vdef ((m,_),t,e))) = do
233           mn <- getMname
234           let isZcMain = vdefIsMainWrapper mn m
235           unless isZcMain $
236              requireModulesEq m mn "value definition" vdef True
237           k <- checkTy (tcenv,tvenv) t
238           checkKind vdef k
239           t' <- checkExp env e
240           require (t == t')
241                    ("declared type doesn't match expression type in:\n"  
242                     ++ show vdef ++ "\n" ++  
243                     "declared type: " ++ show t ++ "\n" ++
244                     "expression type: " ++ show t')
245     
246 vdefIsMainWrapper :: AnMname -> Mname -> Bool
247 vdefIsMainWrapper enclosing defining = 
248    enclosing == mainMname && defining == wrapperMainAnMname
249
250 checkExpr :: AnMname -> Menv -> Tcenv -> Cenv -> Venv -> Tvenv 
251                -> Exp -> Ty
252 checkExpr mn menv _tcenv _cenv venv tvenv e = case runReaderT (do
253   --(tcenv, cenv) <- mkTypeEnvs tdefs
254   -- Since the preprocessor calls checkExpr after code has been
255   -- typechecked, we expect to find the external env in the Menv.
256   case (elookup menv mn) of
257      Just thisEnv ->
258        checkExp ({-tcenv-}tcenv_ thisEnv, tvenv, {-cenv-}cenv_ thisEnv, (venv_ thisEnv), venv) e
259      Nothing -> reportError e ("checkExpr: Environment for " ++ 
260                   show mn ++ " not found")) (mn,menv) of
261          OkC t -> t
262          FailC s -> reportError e s
263
264 checkType :: AnMname -> Menv -> Tcenv -> Tvenv -> Ty -> Kind
265 checkType mn menv _tcenv tvenv t = 
266  case runReaderT (checkTy (tcenv_ (fromMaybe (error "checkType") (elookup menv mn)), tvenv) t) (mn, menv) of
267       OkC k -> k
268       FailC s -> reportError tvenv (s ++ "\n " ++ show menv ++ "\n mname =" ++ show mn)
269
270 checkExp :: (Tcenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
271 checkExp (tcenv,tvenv,cenv,e_venv,l_venv) = ch
272       where 
273         ch e0 =
274           case e0 of
275             Var qv -> 
276               qlookupM venv_ e_venv l_venv qv
277             Dcon qc ->
278               qlookupM cenv_ cenv eempty qc
279             Lit l -> 
280               checkLit l
281             Appt e t -> 
282               do t' <- ch e
283                  k' <- checkTy (tcenv,tvenv) t
284                  case t' of
285                    Tforall (tv,k) t0 ->
286                      do require (k' `subKindOf` k) 
287                                 ("kind doesn't match at type application in:\n" ++ show e0 ++ "\n" ++
288                                  "operator kind: " ++ show k ++ "\n" ++
289                                  "operand kind: " ++ show k') 
290                         return (substl [tv] [t] t0)
291                    _ -> fail ("bad operator type in type application:\n" ++ show e0 ++ "\n" ++
292                                "operator type: " ++ show t')
293             App e1 e2 -> 
294               do t1 <- ch e1
295                  t2 <- ch e2
296                  case t1 of
297                    Tapp(Tapp(Tcon tc) t') t0 | tc == tcArrow ->
298                         do require (t2 == t')
299                                     ("type doesn't match at application in:\n" ++ show e0 ++ "\n" ++ 
300                                      "operator type: " ++ show t' ++ "\n" ++ 
301                                      "operand type: " ++ show t2) 
302                            return t0
303                    _ -> fail ("bad operator type at application in:\n" ++ show e0 ++ "\n" ++
304                                "operator type: " ++ show t1)
305             Lam (Tb tb) e ->
306               do tvenv' <- extendTvenv tvenv tb 
307                  t <- checkExp (tcenv,tvenv',cenv,e_venv,l_venv) e 
308                  return (Tforall tb t)
309             Lam (Vb (vb@(_,vt))) e ->
310               do k <- checkTy (tcenv,tvenv) vt
311                  require (baseKind k)   
312                          ("higher-order kind in:\n" ++ show e0 ++ "\n" ++
313                           "kind: " ++ show k) 
314                  l_venv' <- extendVenv l_venv vb
315                  t <- checkExp (tcenv,tvenv,cenv,e_venv,l_venv') e
316                  require (not (isUtupleTy vt)) ("lambda-bound unboxed tuple in:\n" ++ show e0) 
317                  return (tArrow vt t)
318             Let vdefg e ->
319               do (e_venv',l_venv') <- checkVdefg False (tcenv,tvenv,cenv)
320                                         (e_venv,l_venv) vdefg
321                  checkExp (tcenv,tvenv,cenv,e_venv',l_venv') e
322             Case e (v,t) resultTy alts ->
323               do t' <- ch e 
324                  checkTy (tcenv,tvenv) t
325                  require (t == t')
326                           ("scrutinee declared type doesn't match expression type in:\n" ++ show e0 ++ "\n" ++
327                            "declared type: " ++ show t ++ "\n" ++
328                            "expression type: " ++ show t') 
329                  case (reverse alts) of
330                    (Acon c _ _ _):as ->
331                       let ok ((Acon c _ _ _):as) cs = do require (notElem c cs)
332                                                                  ("duplicate alternative in case:\n" ++ show e0) 
333                                                          ok as (c:cs)
334                           ok ((Alit _ _):_)      _  = fail ("invalid alternative in constructor case:\n" ++ show e0)
335                           ok [Adefault _]        _  = return ()
336                           ok (Adefault _:_)      _  = fail ("misplaced default alternative in case:\n" ++ show e0)
337                           ok []                  _  = return () 
338                       in ok as [c] 
339                    (Alit l _):as -> 
340                       let ok ((Acon _ _ _ _):_) _  = fail ("invalid alternative in literal case:\n" ++ show e0)
341                           ok ((Alit l _):as)    ls = do require (notElem l ls)
342                                                                 ("duplicate alternative in case:\n" ++ show e0) 
343                                                         ok as (l:ls)
344                           ok [Adefault _]       _  = return ()
345                           ok (Adefault _:_)     _  = fail ("misplaced default alternative in case:\n" ++ show e0)
346                           ok []                 _  = fail ("missing default alternative in literal case:\n" ++ show e0)
347                       in ok as [l] 
348                    [Adefault _] -> return ()
349                    _ -> fail ("no alternatives in case:\n" ++ show e0) 
350                  l_venv' <- extendVenv l_venv (v,t)
351                  t:ts <- mapM (checkAlt (tcenv,tvenv,cenv,e_venv,l_venv') t) alts
352                  require (all (== t) ts)
353                          ("alternative types don't match in:\n" ++ show e0 ++ "\n" ++
354                           "types: " ++ show (t:ts))
355                  checkTy (tcenv,tvenv) resultTy
356                  require (t == resultTy) ("case alternative type doesn't " ++
357                    " match case return type in:\n" ++ show e0 ++ "\n" ++
358                    "alt type: " ++ show t ++ " return type: " ++ show resultTy)
359                  return t
360             c@(Cast e t) -> 
361               do eTy <- ch e 
362                  (fromTy, toTy) <- checkTyCo (tcenv,tvenv) t
363                  require (eTy == fromTy) ("Type mismatch in cast: c = "
364                              ++ show c ++ "\nand eTy = " ++ show eTy
365                              ++ "\n and " ++ show fromTy)
366                  return toTy
367             Note _ e -> 
368               ch e
369             External _ t -> 
370               do checkTy (tcenv,eempty) t {- external types must be closed -}
371                  return t
372     
373 checkAlt :: (Tcenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty
374 checkAlt (env@(tcenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
375       where 
376         ch a0 = 
377           case a0 of 
378             Acon qc etbs vbs e ->
379               do let uts = f t0                                      
380                        where f (Tapp t0 t) = f t0 ++ [t]
381                              f _ = []
382                  ct <- qlookupM cenv_ cenv eempty qc
383                  let (tbs,ct_args0,ct_res0) = splitTy ct
384                  {- get universals -}
385                  let (utbs,etbs') = splitAt (length uts) tbs
386                  let utvs = map fst utbs
387                  {- check existentials -}
388                  let (etvs,eks) = unzip etbs
389                  let (etvs',eks') = unzip etbs'
390                  require (all (uncurry eqKind)
391                             (zip eks eks'))  
392                          ("existential kinds don't match in:\n" ++ show a0 ++ "\n" ++
393                           "kinds declared in data constructor: " ++ show eks ++
394                           "kinds declared in case alternative: " ++ show eks') 
395                  tvenv' <- foldM extendTvenv tvenv etbs
396                  {- check term variables -}
397                  let vts = map snd vbs
398                  mapM_ (\vt -> require ((not . isUtupleTy) vt)
399                                        ("pattern-bound unboxed tuple in:\n" ++ show a0 ++ "\n" ++
400                                         "pattern type: " ++ show vt)) vts
401                  vks <- mapM (checkTy (tcenv,tvenv')) vts
402                  mapM_ (\vk -> require (baseKind vk)
403                                        ("higher-order kind in:\n" ++ show a0 ++ "\n" ++
404                                         "kind: " ++ show vk)) vks 
405                  let (ct_res:ct_args) = map (substl (utvs++etvs') (uts++(map Tvar etvs))) (ct_res0:ct_args0)
406                  zipWithM_ 
407                     (\ct_arg vt -> 
408                         require (ct_arg == vt)
409                                  ("pattern variable type doesn't match constructor argument type in:\n" ++ show a0 ++ "\n" ++
410                                   "pattern variable type: " ++ show ct_arg ++ "\n" ++
411                                   "constructor argument type: " ++ show vt)) ct_args vts
412                  require (ct_res == t0)
413                           ("pattern constructor type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
414                            "pattern constructor type: " ++ show ct_res ++ "\n" ++
415                            "scrutinee type: " ++ show t0) 
416                  l_venv' <- foldM extendVenv l_venv vbs
417                  t <- checkExp (tcenv,tvenv',cenv,e_venv,l_venv') e
418                  checkTy (tcenv,tvenv) t  {- check that existentials don't escape in result type -}
419                  return t
420             Alit l e ->
421               do t <- checkLit l
422                  require (t == t0)
423                          ("pattern type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
424                           "pattern type: " ++ show t ++ "\n" ++
425                           "scrutinee type: " ++ show t0) 
426                  checkExp env e
427             Adefault e ->
428               checkExp env e
429     
430 checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
431 checkTy es@(tcenv,tvenv) = ch
432      where
433        ch (Tvar tv) = lookupM tvenv tv
434        ch (Tcon qtc) = do
435          kOrC <- qlookupM tcenv_ tcenv eempty qtc
436          case kOrC of
437             Kind k -> return k
438             Coercion (DefinedCoercion [] (t1,t2)) -> return $ Keq t1 t2
439             Coercion _ -> fail ("Unsaturated coercion app: " ++ show qtc)
440        ch (t@(Tapp t1 t2)) = 
441              case splitTyConApp_maybe t of
442                Just (tc, tys) -> do
443                  tcK <- qlookupM tcenv_ tcenv eempty tc 
444                  case tcK of
445                    Kind _ -> checkTapp t1 t2
446                    Coercion (DefinedCoercion tbs (from,to)) -> do
447                      -- makes sure coercion is fully applied
448                      require (length tys == length tbs) $
449                         ("Arity mismatch in coercion app: " ++ show t)
450                      let (tvs, tks) = unzip tbs
451                      argKs <- mapM (checkTy es) tys
452                      let kPairs = zip argKs tks
453                          -- Simon says it's okay for these to be
454                          -- subkinds
455                      let kindsOk = all (uncurry subKindOf) kPairs
456                      require kindsOk
457                         ("Kind mismatch in coercion app: " ++ show tks 
458                          ++ " and " ++ show argKs ++ " t = " ++ show t)
459                      return $ Keq (substl tvs tys from) (substl tvs tys to)
460                Nothing -> checkTapp t1 t2
461             where checkTapp t1 t2 = do 
462                     k1 <- ch t1
463                     k2 <- ch t2
464                     case k1 of
465                       Karrow k11 k12 -> do
466                          require (k2 `subKindOf` k11) kindError
467                          return k12
468                             where kindError = 
469                                     "kinds don't match in type application: "
470                                     ++ show t ++ "\n" ++
471                                     "operator kind: " ++ show k11 ++ "\n" ++
472                                     "operand kind: " ++ show k2 
473                       _ -> fail ("applied type has non-arrow kind: " ++ show t)
474                            
475        ch (Tforall tb t) = 
476             do tvenv' <- extendTvenv tvenv tb 
477                checkTy (tcenv,tvenv') t
478        ch (TransCoercion t1 t2) = do
479             (ty1,ty2) <- checkTyCo es t1
480             (ty3,ty4) <- checkTyCo es t2
481             require (ty2 == ty3) ("Types don't match in trans. coercion: " ++
482                         show ty2 ++ " and " ++ show ty3)
483             return $ Keq ty1 ty4
484        ch (SymCoercion t1) = do
485             (ty1,ty2) <- checkTyCo es t1
486             return $ Keq ty2 ty1
487        ch (UnsafeCoercion t1 t2) = do
488             checkTy es t1
489             checkTy es t2
490             return $ Keq t1 t2
491        ch (LeftCoercion t1) = do
492             k <- checkTyCo es t1
493             case k of
494               ((Tapp u _), (Tapp w _)) -> return $ Keq u w
495               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
496        ch (RightCoercion t1) = do
497             k <- checkTyCo es t1
498             case k of
499               ((Tapp _ v), (Tapp _ x)) -> return $ Keq v x
500               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
501        ch (InstCoercion ty arg) = do
502             forallK <- checkTyCo es ty
503             case forallK of
504               ((Tforall (v1,k1) b1), (Tforall (v2,k2) b2)) -> do
505                  require (k1 `eqKind` k2) ("Kind mismatch in argument of inst: "
506                                             ++ show ty)
507                  argK <- checkTy es arg
508                  require (argK `eqKind` k1) ("Kind mismatch in type being "
509                            ++ "instantiated: " ++ show arg)
510                  let newLhs = substl [v1] [arg] b1
511                  let newRhs = substl [v2] [arg] b2
512                  return $ Keq newLhs newRhs
513               _ -> fail ("Non-forall-ty in argument to inst: " ++ show ty)
514
515 checkTyCo :: (Tcenv, Tvenv) -> Ty -> CheckResult (Ty, Ty)
516 checkTyCo es@(tcenv,_) t@(Tapp t1 t2) = 
517   (case splitTyConApp_maybe t of
518     Just (tc, tys) -> do
519        tcK <- qlookupM tcenv_ tcenv eempty tc
520        case tcK of
521  -- todo: avoid duplicating this code
522  -- blah, this almost calls for a different syntactic form
523  -- (for a defined-coercion app): (TCoercionApp Tcon [Ty])
524          Coercion (DefinedCoercion tbs (from, to)) -> do
525            require (length tys == length tbs) $ 
526             ("Arity mismatch in coercion app: " ++ show t)
527            let (tvs, tks) = unzip tbs
528            argKs <- mapM (checkTy es) tys
529            let kPairs = zip argKs tks
530            let kindsOk = all (uncurry subKindOf) kPairs
531            require kindsOk
532               ("Kind mismatch in coercion app: " ++ show tks 
533                  ++ " and " ++ show argKs ++ " t = " ++ show t)
534            return (substl tvs tys from, substl tvs tys to)
535          _ -> checkTapp t1 t2
536     _ -> checkTapp t1 t2)
537        where checkTapp t1 t2 = do
538                (lhsRator, rhsRator) <- checkTyCo es t1
539                (lhs, rhs) <- checkTyCo es t2
540                -- Comp rule from paper
541                checkTy es (Tapp lhsRator lhs)
542                checkTy es (Tapp rhsRator rhs)
543                return (Tapp lhsRator lhs, Tapp rhsRator rhs)
544 checkTyCo (tcenv, tvenv) (Tforall tb t) = do
545   tvenv' <- extendTvenv tvenv tb
546   (t1,t2) <- checkTyCo (tcenv, tvenv') t
547   return (Tforall tb t1, Tforall tb t2)
548 checkTyCo es t = do
549   k <- checkTy es t
550   case k of
551     Keq t1 t2 -> return (t1, t2)
552     -- otherwise, expand by the "refl" rule
553     _          -> return (t, t)
554
555 mlookupM :: (Eq a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b -> Mname
556           -> CheckResult (Env a b)
557 mlookupM _ _ local_env    Nothing            = return local_env
558 mlookupM selector external_env local_env (Just m) = do
559   mn <- getMname
560   if m == mn
561      then return external_env
562      else do
563        globalEnv <- getGlobalEnv
564        case elookup globalEnv m of
565          Just env' -> return (selector env')
566          Nothing -> fail ("Check: undefined module name: "
567                       ++ show m ++ show (edomain local_env))
568
569 qlookupM :: (Ord a, Show a,Show b) => (Envs -> Env a b) -> Env a b -> Env a b
570                   -> Qual a -> CheckResult b
571 qlookupM selector external_env local_env (m,k) =
572       do env <- mlookupM selector external_env local_env m
573          lookupM env k
574
575 checkLit :: Lit -> CheckResult Ty
576 checkLit (Literal lit t) =
577   case lit of
578     Lint _ -> 
579           do require (t `elem` intLitTypes)
580                      ("invalid int literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
581              return t
582     Lrational _ ->
583           do require (t `elem` ratLitTypes)
584                      ("invalid rational literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
585              return t
586     Lchar _ -> 
587           do require (t `elem` charLitTypes)
588                      ("invalid char literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
589              return t   
590     Lstring _ ->
591           do require (t `elem` stringLitTypes)
592                      ("invalid string literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)
593              return t
594
595 {- Utilities -}
596
597 {- Split off tbs, arguments and result of a (possibly abstracted)  arrow type -}
598 splitTy :: Ty -> ([Tbind],[Ty],Ty)
599 splitTy (Tforall tb t) = (tb:tbs,ts,tr) 
600                 where (tbs,ts,tr) = splitTy t
601 splitTy (Tapp(Tapp(Tcon tc) t0) t) | tc == tcArrow = (tbs,t0:ts,tr)
602                 where (tbs,ts,tr) = splitTy t
603 splitTy t = ([],[],t)
604
605
606 {- Simultaneous substitution on types for type variables,
607    renaming as neceessary to avoid capture.
608    No checks for correct kindedness. -}
609 substl :: [Tvar] -> [Ty] -> Ty -> Ty
610 substl tvs ts t = f (zip tvs ts) t
611   where 
612     f env t0 =
613      case t0 of
614        Tcon _ -> t0
615        Tvar v -> case lookup v env of
616                    Just t1 -> t1
617                    Nothing -> t0
618        Tapp t1 t2 -> Tapp (f env t1) (f env t2)
619        Tforall (t,k) t1 -> 
620          if t `elem` free then
621            Tforall (t',k) (f ((t,Tvar t'):env) t1)
622          else 
623            Tforall (t,k) (f (filter ((/=t).fst) env) t1)
624        TransCoercion t1 t2 -> TransCoercion (f env t1) (f env t2)
625        SymCoercion t1 -> SymCoercion (f env t1)
626        UnsafeCoercion t1 t2 -> UnsafeCoercion (f env t1) (f env t2)
627        LeftCoercion t1 -> LeftCoercion (f env t1)
628        RightCoercion t1 -> RightCoercion (f env t1)
629        InstCoercion t1 t2 -> InstCoercion (f env t1) (f env t2)
630      where free = foldr union [] (map (freeTvars.snd) env)
631            t' = freshTvar free 
632    
633 {- Return free tvars in a type -}
634 freeTvars :: Ty -> [Tvar]
635 freeTvars (Tcon _) = []
636 freeTvars (Tvar v) = [v]
637 freeTvars (Tapp t1 t2) = freeTvars t1 `union` freeTvars t2
638 freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1) 
639 freeTvars (TransCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
640 freeTvars (SymCoercion t) = freeTvars t
641 freeTvars (UnsafeCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
642 freeTvars (LeftCoercion t) = freeTvars t
643 freeTvars (RightCoercion t) = freeTvars t
644 freeTvars (InstCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
645
646 {- Return any tvar *not* in the argument list. -}
647 freshTvar :: [Tvar] -> Tvar
648 freshTvar tvs = maximum ("":tvs) ++ "x" -- one simple way!
649
650 primCoercionError :: Show a => a -> b
651 primCoercionError s = error $ "Bad coercion application: " ++ show s
652
653 -- todo
654 reportError :: Show a => a -> String -> b
655 reportError e s = error $ ("Core type error: checkExpr failed with "
656                    ++ s ++ " and " ++ show e)