dedc0f41088ab458a8202d80cfb7445c699d9146
[ghc-hetmet.git] / utils / ext-core / Check.hs
1 module Check where
2
3 import Maybe
4 import Control.Monad.Reader
5
6 import Core
7 import Printer
8 import List
9 import Env
10
11 {- Checking is done in a simple error monad.  In addition to
12    allowing errors to be captured, this makes it easy to guarantee
13    that checking itself has been completed for an entire module. -}
14
15 {- We use the Reader monad transformer in order to thread the 
16    top-level module name throughout the computation simply.
17    This is so that checkExp can also be an entry point (we call it
18    from Prep.) -}
19 data CheckRes a = OkC a | FailC String
20 type CheckResult a = ReaderT (AnMname, Menv) CheckRes a
21 getMname :: CheckResult AnMname
22 getMname     = ask >>= (return . fst)
23 getGlobalEnv :: CheckResult Menv
24 getGlobalEnv = ask >>= (return . snd)
25
26 instance Monad CheckRes where
27   OkC a >>= k = k a
28   FailC s >>= k = fail s
29   return = OkC
30   fail = FailC
31
32 require :: Bool -> String -> CheckResult ()
33 require False s = fail s
34 require True  _ = return ()
35
36 requireM :: CheckResult Bool -> String -> CheckResult ()
37 requireM cond s =
38   do b <- cond
39      require b s
40
41 {- Environments. -}
42 type Tvenv = Env Tvar Kind                    -- type variables  (local only)
43 type Tcenv = Env Tcon Kind                    -- type constructors
44 type Tsenv = Env Tcon ([Tvar],Ty)             -- type synonyms
45 type Cenv = Env Dcon Ty                       -- data constructors
46 type Venv = Env Var Ty                        -- values
47 type Menv = Env AnMname Envs                  -- modules
48 data Envs = Envs {tcenv_::Tcenv,tsenv_::Tsenv,cenv_::Cenv,venv_::Venv} -- all the exportable envs
49
50 {- Extend an environment, checking for illegal shadowing of identifiers. -}
51 extendM :: (Ord a, Show a) => Env a b -> (a,b) -> CheckResult (Env a b)
52 extendM env (k,d) = 
53    case elookup env k of
54      Just _ -> fail ("multiply-defined identifier: " ++ show k)
55      Nothing -> return (eextend env (k,d))
56
57 lookupM :: (Ord a, Show a) => Env a b -> a -> CheckResult b
58 lookupM env k =   
59    case elookup env k of
60      Just v -> return v
61      Nothing -> fail ("undefined identifier: " ++ show k)
62             
63 {- Main entry point. -}
64 checkModule :: Menv -> Module -> CheckRes Menv
65 checkModule globalEnv mod@(Module mn tdefs vdefgs) = 
66   runReaderT 
67     (do (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
68         (e_venv,l_venv) <- foldM (checkVdefg True (tcenv,tsenv,eempty,cenv))
69                               (eempty,eempty) 
70                               vdefgs
71         return (eextend globalEnv 
72             (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv})))
73     (mn, globalEnv)   
74
75 checkTdef0 :: (Tcenv,Tsenv) -> Tdef -> CheckResult (Tcenv,Tsenv)
76 checkTdef0 (tcenv,tsenv) tdef = ch tdef
77       where 
78         ch (Data (m,c) tbs _) = 
79             do mn <- getMname
80                requireModulesEq m mn "data type declaration" tdef False
81                tcenv' <- extendM tcenv (c,k)
82                return (tcenv',tsenv)
83             where k = foldr Karrow Klifted (map snd tbs)
84         -- TODO
85         ch (Newtype (m,c) tbs axiom rhs) = 
86             do mn <- getMname
87                requireModulesEq m mn "newtype declaration" tdef False
88                tcenv' <- extendM tcenv (c,k)
89                tsenv' <- case rhs of
90                            Nothing -> return tsenv
91                            Just rep -> extendM tsenv (c,(map fst tbs,rep))
92                return (tcenv', tsenv')
93             where k = foldr Karrow Klifted (map snd tbs)
94     
95 checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
96 checkTdef tcenv cenv = ch
97        where 
98          ch (Data (_,c) utbs cdefs) = 
99             do cbinds <- mapM checkCdef cdefs
100                foldM extendM cenv cbinds
101             where checkCdef (cdef@(Constr (m,dcon) etbs ts)) =
102                     do mn <- getMname
103                        requireModulesEq m mn "constructor declaration" cdef 
104                          False 
105                        tvenv <- foldM extendM eempty tbs 
106                        ks <- mapM (checkTy (tcenv,tvenv)) ts
107                        mapM_ (\k -> require (baseKind k)
108                                             ("higher-order kind in:\n" ++ show cdef ++ "\n" ++
109                                              "kind: " ++ show k) ) ks
110                        return (dcon,t mn) 
111                     where tbs = utbs ++ etbs
112                           t mn = foldr Tforall 
113                                   (foldr tArrow
114                                           (foldl Tapp (Tcon (Just mn,c))
115                                                  (map (Tvar . fst) utbs)) ts) tbs
116          -- TODO
117          ch (tdef@(Newtype c tbs axiom (Just t))) =  
118             do tvenv <- foldM extendM eempty tbs
119                k <- checkTy (tcenv,tvenv) t
120                require (k `eqKind` Klifted) ("bad kind:\n" ++ show tdef) 
121                return cenv
122          ch (tdef@(Newtype c tbs axiom Nothing)) =
123             {- should only occur for recursive Newtypes -}
124             return cenv
125
126 mkTypeEnvs :: [Tdef] -> CheckResult (Tcenv, Tsenv, Cenv)
127 mkTypeEnvs tdefs = do
128   (tcenv, tsenv) <- foldM checkTdef0 (eempty,eempty) tdefs
129   cenv <- foldM (checkTdef tcenv) eempty tdefs
130   return (tcenv, tsenv, cenv)
131
132 requireModulesEq :: Show a => Mname -> AnMname -> String -> a 
133                           -> Bool -> CheckResult ()
134 requireModulesEq (Just mn) m msg t _      = require (mn == m) (mkErrMsg msg t)
135 requireModulesEq Nothing m msg t emptyOk  = require emptyOk (mkErrMsg msg t)
136
137 mkErrMsg :: Show a => String -> a -> String
138 mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t    
139
140 checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) 
141                -> Vdefg -> CheckResult (Venv,Venv)
142 checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg =
143       case vdefg of
144         Rec vdefs ->
145             do e_venv' <- foldM extendM e_venv e_vts
146                l_venv' <- foldM extendM l_venv l_vts
147                let env' = (tcenv,tsenv,tvenv,cenv,e_venv',l_venv')
148                mapM_ (\ (vdef@(Vdef ((m,v),t,e))) -> 
149                             do mn <- getMname
150                                requireModulesEq m mn "value definition" vdef True
151                                k <- checkTy (tcenv,tvenv) t
152                                require (k `eqKind` Klifted) ("unlifted kind in:\n" ++ show vdef)
153                                t' <- checkExp env' e
154                                requireM (equalTy tsenv t t') 
155                                         ("declared type doesn't match expression type in:\n"  ++ show vdef ++ "\n" ++  
156                                          "declared type: " ++ show t ++ "\n" ++
157                                          "expression type: " ++ show t')) vdefs
158                return (e_venv',l_venv')
159             where e_vts  = [ (v,t) | Vdef ((Just _,v),t,_) <- vdefs ]
160                   l_vts  = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
161         Nonrec (vdef@(Vdef ((m,v),t,e))) ->
162             do mn <- getMname
163                requireModulesEq m mn "value definition" vdef True
164                k <- checkTy (tcenv,tvenv) t 
165                require (not (k `eqKind` Kopen)) ("open kind in:\n" ++ show vdef)
166                require ((not top_level) || (not (k `eqKind` Kunlifted))) ("top-level unlifted kind in:\n" ++ show vdef) 
167                t' <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) e
168                requireM (equalTy tsenv t t') 
169                         ("declared type doesn't match expression type in:\n" ++ show vdef  ++ "\n"  ++
170                          "declared type: " ++ show t ++ "\n" ++
171                          "expression type: " ++ show t') 
172                if isNothing m then
173                  do l_venv' <- extendM l_venv (v,t)
174                     return (e_venv,l_venv')
175                 else
176                  do e_venv' <- extendM e_venv (v,t)
177                     return (e_venv',l_venv)
178     
179 checkExpr :: AnMname -> Menv -> [Tdef] -> Venv -> Tvenv 
180                -> Exp -> Ty
181 checkExpr mn menv tdefs venv tvenv e = case (runReaderT (do
182   (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
183   checkExp (tcenv, tsenv, tvenv, cenv, venv, eempty) e) 
184                             (mn, menv)) of
185     OkC t -> t
186     FailC s -> reportError s
187
188 checkExp :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
189 checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch 
190       where 
191         ch e0 = 
192           case e0 of
193             Var qv -> 
194               qlookupM venv_ e_venv l_venv qv
195             Dcon qc ->
196               qlookupM cenv_ cenv eempty qc
197             Lit l -> 
198               checkLit l
199             Appt e t -> 
200               do t' <- ch e
201                  k' <- checkTy (tcenv,tvenv) t
202                  case t' of
203                    Tforall (tv,k) t0 ->
204                      do require (k' `subKindOf` k) 
205                                 ("kind doesn't match at type application in:\n" ++ show e0 ++ "\n" ++
206                                  "operator kind: " ++ show k ++ "\n" ++
207                                  "operand kind: " ++ show k') 
208                         return (substl [tv] [t] t0)
209                    _ -> fail ("bad operator type in type application:\n" ++ show e0 ++ "\n" ++
210                                "operator type: " ++ show t')
211             App e1 e2 -> 
212               do t1 <- ch e1
213                  t2 <- ch e2
214                  case t1 of
215                    Tapp(Tapp(Tcon tc) t') t0 | tc == tcArrow ->
216                         do requireM (equalTy tsenv t2 t') 
217                                     ("type doesn't match at application in:\n" ++ show e0 ++ "\n" ++ 
218                                      "operator type: " ++ show t' ++ "\n" ++ 
219                                      "operand type: " ++ show t2) 
220                            return t0
221                    _ -> fail ("bad operator type at application in:\n" ++ show e0 ++ "\n" ++
222                                "operator type: " ++ show t1)
223             Lam (Tb tb) e ->
224               do tvenv' <- extendM tvenv tb 
225                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv) e 
226                  return (Tforall tb t)
227             Lam (Vb (vb@(_,vt))) e ->
228               do k <- checkTy (tcenv,tvenv) vt
229                  require (baseKind k)   
230                          ("higher-order kind in:\n" ++ show e0 ++ "\n" ++
231                           "kind: " ++ show k) 
232                  l_venv' <- extendM l_venv vb
233                  t <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') e
234                  require (not (isUtupleTy vt)) ("lambda-bound unboxed tuple in:\n" ++ show e0) 
235                  return (tArrow vt t)
236             Let vdefg e ->
237               do (e_venv',l_venv') <- checkVdefg False (tcenv,tsenv,tvenv,cenv) 
238                                         (e_venv,l_venv) vdefg 
239                  checkExp (tcenv,tsenv,tvenv,cenv,e_venv',l_venv') e
240             Case e (v,t) resultTy alts ->
241               do t' <- ch e 
242                  checkTy (tcenv,tvenv) t
243                  requireM (equalTy tsenv t t') 
244                           ("scrutinee declared type doesn't match expression type in:\n" ++ show e0 ++ "\n" ++
245                            "declared type: " ++ show t ++ "\n" ++
246                            "expression type: " ++ show t') 
247                  case (reverse alts) of
248                    (Acon c _ _ _):as ->
249                       let ok ((Acon c _ _ _):as) cs = do require (notElem c cs)
250                                                                  ("duplicate alternative in case:\n" ++ show e0) 
251                                                          ok as (c:cs)
252                           ok ((Alit _ _):_)      _  = fail ("invalid alternative in constructor case:\n" ++ show e0)
253                           ok [Adefault _]        _  = return ()
254                           ok (Adefault _:_)      _  = fail ("misplaced default alternative in case:\n" ++ show e0)
255                           ok []                  _  = return () 
256                       in ok as [c] 
257                    (Alit l _):as -> 
258                       let ok ((Acon _ _ _ _):_) _  = fail ("invalid alternative in literal case:\n" ++ show e0)
259                           ok ((Alit l _):as)    ls = do require (notElem l ls)
260                                                                 ("duplicate alternative in case:\n" ++ show e0) 
261                                                         ok as (l:ls)
262                           ok [Adefault _]       _  = return ()
263                           ok (Adefault _:_)     _  = fail ("misplaced default alternative in case:\n" ++ show e0)
264                           ok []                 _  = fail ("missing default alternative in literal case:\n" ++ show e0)
265                       in ok as [l] 
266                    [Adefault _] -> return ()
267                    [] -> fail ("no alternatives in case:\n" ++ show e0) 
268                  l_venv' <- extendM l_venv (v,t)
269                  t:ts <- mapM (checkAlt (tcenv,tsenv,tvenv,cenv,e_venv,l_venv') t) alts  
270                  bs <- mapM (equalTy tsenv t) ts
271                  require (and bs)
272                          ("alternative types don't match in:\n" ++ show e0 ++ "\n" ++
273                           "types: " ++ show (t:ts))
274                  checkTy (tcenv,tvenv) resultTy
275                  require (t == resultTy) ("case alternative type doesn't " ++
276                    " match case return type in:\n" ++ show e0 ++ "\n" ++
277                    "alt type: " ++ show t ++ " return type: " ++ show resultTy)
278                  return t
279             Cast e t -> 
280               do ch e 
281                  checkTy (tcenv,tvenv) t 
282                  return t
283             Note s e -> 
284               ch e
285             External _ t -> 
286               do checkTy (tcenv,eempty) t {- external types must be closed -}
287                  return t
288     
289 checkAlt :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty 
290 checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
291       where 
292         ch a0 = 
293           case a0 of 
294             Acon qc etbs vbs e ->
295               do let uts = f t0                                      
296                        where f (Tapp t0 t) = f t0 ++ [t]
297                              f _ = []
298                  ct <- qlookupM cenv_ cenv eempty qc
299                  let (tbs,ct_args0,ct_res0) = splitTy ct
300                  {- get universals -}
301                  let (utbs,etbs') = splitAt (length uts) tbs
302                  let utvs = map fst utbs
303                  {- check existentials -}
304                  let (etvs,eks) = unzip etbs
305                  let (etvs',eks') = unzip etbs'
306                  require (all (uncurry eqKind)
307                             (zip eks eks'))  
308                          ("existential kinds don't match in:\n" ++ show a0 ++ "\n" ++
309                           "kinds declared in data constructor: " ++ show eks ++
310                           "kinds declared in case alternative: " ++ show eks') 
311                  tvenv' <- foldM extendM tvenv etbs
312                  {- check term variables -}
313                  let vts = map snd vbs
314                  mapM_ (\vt -> require ((not . isUtupleTy) vt)
315                                        ("pattern-bound unboxed tuple in:\n" ++ show a0 ++ "\n" ++
316                                         "pattern type: " ++ show vt)) vts
317                  vks <- mapM (checkTy (tcenv,tvenv')) vts
318                  mapM_ (\vk -> require (baseKind vk)
319                                        ("higher-order kind in:\n" ++ show a0 ++ "\n" ++
320                                         "kind: " ++ show vk)) vks 
321                  let (ct_res:ct_args) = map (substl (utvs++etvs') (uts++(map Tvar etvs))) (ct_res0:ct_args0)
322                  zipWithM_ 
323                     (\ct_arg vt -> 
324                         requireM (equalTy tsenv ct_arg vt)
325                                  ("pattern variable type doesn't match constructor argument type in:\n" ++ show a0 ++ "\n" ++
326                                   "pattern variable type: " ++ show ct_arg ++ "\n" ++
327                                   "constructor argument type: " ++ show vt)) ct_args vts
328                  requireM (equalTy tsenv ct_res t0)
329                           ("pattern constructor type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
330                            "pattern constructor type: " ++ show ct_res ++ "\n" ++
331                            "scrutinee type: " ++ show t0) 
332                  l_venv' <- foldM extendM l_venv vbs
333                  t <- checkExp (tcenv,tsenv,tvenv',cenv,e_venv,l_venv') e 
334                  checkTy (tcenv,tvenv) t  {- check that existentials don't escape in result type -}
335                  return t
336             Alit l e ->
337               do t <- checkLit l
338                  requireM (equalTy tsenv t t0)
339                          ("pattern type doesn't match scrutinee type in:\n" ++ show a0 ++ "\n" ++
340                           "pattern type: " ++ show t ++ "\n" ++
341                           "scrutinee type: " ++ show t0) 
342                  checkExp env e
343             Adefault e ->
344               checkExp env e
345     
346 checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
347 checkTy (tcenv,tvenv) = ch
348      where
349        ch (Tvar tv) = lookupM tvenv tv
350        ch (Tcon qtc) = qlookupM tcenv_ tcenv eempty qtc
351        ch (t@(Tapp t1 t2)) = 
352            do k1 <- ch t1
353               k2 <- ch t2
354               case k1 of
355                  Karrow k11 k12 ->
356                    do require (k2 `subKindOf` k11) 
357                              ("kinds don't match in type application: " ++ show t ++ "\n" ++
358                               "operator kind: " ++ show k11 ++ "\n" ++
359                               "operand kind: " ++ show k2)              
360                       return k12
361                  _ -> fail ("applied type has non-arrow kind: " ++ show t)
362        ch (Tforall tb t) = 
363             do tvenv' <- extendM tvenv tb 
364                checkTy (tcenv,tvenv') t
365     
366 {- Type equality modulo newtype synonyms. -}
367 equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
368 equalTy tsenv t1 t2 = 
369             do t1' <- expand t1
370                t2' <- expand t2
371                return (t1' == t2')
372       where expand (Tvar v) = return (Tvar v)
373             expand (Tcon qtc) = return (Tcon qtc)
374             expand (Tapp t1 t2) = 
375               do t2' <- expand t2
376                  expapp t1 [t2']
377             expand (Tforall tb t) = 
378               do t' <- expand t
379                  return (Tforall tb t')
380             expapp (t@(Tcon (m,tc))) ts = 
381               do env <- mlookupM tsenv_ tsenv eempty m
382                  case elookup env tc of 
383                     Just (formals,rhs) | (length formals) == (length ts) -> return (substl formals ts rhs)
384                     _ -> return (foldl Tapp t ts)
385             expapp (Tapp t1 t2) ts = 
386               do t2' <- expand t2
387                  expapp t1 (t2':ts)
388             expapp t ts = 
389               do t' <- expand t
390                  return (foldl Tapp t' ts)
391     
392
393 mlookupM :: (Envs -> Env a b) -> Env a b -> Env a b -> Mname 
394           -> CheckResult (Env a b)
395 mlookupM _ _ local_env    Nothing            = return local_env
396 mlookupM selector external_env _ (Just m) = do
397   mn <- getMname
398   if m == mn
399      then return external_env
400      else do
401        globalEnv <- getGlobalEnv
402        case elookup globalEnv m of
403          Just env' -> return (selector env')
404          Nothing -> fail ("Check: undefined module name: " ++ show m)
405
406 qlookupM :: (Ord a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b 
407                   -> Qual a -> CheckResult b
408 qlookupM selector external_env local_env (m,k) =   
409       do env <- mlookupM selector external_env local_env m
410          lookupM env k
411
412
413 checkLit :: Lit -> CheckResult Ty
414 checkLit (Literal lit t) =
415   case lit of
416       -- TODO: restore commented-out stuff.
417     Lint _ -> 
418           do {- require (elem t [tIntzh, {- tInt32zh,tInt64zh, -} tWordzh, {- tWord32zh,tWord64zh, -} tAddrzh, tCharzh]) 
419                      ("invalid int literal: " ++ show lit ++ "\n" ++ "type: " ++ show t) -}
420              return t
421     Lrational _ ->
422           do {- require (elem t [tFloatzh,tDoublezh]) 
423                      ("invalid rational literal: " ++ show lit ++ "\n" ++ "type: " ++ show t) -}
424              return t
425     Lchar _ -> 
426           do {- require (t == tCharzh) 
427                      ("invalid char literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)  -}
428              return t   
429     Lstring _ ->
430           do {- require (t == tAddrzh) 
431                      ("invalid string literal: " ++ show lit ++ "\n" ++ "type: " ++ show t)  -}
432              return t
433
434 {- Utilities -}
435
436 {- Split off tbs, arguments and result of a (possibly abstracted)  arrow type -}
437 splitTy :: Ty -> ([Tbind],[Ty],Ty)
438 splitTy (Tforall tb t) = (tb:tbs,ts,tr) 
439                 where (tbs,ts,tr) = splitTy t
440 splitTy (Tapp(Tapp(Tcon tc) t0) t) | tc == tcArrow = (tbs,t0:ts,tr)
441                 where (tbs,ts,tr) = splitTy t
442 splitTy t = ([],[],t)
443
444
445 {- Simultaneous substitution on types for type variables,
446    renaming as neceessary to avoid capture.
447    No checks for correct kindedness. -}
448 substl :: [Tvar] -> [Ty] -> Ty -> Ty
449 substl tvs ts t = f (zip tvs ts) t
450   where 
451     f env t0 =
452      case t0 of
453        Tcon _ -> t0
454        Tvar v -> case lookup v env of
455                    Just t1 -> t1
456                    Nothing -> t0
457        Tapp t1 t2 -> Tapp (f env t1) (f env t2)
458        Tforall (t,k) t1 -> 
459          if t `elem` free then
460            Tforall (t',k) (f ((t,Tvar t'):env) t1)
461          else 
462            Tforall (t,k) (f (filter ((/=t).fst) env) t1)
463      where free = foldr union [] (map (freeTvars.snd) env)
464            t' = freshTvar free 
465    
466 {- Return free tvars in a type -}
467 freeTvars :: Ty -> [Tvar]
468 freeTvars (Tcon _) = []
469 freeTvars (Tvar v) = [v]
470 freeTvars (Tapp t1 t2) = (freeTvars t1) `union` (freeTvars t2)
471 freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1) 
472
473 {- Return any tvar *not* in the argument list. -}
474 freshTvar :: [Tvar] -> Tvar
475 freshTvar tvs = maximum ("":tvs) ++ "x" -- one simple way!
476
477 -- todo
478 reportError s = error $ ("Core parser error: checkExpr failed with " ++ s)