Substantial improvements to coercion optimisation
[ghc-hetmet.git] / compiler / coreSyn / CoreLint.lhs
index 4893885..ee6541e 100644 (file)
@@ -633,10 +633,9 @@ lintCoercion ty@(FunTy ty1 ty2)
        ; return (FunTy s1 s2, FunTy t1 t2) }
 
 lintCoercion ty@(TyConApp tc tys) 
        ; return (FunTy s1 s2, FunTy t1 t2) }
 
 lintCoercion ty@(TyConApp tc tys) 
-  | Just (ar, rule) <- isCoercionTyCon_maybe tc
+  | Just (ar, desc) <- isCoercionTyCon_maybe tc
   = do { unless (tys `lengthAtLeast` ar) (badCo ty)
   = do { unless (tys `lengthAtLeast` ar) (badCo ty)
-       ; (s,t)   <- rule lintType lintCoercion 
-                         True (take ar tys)
+       ; (s,t) <- lintCoTyConApp ty desc (take ar tys)
        ; (ss,ts) <- mapAndUnzipM lintCoercion (drop ar tys)
        ; check_co_app ty (typeKind s) ss
        ; return (mkAppTys s ss, mkAppTys t ts) }
        ; (ss,ts) <- mapAndUnzipM lintCoercion (drop ar tys)
        ; check_co_app ty (typeKind s) ss
        ; return (mkAppTys s ss, mkAppTys t ts) }
@@ -677,6 +676,70 @@ lintCoercion (ForAllTy tv ty)
 badCo :: Coercion -> LintM a
 badCo co = failWithL (hang (ptext (sLit "Ill-kinded coercion term:")) 2 (ppr co))
 
 badCo :: Coercion -> LintM a
 badCo co = failWithL (hang (ptext (sLit "Ill-kinded coercion term:")) 2 (ppr co))
 
+---------------
+lintCoTyConApp :: Coercion -> CoTyConDesc -> [Coercion] -> LintM (Type,Type)
+-- Always called with correct number of coercion arguments
+-- First arg is just for error message
+lintCoTyConApp _ CoLeft  (co:_) = lintLR   fst             co 
+lintCoTyConApp _ CoRight (co:_) = lintLR   snd             co   
+lintCoTyConApp _ CoCsel1 (co:_) = lintCsel fstOf3   co 
+lintCoTyConApp _ CoCsel2 (co:_) = lintCsel sndOf3   co 
+lintCoTyConApp _ CoCselR (co:_) = lintCsel thirdOf3 co 
+
+lintCoTyConApp _ CoSym (co:_) 
+  = do { (ty1,ty2) <- lintCoercion co
+       ; return (ty2,ty1) }
+
+lintCoTyConApp co CoTrans (co1:co2:_) 
+  = do { (ty1a, ty1b) <- lintCoercion co1
+       ; (ty2a, ty2b) <- lintCoercion co2
+       ; checkL (ty1b `coreEqType` ty2a)
+                (hang (ptext (sLit "Trans coercion mis-match:") <+> ppr co)
+                    2 (vcat [ppr ty1a, ppr ty1b, ppr ty2a, ppr ty2b]))
+       ; return (ty1a, ty2b) }
+
+lintCoTyConApp _ CoInst (co:arg_ty:_) 
+  = do { co_tys <- lintCoercion co
+       ; arg_kind  <- lintType arg_ty
+       ; case decompInst_maybe co_tys of
+          Just ((tv1,tv2), (ty1,ty2)) 
+            | arg_kind `isSubKind` tyVarKind tv1
+            -> return (substTyWith [tv1] [arg_ty] ty1, 
+                       substTyWith [tv2] [arg_ty] ty2) 
+            | otherwise
+            -> failWithL (ptext (sLit "Kind mis-match in inst coercion"))
+         Nothing -> failWithL (ptext (sLit "Bad argument of inst")) }
+
+lintCoTyConApp _ (CoAxiom { co_ax_tvs = tvs 
+                          , co_ax_lhs = lhs_ty, co_ax_rhs = rhs_ty }) cos
+  = do { (tys1, tys2) <- mapAndUnzipM lintCoercion cos
+       ; sequence_ (zipWith checkKinds tvs tys1)
+       ; return (substTyWith tvs tys1 lhs_ty,
+                 substTyWith tvs tys2 rhs_ty) }
+
+lintCoTyConApp _ CoUnsafe (ty1:ty2:_) 
+  = do { _ <- lintType ty1
+       ; _ <- lintType ty2     -- Ignore kinds; it's unsafe!
+       ; return (ty1,ty2) } 
+
+lintCoTyConApp _ _ _ = panic "lintCoTyConApp"  -- Called with wrong number of coercion args
+
+----------
+lintLR :: (forall a. (a,a)->a) -> Coercion -> LintM (Type,Type)
+lintLR sel co
+  = do { (ty1,ty2) <- lintCoercion co
+       ; case decompLR_maybe (ty1,ty2) of
+           Just res -> return (sel res)
+           Nothing  -> failWithL (ptext (sLit "Bad argument of left/right")) }
+
+----------
+lintCsel :: (forall a. (a,a,a)->a) -> Coercion -> LintM (Type,Type)
+lintCsel sel co
+  = do { (ty1,ty2) <- lintCoercion co
+       ; case decompCsel_maybe (ty1,ty2) of
+           Just res -> return (sel res)
+           Nothing  -> failWithL (ptext (sLit "Bad argument of csel")) }
+
 -------------------
 lintType :: OutType -> LintM Kind
 lintType (TyVarTy tv)
 -------------------
 lintType :: OutType -> LintM Kind
 lintType (TyVarTy tv)