[project @ 2002-09-09 12:50:26 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcBinds.lhs
index 35f3923..c6ca52d 100644 (file)
@@ -4,7 +4,7 @@
 \section[TcBinds]{TcBinds}
 
 \begin{code}
-module TcBinds ( tcBindsAndThen, tcTopBinds,
+module TcBinds ( tcBindsAndThen, tcTopBinds, tcMonoBinds,
                 tcSpecSigs, tcBindWithSigs ) where
 
 #include "HsVersions.h"
@@ -12,7 +12,7 @@ module TcBinds ( tcBindsAndThen, tcTopBinds,
 import {-# SOURCE #-} TcMatches ( tcGRHSs, tcMatchesFun )
 import {-# SOURCE #-} TcExpr  ( tcExpr )
 
-import CmdLineOpts     ( opt_NoMonomorphismRestriction )
+import CmdLineOpts     ( DynFlag(Opt_NoMonomorphismRestriction) )
 import HsSyn           ( HsExpr(..), HsBinds(..), MonoBinds(..), Sig(..), 
                          Match(..), HsMatchContext(..), 
                          collectMonoBinders, andMonoBinds,
@@ -25,18 +25,18 @@ import TcMonad
 import Inst            ( LIE, emptyLIE, mkLIE, plusLIE, InstOrigin(..),
                          newDicts, instToId
                        )
-import TcEnv           ( tcExtendLocalValEnv, newLocalName )
-import TcUnify         ( unifyTauTyLists, checkSigTyVars, sigCtxt )
+import TcEnv           ( tcExtendLocalValEnv, tcExtendLocalValEnv2, newLocalName )
+import TcUnify         ( unifyTauTyLists, checkSigTyVarsWrt, sigCtxt )
 import TcSimplify      ( tcSimplifyInfer, tcSimplifyInferCheck, tcSimplifyRestricted, tcSimplifyToDicts )
-import TcMonoType      ( tcHsSigType, UserTypeCtxt(..), 
-                         TcSigInfo(..), tcTySig, maybeSig, tcAddScopedTyVars
+import TcMonoType      ( tcHsSigType, UserTypeCtxt(..), TcSigInfo(..), 
+                         tcTySig, maybeSig, tcSigPolyId, tcSigMonoId, tcAddScopedTyVars
                        )
 import TcPat           ( tcPat, tcSubPat, tcMonoPatBndr )
 import TcSimplify      ( bindInstsOfLocalFuns )
 import TcMType         ( newTyVar, newTyVarTy, newHoleTyVarTy,
-                         zonkTcTyVarToTyVar
+                         zonkTcTyVarToTyVar, readHoleResult
                        )
-import TcType          ( mkTyVarTy, mkForAllTys, mkFunTys, tyVarsOfType, 
+import TcType          ( TcTyVar, mkTyVarTy, mkForAllTys, mkFunTys, tyVarsOfType, 
                          mkPredTy, mkForAllTy, isUnLiftedType, 
                          unliftedTypeKind, liftedTypeKind, openTypeKind, eqKind
                        )
@@ -131,7 +131,7 @@ tc_binds_and_then top_lvl combiner (MonoBind bind sigs is_rec) do_next
                     sigs is_rec                        `thenTc` \ (poly_binds, poly_lie, poly_ids) ->
   
          -- Extend the environment to bind the new polymorphic Ids
-      tcExtendLocalValEnv [(idName poly_id, poly_id) | poly_id <- poly_ids] $
+      tcExtendLocalValEnv poly_ids                     $
   
          -- Build bindings and IdInfos corresponding to user pragmas
       tcSpecSigs sigs          `thenTc` \ (prag_binds, prag_lie) ->
@@ -219,8 +219,8 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
           binder_names  = collectMonoBinders mbind
          poly_ids      = map mk_dummy binder_names
          mk_dummy name = case maybeSig tc_ty_sigs name of
-                           Just (TySigInfo _ poly_id _ _ _ _ _ _) -> poly_id   -- Signature
-                           Nothing -> mkLocalId name forall_a_a                -- No signature
+                           Just sig -> tcSigPolyId sig                 -- Signature
+                           Nothing  -> mkLocalId name forall_a_a       -- No signature
        in
        returnTc (EmptyMonoBinds, emptyLIE, poly_ids)
     )                                          $
@@ -262,18 +262,21 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
        dict_tys = map idType zonked_dict_ids
 
        inlines    = mkNameSet [name | InlineSig True name _ loc <- inline_sigs]
-        no_inlines = listToFM [(name, phase) | InlineSig _ name phase _ <- inline_sigs, 
-                                              not (isAlwaysActive phase)]
+                       -- Any INLINE sig (regardless of phase control) 
+                       -- makes the RHS look small
+        inline_phases = listToFM [(name, phase) | InlineSig _ name phase _ <- inline_sigs, 
+                                                 not (isAlwaysActive phase)]
+                       -- Set the IdInfo field to control the inline phase
                        -- AlwaysActive is the default, so don't bother with them
 
        mk_export binder_name zonked_mono_id
          = (tyvars, 
-            attachNoInlinePrag no_inlines poly_id,
+            attachInlinePhase inline_phases poly_id,
             zonked_mono_id)
          where
            (tyvars, poly_id) = 
                case maybeSig tc_ty_sigs binder_name of
-                 Just (TySigInfo _ sig_poly_id sig_tyvars _ _ _ _ _) -> 
+                 Just (TySigInfo sig_poly_id sig_tyvars _ _ _ _ _) -> 
                        (sig_tyvars, sig_poly_id)
                  Nothing -> (real_tyvars_to_gen, new_poly_id)
 
@@ -313,8 +316,8 @@ tcBindWithSigs top_lvl mbind tc_ty_sigs inline_sigs is_rec
        lie_free, poly_ids
     )
 
-attachNoInlinePrag no_inlines bndr
-  = case lookupFM no_inlines (idName bndr) of
+attachInlinePhase inline_phases bndr
+  = case lookupFM inline_phases (idName bndr) of
        Just prag -> bndr `setInlinePragma` prag
        Nothing   -> bndr
 
@@ -412,9 +415,16 @@ is doing.
 %************************************************************************
 
 \begin{code}
-generalise binder_names mbind tau_tvs lie_req sigs
-  | not is_unrestricted        -- RESTRICTED CASE
-  =    -- Check signature contexts are empty 
+generalise binder_names mbind tau_tvs lie_req sigs =
+
+  -- check for -fno-monomorphism-restriction
+  doptsTc Opt_NoMonomorphismRestriction                `thenTc` \ no_MR ->
+  let is_unrestricted | no_MR    = True
+                     | otherwise = isUnRestrictedGroup tysig_names mbind
+  in
+
+  if not is_unrestricted then  -- RESTRICTED CASE
+       -- Check signature contexts are empty 
     checkTc (all is_mono_sig sigs)
            (restrictedBindCtxtErr binder_names)        `thenTc_`
 
@@ -423,33 +433,30 @@ generalise binder_names mbind tau_tvs lie_req sigs
     tcSimplifyRestricted doc tau_tvs lie_req           `thenTc` \ (qtvs, lie_free, binds) ->
 
        -- Check that signature type variables are OK
-    checkSigsTyVars sigs                               `thenTc_`
+    checkSigsTyVars qtvs sigs                          `thenTc` \ final_qtvs ->
 
-    returnTc (qtvs, lie_free, binds, [])
+    returnTc (final_qtvs, lie_free, binds, [])
 
-  | null sigs                  -- UNRESTRICTED CASE, NO TYPE SIGS
-  = tcSimplifyInfer doc tau_tvs lie_req
+  else if null sigs then       -- UNRESTRICTED CASE, NO TYPE SIGS
+    tcSimplifyInfer doc tau_tvs lie_req
 
-  | otherwise                  -- UNRESTRICTED CASE, WITH TYPE SIGS
-  =    -- CHECKING CASE: Unrestricted group, there are type signatures
-       -- Check signature contexts are empty 
-    checkSigsCtxts sigs                                `thenTc` \ (sig_avails, sig_dicts) ->
+  else                                 -- UNRESTRICTED CASE, WITH TYPE SIGS
+       -- CHECKING CASE: Unrestricted group, there are type signatures
+       -- Check signature contexts are identical
+    checkSigsCtxts sigs                        `thenTc` \ (sig_avails, sig_dicts) ->
     
        -- Check that the needed dicts can be
        -- expressed in terms of the signature ones
     tcSimplifyInferCheck doc tau_tvs sig_avails lie_req        `thenTc` \ (forall_tvs, lie_free, dict_binds) ->
        
        -- Check that signature type variables are OK
-    checkSigsTyVars sigs                                       `thenTc_`
+    checkSigsTyVars forall_tvs sigs                    `thenTc` \ final_qtvs ->
 
-    returnTc (forall_tvs, lie_free, dict_binds, sig_dicts)
+    returnTc (final_qtvs, lie_free, dict_binds, sig_dicts)
 
   where
-    is_unrestricted | opt_NoMonomorphismRestriction = True
-                   | otherwise                     = isUnRestrictedGroup tysig_names mbind
-
-    tysig_names = [name | (TySigInfo name _ _ _ _ _ _ _) <- sigs]
-    is_mono_sig (TySigInfo _ _ _ theta _ _ _ _) = null theta
+    tysig_names = map (idName . tcSigPolyId) sigs
+    is_mono_sig (TySigInfo _ _ theta _ _ _ _) = null theta
 
     doc = ptext SLIT("type signature(s) for") <+> pprBinders binder_names
 
@@ -461,7 +468,7 @@ generalise binder_names mbind tau_tvs lie_req sigs
        -- We unify them because, with polymorphic recursion, their types
        -- might not otherwise be related.  This is a rather subtle issue.
        -- ToDo: amplify
-checkSigsCtxts sigs@(TySigInfo _ id1 sig_tvs theta1 _ _ _ src_loc : other_sigs)
+checkSigsCtxts sigs@(TySigInfo id1 sig_tvs theta1 _ _ _ src_loc : other_sigs)
   = tcAddSrcLoc src_loc                        $
     mapTc_ check_one other_sigs                `thenTc_` 
     if null theta1 then
@@ -477,21 +484,37 @@ checkSigsCtxts sigs@(TySigInfo _ id1 sig_tvs theta1 _ _ _ src_loc : other_sigs)
     returnTc (sig_avails, map instToId sig_dicts)
   where
     sig1_dict_tys = map mkPredTy theta1
-    sig_meths    = concat [insts | TySigInfo _ _ _ _ _ _ insts _ <- sigs]
+    sig_meths    = concat [insts | TySigInfo _ _ _ _ _ insts _ <- sigs]
 
-    check_one sig@(TySigInfo _ id _ theta _ _ _ src_loc)
+    check_one sig@(TySigInfo id _ theta _ _ _ _)
        = tcAddErrCtxt (sigContextsCtxt id1 id)                 $
         checkTc (equalLength theta theta1) sigContextsErr      `thenTc_`
         unifyTauTyLists sig1_dict_tys (map mkPredTy theta)
 
-checkSigsTyVars sigs = mapTc_ check_one sigs
+checkSigsTyVars :: [TcTyVar] -> [TcSigInfo] -> TcM [TcTyVar]
+checkSigsTyVars qtvs sigs 
+  = mapTc check_one sigs       `thenTc` \ sig_tvs_s ->
+    let
+       -- Sigh.  Make sure that all the tyvars in the type sigs
+       -- appear in the returned ty var list, which is what we are
+       -- going to generalise over.  Reason: we occasionally get
+       -- silly types like
+       --      type T a = () -> ()
+       --      f :: T a
+       --      f () = ()
+       -- Here, 'a' won't appear in qtvs, so we have to add it
+
+       sig_tvs = foldr (unionVarSet . mkVarSet) emptyVarSet sig_tvs_s
+       all_tvs = mkVarSet qtvs `unionVarSet` sig_tvs
+    in
+    returnTc (varSetElems all_tvs)
   where
-    check_one (TySigInfo _ id sig_tyvars sig_theta sig_tau _ _ src_loc)
+    check_one (TySigInfo id sig_tyvars sig_theta sig_tau _ _ src_loc)
       = tcAddSrcLoc src_loc                                            $
        tcAddErrCtxt (ptext SLIT("When checking the type signature for") 
                      <+> quotes (ppr id))                              $
-       tcAddErrCtxtM (sigCtxt sig_tyvars sig_theta sig_tau)            $
-       checkSigTyVars sig_tyvars (idFreeTyVars id)
+       tcAddErrCtxtM (sigCtxt id sig_tyvars sig_theta sig_tau)         $
+       checkSigTyVarsWrt (idFreeTyVars id) sig_tyvars
 \end{code}
 
 @getTyVarsToGen@ decides what type variables to generalise over.
@@ -608,8 +631,10 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
   where
 
     mk_bind (name, mono_id) = case maybeSig tc_ty_sigs name of
-                               Nothing                                   -> (name, mono_id)
-                               Just (TySigInfo name poly_id _ _ _ _ _ _) -> (name, poly_id)
+                               Nothing  -> (name, mono_id)
+                               Just sig -> (idName poly_id, poly_id)
+                                        where
+                                           poly_id = tcSigPolyId sig
 
     tc_mb_pats EmptyMonoBinds
       = returnTc (\ xve -> returnTc (EmptyMonoBinds, emptyLIE), emptyLIE, emptyBag, emptyBag, emptyLIE)
@@ -630,19 +655,18 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
 
     tc_mb_pats (FunMonoBind name inf matches locn)
       = (case maybeSig tc_ty_sigs name of
-           Just (TySigInfo _ _ _ _ _ mono_id _ _) 
-                   -> returnNF_Tc mono_id
-           Nothing -> newLocalName name        `thenNF_Tc` \ bndr_name ->
-                      newTyVarTy openTypeKind  `thenNF_Tc` \ bndr_ty -> 
+           Just sig -> returnNF_Tc (tcSigMonoId sig)
+           Nothing  -> newLocalName name       `thenNF_Tc` \ bndr_name ->
+                       newTyVarTy openTypeKind `thenNF_Tc` \ bndr_ty -> 
                        -- NB: not a 'hole' tyvar; since there is no type 
                        -- signature, we revert to ordinary H-M typechecking
                        -- which means the variable gets an inferred tau-type
-                      returnNF_Tc (mkLocalId bndr_name bndr_ty)
+                       returnNF_Tc (mkLocalId bndr_name bndr_ty)
        )                                       `thenNF_Tc` \ bndr_id ->
        let
           bndr_ty         = idType bndr_id
           complete_it xve = tcAddSrcLoc locn                           $
-                            tcMatchesFun xve name bndr_ty  matches     `thenTc` \ (matches', lie) ->
+                            tcMatchesFun xve name bndr_ty matches      `thenTc` \ (matches', lie) ->
                             returnTc (FunMonoBind bndr_id inf matches' locn, lie)
        in
        returnTc (complete_it, emptyLIE, emptyBag, unitBag (name, bndr_id), emptyLIE)
@@ -660,11 +684,12 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
                -- so we don't have to do anything here.
 
        tcPat tc_pat_bndr pat pat_ty            `thenTc` \ (pat', lie_req, tvs, ids, lie_avail) ->
+       readHoleResult pat_ty                   `thenTc` \ pat_ty' ->
        let
           complete_it xve = tcAddSrcLoc locn                           $
                             tcAddErrCtxt (patMonoBindsCtxt bind)       $
-                            tcExtendLocalValEnv xve                    $
-                            tcGRHSs PatBindRhs grhss pat_ty            `thenTc` \ (grhss', lie) ->
+                            tcExtendLocalValEnv2 xve                   $
+                            tcGRHSs PatBindRhs grhss pat_ty'           `thenTc` \ (grhss', lie) ->
                             returnTc (PatMonoBind pat' grhss' locn, lie)
        in
        returnTc (complete_it, lie_req, tvs, ids, lie_avail)
@@ -683,10 +708,11 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
                -> newLocalName name    `thenNF_Tc` \ bndr_name ->
                   tcMonoPatBndr bndr_name pat_ty
 
-           Just (TySigInfo _ _ _ _ _ mono_id _ _)
-               -> tcAddSrcLoc (getSrcLoc name)         $
-                  tcSubPat pat_ty (idType mono_id)     `thenTc` \ (co_fn, lie) ->
-                  returnTc (co_fn, lie, mono_id)
+           Just sig -> tcAddSrcLoc (getSrcLoc name)            $
+                       tcSubPat (idType mono_id) pat_ty        `thenTc` \ (co_fn, lie) ->
+                       returnTc (co_fn, lie, mono_id)
+                    where
+                       mono_id = tcSigMonoId sig
 \end{code}