Fix Trac #1470: improve handling of recursive instances (needed for SYB3)
[ghc-hetmet.git] / compiler / typecheck / TcTyFuns.lhs
index 41432e6..113ea43 100644 (file)
@@ -30,7 +30,6 @@ import Type
 import TypeRep         ( Type(..) )
 import TyCon
 import HsSyn
-import Id
 import VarEnv
 import VarSet
 import Var
@@ -333,35 +332,26 @@ propagateEqs eqCfg@(EqConfig {eqs = todoEqs})
 -- set of instances are the locals (without equalities) and the second set are
 -- all residual wanteds, including equalities. 
 --
--- Remove all identity dictinary bindings (i.e., those whose source and target
--- dictionary are the same).  This is important for termination, as
--- TcSimplify.reduceContext takes the presence of dictionary bindings as an
--- indicator that there was some improvement.
---
 finaliseEqsAndDicts :: EqConfig 
                     -> TcM ([Inst], [Inst], TcDictBinds, Bool)
 finaliseEqsAndDicts (EqConfig { eqs     = eqs
                               , locals  = locals
                               , wanteds = wanteds
                               , binds   = binds
+                              , skolems = skolems
                               })
   = do { traceTc $ ptext (sLit "finaliseEqsAndDicts")
        ; (eqs', subst_binds, locals', wanteds') <- substitute eqs locals wanteds
-       ; (eqs'', improved) <- instantiateAndExtract eqs'
-       ; final_binds <- filterM nonTrivialDictBind $
-                          bagToList (subst_binds `unionBags` binds)
+       ; (eqs'', improved) <- instantiateAndExtract eqs' (null locals) skolems
+       ; let final_binds = subst_binds `unionBags` binds
 
+         -- Assert that all cotvs of wanted equalities are still unfilled, and
+         -- zonk all final insts, to make any improvement visible
        ; ASSERTM2( allM isValidWantedEqInst eqs'', ppr eqs'' )
-       ; return (locals', eqs'' ++ wanteds', listToBag final_binds, improved)
+       ; zonked_locals  <- zonkInsts locals'
+       ; zonked_wanteds <- zonkInsts (eqs'' ++ wanteds')
+       ; return (zonked_locals, zonked_wanteds, final_binds, improved)
        }
-  where
-    nonTrivialDictBind (L _ (VarBind { var_id = ide1
-                                     , var_rhs = L _ (HsWrap _ (HsVar ide2))}))
-      = do { ty1 <- zonkTcType (idType ide1)
-           ; ty2 <- zonkTcType (idType ide2)
-           ; return $ not (ty1 `tcEqType` ty2)
-           }
-    nonTrivialDictBind _ = return True
 \end{code}
 
 
@@ -442,17 +432,21 @@ deriveEqInst rewrite ty1 ty2 co
 
 instance Outputable RewriteInst where
   ppr (RewriteFam {rwi_fam = fam, rwi_args = args, rwi_right = rhs, rwi_co =co})
-    = hsep [ ppr co <+> text "::" 
+    = hsep [ pprEqInstCo co <+> text "::" 
            , ppr (mkTyConApp fam args)
            , text "~>"
            , ppr rhs
            ]
   ppr (RewriteVar {rwi_var = tv, rwi_right = rhs, rwi_co =co})
-    = hsep [ ppr co <+> text "::" 
+    = hsep [ pprEqInstCo co <+> text "::" 
            , ppr tv
            , text "~>"
            , ppr rhs
            ]
+
+pprEqInstCo :: EqInstCo -> SDoc
+pprEqInstCo (Left cotv) = ptext (sLit "Wanted") <+> ppr cotv
+pprEqInstCo (Right co)  = ptext (sLit "Local") <+> ppr co
 \end{code}
 
 The following functions turn an arbitrary equality into a set of normal
@@ -579,7 +573,13 @@ checkOrientation :: Type -> Type -> EqInstCo -> Inst -> TcM [RewriteInst]
 -- NB: We cannot assume that the two types already have outermost type
 --     synonyms expanded due to the recursion in the case of type applications.
 checkOrientation ty1 ty2 co inst
-  = go ty1 ty2
+  = do { traceTc $ ptext (sLit "checkOrientation of ") <+> 
+                   pprEqInstCo co <+> text "::" <+> 
+                   ppr ty1 <+> text "~" <+> ppr ty2
+       ; eqs <- go ty1 ty2
+       ; traceTc $ ptext (sLit "checkOrientation returns") <+> ppr eqs
+       ; return eqs
+       }
   where
       -- look through synonyms
     go ty1 ty2 | Just ty1' <- tcView ty1 = go ty1' ty2
@@ -656,7 +656,14 @@ flattenType inst ty
   = go ty
   where
       -- look through synonyms
-    go ty | Just ty' <- tcView ty = go ty'
+    go ty | Just ty' <- tcView ty 
+      = do { (ty_flat, co, eqs, skolems) <- go ty'
+           ; if null eqs
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (ty_flat, co, eqs, skolems)
+           }
 
       -- type variable => nothing to do
     go ty@(TyVarTy _)
@@ -687,34 +694,45 @@ flattenType inst ty
 
       -- data constructor application => flatten subtypes
       -- NB: Special cased for efficiency - could be handled as type application
-    go (TyConApp con args)
+    go ty@(TyConApp con args)
       = do { (args', cargs, args_eqss, args_skolemss) <- mapAndUnzip4M go args
-           ; return (mkTyConApp con args', 
-                     mkTyConApp con cargs,
-                     concat args_eqss,
-                     unionVarSets args_skolemss)
+           ; if null args_eqss
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkTyConApp con args', 
+                       mkTyConApp con cargs,
+                       concat args_eqss,
+                       unionVarSets args_skolemss)
            }
 
       -- function type => flatten subtypes
       -- NB: Special cased for efficiency - could be handled as type application
-    go (FunTy ty_l ty_r)
+    go ty@(FunTy ty_l ty_r)
       = do { (ty_l', co_l, eqs_l, skolems_l) <- go ty_l
            ; (ty_r', co_r, eqs_r, skolems_r) <- go ty_r
-           ; return (mkFunTy ty_l' ty_r', 
-                     mkFunTy co_l co_r,
-                     eqs_l ++ eqs_r, 
-                     skolems_l `unionVarSet` skolems_r)
+           ; if null eqs_l && null eqs_r
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkFunTy ty_l' ty_r', 
+                       mkFunTy co_l co_r,
+                       eqs_l ++ eqs_r, 
+                       skolems_l `unionVarSet` skolems_r)
            }
 
       -- type application => flatten subtypes
-    go (AppTy ty_l ty_r)
---      | Just (ty_l, ty_r) <- repSplitAppTy_maybe ty
+    go ty@(AppTy ty_l ty_r)
       = do { (ty_l', co_l, eqs_l, skolems_l) <- go ty_l
            ; (ty_r', co_r, eqs_r, skolems_r) <- go ty_r
-           ; return (mkAppTy ty_l' ty_r', 
-                     mkAppTy co_l co_r, 
-                     eqs_l ++ eqs_r, 
-                     skolems_l `unionVarSet` skolems_r)
+           ; if null eqs_l && null eqs_r
+             then     -- unchanged, keep the old type with folded synonyms
+               return (ty, ty, [], emptyVarSet)
+             else 
+               return (mkAppTy ty_l' ty_r', 
+                       mkAppTy co_l co_r, 
+                       eqs_l ++ eqs_r, 
+                       skolems_l `unionVarSet` skolems_r)
            }
 
       -- forall type => panic if the body contains a type family
@@ -1013,10 +1031,26 @@ implied by one variable equality exhaustively before turning to the next and
 We also apply the same substitutions to the local and wanted class and IP
 dictionaries.
 
-NB: Given that we apply the substitution corresponding to a single equality
-exhaustively, before turning to the next, and because we eliminate recursive
-equalities, all opportunities for subtitution will have been exhausted after
-we have considered each equality once.
+The treatment of flexibles in wanteds is quite subtle.  We absolutely want to
+substitute them into right-hand sides of equalities, to avoid getting two
+competing instantiations for a type variables; e.g., consider
+
+  F s ~ alpha, alpha ~ t
+
+If we don't substitute `alpha ~ t', we may instantiate t with `F s' instead.
+This would be bad as `F s' is less useful, eg, as an argument to a class
+constraint.
+
+However, there is no reason why we would want to *substitute* `alpha ~ t' into a
+class constraint.  We rather wait until `alpha' is instantiated to `t` and
+save the extra dictionary binding that substitution would introduce.
+Moreover, we may substitute wanted equalities only into wanted dictionaries.
+
+NB: 
+* Given that we apply the substitution corresponding to a single equality
+  exhaustively, before turning to the next, and because we eliminate recursive
+  equalities, all opportunities for subtitution will have been exhausted after
+  we have considered each equality once.
 
 \begin{code}
 substitute :: [RewriteInst]       -- equalities
@@ -1030,20 +1064,35 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
   where
     subst [] res binds locals wanteds 
       = return (res, binds, locals, wanteds)
+
     subst (eq@(RewriteVar {rwi_var = tv, rwi_right = ty, rwi_co = co}):eqs) 
           res binds locals wanteds
-      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr tv <+>
-                       ptext (sLit "->") <+> ppr ty
+      = do { traceTc $ ptext (sLit "TcTyFuns.substitute:") <+> ppr eq
+
            ; let coSubst = zipOpenTvSubst [tv] [eqInstCoType co]
                  tySubst = zipOpenTvSubst [tv] [ty]
-           ; eqs'               <- mapM (substEq eq coSubst tySubst) eqs
-           ; res'               <- mapM (substEq eq coSubst tySubst) res
-           ; (lbinds, locals')  <- mapAndUnzipM 
-                                     (substDict eq coSubst tySubst False) 
-                                     locals
-           ; (wbinds, wanteds') <- mapAndUnzipM 
-                                     (substDict eq coSubst tySubst True) 
-                                     wanteds
+           ; eqs' <- mapM (substEq eq coSubst tySubst) eqs
+           ; res' <- mapM (substEq eq coSubst tySubst) res
+
+             -- only susbtitute local equalities into local dictionaries
+           ; (lbinds, locals')  <- if not (isWantedCo co)
+                                   then 
+                                     mapAndUnzipM 
+                                       (substDict eq coSubst tySubst False) 
+                                       locals
+                                   else
+                                     return ([], locals)
+
+              -- flexible tvs in wanteds will be instantiated anyway, there is
+              -- no need to substitute them into dictionaries
+           ; (wbinds, wanteds') <- if not (isMetaTyVar tv && isWantedCo co)
+                                   then
+                                     mapAndUnzipM 
+                                       (substDict eq coSubst tySubst True) 
+                                       wanteds
+                                   else
+                                     return ([], wanteds)
+
            ; let binds' = unionManyBags $ binds : lbinds ++ wbinds
            ; subst eqs' (eq:res') binds' locals' wanteds'
            }
@@ -1076,8 +1125,7 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
       -- We have, co :: tv ~ ty 
       -- => apply [ty/tv] to dictionary predicate
       --    (but only if tv actually occurs in the predicate)
-    substDict (RewriteVar {rwi_var = tv}) 
-              coSubst tySubst isWanted dict
+    substDict (RewriteVar {rwi_var = tv}) coSubst tySubst isWanted dict
       | isClassDict dict
       , tv `elemVarSet` tyVarsOfPred (tci_pred dict)
       = do { let co1Subst = PredTy (substPred coSubst (tci_pred dict))
@@ -1095,19 +1143,23 @@ substitute eqs locals wanteds = subst eqs [] emptyBag locals wanteds
 For any *wanted* variable equality of the form co :: alpha ~ t or co :: a ~
 alpha, we instantiate alpha with t or a, respectively, and set co := id.
 Return all remaining wanted equalities.  The Boolean result component is True
-if at least one instantiation of a flexible was performed.
+if at least one instantiation of a flexible that is *not* a skolem from
+flattening was performed.
 
 \begin{code}
-instantiateAndExtract :: [RewriteInst] -> TcM ([Inst], Bool)
-instantiateAndExtract eqs
-  = do { let wanteds = filter (isWantedCo . rwi_co) eqs
-       ; wanteds' <- mapM inst wanteds
-       ; let residuals = catMaybes wanteds'
-             improved  = length wanteds /= length residuals
+instantiateAndExtract :: [RewriteInst] -> Bool -> TyVarSet -> TcM ([Inst], Bool)
+instantiateAndExtract eqs localsEmpty skolems
+  = do { results <- mapM inst wanteds
+       ; let residuals    = [eq | Left eq <- results]
+             only_skolems = and [tv `elemVarSet` skolems | Right tv <- results]
        ; residuals' <- mapM rewriteInstToInst residuals
-       ; return (residuals', improved)
+       ; return (residuals', not only_skolems)
        }
   where
+    wanteds      = filter (isWantedCo . rwi_co) eqs
+    checkingMode = length eqs > length wanteds || not localsEmpty
+                     -- no local equalities or dicts => checking mode
+
     inst eq@(RewriteVar {rwi_var = tv1, rwi_right = ty2, rwi_co = co})
 
         -- co :: alpha ~ t
@@ -1119,7 +1171,19 @@ instantiateAndExtract eqs
       , isMetaTyVar tv2
       = doInst (not $ rwi_swapped eq) tv2 (mkTyVarTy tv1) co eq
 
-    inst eq = return $ Just eq
+        -- co :: F args ~ alpha, and we are in checking mode (ie, no locals)
+    inst eq@(RewriteFam {rwi_fam = fam, rwi_args = args, rwi_right = ty2, 
+                         rwi_co = co})
+      | Just tv2 <- tcGetTyVar_maybe ty2
+      , isMetaTyVar tv2
+      , checkingMode || tv2 `elemVarSet` skolems
+                        -- !!!TODO: this is too liberal, even if tv2 is in 
+                        -- skolems we shouldn't instantiate if tvs occurs 
+                        -- in other equalities that may propagate it into the
+                        -- environment
+      = doInst (not $ rwi_swapped eq) tv2 (mkTyConApp fam args) co eq
+
+    inst eq = return $ Left eq
 
     doInst _swapped _tv _ty (Right ty) _eq 
       = pprPanic "TcTyFuns.doInst: local eq: " (ppr ty)
@@ -1129,9 +1193,14 @@ instantiateAndExtract eqs
            }
       where
         -- meta variable has been filled already
-        -- => ignore (must be a skolem that was introduced by flattening locals)
-        uMeta _swapped _tv (IndirectTv _) _ty _cotv
-          = return Nothing
+        -- => keep the equality
+        uMeta _swapped tv (IndirectTv fill_ty) ty _cotv
+          = do { traceTc $ 
+                   ptext (sLit "flexible") <+> ppr tv <+>
+                   ptext (sLit "already filled with") <+> ppr fill_ty <+>
+                   ptext (sLit "meant to fill with") <+> ppr ty
+               ; return $ Left eq
+               }
 
         -- type variable meets type variable
         -- => check that tv2 hasn't been updated yet and choose which to update
@@ -1153,7 +1222,7 @@ instantiateAndExtract eqs
         -- signature skolem meets non-variable type
         -- => cannot update (retain the equality)!
         uMeta _swapped _tv (DoneTv (MetaTv (SigTv _) _)) _non_tv_ty _cotv
-          = return $ Just eq
+          = return $ Left eq
 
         -- updatable meta variable meets non-variable type
         -- => occurs check, monotype check, and kinds match check, then update
@@ -1168,7 +1237,7 @@ instantiateAndExtract eqs
                    Just ty' ->
                      do { checkUpdateMeta swapped tv ref ty'  -- update meta var
                         ; writeMetaTyVar cotv ty'             -- update co var
-                        ; return Nothing
+                        ; return $ Right tv
                         }
                }
 
@@ -1180,35 +1249,38 @@ instantiateAndExtract eqs
         uMetaVar swapped tv1 (MetaTv _ ref) tv2 (SkolemTv _) cotv
           = do { checkUpdateMeta swapped tv1 ref (mkTyVarTy tv2)
                ; writeMetaTyVar cotv (mkTyVarTy tv2)
-               ; return Nothing
+               ; return $ Right tv1
                }
 
         -- meta variable meets meta variable 
         -- => be clever about which of the two to update 
         --   (from TcUnify.uUnfilledVars minus boxy stuff)
         uMetaVar swapped tv1 (MetaTv info1 ref1) tv2 (MetaTv info2 ref2) cotv
-          = do { case (info1, info2) of
-                   -- Avoid SigTvs if poss
-                   (SigTv _, _      ) | k1_sub_k2 -> update_tv2
-                   (_,       SigTv _) | k2_sub_k1 -> update_tv1
-
-                   (_,   _) | k1_sub_k2 -> if k2_sub_k1 && nicer_to_update_tv1
-                                           then update_tv1     -- Same kinds
-                                           else update_tv2
-                            | k2_sub_k1 -> update_tv1
-                            | otherwise -> kind_err
+          = do { tv <- case (info1, info2) of
+                         -- Avoid SigTvs if poss
+                         (SigTv _, _      ) | k1_sub_k2 -> update_tv2
+                         (_,       SigTv _) | k2_sub_k1 -> update_tv1
+
+                         (_,   _) | k1_sub_k2 -> if k2_sub_k1 && 
+                                                    nicer_to_update_tv1
+                                                 then update_tv1  -- Same kinds
+                                                 else update_tv2
+                                  | k2_sub_k1 -> update_tv1
+                                  | otherwise -> kind_err >> return tv1
               -- Update the variable with least kind info
               -- See notes on type inference in Kind.lhs
               -- The "nicer to" part only applies if the two kinds are the same,
               -- so we can choose which to do.
 
                ; writeMetaTyVar cotv (mkTyVarTy tv2)
-               ; return Nothing
+               ; return $ Right tv
                }
           where
                 -- Kinds should be guaranteed ok at this point
             update_tv1 = updateMeta tv1 ref1 (mkTyVarTy tv2)
+                         >> return tv1
             update_tv2 = updateMeta tv2 ref2 (mkTyVarTy tv1)
+                         >> return tv2
 
             kind_err = addErrCtxtM (unifyKindCtxt swapped tv1 (mkTyVarTy tv2)) $
                        unifyKindMisMatch k1 k2