remove empty dir
[ghc-hetmet.git] / ghc / compiler / typecheck / TcBinds.lhs
index 2040f53..cffcb9c 100644 (file)
@@ -21,8 +21,8 @@ import HsSyn          ( HsExpr(..), HsBind(..), LHsBinds, LHsBind, Sig(..),
                          LSig, Match(..), IPBind(..), Prag(..),
                          HsType(..), LHsType, HsExplicitForAll(..), hsLTyVarNames, 
                          isVanillaLSig, sigName, placeHolderNames, isPragLSig,
-                         LPat, GRHSs, MatchGroup(..), pprLHsBinds,
-                         collectHsBindBinders, collectPatBinders, pprPatBind
+                         LPat, GRHSs, MatchGroup(..), pprLHsBinds, mkHsCoerce,
+                         collectHsBindBinders, collectPatBinders, pprPatBind, isBangHsBind
                        )
 import TcHsSyn         ( zonkId )
 
@@ -54,7 +54,7 @@ import TysPrim                ( alphaTyVar )
 import Id              ( Id, mkLocalId, mkVanillaGlobal )
 import IdInfo          ( vanillaIdInfo )
 import Var             ( TyVar, idType, idName )
-import Name            ( Name, getSrcLoc )
+import Name            ( Name )
 import NameSet
 import NameEnv
 import VarSet
@@ -62,7 +62,7 @@ import SrcLoc         ( Located(..), unLoc, getLoc )
 import Bag
 import ErrUtils                ( Message )
 import Digraph         ( SCC(..), stronglyConnComp )
-import Maybes          ( fromJust, isJust, isNothing, orElse )
+import Maybes          ( expectJust, isJust, isNothing, orElse )
 import Util            ( singleton )
 import BasicTypes      ( TopLevelFlag(..), isTopLevel, isNotTopLevel,
                          RecFlag(..), isNonRec, InlineSpec, defaultInlineSpec )
@@ -251,10 +251,8 @@ mkEdges :: TcSigFun -> LHsBinds Name
 type BKey  = Int -- Just number off the bindings
 
 mkEdges sig_fn binds
-  = [ (bind, key, [fromJust mb_key | n <- nameSetToList (bind_fvs (unLoc bind)),
-                                    let mb_key = lookupNameEnv key_map n,
-                                    isJust mb_key,
-                                    no_sig n ])
+  = [ (bind, key, [key | n <- nameSetToList (bind_fvs (unLoc bind)),
+                        Just key <- [lookupNameEnv key_map n], no_sig n ])
     | (bind, key) <- keyd_binds
     ]
   where
@@ -347,11 +345,11 @@ tc_poly_binds top_lvl rec_group rec_tc sig_fn prag_fn binds
        -- These must be non-recursive etc, and are not generalised
        -- They desugar to a case expression in the end
   ; zonked_mono_tys <- zonkTcTypes (map getMonoType mono_bind_infos)
-  ; if any isUnLiftedType zonked_mono_tys then
-    do {       -- Unlifted bindings
-         checkUnliftedBinds top_lvl rec_group binds' mono_bind_infos
-       ; extendLIEs lie_req
-       ; let exports  = zipWith mk_export mono_bind_infos zonked_mono_tys
+  ; is_strict <- checkStrictBinds top_lvl rec_group binds' 
+                                 zonked_mono_tys mono_bind_infos
+  ; if is_strict then
+    do { extendLIEs lie_req
+       ; let exports = zipWith mk_export mono_bind_infos zonked_mono_tys
              mk_export (name, Nothing,  mono_id) mono_ty = ([], mkLocalId name mono_ty, mono_id, [])
              mk_export (name, Just sig, mono_id) mono_ty = ([], sig_id sig,             mono_id, [])
                        -- ToDo: prags for unlifted bindings
@@ -419,7 +417,8 @@ type TcPragFun = Name -> [LSig Name]
 mkPragFun :: [LSig Name] -> TcPragFun
 mkPragFun sigs = \n -> lookupNameEnv env n `orElse` []
        where
-         prs = [(fromJust (sigName sig), sig) | sig <- sigs, isPragLSig sig]
+         prs = [(expectJust "mkPragFun" (sigName sig), sig) 
+               | sig <- sigs, isPragLSig sig]
          env = foldl add emptyNameEnv prs
          add env (n,p) = extendNameEnv_Acc (:) singleton env n p
 
@@ -444,7 +443,7 @@ tcSpecPrag poly_id hs_ty inl
        ; (co_fn, lie) <- getLIE (tcSubExp (idType poly_id) spec_ty)
        ; extendLIEs lie
        ; let const_dicts = map instToId lie
-       ; return (SpecPrag (HsCoerce co_fn (HsVar poly_id)) spec_ty const_dicts inl) }
+       ; return (SpecPrag (mkHsCoerce co_fn (HsVar poly_id)) spec_ty const_dicts inl) }
   
 --------------
 -- If typechecking the binds fails, then return with each
@@ -469,20 +468,40 @@ forall_a_a = mkForAllTy alphaTyVar (mkTyVarTy alphaTyVar)
 --     b) not top level, 
 --     c) not a multiple-binding group (more or less implied by (a))
 
-checkUnliftedBinds :: TopLevelFlag -> RecFlag
-                  -> LHsBinds TcId -> [MonoBindInfo] -> TcM ()
-checkUnliftedBinds top_lvl rec_group mbind infos
+checkStrictBinds :: TopLevelFlag -> RecFlag
+                -> LHsBinds TcId -> [TcType] -> [MonoBindInfo]
+                -> TcM Bool
+checkStrictBinds top_lvl rec_group mbind mono_tys infos
+  | unlifted || bang_pat
   = do         { checkTc (isNotTopLevel top_lvl)
-                 (unliftedBindErr "Top-level" mbind)
+                 (strictBindErr "Top-level" unlifted mbind)
        ; checkTc (isNonRec rec_group)
-                 (unliftedBindErr "Recursive" mbind)
+                 (strictBindErr "Recursive" unlifted mbind)
        ; checkTc (isSingletonBag mbind)
-                 (unliftedBindErr "Multiple" mbind) 
-       ; mapM_ check_sig infos }
+                 (strictBindErr "Multiple" unlifted mbind) 
+       ; mapM_ check_sig infos
+       ; return True }
+  | otherwise
+  = return False
   where
+    unlifted = any isUnLiftedType mono_tys
+    bang_pat = anyBag (isBangHsBind . unLoc) mbind
     check_sig (_, Just sig, _) = checkTc (null (sig_tvs sig) && null (sig_theta sig))
-                                        (badUnliftedSig sig)
+                                        (badStrictSig unlifted sig)
     check_sig other           = return ()
+
+strictBindErr flavour unlifted mbind
+  = hang (text flavour <+> msg <+> ptext SLIT("aren't allowed:")) 4 (ppr mbind)
+  where
+    msg | unlifted  = ptext SLIT("bindings for unlifted types")
+       | otherwise = ptext SLIT("bang-pattern bindings")
+
+badStrictSig unlifted sig
+  = hang (ptext SLIT("Illegal polymorphic signature in") <+> msg)
+        4 (ppr sig)
+  where
+    msg | unlifted  = ptext SLIT("an unlifted binding")
+       | otherwise = ptext SLIT("a bang-pattern binding")
 \end{code}
 
 
@@ -498,9 +517,9 @@ The signatures have been dealt with already.
 \begin{code}
 tcMonoBinds :: [LHsBind Name]
            -> TcSigFun
-           -> RecFlag  -- True <=> the binding is recursive for typechecking purposes
-                       --          i.e. the binders are mentioned in their RHSs, and
-                       --               we are not resuced by a type signature
+           -> RecFlag  -- Whether the binding is recursive for typechecking purposes
+                       -- i.e. the binders are mentioned in their RHSs, and
+                       --      we are not resuced by a type signature
            -> TcM (LHsBinds TcId, [MonoBindInfo])
 
 tcMonoBinds [L b_loc (FunBind { fun_id = L nm_loc name, fun_infix = inf, 
@@ -938,7 +957,7 @@ mkSigFun :: [LSig Name] -> TcSigFun
 -- Precondition: no duplicates
 mkSigFun sigs = lookupNameEnv env
   where
-    env = mkNameEnv [(fromJust (sigName sig), sig) | sig <- sigs]
+    env = mkNameEnv [(expectJust "mkSigFun" (sigName sig), sig) | sig <- sigs]
 
 ---------------
 data TcSigInfo
@@ -1083,15 +1102,6 @@ sigContextsCtxt sig1 sig2
 
 
 -----------------------------------------------
-unliftedBindErr flavour mbind
-  = hang (text flavour <+> ptext SLIT("bindings for unlifted types aren't allowed:"))
-        4 (ppr mbind)
-
-badUnliftedSig sig
-  = hang (ptext SLIT("Illegal polymorphic signature in an unlifted binding"))
-        4 (ppr sig)
-
------------------------------------------------
 unboxedTupleErr name ty
   = hang (ptext SLIT("Illegal binding of unboxed tuple"))
         4 (ppr name <+> dcolon <+> ppr ty)