Use "on the spot" solving for fundeps
[ghc-hetmet.git] / compiler / typecheck / TcCanonical.lhs
index 861b262..8668d90 100644 (file)
@@ -1,7 +1,8 @@
 \begin{code}
 module TcCanonical(
     mkCanonical, mkCanonicals, mkCanonicalFEV, canWanteds, canGivens,
-    canOccursCheck, canEq
+    canOccursCheck, canEq,
+    rewriteWithFunDeps
  ) where
 
 #include "HsVersions.h"
@@ -9,7 +10,8 @@ module TcCanonical(
 import BasicTypes
 import Type
 import TcRnTypes
-
+import FunDeps
+import qualified TcMType as TcM
 import TcType
 import TcErrors
 import Coercion
@@ -18,6 +20,7 @@ import TyCon
 import TypeRep
 import Name
 import Var
+import VarEnv          ( TidyEnv )
 import Outputable
 import Control.Monad    ( unless, when, zipWithM, zipWithM_ )
 import MonadUtils
@@ -28,6 +31,7 @@ import Bag
 
 import HsBinds
 import TcSMonad
+import FastString
 \end{code}
 
 Note [Canonicalisation]
@@ -991,4 +995,75 @@ a.  If this turns out to be impossible, we next try expanding F
 itself, and so on.
 
 
+%************************************************************************
+%*                                                                      *
+%*          Functional dependencies, instantiation of equations
+%*                                                                      *
+%************************************************************************
 
+\begin{code}
+rewriteWithFunDeps :: [Equation]
+                   -> [Xi] -> CtFlavor
+                   -> TcS (Maybe ([Xi], [Coercion], CanonicalCts))
+rewriteWithFunDeps eqn_pred_locs xis fl
+ = do { fd_ev_poss <- mapM (instFunDepEqn fl) eqn_pred_locs
+      ; let fd_ev_pos :: [(Int,FlavoredEvVar)]
+            fd_ev_pos = concat fd_ev_poss
+            (rewritten_xis, cos) = unzip (rewriteDictParams fd_ev_pos xis)
+      ; fds <- mapM (\(_,fev) -> mkCanonicalFEV fev) fd_ev_pos
+      ; let fd_work = unionManyBags fds
+      ; if isEmptyBag fd_work 
+        then return Nothing
+        else return (Just (rewritten_xis, cos, fd_work)) }
+
+instFunDepEqn :: CtFlavor -- Precondition: Only Wanted or Derived
+              -> Equation
+              -> TcS [(Int, FlavoredEvVar)]
+-- Post: Returns the position index as well as the corresponding FunDep equality
+instFunDepEqn fl (FDEqn { fd_qtvs = qtvs, fd_eqs = eqs
+                        , fd_pred1 = d1, fd_pred2 = d2 })
+  = do { let tvs = varSetElems qtvs
+       ; tvs' <- mapM instFlexiTcS tvs
+       ; let subst = zipTopTvSubst tvs (mkTyVarTys tvs')
+       ; mapM (do_one subst) eqs }
+  where 
+    fl' = case fl of 
+             Given _     -> panic "mkFunDepEqns"
+             Wanted  loc -> Wanted  (push_ctx loc)
+             Derived loc -> Derived (push_ctx loc)
+
+    push_ctx loc = pushErrCtxt FunDepOrigin (False, mkEqnMsg d1 d2) loc
+
+    do_one subst (FDEq { fd_pos = i, fd_ty_left = ty1, fd_ty_right = ty2 })
+       = do { let sty1 = substTy subst ty1
+                  sty2 = substTy subst ty2
+            ; ev <- newCoVar sty1 sty2
+            ; return (i, mkEvVarX ev fl') }
+
+rewriteDictParams :: [(Int,FlavoredEvVar)] -- A set of coercions : (pos, ty' ~ ty)
+                  -> [Type]                -- A sequence of types: tys
+                  -> [(Type,Coercion)]     -- Returns            : [(ty', co : ty' ~ ty)]
+rewriteDictParams param_eqs tys
+  = zipWith do_one tys [0..]
+  where
+    do_one :: Type -> Int -> (Type,Coercion)
+    do_one ty n = case lookup n param_eqs of
+                    Just wev -> (get_fst_ty wev, mkCoVarCoercion (evVarOf wev))
+                    Nothing  -> (ty,ty)                -- Identity
+
+    get_fst_ty wev = case evVarOfPred wev of
+                          EqPred ty1 _ -> ty1
+                          _ -> panic "rewriteDictParams: non equality fundep"
+
+mkEqnMsg :: (TcPredType, SDoc) -> (TcPredType, SDoc) -> TidyEnv
+         -> TcM (TidyEnv, SDoc)
+mkEqnMsg (pred1,from1) (pred2,from2) tidy_env
+  = do  { zpred1 <- TcM.zonkTcPredType pred1
+        ; zpred2 <- TcM.zonkTcPredType pred2
+       ; let { tpred1 = tidyPred tidy_env zpred1
+              ; tpred2 = tidyPred tidy_env zpred2 }
+       ; let msg = vcat [ptext (sLit "When using functional dependencies to combine"),
+                         nest 2 (sep [ppr tpred1 <> comma, nest 2 from1]), 
+                         nest 2 (sep [ppr tpred2 <> comma, nest 2 from2])]
+       ; return (tidy_env, msg) }
+\end{code}
\ No newline at end of file