External Core typechecker - improve handling of coercions
[ghc-hetmet.git] / utils / ext-core / Check.hs
index 95c7281..cab3e62 100644 (file)
@@ -91,7 +91,8 @@ checkModule globalEnv (Module mn tdefs vdefgs) =
                               vdefgs
         return (eextend globalEnv 
             (mn,Envs{tcenv_=tcenv,tsenv_=tsenv,cenv_=cenv,venv_=e_venv})))
-    (mn, globalEnv)   
+         -- avoid name shadowing
+    (mn, eremove globalEnv mn)
 
 -- Like checkModule, but doesn't typecheck the code, instead just
 -- returning declared types for top-level defns.
@@ -229,51 +230,44 @@ mkErrMsg msg t = "wrong module name in " ++ msg ++ ":\n" ++ show t
 
 checkVdefg :: Bool -> (Tcenv,Tsenv,Tvenv,Cenv) -> (Venv,Venv) 
                -> Vdefg -> CheckResult (Venv,Venv)
-checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg =
+checkVdefg top_level (tcenv,tsenv,tvenv,cenv) (e_venv,l_venv) vdefg = do
+      mn <- getMname
       case vdefg of
        Rec vdefs ->
-           do e_venv' <- foldM extendVenv e_venv e_vts
-              l_venv' <- foldM extendVenv l_venv l_vts
+           do (e_venv', l_venv') <- makeEnv mn vdefs
               let env' = (tcenv,tsenv,tvenv,cenv,e_venv',l_venv')
-              mapM_ (\ (vdef@(Vdef ((m,_),t,e))) -> 
-                           do mn <- getMname
-                               requireModulesEq m mn "value definition" vdef True
-                              k <- checkTy (tcenv,tvenv) t
-                              require (k `eqKind` Klifted) ("unlifted kind in:\n" ++ show vdef)
-                              t' <- checkExp env' e
-                              requireM (equalTy tsenv t t') 
-                                       ("declared type doesn't match expression type in:\n"  ++ show vdef ++ "\n" ++  
-                                        "declared type: " ++ show t ++ "\n" ++
-                                        "expression type: " ++ show t')) vdefs
-              return (e_venv',l_venv')
-           where e_vts  = [ (v,t) | Vdef ((Just _,v),t,_) <- vdefs ]
-                 l_vts  = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
-       Nonrec (vdef@(Vdef ((m,v),t,e))) ->
-           do mn <- getMname
-               -- TODO: document this weirdness
-               let isZcMain = vdefIsMainWrapper mn m 
-               unless isZcMain $
-                    requireModulesEq m mn "value definition" vdef True
-              k <- checkTy (tcenv,tvenv) t 
-              require (not (k `eqKind` Kopen)) ("open kind in:\n" ++ show vdef)
-              require ((not top_level) || (not (k `eqKind` Kunlifted))) 
-                 ("top-level unlifted kind in:\n" ++ show vdef) 
-              t' <- checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) e
-              requireM (equalTy tsenv t t') 
-                       ("declared type doesn't match expression type in:\n" 
-                         ++ show vdef  ++ "\n"  ++
-                        "declared type: " ++ show t ++ "\n" ++
-                        "expression type: " ++ show t') 
-              if isNothing m then
-                 do l_venv' <- extendVenv l_venv (v,t)
-                    return (e_venv,l_venv')
-               else
-                              -- awful, but avoids name shadowing --
-                              -- otherwise we'd have two bindings for "main"
-                do e_venv' <- if isZcMain 
-                                 then return e_venv 
-                                 else extendVenv e_venv (v,t)
-                    return (e_venv',l_venv)
+               mapM_ (checkVdef (\ vdef k -> require (k `eqKind` Klifted) 
+                        ("unlifted kind in:\n" ++ show vdef)) env') 
+                     vdefs
+               return (e_venv', l_venv')
+       Nonrec vdef ->
+           do let env' = (tcenv, tsenv, tvenv, cenv, e_venv, l_venv)
+               checkVdef (\ vdef k -> do
+                     require (not (k `eqKind` Kopen)) ("open kind in:\n" ++ show vdef)
+                    require ((not top_level) || (not (k `eqKind` Kunlifted))) 
+                       ("top-level unlifted kind in:\n" ++ show vdef)) env' vdef
+               makeEnv mn [vdef]
+
+  where makeEnv mn vdefs = do
+             ev <- foldM extendVenv e_venv e_vts
+             lv <- foldM extendVenv l_venv l_vts
+             return (ev, lv)
+           where e_vts = [ (v,t) | Vdef ((Just m,v),t,_) <- vdefs,
+                                     not (vdefIsMainWrapper mn (Just m))]
+                 l_vts = [ (v,t) | Vdef ((Nothing,v),t,_) <- vdefs]
+        checkVdef checkKind env (vdef@(Vdef ((m,_),t,e))) = do
+          mn <- getMname
+          let isZcMain = vdefIsMainWrapper mn m
+          unless isZcMain $
+             requireModulesEq m mn "value definition" vdef True
+         k <- checkTy (tcenv,tvenv) t
+         checkKind vdef k
+         t' <- checkExp env e
+         requireM (equalTy tsenv t t') 
+                  ("declared type doesn't match expression type in:\n"  
+                    ++ show vdef ++ "\n" ++  
+                   "declared type: " ++ show t ++ "\n" ++
+                   "expression type: " ++ show t')
     
 vdefIsMainWrapper :: AnMname -> Mname -> Bool
 vdefIsMainWrapper enclosing defining = 
@@ -393,15 +387,11 @@ checkExp (tcenv,tsenv,tvenv,cenv,e_venv,l_venv) = ch
                 return t
            c@(Cast e t) -> 
              do eTy <- ch e 
-                k <- checkTy (tcenv,tvenv) t
-                 case k of
-                    (Keq fromTy toTy) -> do
-                        require (eTy == fromTy) ("Type mismatch in cast: c = "
-                             ++ show c ++ " and " ++ show eTy
-                             ++ " and " ++ show fromTy)
-                        return toTy
-                    _ -> fail ("Cast with non-equality kind: " ++ show e 
-                               ++ " and " ++ show t ++ " has kind " ++ show k)
+                (fromTy, toTy) <- checkTyCo (tcenv,tvenv) t
+                 require (eTy == fromTy) ("Type mismatch in cast: c = "
+                             ++ show c ++ "\nand eTy = " ++ show eTy
+                             ++ "\n and " ++ show fromTy)
+                 return toTy
            Note _ e -> 
              ch e
            External _ t -> 
@@ -466,7 +456,7 @@ checkAlt (env@(tcenv,tsenv,tvenv,cenv,e_venv,l_venv)) t0 = ch
              checkExp env e
     
 checkTy :: (Tcenv,Tvenv) -> Ty -> CheckResult Kind
-checkTy es@(tcenv,tvenv) t = ch t
+checkTy es@(tcenv,tvenv) = ch
      where
        ch (Tvar tv) = lookupM tvenv tv
        ch (Tcon qtc) = do
@@ -495,7 +485,7 @@ checkTy es@(tcenv,tvenv) t = ch t
                        -- :Co:TTypeable2 (->)
                        -- where in this case, :Co:TTypeable2 expects an
                        -- argument of kind (*->(*->*)) and (->) has kind
-                       -- (?->(?->*)). In general, I don't think it's
+                       -- (?->(?->*)). I'm not sure whether or not it's
                        -- sound to apply an arbitrary coercion to an
                        -- argument that's a subkind of what it expects.
                        then warn $ "Applying coercion " ++ show tc ++
@@ -507,23 +497,12 @@ checkTy es@(tcenv,tvenv) t = ch t
                      return $ Keq (substl tvs tys from) (substl tvs tys to)
                Nothing -> checkTapp t1 t2
             where checkTapp t1 t2 = do 
-                   k1 <- ch t1
+                    k1 <- ch t1
                    k2 <- ch t2
                    case k1 of
-                     Karrow k11 k12 ->
-                           case k2 of
-                               Keq eqTy1 eqTy2 -> do
-                                 -- Distribute the type operator over the
-                                 -- coercion
-                                 subK1 <- checkTy es eqTy1
-                                 subK2 <- checkTy es eqTy2
-                                 require (subK2 `subKindOf` k11 && 
-                                          subK1 `subKindOf` k11) $
-                                    kindError
-                                 return $ Keq (Tapp t1 eqTy1) (Tapp t1 eqTy2)
-                              _               -> do
-                                   require (k2 `subKindOf` k11) kindError
-                                   return k12
+                     Karrow k11 k12 -> do
+                         require (k2 `subKindOf` k11) kindError
+                         return k12
                             where kindError = 
                                     "kinds don't match in type application: "
                                     ++ show t ++ "\n" ++
@@ -533,39 +512,89 @@ checkTy es@(tcenv,tvenv) t = ch t
                            
        ch (Tforall tb t) = 
            do tvenv' <- extendTvenv tvenv tb 
-              k <- checkTy (tcenv,tvenv') t
-               return $ case k of
-                 -- distribute the forall
-                 Keq t1 t2 -> Keq (Tforall tb t1) (Tforall tb t2)
-                 _         -> k
+               checkTy (tcenv,tvenv') t
        ch (TransCoercion t1 t2) = do
-            k1 <- checkTy es t1
-            k2 <- checkTy es t2
-            case (k1, k2) of
-              (Keq ty1 ty2, Keq ty3 ty4) | ty2 == ty3 ->
-                  return $ Keq ty1 ty4
-              _ -> fail ("Bad kinds in trans. coercion: " ++
-                           show k1 ++ " " ++ show k2)
+            (ty1,ty2) <- checkTyCo es t1
+            (ty3,ty4) <- checkTyCo es t2
+            require (ty2 == ty3) ("Types don't match in trans. coercion: " ++
+                        show ty2 ++ " and " ++ show ty3)
+            return $ Keq ty1 ty4
        ch (SymCoercion t1) = do
-            k <- checkTy es t1
-            case k of
-               Keq ty1 ty2 -> return $ Keq ty2 ty1
-               _           -> fail ("Bad kind in sym. coercion: "
-                            ++ show k)
+            (ty1,ty2) <- checkTyCo es t1
+            return $ Keq ty2 ty1
        ch (UnsafeCoercion t1 t2) = do
             checkTy es t1
             checkTy es t2
             return $ Keq t1 t2
        ch (LeftCoercion t1) = do
-            k <- checkTy es t1
+            k <- checkTyCo es t1
             case k of
-              Keq (Tapp u _) (Tapp w _) -> return $ Keq u w
+              ((Tapp u _), (Tapp w _)) -> return $ Keq u w
               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
        ch (RightCoercion t1) = do
-            k <- checkTy es t1
+            k <- checkTyCo es t1
             case k of
-              Keq (Tapp _ v) (Tapp _ x) -> return $ Keq v x
+              ((Tapp _ v), (Tapp _ x)) -> return $ Keq v x
               _ -> fail ("Bad coercion kind in operand of left: " ++ show k)
+       ch (InstCoercion ty arg) = do
+            forallK <- checkTyCo es ty
+            case forallK of
+              ((Tforall (v1,k1) b1), (Tforall (v2,k2) b2)) -> do
+                 require (k1 `eqKind` k2) ("Kind mismatch in argument of inst: "
+                                            ++ show ty)
+                 argK <- checkTy es arg
+                 require (argK `eqKind` k1) ("Kind mismatch in type being "
+                           ++ "instantiated: " ++ show arg)
+                 let newLhs = substl [v1] [arg] b1
+                 let newRhs = substl [v2] [arg] b2
+                 return $ Keq newLhs newRhs
+              _ -> fail ("Non-forall-ty in argument to inst: " ++ show ty)
+
+checkTyCo :: (Tcenv, Tvenv) -> Ty -> CheckResult (Ty, Ty)
+checkTyCo es@(tcenv,_) t@(Tapp t1 t2) = 
+  (case splitTyConApp_maybe t of
+    Just (tc, tys) -> do
+       tcK <- qlookupM tcenv_ tcenv eempty tc
+       case tcK of
+ -- todo: avoid duplicating this code
+ -- blah, this almost calls for a different syntactic form
+ -- (for a defined-coercion app): (TCoercionApp Tcon [Ty])
+         Coercion (DefinedCoercion tbs (from, to)) -> do
+           require (length tys == length tbs) $ 
+            ("Arity mismatch in coercion app: " ++ show t)
+           let (tvs, tks) = unzip tbs
+           argKs <- mapM (checkTy es) tys
+           let kPairs = zip argKs tks
+           let kindsOk = all (uncurry eqKind) kPairs
+           if not kindsOk &&
+                        all (uncurry subKindOf) kPairs
+                       -- see comments in checkTy about this
+                       then warn $ "Applying coercion " ++ show tc ++
+                               " to arguments of kind " ++ show argKs
+                               ++ " when it expects: " ++ show tks
+                       else require kindsOk
+                        ("Kind mismatch in coercion app: " ++ show tks 
+                         ++ " and " ++ show argKs ++ " t = " ++ show t)
+           return (substl tvs tys from, substl tvs tys to)
+         _ -> checkTapp t1 t2
+    _ -> checkTapp t1 t2)
+       where checkTapp t1 t2 = do
+               (lhsRator, rhsRator) <- checkTyCo es t1
+               (lhs, rhs) <- checkTyCo es t2
+               -- Comp rule from paper
+               checkTy es (Tapp lhsRator lhs)
+               checkTy es (Tapp rhsRator rhs)
+               return (Tapp lhsRator lhs, Tapp rhsRator rhs)
+checkTyCo (tcenv, tvenv) (Tforall tb t) = do
+  tvenv' <- extendTvenv tvenv tb
+  (t1,t2) <- checkTyCo (tcenv, tvenv') t
+  return (Tforall tb t1, Tforall tb t2)
+checkTyCo es t = do
+  k <- checkTy es t
+  case k of
+    Keq t1 t2 -> return (t1, t2)
+    -- otherwise, expand by the "refl" rule
+    _          -> return (t, t)
 
 {- Type equality modulo newtype synonyms. -}
 equalTy :: Tsenv -> Ty -> Ty -> CheckResult Bool
@@ -598,6 +627,10 @@ equalTy tsenv t1 t2 =
             expand (RightCoercion t1) = do
                exp1 <- expand t1
                return $ RightCoercion exp1
+            expand (InstCoercion t1 t2) = do
+               exp1 <- expand t1
+               exp2 <- expand t2
+               return $ InstCoercion exp1 exp2
            expapp (t@(Tcon (m,tc))) ts = 
              do env <- mlookupM tsenv_ tsenv eempty m
                 case elookup env tc of 
@@ -685,6 +718,7 @@ substl tvs ts t = f (zip tvs ts) t
        UnsafeCoercion t1 t2 -> UnsafeCoercion (f env t1) (f env t2)
        LeftCoercion t1 -> LeftCoercion (f env t1)
        RightCoercion t1 -> RightCoercion (f env t1)
+       InstCoercion t1 t2 -> InstCoercion (f env t1) (f env t2)
      where free = foldr union [] (map (freeTvars.snd) env)
            t' = freshTvar free 
    
@@ -692,13 +726,14 @@ substl tvs ts t = f (zip tvs ts) t
 freeTvars :: Ty -> [Tvar]
 freeTvars (Tcon _) = []
 freeTvars (Tvar v) = [v]
-freeTvars (Tapp t1 t2) = (freeTvars t1) `union` (freeTvars t2)
+freeTvars (Tapp t1 t2) = freeTvars t1 `union` freeTvars t2
 freeTvars (Tforall (t,_) t1) = delete t (freeTvars t1) 
 freeTvars (TransCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
 freeTvars (SymCoercion t) = freeTvars t
 freeTvars (UnsafeCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
 freeTvars (LeftCoercion t) = freeTvars t
 freeTvars (RightCoercion t) = freeTvars t
+freeTvars (InstCoercion t1 t2) = freeTvars t1 `union` freeTvars t2
 
 {- Return any tvar *not* in the argument list. -}
 freshTvar :: [Tvar] -> Tvar