[project @ 2000-01-28 20:52:37 by lewie]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcBinds.lhs
index 8c0ac2a..ec5a592 100644 (file)
@@ -21,6 +21,7 @@ import TcHsSyn                ( TcHsBinds, TcMonoBinds, TcId, zonkId, mkHsLet )
 import TcMonad
 import Inst            ( Inst, LIE, emptyLIE, mkLIE, plusLIE, plusLIEs, InstOrigin(..),
                          newDicts, tyVarsOfInst, instToId,
+                         getAllFunDepsOfLIE, getIPsOfLIE, zonkFunDeps
                        )
 import TcEnv           ( tcExtendLocalValEnv,
                          newSpecPragmaId, newLocalId,
@@ -28,6 +29,7 @@ import TcEnv          ( tcExtendLocalValEnv,
                          tcGetGlobalTyVars, tcExtendGlobalTyVars
                        )
 import TcSimplify      ( tcSimplify, tcSimplifyAndCheck, tcSimplifyToDicts )
+import TcImprove       ( tcImprove )
 import TcMonoType      ( tcHsType, checkSigTyVars,
                          TcSigInfo(..), tcTySig, maybeSig, sigCtxt
                        )
@@ -44,14 +46,15 @@ import PrelInfo             ( main_NAME, ioTyCon_NAME )
 
 import Id              ( Id, mkVanillaId, setInlinePragma )
 import Var             ( idType, idName )
-import IdInfo          ( IdInfo, vanillaIdInfo, setInlinePragInfo, InlinePragInfo(..) )
+import IdInfo          ( setInlinePragInfo, InlinePragInfo(..) )
 import Name            ( Name, getName, getOccName, getSrcLoc )
 import NameSet
 import Type            ( mkTyVarTy, tyVarsOfTypes, mkTyConApp,
                          splitSigmaTy, mkForAllTys, mkFunTys, getTyVar, 
-                         mkDictTy, splitRhoTy, mkForAllTy, isUnLiftedType, 
+                         mkPredTy, splitRhoTy, mkForAllTy, isUnLiftedType, 
                          isUnboxedType, unboxedTypeKind, boxedTypeKind
                        )
+import FunDeps         ( tyVarFunDep, oclose )
 import Var             ( TyVar, tyVarKind )
 import VarSet
 import Bag
@@ -250,6 +253,14 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
        -- (must do this before getTyVarsToGen)
     checkSigMatch top_lvl binder_names mono_ids tc_ty_sigs     `thenTc` \ maybe_sig_theta ->   
 
+       -- IMPROVE the LIE
+       -- Force any unifications dictated by functional dependencies.
+       -- Because unification may happen, it's important that this step
+       -- come before:
+       --   - computing vars over which to quantify
+       --   - zonking the generalized type vars
+    tcImprove lie_req `thenTc_`
+
        -- COMPUTE VARIABLES OVER WHICH TO QUANTIFY, namely tyvars_to_gen
        -- The tyvars_not_to_gen are free in the environment, and hence
        -- candidates for generalisation, but sometimes the monomorphism
@@ -279,8 +290,9 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
 
        -- SIMPLIFY THE LIE
     tcExtendGlobalTyVars tyvars_not_to_gen (
-       if null real_tyvars_to_gen_list then
-               -- No polymorphism, so no need to simplify context
+       let ips = getIPsOfLIE lie_req in
+       if null real_tyvars_to_gen_list && null ips then
+               -- No polymorphism, and no IPs, so no need to simplify context
            returnTc (lie_req, EmptyMonoBinds, [])
        else
        case maybe_sig_theta of
@@ -289,7 +301,7 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
                -- NB: no signatures => no polymorphic recursion, so no
                -- need to use lie_avail (which will be empty anyway)
            tcSimplify (text "tcBinds1" <+> ppr binder_names)
-                      top_lvl real_tyvars_to_gen lie_req       `thenTc` \ (lie_free, dict_binds, lie_bound) ->
+                      real_tyvars_to_gen lie_req       `thenTc` \ (lie_free, dict_binds, lie_bound) ->
            returnTc (lie_free, dict_binds, map instToId (bagToList lie_bound))
 
          Just (sig_theta, lie_avail) ->
@@ -397,6 +409,7 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
 
         -- BUILD RESULTS
     returnTc (
+        -- pprTrace "binding.." (ppr ((dicts_bound, dict_binds), exports, [idType poly_id | (_, poly_id, _) <- exports])) $
         AbsBinds real_tyvars_to_gen_list
                  dicts_bound
                  exports
@@ -482,7 +495,7 @@ is doing.
 %*                                                                     *
 %************************************************************************
 
-@getTyVarsToGen@ decides what type variables generalise over.
+@getTyVarsToGen@ decides what type variables to generalise over.
 
 For a "restricted group" -- see the monomorphism restriction
 for a definition -- we bind no dictionaries, and
@@ -524,22 +537,27 @@ getTyVarsToGen is_unrestricted mono_id_tys lie
   = tcGetGlobalTyVars                  `thenNF_Tc` \ free_tyvars ->
     zonkTcTypes mono_id_tys            `thenNF_Tc` \ zonked_mono_id_tys ->
     let
-       tyvars_to_gen = tyVarsOfTypes zonked_mono_id_tys `minusVarSet` free_tyvars
+       body_tyvars = tyVarsOfTypes zonked_mono_id_tys `minusVarSet` free_tyvars
     in
     if is_unrestricted
     then
-       returnNF_Tc (emptyVarSet, tyvars_to_gen)
+       let fds = getAllFunDepsOfLIE lie in
+       zonkFunDeps fds         `thenNF_Tc` \ fds' ->
+       let tvFundep = tyVarFunDep fds'
+           extended_tyvars = oclose tvFundep body_tyvars in
+       -- pprTrace "gTVTG" (ppr (lie, body_tyvars, extended_tyvars)) $
+       returnNF_Tc (emptyVarSet, extended_tyvars)
     else
        -- This recover and discard-errs is to avoid duplicate error
        -- messages; this, after all, is an "extra" call to tcSimplify
-       recoverNF_Tc (returnNF_Tc (emptyVarSet, tyvars_to_gen))         $
+       recoverNF_Tc (returnNF_Tc (emptyVarSet, body_tyvars))           $
        discardErrsTc                                                   $
 
-       tcSimplify (text "getTVG") NotTopLevel tyvars_to_gen lie    `thenTc` \ (_, _, constrained_dicts) ->
+       tcSimplify (text "getTVG") body_tyvars lie    `thenTc` \ (_, _, constrained_dicts) ->
        let
          -- ASSERT: dicts_sig is already zonked!
            constrained_tyvars    = foldrBag (unionVarSet . tyVarsOfInst) emptyVarSet constrained_dicts
-           reduced_tyvars_to_gen = tyvars_to_gen `minusVarSet` constrained_tyvars
+           reduced_tyvars_to_gen = body_tyvars `minusVarSet` constrained_tyvars
         in
         returnTc (constrained_tyvars, reduced_tyvars_to_gen)
 \end{code}
@@ -776,7 +794,7 @@ checkSigMatch top_lvl binder_names mono_ids sigs
        = tcAddSrcLoc src_loc   $
          checkTc (null theta) (mainContextsErr id)
 
-    mk_dict_tys theta = [mkDictTy c ts | (c,ts) <- theta]
+    mk_dict_tys theta = map mkPredTy theta
 
     sig_msg id tidy_ty = sep [ptext SLIT("When checking the type signature"),
                              nest 4 (ppr id <+> dcolon <+> ppr tidy_ty)]