First cut at reviving the External Core tools
[ghc-hetmet.git] / utils / ext-core / Check.hs
index a9a3eac..8b928b0 100644 (file)
@@ -1,6 +1,8 @@
 module Check where
 
-import Monad
+import Maybe
+import Control.Monad.Reader
+
 import Core
 import Printer
 import List
@@ -10,9 +12,18 @@ import Env
    allowing errors to be captured, this makes it easy to guarantee
    that checking itself has been completed for an entire module. -}
 
-data CheckResult a = OkC a | FailC String
+{- We use the Reader monad transformer in order to thread the 
+   top-level module name throughout the computation simply.
+   This is so that checkExp can also be an entry point (we call it
+   from Prep.) -}
+data CheckRes a = OkC a | FailC String
+type CheckResult a = ReaderT (AnMname, Menv) CheckRes a
+getMname :: CheckResult AnMname
+getMname     = ask >>= (return . fst)
+getGlobalEnv :: CheckResult Menv
+getGlobalEnv = ask >>= (return . snd)
 
-instance Monad CheckResult where
+instance Monad CheckRes where
   OkC a >>= k = k a
   FailC s >>= k = fail s
   return = OkC
@@ -33,7 +44,7 @@ type Tcenv = Env Tcon Kind                    -- type constructors
 type Tsenv = Env Tcon ([Tvar],Ty)             -- type synonyms
 type Cenv = Env Dcon Ty                      -- data constructors
 type Venv = Env Var Ty                               -- values
-type Menv = Env Mname Envs                   -- modules
+type Menv = Env AnMname Envs                 -- modules
 data Envs = Envs {tcenv_::Tcenv,tsenv_::Tsenv,cenv_::Cenv,venv_::Venv} -- all the exportable envs
 
 {- Extend an environment, checking for illegal shadowing of identifiers. -}
@@ -50,24 +61,29 @@ lookupM env k =
      Nothing -> fail ("undefined identifier: " ++ show k)
             
 {- Main entry point. -}
-checkModule :: Menv -> Module -> CheckResult Menv
-checkModule globalEnv (Module mn tdefs vdefgs) = 
-  do (tcenv,tsenv) <- foldM checkTdef0 (eempty,eempty) tdefs
-     cenv <- foldM (checkTdef tcenv) eempty tdefs
-     (e_venv,l_venv) <- foldM (checkVdefg True (tcenv,tsenv,eempty,cenv)) (eempty,eempty) vdefgs
-     return (eextend globalEnv (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv}))
-  where 
+checkModule :: Menv -> Module -> CheckRes Menv
+checkModule globalEnv mod@(Module mn tdefs vdefgs) = 
+  runReaderT 
+    (do (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
+        (e_venv,l_venv) <- foldM (checkVdefg True (tcenv,tsenv,eempty,cenv))
+                              (eempty,eempty) 
+                              vdefgs
+        return (eextend globalEnv 
+            (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv})))
+    (mn, globalEnv)   
 
-    checkTdef0 :: (Tcenv,Tsenv) -> Tdef -> CheckResult (Tcenv,Tsenv)
-    checkTdef0 (tcenv,tsenv) tdef = ch tdef
+checkTdef0 :: (Tcenv,Tsenv) -> Tdef -> CheckResult (Tcenv,Tsenv)
+checkTdef0 (tcenv,tsenv) tdef = ch tdef
       where 
        ch (Data (m,c) tbs _) = 
-           do require (m == mn) ("wrong module name in data type declaration:\n" ++ show tdef)
+           do mn <- getMname
+               requireModulesEq m mn "data type declaration" tdef False
               tcenv' <- extendM tcenv (c,k)
               return (tcenv',tsenv)
            where k = foldr Karrow Klifted (map snd tbs)
        ch (Newtype (m,c) tbs rhs) = 
-           do require (m == mn) ("wrong module name in newtype declaration:\n" ++ show tdef)
+           do mn <- getMname
+               requireModulesEq m mn "newtype declaration" tdef False
               tcenv' <- extendM tcenv (c,k)
               tsenv' <- case rhs of
                           Nothing -> return tsenv
@@ -75,24 +91,26 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
               return (tcenv', tsenv')
            where k = foldr Karrow Klifted (map snd tbs)
     
-    checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
-    checkTdef tcenv cenv = ch
+checkTdef :: Tcenv -> Cenv -> Tdef -> CheckResult Cenv
+checkTdef tcenv cenv = ch
        where 
         ch (Data (_,c) utbs cdefs) = 
            do cbinds <- mapM checkCdef cdefs
               foldM extendM cenv cbinds
            where checkCdef (cdef@(Constr (m,dcon) etbs ts)) =
-                   do require (m == mn) ("wrong module name in constructor declaration:\n" ++ show cdef)
+                   do mn <- getMname
+                       requireModulesEq m mn "constructor declaration" cdef 
+                         False 
                       tvenv <- foldM extendM eempty tbs 
                       ks <- mapM (checkTy (tcenv,tvenv)) ts
                       mapM_ (\k -> require (baseKind k)
                                            ("higher-order kind in:\n" ++ show cdef ++ "\n" ++
                                             "kind: " ++ show k) ) ks
-                      return (dcon,t) 
+                      return (dcon,t mn) 
                    where tbs = utbs ++ etbs
-                         t = foldr Tforall 
+                         t mn = foldr Tforall 
                                  (foldr tArrow
-                                         (foldl Tapp (Tcon (mn,c))
+                                         (foldl Tapp (Tcon (Just mn,c))
                                                 (map (Tvar . fst) utbs)) ts) tbs
         ch (tdef@(Newtype c tbs (Just t))) =  
            do tvenv <- foldM extendM eempty tbs
@@ -102,17 +120,32 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
         ch (tdef@(Newtype c tbs Nothing)) =
            {- should only occur for recursive Newtypes -}
            return cenv
-    
 
-    checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) -> Vdefg -> CheckResult (Venv,Venv)
-    checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg =
+mkTypeEnvs :: [Tdef] -> CheckResult (Tcenv, Tsenv, Cenv)
+mkTypeEnvs tdefs = do
+  (tcenv, tsenv) <- foldM checkTdef0 (eempty,eempty) tdefs
+  cenv <- foldM (checkTdef tcenv) eempty tdefs
+  return (tcenv, tsenv, cenv)
+
+requireModulesEq :: Show a => Mname -> AnMname -> String -> a 
+                          -> Bool -> CheckResult ()
+requireModulesEq (Just mn) m msg t _      = require (mn == m) (mkErrMsg msg t)
+requireModulesEq Nothing m msg t emptyOk  = require emptyOk (mkErrMsg msg t)
+
+mkErrMsg :: Show a => String -> a -> String
+mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t    
+
+checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) 
+               -> Vdefg -> CheckResult (Venv,Venv)
+checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg =
       case vdefg of
        Rec vdefs ->
            do e_venv' <- foldM extendM e_venv e_vts
               l_venv' <- foldM extendM l_venv l_vts
               let env' = (tcenv,tsenv,tvenv,cenv,e_venv',l_venv')
               mapM_ (\ (vdef@(Vdef ((m,v),t,e))) -> 
-                           do require (m == "" || m == mn) ("wrong module name in value definition:\n" ++ show vdef)
+                           do mn <- getMname
+                               requireModulesEq m mn "value definition" vdef True
                               k <- checkTy (tcenv,tvenv) t
                               require (k==Klifted) ("unlifted kind in:\n" ++ show vdef)
                               t' <- checkExp env' e
@@ -121,10 +154,11 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                                         "declared type: " ++ show t ++ "\n" ++
                                         "expression type: " ++ show t')) vdefs
               return (e_venv',l_venv')
-           where e_vts  = [ (v,t) | Vdef ((m,v),t,_) <- vdefs, m /= "" ]
-                 l_vts  = [ (v,t) | Vdef (("",v),t,_) <- vdefs]
+           where e_vts  = [ (v,t) | Vdef ((Just _,v),t,_) <- vdefs ]
+                 l_vts  = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
        Nonrec (vdef@(Vdef ((m,v),t,e))) ->
-           do require (m == "" || m == mn) ("wrong module name in value definition:\n" ++ show vdef)
+           do mn <- getMname
+               requireModulesEq m mn "value definition" vdef True
               k <- checkTy (tcenv,tvenv) t 
               require (k /= Kopen) ("open kind in:\n" ++ show vdef)
               require ((not top_level) || (k /= Kunlifted)) ("top-level unlifted kind in:\n" ++ show vdef) 
@@ -133,15 +167,24 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                        ("declared type doesn't match expression type in:\n" ++ show vdef  ++ "\n"  ++
                         "declared type: " ++ show t ++ "\n" ++
                         "expression type: " ++ show t') 
-              if m == "" then
+              if isNothing m then
                  do l_venv' <- extendM l_venv (v,t)
                     return (e_venv,l_venv')
                else
                 do e_venv' <- extendM e_venv (v,t)
                     return (e_venv',l_venv)
     
-    checkExp ::  (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
-    checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch 
+checkExpr :: AnMname -> Menv -> [Tdef] -> Venv -> Tvenv 
+               -> Exp -> Ty
+checkExpr mn menv tdefs venv tvenv e = case (runReaderT (do
+  (tcenv, tsenv, cenv) <- mkTypeEnvs tdefs
+  checkExp (tcenv, tsenv, tvenv, cenv, venv, eempty) e) 
+                            (mn, menv)) of
+    OkC t -> t
+    FailC s -> reportError s
+
+checkExp :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Exp -> CheckResult Ty
+checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch 
       where 
        ch e0 = 
          case e0 of
@@ -189,9 +232,10 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                 require (not (isUtupleTy vt)) ("lambda-bound unboxed tuple in:\n" ++ show e0) 
                 return (tArrow vt t)
            Let vdefg e ->
-             do (e_venv',l_venv') <- checkVdefg False (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg 
+             do (e_venv',l_venv') <- checkVdefg False (tcenv,tsenv,tvenv,cenv) 
+                                        (e_venv,l_venv) vdefg 
                 checkExp (tcenv,tsenv,tvenv,cenv,e_venv',l_venv') e
-           Case e (v,t) alts ->
+           Case e (v,t) resultTy alts ->
              do t' <- ch e 
                 checkTy (tcenv,tvenv) t
                 requireM (equalTy tsenv t t') 
@@ -225,8 +269,12 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                 require (and bs)
                         ("alternative types don't match in:\n" ++ show e0 ++ "\n" ++
                          "types: " ++ show (t:ts))
+                 checkTy (tcenv,tvenv) resultTy
+                 require (t == resultTy) ("case alternative type doesn't " ++
+                   " match case return type in:\n" ++ show e0 ++ "\n" ++
+                   "alt type: " ++ show t ++ " return type: " ++ show resultTy)
                 return t
-           Coerce t e -> 
+           Cast e t -> 
              do ch e 
                 checkTy (tcenv,tvenv) t 
                 return t
@@ -236,8 +284,8 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
              do checkTy (tcenv,eempty) t {- external types must be closed -}
                 return t
     
-    checkAlt :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty 
-    checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
+checkAlt :: (Tcenv,Tsenv,Tvenv,Cenv,Venv,Venv) -> Ty -> Alt -> CheckResult Ty 
+checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
       where 
        ch a0 = 
          case a0 of 
@@ -292,8 +340,8 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
            Adefault e ->
              checkExp env e
     
-    checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
-    checkTy (tcenv,tvenv) = ch
+checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
+checkTy (tcenv,tvenv) = ch
      where
        ch (Tvar tv) = lookupM tvenv tv
        ch (Tcon qtc) = qlookupM tcenv_ tcenv eempty qtc
@@ -312,9 +360,9 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
            do tvenv' <- extendM tvenv tb 
               checkTy (tcenv,tvenv') t
     
-    {- Type equality modulo newtype synonyms. -}
-    equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
-    equalTy tsenv t1 t2 = 
+{- Type equality modulo newtype synonyms. -}
+equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
+equalTy tsenv t1 t2 = 
            do t1' <- expand t1
               t2' <- expand t2
               return (t1' == t2')
@@ -339,19 +387,22 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                 return (foldl Tapp t' ts)
     
 
-    mlookupM :: (Envs -> Env a b) -> Env a b -> Env a b -> Mname -> CheckResult (Env a b)
-    mlookupM selector external_env local_env m =   
-      if m == "" then
-        return local_env
-      else if m == mn then
-        return external_env
-      else 
-        case elookup globalEnv m of
-          Just env' -> return (selector env')
-          Nothing -> fail ("undefined module name: " ++ show m)
+mlookupM :: (Envs -> Env a b) -> Env a b -> Env a b -> Mname 
+          -> CheckResult (Env a b)
+mlookupM _ _ local_env    Nothing            = return local_env
+mlookupM selector external_env _ (Just m) = do
+  mn <- getMname
+  if m == mn
+     then return external_env
+     else do
+       globalEnv <- getGlobalEnv
+       case elookup globalEnv m of
+         Just env' -> return (selector env')
+         Nothing -> fail ("undefined module name: " ++ show m)
 
-    qlookupM :: (Ord a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b -> (Mname,a) -> CheckResult b
-    qlookupM selector external_env local_env (m,k) =   
+qlookupM :: (Ord a, Show a) => (Envs -> Env a b) -> Env a b -> Env a b 
+                  -> Qual a -> CheckResult b
+qlookupM selector external_env local_env (m,k) =   
       do env <- mlookupM selector external_env local_env m
         lookupM env k
 
@@ -419,3 +470,5 @@ freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1)
 freshTvar :: [Tvar] -> Tvar
 freshTvar tvs = maximum ("":tvs) ++ "x" -- one simple way!
 
+-- todo
+reportError s = error $ ("Core parser error: checkExpr failed with " ++ s)