Use implication constraints to improve type inference
[ghc-hetmet.git] / compiler / typecheck / TcBinds.lhs
index 33c8ddb..9e0b583 100644 (file)
@@ -1,4 +1,5 @@
 %
+% (c) The University of Glasgow 2006
 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
 %
 \section[TcBinds]{TcBinds}
@@ -15,58 +16,36 @@ module TcBinds ( tcLocalBinds, tcTopBinds,
 import {-# SOURCE #-} TcMatches ( tcGRHSsPat, tcMatchesFun )
 import {-# SOURCE #-} TcExpr  ( tcMonoExpr )
 
-import DynFlags                ( dopt, DynFlags,
-                         DynFlag(Opt_MonomorphismRestriction, Opt_MonoPatBinds, Opt_GlasgowExts) )
-import HsSyn           ( HsExpr(..), HsBind(..), LHsBinds, LHsBind, Sig(..),
-                         HsLocalBinds(..), HsValBinds(..), HsIPBinds(..),
-                         LSig, Match(..), IPBind(..), Prag(..),
-                         HsType(..), LHsType, HsExplicitForAll(..), hsLTyVarNames, 
-                         isVanillaLSig, sigName, placeHolderNames, isPragLSig,
-                         LPat, GRHSs, MatchGroup(..), pprLHsBinds, mkHsCoerce,
-                         collectHsBindBinders, collectPatBinders, pprPatBind, isBangHsBind
-                       )
-import TcHsSyn         ( zonkId )
+import DynFlags
+import HsSyn
+import TcHsSyn
 
 import TcRnMonad
-import Inst            ( newDictsAtLoc, newIPDict, instToId )
-import TcEnv           ( tcExtendIdEnv, tcExtendIdEnv2, tcExtendTyVarEnv2, 
-                         pprBinders, tcLookupLocalId_maybe, tcLookupId,
-                         tcGetGlobalTyVars )
-import TcUnify         ( tcInfer, tcSubExp, unifyTheta, 
-                         bleatEscapedTvs, sigCtxt )
-import TcSimplify      ( tcSimplifyInfer, tcSimplifyInferCheck, 
-                         tcSimplifyRestricted, tcSimplifyIPs )
-import TcHsType                ( tcHsSigType, UserTypeCtxt(..) )
-import TcPat           ( tcPat, PatCtxt(..) )
-import TcSimplify      ( bindInstsOfLocalFuns )
-import TcMType         ( newFlexiTyVarTy, zonkQuantifiedTyVar, zonkSigTyVar,
-                         tcInstSigTyVars, tcInstSkolTyVars, tcInstType, 
-                         zonkTcType, zonkTcTypes, zonkTcTyVars )
-import TcType          ( TcType, TcTyVar, TcThetaType, 
-                         SkolemInfo(SigSkol), UserTypeCtxt(FunSigCtxt), 
-                         TcTauType, TcSigmaType, isUnboxedTupleType,
-                         mkTyVarTy, mkForAllTys, mkFunTys, exactTyVarsOfType, 
-                         mkForAllTy, isUnLiftedType, tcGetTyVar, 
-                         mkTyVarTys, tidyOpenTyVar )
-import Kind            ( argTypeKind )
-import VarEnv          ( TyVarEnv, emptyVarEnv, lookupVarEnv, extendVarEnv ) 
-import TysWiredIn      ( unitTy )
-import TysPrim         ( alphaTyVar )
-import Id              ( Id, mkLocalId, mkVanillaGlobal )
-import IdInfo          ( vanillaIdInfo )
-import Var             ( TyVar, idType, idName )
-import Name            ( Name )
+import Inst
+import TcEnv
+import TcUnify
+import TcSimplify
+import TcHsType
+import TcPat
+import TcMType
+import TcType
+import {- Kind parts of -} Type
+import VarEnv
+import TysPrim
+import Id
+import IdInfo
+import Var ( TyVar )
+import Name
 import NameSet
 import NameEnv
 import VarSet
-import SrcLoc          ( Located(..), unLoc, getLoc )
+import SrcLoc
 import Bag
-import ErrUtils                ( Message )
-import Digraph         ( SCC(..), stronglyConnComp )
-import Maybes          ( expectJust, isJust, isNothing, orElse )
-import Util            ( singleton )
-import BasicTypes      ( TopLevelFlag(..), isTopLevel, isNotTopLevel,
-                         RecFlag(..), isNonRec, InlineSpec, defaultInlineSpec )
+import ErrUtils
+import Digraph
+import Maybes
+import Util
+import BasicTypes
 import Outputable
 \end{code}
 
@@ -323,7 +302,7 @@ tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc binds
     in
        -- SET UP THE MAIN RECOVERY; take advantage of any type sigs
     setSrcSpan loc                             $
-    recoverM (recoveryCode binder_names)       $ do 
+    recoverM (recoveryCode binder_names sig_fn)        $ do 
 
   { traceTc (ptext SLIT("------------------------------------------------"))
   ; traceTc (ptext SLIT("Bindings for") <+> ppr binder_names)
@@ -364,43 +343,47 @@ tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc binds
   ; exports <- mapM (mkExport prag_fn tyvars_to_gen' (map idType dict_ids))
                    mono_bind_infos
 
-       -- ZONK THE poly_ids, because they are used to extend the type 
-       -- environment; see the invariant on TcEnv.tcExtendIdEnv 
   ; let        poly_ids = [poly_id | (_, poly_id, _, _) <- exports]
-  ; zonked_poly_ids <- mappM zonkId poly_ids
-
-  ; traceTc (text "binding:" <+> ppr (zonked_poly_ids `zip` map idType zonked_poly_ids))
+  ; traceTc (text "binding:" <+> ppr (poly_ids `zip` map idType poly_ids))
 
   ; let abs_bind = L loc $ AbsBinds tyvars_to_gen'
                                    dict_ids exports
                                    (dict_binds `unionBags` binds')
 
-  ; return ([unitBag abs_bind], zonked_poly_ids)
+  ; return ([unitBag abs_bind], poly_ids)      -- poly_ids are guaranteed zonked by mkExport
   } }
 
 
 --------------
 mkExport :: TcPragFun -> [TyVar] -> [TcType] -> MonoBindInfo
         -> TcM ([TyVar], Id, Id, [Prag])
+-- mkExport generates exports with 
+--     zonked type variables, 
+--     zonked poly_ids
+-- The former is just because no further unifications will change
+-- the quantified type variables, so we can fix their final form
+-- right now.
+-- The latter is needed because the poly_ids are used to extend the
+-- type environment; see the invariant on TcEnv.tcExtendIdEnv 
+
+-- Pre-condition: the inferred_tvs are already zonked
+
 mkExport prag_fn inferred_tvs dict_tys (poly_name, mb_sig, mono_id)
-  = case mb_sig of
-      Nothing  -> do { prags <- tcPrags poly_id (prag_fn poly_name)
-                    ; return (inferred_tvs, poly_id, mono_id, prags) }
-         where
-           poly_id = mkLocalId poly_name poly_ty
-           poly_ty = mkForAllTys inferred_tvs
-                                      $ mkFunTys dict_tys 
-                                      $ idType mono_id
-
-      Just sig -> do { let poly_id = sig_id sig
-                    ; prags <- tcPrags poly_id (prag_fn poly_name)
-                    ; sig_tys <- zonkTcTyVars (sig_tvs sig)
-                    ; let sig_tvs' = map (tcGetTyVar "mkExport") sig_tys
-                    ; return (sig_tvs', poly_id, mono_id, prags) }
-               -- We zonk the sig_tvs here so that the export triple
-               -- always has zonked type variables; 
-               -- a convenient invariant
+  = do { (tvs, poly_id) <- mk_poly_id mb_sig
+
+       ; poly_id' <- zonkId poly_id
+       ; prags <- tcPrags poly_id' (prag_fn poly_name)
+               -- tcPrags requires a zonked poly_id
+
+       ; return (tvs, poly_id', mono_id, prags) }
+  where
+    poly_ty = mkForAllTys inferred_tvs (mkFunTys dict_tys (idType mono_id))
 
+    mk_poly_id Nothing    = return (inferred_tvs, mkLocalId poly_name poly_ty)
+    mk_poly_id (Just sig) = do { tvs <- mapM zonk_tv (sig_tvs sig)
+                              ; return (tvs,  sig_id sig) }
+
+    zonk_tv tv = do { ty <- zonkTcTyVar tv; return (tcGetTyVar "mkExport" ty) }
 
 ------------------------
 type TcPragFun = Name -> [LSig Name]
@@ -423,6 +406,8 @@ tcPrags poly_id prags = mapM tc_prag prags
 pragSigCtxt prag = hang (ptext SLIT("In the pragma")) 2 (ppr prag)
 
 tcPrag :: TcId -> Sig Name -> TcM Prag
+-- Pre-condition: the poly_id is zonked
+-- Reason: required by tcSubExp
 tcPrag poly_id (SpecSig orig_name hs_ty inl) = tcSpecPrag poly_id hs_ty inl
 tcPrag poly_id (SpecInstSig hs_ty)          = tcSpecPrag poly_id hs_ty defaultInlineSpec
 tcPrag poly_id (InlineSig v inl)             = return (InlinePrag inl)
@@ -434,7 +419,7 @@ tcSpecPrag poly_id hs_ty inl
        ; (co_fn, lie) <- getLIE (tcSubExp (idType poly_id) spec_ty)
        ; extendLIEs lie
        ; let const_dicts = map instToId lie
-       ; return (SpecPrag (mkHsCoerce co_fn (HsVar poly_id)) spec_ty const_dicts inl) }
+       ; return (SpecPrag (mkHsWrap co_fn (HsVar poly_id)) spec_ty const_dicts inl) }
        -- Most of the work of specialisation is done by 
        -- the desugarer, guided by the SpecPrag
   
@@ -442,15 +427,14 @@ tcSpecPrag poly_id hs_ty inl
 -- If typechecking the binds fails, then return with each
 -- signature-less binder given type (forall a.a), to minimise 
 -- subsequent error messages
-recoveryCode binder_names
+recoveryCode binder_names sig_fn
   = do { traceTc (text "tcBindsWithSigs: error recovery" <+> ppr binder_names)
        ; poly_ids <- mapM mk_dummy binder_names
        ; return ([], poly_ids) }
   where
-    mk_dummy name = do { mb_id <- tcLookupLocalId_maybe name
-                       ; case mb_id of
-                             Just id -> return id              -- Had signature, was in envt
-                             Nothing -> return (mkLocalId name forall_a_a) }    -- No signature
+    mk_dummy name 
+       | isJust (sig_fn name) = tcLookupId name        -- Had signature; look it up
+       | otherwise            = return (mkLocalId name forall_a_a)    -- No signature
 
 forall_a_a :: TcType
 forall_a_a = mkForAllTy alphaTyVar (mkTyVarTy alphaTyVar)
@@ -542,7 +526,7 @@ tcMonoBinds [L b_loc (FunBind { fun_id = L nm_loc name, fun_infix = inf,
        ; let mono_id = mkLocalId mono_name zonked_rhs_ty
        ; return (unitBag (L b_loc (FunBind { fun_id = L nm_loc mono_id, fun_infix = inf,
                                              fun_matches = matches', bind_fvs = fvs,
-                                             fun_co_fn = co_fn })),
+                                             fun_co_fn = co_fn, fun_tick = Nothing })),
                  [(name, Nothing, mono_id)]) }
 
 tcMonoBinds [L b_loc (FunBind { fun_id = L nm_loc name, fun_infix = inf, 
@@ -566,7 +550,8 @@ tcMonoBinds [L b_loc (FunBind { fun_id = L nm_loc name, fun_infix = inf,
 
        ; let fun_bind' = FunBind { fun_id = L nm_loc mono_id, 
                                    fun_infix = inf, fun_matches = matches',
-                                   bind_fvs = placeHolderNames, fun_co_fn = co_fn }
+                                   bind_fvs = placeHolderNames, fun_co_fn = co_fn, 
+                                   fun_tick = Nothing }
        ; return (unitBag (L b_loc fun_bind'),
                  [(name, Just tc_sig, mono_id)]) }
 
@@ -645,9 +630,8 @@ tcLhs sig_fn bind@(PatBind { pat_lhs = pat, pat_rhs = grhss })
                                      | (name, Just sig) <- nm_sig_prs]
              sig_tau_fn  = lookupNameEnv tau_sig_env
 
-             tc_pat exp_ty = tcPat (LetPat sig_tau_fn) pat exp_ty unitTy $ \ _ ->
+             tc_pat exp_ty = tcLetPat sig_tau_fn pat exp_ty $
                              mapM lookup_info nm_sig_prs
-               -- The unitTy is a bit bogus; it's the "result type" for lookup_info.  
 
                -- After typechecking the pattern, look up the binder
                -- names, which the pattern has brought into scope.
@@ -672,7 +656,8 @@ tcRhs (TcFunBind info fun'@(L _ mono_id) inf matches)
   = do { (co_fn, matches') <- tcMatchesFun (idName mono_id) matches 
                                            (idType mono_id)
        ; return (FunBind { fun_id = fun', fun_infix = inf, fun_matches = matches',
-                           bind_fvs = placeHolderNames, fun_co_fn = co_fn }) }
+                           bind_fvs = placeHolderNames, fun_co_fn = co_fn,
+                           fun_tick = Nothing }) }
 
 tcRhs bind@(TcPatBind _ pat' grhss pat_ty)
   = do { grhss' <- addErrCtxt (patMonoBindsCtxt pat' grhss) $
@@ -725,16 +710,17 @@ generalise dflags top_lvl bind_list sig_fn mono_infos lie_req
   = tcSimplifyInfer doc tau_tvs lie_req
 
   | otherwise  -- UNRESTRICTED CASE, WITH TYPE SIGS
-  = do { sig_lie <- unifyCtxts sigs    -- sigs is non-empty
+  = do { sig_lie <- unifyCtxts sigs    -- sigs is non-empty; sig_lie is zonked
        ; let   -- The "sig_avails" is the stuff available.  We get that from
                -- the context of the type signature, BUT ALSO the lie_avail
                -- so that polymorphic recursion works right (see Note [Polymorphic recursion])
                local_meths = [mkMethInst sig mono_id | (_, Just sig, mono_id) <- mono_infos]
                sig_avails = sig_lie ++ local_meths
+               loc = sig_loc (head sigs)
 
        -- Check that the needed dicts can be
        -- expressed in terms of the signature ones
-       ; (forall_tvs, dict_binds) <- tcSimplifyInferCheck doc tau_tvs sig_avails lie_req
+       ; (forall_tvs, dict_binds) <- tcSimplifyInferCheck loc tau_tvs sig_avails lie_req
        
        -- Check that signature type variables are OK
        ; final_qtvs <- checkSigsTyVars forall_tvs sigs
@@ -751,7 +737,8 @@ generalise dflags top_lvl bind_list sig_fn mono_infos lie_req
 
     mkMethInst (TcSigInfo { sig_id = poly_id, sig_tvs = tvs, 
                            sig_theta = theta, sig_loc = loc }) mono_id
-      = Method mono_id poly_id (mkTyVarTys tvs) theta loc
+      = Method {tci_id = mono_id, tci_oid = poly_id, tci_tys = mkTyVarTys tvs,
+               tci_theta = theta, tci_loc = loc}
 \end{code}
 
 unifyCtxts checks that all the signature contexts are the same
@@ -768,14 +755,16 @@ might not otherwise be related.  This is a rather subtle issue.
 
 \begin{code}
 unifyCtxts :: [TcSigInfo] -> TcM [Inst]
+-- Post-condition: the returned Insts are full zonked
 unifyCtxts (sig1 : sigs)       -- Argument is always non-empty
   = do { mapM unify_ctxt sigs
-       ; newDictsAtLoc (sig_loc sig1) (sig_theta sig1) }
+       ; theta <- zonkTcThetaType (sig_theta sig1)
+       ; newDictBndrs (sig_loc sig1) theta }
   where
     theta1 = sig_theta sig1
     unify_ctxt :: TcSigInfo -> TcM ()
     unify_ctxt sig@(TcSigInfo { sig_theta = theta })
-       = setSrcSpan (instLocSrcSpan (sig_loc sig))     $
+       = setSrcSpan (instLocSpan (sig_loc sig))        $
          addErrCtxt (sigContextsCtxt sig1 sig)         $
          unifyTheta theta1 theta
 
@@ -970,13 +959,12 @@ mkTcSigFun :: [LSig Name] -> TcSigFun
 -- Precondition: no duplicates
 mkTcSigFun sigs = lookupNameEnv env
   where
-    env = mkNameEnv [(name, scoped_tyvars hs_ty)
-                   | L span (TypeSig (L _ name) (L _ hs_ty)) <- sigs]
-    scoped_tyvars (HsForAllTy Explicit tvs _ _) = hsLTyVarNames tvs
-    scoped_tyvars other                                = []
+    env = mkNameEnv [(name, hsExplicitTvs lhs_ty)
+                   | L span (TypeSig (L _ name) lhs_ty) <- sigs]
        -- The scoped names are the ones explicitly mentioned
        -- in the HsForAll.  (There may be more in sigma_ty, because
        -- of nested type synonyms.  See Note [Scoped] with TcSigInfo.)
+       -- See Note [Only scoped tyvars are in the TyVarEnv]
 
 ---------------
 data TcSigInfo
@@ -995,6 +983,19 @@ data TcSigInfo
        sig_loc    :: InstLoc           -- The location of the signature
     }
 
+
+--     Note [Only scoped tyvars are in the TyVarEnv]
+-- We are careful to keep only the *lexically scoped* type variables in
+-- the type environment.  Why?  After all, the renamer has ensured
+-- that only legal occurrences occur, so we could put all type variables
+-- into the type env.
+--
+-- But we want to check that two distinct lexically scoped type variables
+-- do not map to the same internal type variable.  So we need to know which
+-- the lexically-scoped ones are... and at the moment we do that by putting
+-- only the lexically scoped ones into the environment.
+
+
 --     Note [Scoped]
 -- There may be more instantiated type variables than scoped 
 -- ones.  For example:
@@ -1007,7 +1008,7 @@ data TcSigInfo
 -- and remember the names from the original HsForAllTy in sig_scoped
 
 --     Note [Instantiate sig]
--- It's vital to instantiate a type signature with fresh variable.
+-- It's vital to instantiate a type signature with fresh variables.
 -- For example:
 --     type S = forall a. a->a
 --     f,g :: S
@@ -1041,9 +1042,15 @@ tcInstSig_maybe sig_fn name
 
 tcInstSig :: Bool -> Name -> [Name] -> TcM TcSigInfo
 -- Instantiate the signature, with either skolems or meta-type variables
--- depending on the use_skols boolean
+-- depending on the use_skols boolean.  This variable is set True
+-- when we are typechecking a single function binding; and False for
+-- pattern bindings and a group of several function bindings.
+-- Reason: in the latter cases, the "skolems" can be unified together, 
+--        so they aren't properly rigid in the type-refinement sense.
+-- NB: unless we are doing H98, each function with a sig will be done
+--     separately, even if it's mutually recursive, so use_skols will be True
 --
--- We always instantiate with freshs uniques,
+-- We always instantiate with fresh uniques,
 -- although we keep the same print-name
 --     
 --     type T = forall a. [a] -> [a]
@@ -1056,8 +1063,7 @@ tcInstSig use_skols name scoped_names
   = do { poly_id <- tcLookupId name    -- Cannot fail; the poly ids are put into 
                                        -- scope when starting the binding group
        ; let skol_info = SigSkol (FunSigCtxt name)
-             inst_tyvars | use_skols = tcInstSkolTyVars skol_info
-                         | otherwise = tcInstSigTyVars  skol_info
+             inst_tyvars = tcInstSigTyVars use_skols skol_info
        ; (tvs, theta, tau) <- tcInstType inst_tyvars (idType poly_id)
        ; loc <- getInstLoc (SigOrigin skol_info)
        ; return (TcSigInfo { sig_id = poly_id,