Type families: unify with family apps in checking mode
[ghc-hetmet.git] / compiler / typecheck / TcTyFuns.lhs
index 41432e6..ba73891 100644 (file)
@@ -347,7 +347,7 @@ finaliseEqsAndDicts (EqConfig { eqs     = eqs
                               })
   = do { traceTc $ ptext (sLit "finaliseEqsAndDicts")
        ; (eqs', subst_binds, locals', wanteds') <- substitute eqs locals wanteds
-       ; (eqs'', improved) <- instantiateAndExtract eqs'
+       ; (eqs'', improved) <- instantiateAndExtract eqs' (null locals)
        ; final_binds <- filterM nonTrivialDictBind $
                           bagToList (subst_binds `unionBags` binds)
 
@@ -442,17 +442,21 @@ deriveEqInst rewrite ty1 ty2 co
 
 instance Outputable RewriteInst where
   ppr (RewriteFam {rwi_fam = fam, rwi_args = args, rwi_right = rhs, rwi_co =co})
-    = hsep [ ppr co <+> text "::" 
+    = hsep [ pprEqInstCo co <+> text "::" 
            , ppr (mkTyConApp fam args)
            , text "~>"
            , ppr rhs
            ]
   ppr (RewriteVar {rwi_var = tv, rwi_right = rhs, rwi_co =co})
-    = hsep [ ppr co <+> text "::" 
+    = hsep [ pprEqInstCo co <+> text "::" 
            , ppr tv
            , text "~>"
            , ppr rhs
            ]
+
+pprEqInstCo :: EqInstCo -> SDoc
+pprEqInstCo (Left cotv) = ptext (sLit "Wanted") <+> ppr cotv
+pprEqInstCo (Right co)  = ptext (sLit "Local") <+> ppr co
 \end{code}
 
 The following functions turn an arbitrary equality into a set of normal
@@ -579,7 +583,13 @@ checkOrientation :: Type -> Type -> EqInstCo -> Inst -> TcM [RewriteInst]
 -- NB: We cannot assume that the two types already have outermost type
 --     synonyms expanded due to the recursion in the case of type applications.
 checkOrientation ty1 ty2 co inst
-  = go ty1 ty2
+  = do { traceTc $ ptext (sLit "checkOrientation of ") <+> 
+                   pprEqInstCo co <+> text "::" <+> 
+                   ppr ty1 <+> text "~" <+> ppr ty2
+       ; eqs <- go ty1 ty2
+       ; traceTc $ ptext (sLit "checkOrientation returns") <+> ppr eqs
+       ; return eqs
+       }
   where
       -- look through synonyms
     go ty1 ty2 | Just ty1' <- tcView ty1 = go ty1' ty2
@@ -656,7 +666,14 @@ flattenType inst ty
   = go ty
   where
       -- look through synonyms
-    go ty | Just ty' <- tcView ty = go ty'
+    go ty | Just ty' <- tcView ty 
+      = do { (ty_flat, co, eqs, skolems) <- go ty'
+           ; if null eqs
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (ty_flat, co, eqs, skolems)
+           }
 
       -- type variable => nothing to do
     go ty@(TyVarTy _)
@@ -687,34 +704,45 @@ flattenType inst ty
 
       -- data constructor application => flatten subtypes
       -- NB: Special cased for efficiency - could be handled as type application
-    go (TyConApp con args)
+    go ty@(TyConApp con args)
       = do { (args', cargs, args_eqss, args_skolemss) <- mapAndUnzip4M go args
-           ; return (mkTyConApp con args', 
-                     mkTyConApp con cargs,
-                     concat args_eqss,
-                     unionVarSets args_skolemss)
+           ; if null args_eqss
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkTyConApp con args', 
+                       mkTyConApp con cargs,
+                       concat args_eqss,
+                       unionVarSets args_skolemss)
            }
 
       -- function type => flatten subtypes
       -- NB: Special cased for efficiency - could be handled as type application
-    go (FunTy ty_l ty_r)
+    go ty@(FunTy ty_l ty_r)
       = do { (ty_l', co_l, eqs_l, skolems_l) <- go ty_l
            ; (ty_r', co_r, eqs_r, skolems_r) <- go ty_r
-           ; return (mkFunTy ty_l' ty_r', 
-                     mkFunTy co_l co_r,
-                     eqs_l ++ eqs_r, 
-                     skolems_l `unionVarSet` skolems_r)
+           ; if null eqs_l && null eqs_r
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkFunTy ty_l' ty_r', 
+                       mkFunTy co_l co_r,
+                       eqs_l ++ eqs_r, 
+                       skolems_l `unionVarSet` skolems_r)
            }
 
       -- type application => flatten subtypes
-    go (AppTy ty_l ty_r)
---      | Just (ty_l, ty_r) <- repSplitAppTy_maybe ty
+    go ty@(AppTy ty_l ty_r)
       = do { (ty_l', co_l, eqs_l, skolems_l) <- go ty_l
            ; (ty_r', co_r, eqs_r, skolems_r) <- go ty_r
-           ; return (mkAppTy ty_l' ty_r', 
-                     mkAppTy co_l co_r, 
-                     eqs_l ++ eqs_r, 
-                     skolems_l `unionVarSet` skolems_r)
+           ; if null eqs_l && null eqs_r
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkAppTy ty_l' ty_r', 
+                       mkAppTy co_l co_r, 
+                       eqs_l ++ eqs_r, 
+                       skolems_l `unionVarSet` skolems_r)
            }
 
       -- forall type => panic if the body contains a type family
@@ -1032,8 +1060,7 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       = return (res, binds, locals, wanteds)
     subst (eq@(RewriteVar {rwi_var = tv, rwi_right = ty, rwi_co = co}):eqs) 
           res binds locals wanteds
-      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr tv <+>
-                       ptext (sLit "->") <+> ppr ty
+      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr eq
            ; let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
                  tySubst = zipOpenTvSubst [tv] [ty]
            ; eqs'               <- mapM (substEq eq coSubst tySubst) eqs
@@ -1098,16 +1125,19 @@ Return all remaining wanted equalities.  The Boolean result component is True
 if at least one instantiation of a flexible was performed.
 
 \begin{code}
-instantiateAndExtract :: [RewriteInst] -> TcM ([Inst], Bool)
-instantiateAndExtract eqs
-  = do { let wanteds = filter (isWantedCo . rwi_co) eqs
-       ; wanteds' <- mapM inst wanteds
+instantiateAndExtract :: [RewriteInst] -> Bool -> TcM ([Inst], Bool)
+instantiateAndExtract eqs localsEmpty
+  = do { wanteds' <- mapM inst wanteds
        ; let residuals = catMaybes wanteds'
              improved  = length wanteds /= length residuals
        ; residuals' <- mapM rewriteInstToInst residuals
        ; return (residuals', improved)
        }
   where
+    wanteds      = filter (isWantedCo . rwi_co) eqs
+    checkingMode = length eqs > length wanteds || not localsEmpty
+                     -- no local equalities or dicts => checking mode
+
     inst eq@(RewriteVar {rwi_var = tv1, rwi_right = ty2, rwi_co = co})
 
         -- co :: alpha ~ t
@@ -1119,6 +1149,14 @@ instantiateAndExtract eqs
       , isMetaTyVar tv2
       = doInst (not $ rwi_swapped eq) tv2 (mkTyVarTy tv1) co eq
 
+        -- co :: F args ~ alpha, and we are in checking mode (ie, no locals)
+    inst eq@(RewriteFam {rwi_fam = fam, rwi_args = args, rwi_right = ty2, 
+                         rwi_co = co})
+      | checkingMode
+      , Just tv2 <- tcGetTyVar_maybe ty2
+      , isMetaTyVar tv2
+      = doInst (not $ rwi_swapped eq) tv2 (mkTyConApp fam args) co eq
+
     inst eq = return $ Just eq
 
     doInst _swapped _tv _ty (Right ty) _eq