[project @ 2004-12-20 17:16:24 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcBinds.lhs
index f9bcc6d..ad5060b 100644 (file)
@@ -11,7 +11,7 @@ module TcBinds ( tcBindsAndThen, tcTopBinds, tcMonoBinds, tcSpecSigs ) where
 import {-# SOURCE #-} TcMatches ( tcGRHSsPat, tcMatchesFun )
 import {-# SOURCE #-} TcExpr  ( tcCheckSigma, tcCheckRho )
 
-import CmdLineOpts     ( DynFlag(Opt_NoMonomorphismRestriction) )
+import CmdLineOpts     ( DynFlag(Opt_MonomorphismRestriction) )
 import HsSyn           ( HsExpr(..), HsBind(..), LHsBinds, Sig(..),
                          LSig, Match(..), HsBindGroup(..), IPBind(..),
                          LPat, GRHSs, MatchGroup(..), emptyLHsBinds, isEmptyLHsBinds,
@@ -30,7 +30,7 @@ import TcHsType               ( tcHsSigType, UserTypeCtxt(..), tcAddLetBoundTyVars,
                        )
 import TcPat           ( tcPat, PatCtxt(..) )
 import TcSimplify      ( bindInstsOfLocalFuns )
-import TcMType         ( newTyFlexiVarTy, tcSkolType, zonkQuantifiedTyVar )
+import TcMType         ( newTyFlexiVarTy, tcSkolType, zonkQuantifiedTyVar, zonkTcTypes )
 import TcType          ( TcTyVar, SkolemInfo(SigSkol), 
                          TcTauType, TcSigmaType, 
                          TvSubstEnv, mkTvSubst, substTheta, substTy, 
@@ -50,6 +50,7 @@ import VarSet
 import SrcLoc          ( Located(..), unLoc, noLoc, getLoc )
 import Bag
 import Util            ( isIn )
+import Maybes          ( orElse )
 import BasicTypes      ( TopLevelFlag(..), RecFlag(..), isNonRec, isRec, 
                          isNotTopLevel, isAlwaysActive )
 import FiniteMap       ( listToFM, lookupFM )
@@ -237,6 +238,7 @@ tcBindWithSigs      :: TopLevelFlag
                -> [LSig Name]
                -> RecFlag
                -> TcM (LHsBinds TcId, [TcId])
+       -- The returned TcIds are guaranteed zonked
 
 tcBindWithSigs top_lvl mbind sigs is_rec = do  
   {    -- TYPECHECK THE SIGNATURES
@@ -254,8 +256,23 @@ tcBindWithSigs top_lvl mbind sigs is_rec = do
   ; ((mbind', mono_bind_infos), lie_req) 
        <- getLIE (tcMonoBinds mbind lookup_sig is_rec)
 
-       -- GENERALISE
-  ; is_unres <- isUnRestrictedGroup mbind tc_ty_sigs
+       -- CHECK FOR UNLIFTED BINDINGS
+       -- These must be non-recursive etc, and are not generalised
+       -- They desugar to a case expression in the end
+  ; zonked_mono_tys <- zonkTcTypes (map getMonoType mono_bind_infos)
+  ; if any isUnLiftedType zonked_mono_tys then
+    do {       -- Unlifted bindings
+         checkUnliftedBinds top_lvl is_rec mbind
+       ; extendLIEs lie_req
+       ; let exports  = zipWith mk_export mono_bind_infos zonked_mono_tys
+             mk_export (name, Nothing,  mono_id) mono_ty = ([], mkLocalId name mono_ty, mono_id)
+             mk_export (name, Just sig, mono_id) mono_ty = ([], sig_id sig,             mono_id)
+
+       ; return ( unitBag $ noLoc $ AbsBinds [] [] exports emptyNameSet mbind',
+                  [poly_id | (_, poly_id, _) <- exports]) }    -- Guaranteed zonked
+
+    else do    -- The normal lifted case: GENERALISE
+  { is_unres <- isUnRestrictedGroup mbind tc_ty_sigs
   ; (tyvars_to_gen, dict_binds, dict_ids)
        <- setSrcSpan (getLoc (head (bagToList mbind)))     $
                -- TODO: location a bit awkward, but the mbinds have been
@@ -303,28 +320,16 @@ tcBindWithSigs top_lvl mbind sigs is_rec = do
   ; traceTc (text "binding:" <+> ppr ((dict_ids, dict_binds),
                                      exports, map idType zonked_poly_ids))
 
-       -- Check for an unlifted, non-overloaded group
-       -- In that case we must make extra checks
-  ; if any (isUnLiftedType . idType) zonked_poly_ids
-    then       -- Some bindings are unlifted
-       do { checkUnliftedBinds top_lvl is_rec tyvars_to_gen' mbind
-          ; return (
-                   unitBag $ noLoc $
-                   AbsBinds [] [] exports inlines mbind',
-                       -- Do not generate even any x=y bindings
-                   zonked_poly_ids )}
-
-    else       -- The normal case
-       return (
+  ; return (
            unitBag $ noLoc $
            AbsBinds tyvars_to_gen'
-                dict_ids
-                exports
-                inlines
-                (dict_binds `unionBags` mbind'),
+                    dict_ids
+                    exports
+                    inlines
+                    (dict_binds `unionBags` mbind'),
            zonked_poly_ids
         )
-  } }
+  } } }
 
 -- If typechecking the binds fails, then return with each
 -- signature-less binder given type (forall a.a), to minimise 
@@ -348,26 +353,15 @@ attachInlinePhase inline_phases bndr
 -- Check that non-overloaded unlifted bindings are
 --     a) non-recursive,
 --     b) not top level, 
---     c) non-polymorphic
---     d) not a multiple-binding group (more or less implied by (a))
-
-checkUnliftedBinds top_lvl is_rec tyvars_to_gen mbind
-  = ASSERT( not (any (isUnliftedTypeKind . tyVarKind) tyvars_to_gen) )
-               -- The instCantBeGeneralised stuff in tcSimplify should have
-               -- already raised an error if we're trying to generalise an 
-               -- unboxed tyvar (NB: unboxed tyvars are always introduced 
-               -- along with a class constraint) and it's better done there 
-               -- because we have more precise origin information.
-               -- That's why we just use an ASSERT here.
-
-    checkTc (isNotTopLevel top_lvl)
+--     c) not a multiple-binding group (more or less implied by (a))
+
+checkUnliftedBinds top_lvl is_rec mbind
+  = checkTc (isNotTopLevel top_lvl)
            (unliftedBindErr "Top-level" mbind)         `thenM_`
     checkTc (isNonRec is_rec)
            (unliftedBindErr "Recursive" mbind)         `thenM_`
     checkTc (isSingletonBag mbind)
-           (unliftedBindErr "Multiple" mbind)          `thenM_`
-    checkTc (null tyvars_to_gen)
-           (unliftedBindErr "Polymorphic" mbind)
+           (unliftedBindErr "Multiple" mbind)
 \end{code}
 
 
@@ -574,14 +568,14 @@ tcTySig sig1 (L span (Sig (L _ name) ty))
 
        -- Try to match the context of this signature with 
        -- that of the first signature
-       ; case tcMatchPreds tvs (sig_theta sig1) theta of { 
+       ; case tcMatchPreds tvs theta (sig_theta sig1) of { 
            Nothing   -> bale_out
        ;   Just tenv -> do
        ; case check_tvs tenv tvs of
            Nothing   -> bale_out
-           Just tvs' -> do 
+           Just tvs' -> do {
 
-       { let subst  = mkTvSubst tenv
+         let subst  = mkTvSubst tenv
              theta' = substTheta subst theta
              tau'   = substTy subst tau
        ; loc <- getInstLoc (SigOrigin rigid_info)
@@ -600,15 +594,12 @@ tcTySig sig1 (L span (Sig (L _ name) ty))
     check_tvs :: TvSubstEnv -> [TcTyVar] -> Maybe [TcTyVar]
     check_tvs tenv [] = Just []
     check_tvs tenv (tv:tvs) 
-       | Just ty <- lookupVarEnv tenv tv
-       = do { tv' <- tcGetTyVar_maybe ty
+       = do { let ty = lookupVarEnv tenv tv `orElse` mkTyVarTy tv
+            ; tv'  <- tcGetTyVar_maybe ty
             ; tvs' <- check_tvs tenv tvs
             ; if tv' `elem` tvs'
               then Nothing
               else Just (tv':tvs') }
-       | otherwise
-       = do { tvs' <- check_tvs tenv tvs
-            ; Just (tv:tvs') }
 \end{code}
 
 \begin{code}
@@ -727,8 +718,8 @@ find which tyvars are constrained.
 \begin{code}
 isUnRestrictedGroup :: LHsBinds Name -> [TcSigInfo] -> TcM Bool
 isUnRestrictedGroup binds sigs
-  = do { no_MR <- doptM Opt_NoMonomorphismRestriction
-       ; return (no_MR || all_unrestricted) }
+  = do { mono_restriction <- doptM Opt_MonomorphismRestriction
+       ; return (not mono_restriction || all_unrestricted) }
   where 
     all_unrestricted = all (unrestricted . unLoc) (bagToList binds)
     tysig_names      = map (idName . sig_id) sigs