75470d5a564f9af24067dffbc996f9ad124e0d43
[ghc-hetmet.git] / utils / ext-core / Check.hs
1 module Check where
2
3 import Maybe
4 import Control.Monad.Reader
5
6 import Core
7 import Printer
8 import List
9 import Env
10
11 {- Checking is done in a simple error monad.  In addition to
12    allowing errors to be captured, this makes it easy to guarantee
13    that checking itself has been completed for an entire module. -}
14
15 {- We use the Reader monad transformer in order to thread the 
16    top-level module name throughout the computation simply.
17    This is so that checkExp can also be an entry point (we call it
18    from Prep.) -}
19 data CheckRes a = OkC a | FailC String
20 type CheckResult a = ReaderT (AnMname, Menv) CheckRes a
21 getMname :: CheckResult AnMname
22 getMname     = ask >>= (return . fst)
23 getGlobalEnv :: CheckResult Menv
24 getGlobalEnv = ask >>= (return . snd)
25
26 instance Monad CheckRes where
27   OkC a >>= k = k a
28   FailC s >>= k = fail s
29   return = OkC
30   fail = FailC
31
32 require :: Bool -> String -> CheckResult ()
33 require False s = fail s
34 require True  _ = return ()
35
36 requireM :: CheckResult Bool -> String -> CheckResult ()
37 requireM cond s =
38   do b <- cond
39      require b s
40
41 {- Environments. -}
42 type Tvenv = Env Tvar Kind                    -- type variables  (local only)
43 type Tcenv = Env Tcon Kind                    -- type constructors
44 type Tsenv = Env Tcon ([Tvar],Ty)             -- type synonyms
45 type Cenv = Env Dcon Ty                       -- data constructors
46 type Venv = Env Var Ty                        -- values
47 type Menv = Env AnMname Envs                  -- modules
48 data Envs = Envs {tcenv_::Tcenv,tsenv_::Tsenv,cenv_::Cenv,venv_::Venv} -- all the exportable envs
49
50 {- Extend an environment, checking for illegal shadowing of identifiers. -}
51 extendM :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
52 extendM env (k,d) = 
53    case elookup env k of
54      Just _ -> fail ("multiply-defined identifier: " ++ show k)
55      Nothing -> return (eextend env (k,d))
56
57 lookupM :: (Ord a, Show a) => Env a b -> a -> CheckResult b
58 lookupM env k =   
59    case elookup env k of
60      Just v -> return v
61      Nothing -> fail ("undefined identifier: " ++ show k)
62             
63 {- Main entry point. -}
64 checkModule :: Menv -> Module -> CheckRes Menv
65 checkModule globalEnv mod@(Module mn tdefs vdefgs) = 
66   runReaderT 
67     (do (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
68         (e_venv,l_venv) <- foldM (checkVdefg True (tcenv,tsenv,eempty,cenv))
69                               (eempty,eempty) 
70                               vdefgs
71         return (eextend globalEnv 
72             (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv})))
73     (mn, globalEnv)   
74
75 checkTdef0 :: (Tcenv,Tsenv) -> Tdef -> CheckResult (Tcenv,Tsenv)
76 checkTdef0 (tcenv,tsenv) tdef = ch tdef
77       where 
78         ch (Data (m,c) tbs _) = 
79             do mn <- getMname
80                requireModulesEq m mn "data type declaration" tdef False
81                tcenv' <- extendM tcenv (c,k)
82                return (tcenv',tsenv)
83             where k = foldr Karrow Klifted (map snd tbs)
84         ch (Newtype (m,c) tbs rhs) = 
85             do mn <- getMname
86                requireModulesEq m mn "newtype declaration" tdef False
87                tcenv' <- extendM tcenv (c,k)
88                tsenv' <- case rhs of
89                            Nothing -> return tsenv
90                            Just rep -> extendM tsenv (c,(map fst tbs,rep))
91                return (tcenv', tsenv')
92             where k = foldr Karrow Klifted (map snd tbs)
93     
94 checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
95 checkTdef tcenv cenv = ch
96        where 
97          ch (Data (_,c) utbs cdefs) = 
98             do cbinds <- mapM checkCdef cdefs
99                foldM extendM cenv cbinds
100             where checkCdef (cdef@(Constr (m,dcon) etbs ts)) =
101                     do mn <- getMname
102                        requireModulesEq m mn "constructor declaration" cdef 
103                          False 
104                        tvenv <- foldM extendM eempty tbs 
105                        ks <- mapM (checkTy (tcenv,tvenv)) ts
106                        mapM_ (\k -> require (baseKind k)
107                                             ("higher-order kind in:\n" ++ show cdef ++ "\n" ++
108                                              "kind: " ++ show k) ) ks
109                        return (dcon,t mn) 
110                     where tbs = utbs ++ etbs
111                           t mn = foldr Tforall 
112                                   (foldr tArrow
113                                           (foldl Tapp (Tcon (Just mn,c))
114                                                  (map (Tvar . fst) utbs)) ts) tbs
115          ch (tdef@(Newtype c tbs (Just t))) =  
116             do tvenv <- foldM extendM eempty tbs
117                k <- checkTy (tcenv,tvenv) t
118                require (k==Klifted) ("bad kind:\n" ++ show tdef) 
119                return cenv
120          ch (tdef@(Newtype c tbs Nothing)) =
121             {- should only occur for recursive Newtypes -}
122             return cenv
123
124 mkTypeEnvs :: [Tdef] -> CheckResult (Tcenv, Tsenv, Cenv)
125 mkTypeEnvs tdefs = do
126   (tcenv, tsenv) <- foldM checkTdef0 (eempty,eempty) tdefs
127   cenv <- foldM (checkTdef tcenv) eempty tdefs
128   return (tcenv, tsenv, cenv)
129
130 requireModulesEq :: Show a => Mname -> AnMname -> String -> a 
131                           -> Bool -> CheckResult ()
132 requireModulesEq (Just mn) m msg t _      = require (mn == m) (mkErrMsg msg t)
133 requireModulesEq Nothing m msg t emptyOk  = require emptyOk (mkErrMsg msg t)
134
135 mkErrMsg :: Show a => String -> a -> String
136 mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t    
137
138 checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) 
139                -> Vdefg -> CheckResult (Venv,Venv)
140 checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg =
141       case vdefg of
142         Rec vdefs ->
143             do e_venv' <- foldM extendM e_venv e_vts
144                l_venv' <- foldM extendM l_venv l_vts
145                let env' = (tcenv,tsenv,tvenv,cenv,e_venv',l_venv')
146                mapM_ (\ (vdef@(Vdef ((m,v),t,e))) -> 
147                             do mn <- getMname
148                                requireModulesEq m mn "value definition" vdef True
149                                k <- checkTy (tcenv,tvenv) t
150                                require (k==Klifted) ("unlifted kind in:\n" ++ show vdef)
151                                t' <- checkExp env' e
152                                requireM (equalTy tsenv t t') 
153                                         ("declared type doesn't match expression type in:\n"  ++ show vdef ++ "\n" ++  
154                                          "declared type: " ++ show t ++ "\n" ++
155                                          "expression type: " ++ show t')) vdefs
156                return (e_venv',l_venv')
157             where e_vts  = [ (v,t) | Vdef ((Just _,v),t,_) <- vdefs ]
158                   l_vts  = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
159         Nonrec (vdef@(Vdef ((m,v),t,e))) ->
160             do mn <- getMname
161                requireModulesEq m mn "value definition" vdef True
162                k <- checkTy (tcenv,tvenv) t 
163                require (k /= Kopen) ("open kind in:\n" ++ show vdef)
164                require ((not top_level) || (k /= Kunlifted)) ("top-level unlifted kind in:\n" ++ show vdef) 
165                t' <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) e
166                requireM (equalTy tsenv t t') 
167                         ("declared type doesn't match expression type in:\n" ++ show vdef  ++ "\n"  ++
168                          "declared type: " ++ show t ++ "\n" ++
169                          "expression type: " ++ show t') 
170                if isNothing m then
171                  do l_venv' <- extendM l_venv (v,t)
172                     return (e_venv,l_venv')
173                 else
174                  do e_venv' <- extendM e_venv (v,t)
175                     return (e_venv',l_venv)
176     
177 checkExpr :: AnMname -> Menv -> [Tdef] -> Venv -> Tvenv 
178                -> Exp -> Ty
179 checkExpr mn menv tdefs venv tvenv e = case (runReaderT (do
180   (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
181   checkExp (tcenv, tsenv, tvenv, cenv, venv, eempty) e) 
182                             (mn, menv)) of
183     OkC t -> t
184     FailC s -> reportError s
185
186 checkExp :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
187 checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch 
188       where 
189         ch e0 = 
190           case e0 of
191             Var qv -> 
192               qlookupM venv_ e_venv l_venv qv
193             Dcon qc ->
194               qlookupM cenv_ cenv eempty qc
195             Lit l -> 
196               checkLit l
197             Appt e t -> 
198               do t' <- ch e
199                  k' <- checkTy (tcenv,tvenv) t
200                  case t' of
201                    Tforall (tv,k) t0 ->
202                      do require (k' <= k) 
203                                 ("kind doesn't match at type application in:\n" ++ show e0 ++ "\n" ++
204                                  "operator kind: " ++ show k ++ "\n" ++
205                                  "operand kind: " ++ show k') 
206                         return (substl [tv] [t] t0)
207                    _ -> fail ("bad operator type in type application:\n" ++ show e0 ++ "\n" ++
208                                "operator type: " ++ show t')
209             App e1 e2 -> 
210               do t1 <- ch e1
211                  t2 <- ch e2
212                  case t1 of
213                    Tapp(Tapp(Tcon tc) t') t0 | tc == tcArrow ->
214                         do requireM (equalTy tsenv t2 t') 
215                                     ("type doesn't match at application in:\n" ++ show e0 ++ "\n" ++ 
216                                      "operator type: " ++ show t' ++ "\n" ++ 
217                                      "operand type: " ++ show t2) 
218                            return t0
219                    _ -> fail ("bad operator type at application in:\n" ++ show e0 ++ "\n" ++
220                                "operator type: " ++ show t1)
221             Lam (Tb tb) e ->
222               do tvenv' <- extendM tvenv tb 
223                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv) e 
224                  return (Tforall tb t)
225             Lam (Vb (vb@(_,vt))) e ->
226               do k <- checkTy (tcenv,tvenv) vt
227                  require (baseKind k)   
228                          ("higher-order kind in:\n" ++ show e0 ++ "\n" ++
229                           "kind: " ++ show k) 
230                  l_venv' <- extendM l_venv vb
231                  t <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') e
232                  require (not (isUtupleTy vt)) ("lambda-bound unboxed tuple in:\n" ++ show e0) 
233                  return (tArrow vt t)
234             Let vdefg e ->
235               do (e_venv',l_venv') <- checkVdefg False (tcenv,tsenv,tvenv,cenv) 
236                                         (e_venv,l_venv) vdefg 
237                  checkExp (tcenv,tsenv,tvenv,cenv,e_venv',l_venv') e
238             Case e (v,t) resultTy alts ->
239               do t' <- ch e 
240                  checkTy (tcenv,tvenv) t
241                  requireM (equalTy tsenv t t') 
242                           ("scrutinee declared type doesn't match expression type in:\n" ++ show e0 ++ "\n" ++
243                            "declared type: " ++ show t ++ "\n" ++
244                            "expression type: " ++ show t') 
245                  case (reverse alts) of
246                    (Acon c _ _ _):as ->
247                       let ok ((Acon c _ _ _):as) cs = do require (notElem c cs)
248                                                                  ("duplicate alternative in case:\n" ++ show e0) 
249                                                          ok as (c:cs)
250                           ok ((Alit _ _):_)      _  = fail ("invalid alternative in constructor case:\n" ++ show e0)
251                           ok [Adefault _]        _  = return ()
252                           ok (Adefault _:_)      _  = fail ("misplaced default alternative in case:\n" ++ show e0)
253                           ok []                  _  = return () 
254                       in ok as [c] 
255                    (Alit l _):as -> 
256                       let ok ((Acon _ _ _ _):_) _  = fail ("invalid alternative in literal case:\n" ++ show e0)
257                           ok ((Alit l _):as)    ls = do require (notElem l ls)
258                                                                 ("duplicate alternative in case:\n" ++ show e0) 
259                                                         ok as (l:ls)
260                           ok [Adefault _]       _  = return ()
261                           ok (Adefault _:_)     _  = fail ("misplaced default alternative in case:\n" ++ show e0)
262                           ok []                 _  = fail ("missing default alternative in literal case:\n" ++ show e0)
263                       in ok as [l] 
264                    [Adefault _] -> return ()
265                    [] -> fail ("no alternatives in case:\n" ++ show e0) 
266                  l_venv' <- extendM l_venv (v,t)
267                  t:ts <- mapM (checkAlt (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') t) alts  
268                  bs <- mapM (equalTy tsenv t) ts
269                  require (and bs)
270                          ("alternative types don't match in:\n" ++ show e0 ++ "\n" ++
271                           "types: " ++ show (t:ts))
272                  checkTy (tcenv,tvenv) resultTy
273                  require (t == resultTy) ("case alternative type doesn't " ++
274                    " match case return type in:\n" ++ show e0 ++ "\n" ++
275                    "alt type: " ++ show t ++ " return type: " ++ show resultTy)
276                  return t
277             Cast e t -> 
278               do ch e 
279                  checkTy (tcenv,tvenv) t 
280                  return t
281             Note s e -> 
282               ch e
283             External _ t -> 
284               do checkTy (tcenv,eempty) t {- external types must be closed -}
285                  return t
286     
287 checkAlt :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty 
288 checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
289       where 
290         ch a0 = 
291           case a0 of 
292             Acon qc etbs vbs e ->
293               do let uts = f t0                                      
294                        where f (Tapp t0 t) = f t0 ++ [t]
295                              f _ = []
296                  ct <- qlookupM cenv_ cenv eempty qc
297                  let (tbs,ct_args0,ct_res0) = splitTy ct
298                  {- get universals -}
299                  let (utbs,etbs') = splitAt (length uts) tbs
300                  let utvs = map fst utbs
301                  {- check existentials -}
302                  let (etvs,eks) = unzip etbs
303                  let (etvs',eks') = unzip etbs'
304                  require (eks == eks')  
305                          ("existential kinds don't match in:\n" ++ show a0 ++ "\n" ++
306                           "kinds declared in data constructor: " ++ show eks ++
307                           "kinds declared in case alternative: " ++ show eks') 
308                  tvenv' <- foldM extendM tvenv etbs
309                  {- check term variables -}
310                  let vts = map snd vbs
311                  mapM_ (\vt -> require ((not . isUtupleTy) vt)
312                                        ("pattern-bound unboxed tuple in:\n" ++ show a0 ++ "\n" ++
313                                         "pattern type: " ++ show vt)) vts
314                  vks <- mapM (checkTy (tcenv,tvenv')) vts
315                  mapM_ (\vk -> require (baseKind vk)
316                                        ("higher-order kind in:\n" ++ show a0 ++ "\n" ++
317                                         "kind: " ++ show vk)) vks 
318                  let (ct_res:ct_args) = map (substl (utvs++etvs') (uts++(map Tvar etvs))) (ct_res0:ct_args0)
319                  zipWithM_ 
320                     (\ct_arg vt -> 
321                         requireM (equalTy tsenv ct_arg vt)
322                                  ("pattern variable type doesn't match constructor argument type in:\n" ++ show a0 ++ "\n" ++
323                                   "pattern variable type: " ++ show ct_arg ++ "\n" ++
324                                   "constructor argument type: " ++ show vt)) ct_args vts
325                  requireM (equalTy tsenv ct_res t0)
326                           ("pattern constructor type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
327                            "pattern constructor type: " ++ show ct_res ++ "\n" ++
328                            "scrutinee type: " ++ show t0) 
329                  l_venv' <- foldM extendM l_venv vbs
330                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv') e 
331                  checkTy (tcenv,tvenv) t  {- check that existentials don't escape in result type -}
332                  return t
333             Alit l e ->
334               do t <- checkLit l
335                  requireM (equalTy tsenv t t0)
336                          ("pattern type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
337                           "pattern type: " ++ show t ++ "\n" ++
338                           "scrutinee type: " ++ show t0) 
339                  checkExp env e
340             Adefault e ->
341               checkExp env e
342     
343 checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
344 checkTy (tcenv,tvenv) = ch
345      where
346        ch (Tvar tv) = lookupM tvenv tv
347        ch (Tcon qtc) = qlookupM tcenv_ tcenv eempty qtc
348        ch (t@(Tapp t1 t2)) = 
349            do k1 <- ch t1
350               k2 <- ch t2
351               case k1 of
352                  Karrow k11 k12 ->
353                    do require (k2 <= k11) 
354                              ("kinds don't match in type application: " ++ show t ++ "\n" ++
355                               "operator kind: " ++ show k11 ++ "\n" ++
356                               "operand kind: " ++ show k2)              
357                       return k12
358                  _ -> fail ("applied type has non-arrow kind: " ++ show t)
359        ch (Tforall tb t) = 
360             do tvenv' <- extendM tvenv tb 
361                checkTy (tcenv,tvenv') t
362     
363 {- Type equality modulo newtype synonyms. -}
364 equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
365 equalTy tsenv t1 t2 = 
366             do t1' <- expand t1
367                t2' <- expand t2
368                return (t1' == t2')
369       where expand (Tvar v) = return (Tvar v)
370             expand (Tcon qtc) = return (Tcon qtc)
371             expand (Tapp t1 t2) = 
372               do t2' <- expand t2
373                  expapp t1 [t2']
374             expand (Tforall tb t) = 
375               do t' <- expand t
376                  return (Tforall tb t')
377             expapp (t@(Tcon (m,tc))) ts = 
378               do env <- mlookupM tsenv_ tsenv eempty m
379                  case elookup env tc of 
380                     Just (formals,rhs) | (length formals) == (length ts) -> return (substl formals ts rhs)
381                     _ -> return (foldl Tapp t ts)
382             expapp (Tapp t1 t2) ts = 
383               do t2' <- expand t2
384                  expapp t1 (t2':ts)
385             expapp t ts = 
386               do t' <- expand t
387                  return (foldl Tapp t' ts)
388     
389
390 mlookupM :: (Envs -> Env a b) -> Env a b -> Env a b -> Mname 
391           -> CheckResult (Env a b)
392 mlookupM _ _ local_env    Nothing            = return local_env
393 mlookupM selector external_env _ (Just m) = do
394   mn <- getMname
395   if m == mn
396      then return external_env
397      else do
398        globalEnv <- getGlobalEnv
399        case elookup globalEnv m of
400          Just env' -> return (selector env')
401          Nothing -> fail ("Check: undefined module name: " ++ show m)
402
403 qlookupM :: (Ord a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b 
404                   -> Qual a -> CheckResult b
405 qlookupM selector external_env local_env (m,k) =   
406       do env <- mlookupM selector external_env local_env m
407          lookupM env k
408
409
410 checkLit :: Lit -> CheckResult Ty
411 checkLit lit =
412   case lit of
413     Lint _ t -> 
414           do {- require (elem t [tIntzh, {- tInt32zh,tInt64zh, -} tWordzh, {- tWord32zh,tWord64zh, -} tAddrzh, tCharzh]) 
415                      ("invalid int literal: " ++ show lit ++ "\n" ++ "type: " ++ show t) -}
416              return t
417     Lrational _ t ->
418           do {- require (elem t [tFloatzh,tDoublezh]) 
419                      ("invalid rational literal: " ++ show lit ++ "\n" ++ "type: " ++ show t) -}
420              return t
421     Lchar _ t -> 
422           do {- require (t == tCharzh) 
423                      ("invalid char literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)  -}
424              return t   
425     Lstring _ t ->
426           do {- require (t == tAddrzh) 
427                      ("invalid string literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)  -}
428              return t
429
430 {- Utilities -}
431
432 {- Split off tbs, arguments and result of a (possibly abstracted)  arrow type -}
433 splitTy :: Ty -> ([Tbind],[Ty],Ty)
434 splitTy (Tforall tb t) = (tb:tbs,ts,tr) 
435                 where (tbs,ts,tr) = splitTy t
436 splitTy (Tapp(Tapp(Tcon tc) t0) t) | tc == tcArrow = (tbs,t0:ts,tr)
437                 where (tbs,ts,tr) = splitTy t
438 splitTy t = ([],[],t)
439
440
441 {- Simultaneous substitution on types for type variables,
442    renaming as neceessary to avoid capture.
443    No checks for correct kindedness. -}
444 substl :: [Tvar] -> [Ty] -> Ty -> Ty
445 substl tvs ts t = f (zip tvs ts) t
446   where 
447     f env t0 =
448      case t0 of
449        Tcon _ -> t0
450        Tvar v -> case lookup v env of
451                    Just t1 -> t1
452                    Nothing -> t0
453        Tapp t1 t2 -> Tapp (f env t1) (f env t2)
454        Tforall (t,k) t1 -> 
455          if t `elem` free then
456            Tforall (t',k) (f ((t,Tvar t'):env) t1)
457          else 
458            Tforall (t,k) (f (filter ((/=t).fst) env) t1)
459      where free = foldr union [] (map (freeTvars.snd) env)
460            t' = freshTvar free 
461    
462 {- Return free tvars in a type -}
463 freeTvars :: Ty -> [Tvar]
464 freeTvars (Tcon _) = []
465 freeTvars (Tvar v) = [v]
466 freeTvars (Tapp t1 t2) = (freeTvars t1) `union` (freeTvars t2)
467 freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1) 
468
469 {- Return any tvar *not* in the argument list. -}
470 freshTvar :: [Tvar] -> Tvar
471 freshTvar tvs = maximum ("":tvs) ++ "x" -- one simple way!
472
473 -- todo
474 reportError s = error $ ("Core parser error: checkExpr failed with " ++ s)