[project @ 2004-07-28 12:59:53 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcGenDeriv.lhs
index 83134d8..f812b20 100644 (file)
@@ -35,18 +35,13 @@ import BasicTypes   ( Fixity(..), maxPrecedence, Boxity(..) )
 import FieldLabel       ( fieldLabelName )
 import DataCon         ( isNullaryDataCon, dataConTag,
                          dataConOrigArgTys, dataConSourceArity, fIRST_TAG,
-                         DataCon, dataConName,
+                         DataCon, dataConName, dataConIsInfix,
                          dataConFieldLabels )
-import Name            ( getOccString, getOccName, getSrcLoc, occNameString, 
-                         occNameUserString, 
-                         Name, NamedThing(..), 
-                         isDataSymOcc, isSymOcc
-                       )
+import Name            ( getOccString, getSrcLoc, Name, NamedThing(..) )
 
 import HscTypes                ( FixityEnv, lookupFixity )
 import PrelInfo
 import PrelNames
-import TysWiredIn
 import MkId            ( eRROR_ID )
 import PrimOp          ( PrimOp(..) )
 import SrcLoc          ( Located(..), noLoc, srcLocSpan )
@@ -56,7 +51,8 @@ import TyCon          ( TyCon, isNewTyCon, tyConDataCons, isEnumerationTyCon, tyConArity
 import TcType          ( isUnLiftedType, tcEqType, Type )
 import TysPrim         ( charPrimTy, intPrimTy, wordPrimTy, addrPrimTy, floatPrimTy, doublePrimTy,
                          intPrimTyCon )
-import TysWiredIn      ( charDataCon, intDataCon, floatDataCon, doubleDataCon )
+import TysWiredIn      ( charDataCon, intDataCon, floatDataCon, doubleDataCon,
+                         intDataCon_RDR, true_RDR, false_RDR )
 import Util            ( zipWithEqual, isSingleton,
                          zipWith3Equal, nOfThem, zipEqual )
 import Char            ( isAlpha )
@@ -164,7 +160,7 @@ gen_Eq_binds tycon
                case maybeTyConSingleCon tycon of
                  Just _ -> []
                  Nothing -> -- if cons don't match, then False
-                    [([wildPat, wildPat], false_Expr)]
+                    [([nlWildPat, nlWildPat], false_Expr)]
            else -- calc. and compare the tags
                 [([a_Pat, b_Pat],
                    untag_Expr tycon [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
@@ -329,13 +325,13 @@ gen_Ord_binds tycon
                           -- Catch this specially to avoid warnings
                           -- about overlapping patterns from the desugarer,
                           -- and to avoid unnecessary pattern-matching
-      = [([wildPat,wildPat], eqTag_Expr)]
+      = [([nlWildPat,nlWildPat], eqTag_Expr)]
       | otherwise
       = map pats_etc nonnullary_cons ++
        (if single_con_type then        -- Omit wildcards when there's just one 
              []                        -- constructor, to silence desugarer
        else
-              [([wildPat, wildPat], default_rhs)])
+              [([nlWildPat, nlWildPat], default_rhs)])
 
       where
        pats_etc data_con
@@ -597,7 +593,7 @@ gen_Ix_binds tycon
     enum_index
       = mk_easy_FunBind tycon_loc index_RDR 
                [noLoc (AsPat (noLoc c_RDR) 
-                          (nlTuplePat [a_Pat, wildPat] Boxed)), 
+                          (nlTuplePat [a_Pat, nlWildPat] Boxed)), 
                                d_Pat] emptyBag (
        nlHsIf (nlHsPar (nlHsVarApps inRange_RDR [c_RDR, d_RDR])) (
           untag_Expr tycon [(a_RDR, ah_RDR)] (
@@ -784,7 +780,7 @@ gen_Read_binds get_fixity tycon
              | otherwise         = prefix_stmts
      
                prefix_stmts            -- T a b c
-                 = [bindLex (ident_pat (data_con_str data_con))]
+                 = [bindLex (ident_pat (data_con_str_w_parens data_con))]
                    ++ read_args
                    ++ [result_stmt data_con as_needed]
         
@@ -795,7 +791,7 @@ gen_Read_binds get_fixity tycon
             result_stmt data_con [a1,a2]]
      
                lbl_stmts               -- T { f1 = a, f2 = b }
-                 = [bindLex (ident_pat (data_con_str data_con)),
+                 = [bindLex (ident_pat (data_con_str_w_parens data_con)),
                     read_punc "{"]
                    ++ concat (intersperse [read_punc ","] field_stmts)
                    ++ [read_punc "}", result_stmt data_con as_needed]
@@ -805,7 +801,7 @@ gen_Read_binds get_fixity tycon
                con_arity    = dataConSourceArity data_con
                labels       = dataConFieldLabels data_con
                dc_nm        = getName data_con
-               is_infix     = isDataSymOcc (getOccName dc_nm)
+               is_infix     = dataConIsInfix data_con
                as_needed    = take con_arity as_RDRs
        read_args    = zipWithEqual "gen_Read_binds" read_arg as_needed (dataConOrigArgTys data_con)
                (read_a1:read_a2:_) = read_args
@@ -824,7 +820,8 @@ gen_Read_binds get_fixity tycon
     ident_pat s  = nlConPat ident_RDR [nlLitPat s]               -- Ident "foo"
     symbol_pat s = nlConPat symbol_RDR [nlLitPat s]              -- Symbol ">>"
     
-    data_con_str con = mkHsString (occNameUserString (getOccName con))
+    data_con_str          con = mkHsString (occNameUserString (getOccName con))
+    data_con_str_w_parens con = mkHsString (occNameUserString_with_parens (getOccName con))
     
     read_punc c = bindLex (punc_pat c)
     read_arg a ty 
@@ -898,7 +895,7 @@ gen_Show_binds get_fixity tycon
        pats_etc data_con
          | nullary_con =  -- skip the showParen junk...
             ASSERT(null bs_needed)
-            ([wildPat, con_pat], mk_showString_app con_str)
+            ([nlWildPat, con_pat], mk_showString_app con_str)
          | otherwise   =
             ([a_Pat, con_pat],
                  showParen_Expr (nlHsPar (genOpApp a_Expr ge_RDR (nlHsLit (HsInt con_prec_plus_one))))
@@ -917,24 +914,22 @@ gen_Show_binds get_fixity tycon
             dc_nm          = getName data_con
             dc_occ_nm      = getOccName data_con
              con_str        = occNameUserString dc_occ_nm
+            op_con_str     = occNameUserString_with_parens dc_occ_nm
 
             show_thingies 
                | is_infix      = [show_arg1, mk_showString_app (" " ++ con_str ++ " "), show_arg2]
-               | record_syntax = mk_showString_app (con_str ++ " {") : 
+               | record_syntax = mk_showString_app (op_con_str ++ " {") : 
                                  show_record_args ++ [mk_showString_app "}"]
-               | otherwise     = mk_showString_app (con_str ++ " ") : show_prefix_args
+               | otherwise     = mk_showString_app (op_con_str ++ " ") : show_prefix_args
                 
-            show_label l = mk_showString_app (the_name ++ " = ")
+            show_label l = mk_showString_app (nm ++ " = ")
                        -- Note the spaces around the "=" sign.  If we don't have them
                        -- then we get Foo { x=-1 } and the "=-" parses as a single
                        -- lexeme.  Only the space after the '=' is necessary, but
                        -- it seems tidier to have them both sides.
                 where
                   occ_nm   = getOccName (fieldLabelName l)
-                  nm       = occNameUserString occ_nm
-                  is_op    = isSymOcc occ_nm       -- Legal, but rare.
-                  the_name | is_op     = '(':nm ++ ")"
-                           | otherwise = nm
+                  nm       = occNameUserString_with_parens occ_nm
 
              show_args                      = zipWith show_arg bs_needed arg_tys
             (show_arg1:show_arg2:_) = show_args
@@ -955,11 +950,18 @@ gen_Show_binds get_fixity tycon
                                                         box_if_necy "Show" tycon (nlHsVar b) arg_ty]
 
                -- Fixity stuff
-            is_infix = isDataSymOcc dc_occ_nm
+            is_infix = dataConIsInfix data_con
              con_prec_plus_one = 1 + getPrec is_infix get_fixity dc_nm
             arg_prec | record_syntax = 0       -- Record fields don't need parens
                      | otherwise     = con_prec_plus_one
 
+occNameUserString_with_parens :: OccName -> String
+occNameUserString_with_parens occ
+  | isSymOcc occ = '(':nm ++ ")"
+  | otherwise    = nm
+  where
+   nm = occNameUserString occ
+
 mk_showString_app str = nlHsApp (nlHsVar showString_RDR) (nlHsLit (mkHsString str))
 \end{code}
 
@@ -1004,7 +1006,7 @@ gen_Typeable_binds tycon
   = unitBag $
        mk_easy_FunBind tycon_loc 
                (mk_typeOf_RDR tycon)   -- Name of appropriate type0f function
-               [wildPat] emptyBag
+               [nlWildPat] emptyBag
                (nlHsApps mkTypeRep_RDR [tycon_rep, nlList []])
   where
     tycon_loc = getSrcSpan tycon
@@ -1062,9 +1064,11 @@ gen_Data_binds fix_env tycon
                -- Auxiliary definitions: the data type and constructors
      datatype_bind `consBag` listToBag (map mk_con_bind data_cons))
   where
-    tycon_loc = getSrcSpan tycon
+    tycon_loc  = getSrcSpan tycon
     tycon_name = tyConName tycon
-    data_cons = tyConDataCons tycon
+    data_cons  = tyConDataCons tycon
+    n_cons     = length data_cons
+    one_constr = n_cons == 1
 
        ------------ gfoldl
     gfoldl_bind = mk_FunBind tycon_loc gfoldl_RDR (map gfoldl_eqn data_cons)
@@ -1079,19 +1083,25 @@ gen_Data_binds fix_env tycon
        ------------ gunfold
     gunfold_bind = mk_FunBind tycon_loc
                               gunfold_RDR
-                              [([k_Pat,z_Pat,c_Pat], gunfold_rhs)]
+                              [([k_Pat, z_Pat, if one_constr then nlWildPat else c_Pat], 
+                               gunfold_rhs)]
 
-    gunfold_rhs = nlHsCase (nlHsVar conIndex_RDR `nlHsApp` c_Expr) 
-                          (map gunfold_alt data_cons)
+    gunfold_rhs 
+       | one_constr = mk_unfold_rhs (head data_cons)   -- No need for case
+       | otherwise  = nlHsCase (nlHsVar conIndex_RDR `nlHsApp` c_Expr) 
+                               (map gunfold_alt data_cons)
 
-    gunfold_alt dc =
-      mkSimpleHsAlt (nlConPat intDataCon_RDR
-                            [nlLitPat (HsIntPrim (toInteger (dataConTag dc)))])
-                   (foldr nlHsApp
+    gunfold_alt dc = mkSimpleHsAlt (mk_unfold_pat dc) (mk_unfold_rhs dc)
+    mk_unfold_rhs dc = foldr nlHsApp
                            (nlHsVar z_RDR `nlHsApp` nlHsVar (getRdrName dc))
                            (replicate (dataConSourceArity dc) (nlHsVar k_RDR))
-                    )
 
+    mk_unfold_pat dc   -- Last one is a wild-pat, to avoid 
+                       -- redundant test, and annoying warning
+      | tag-fIRST_TAG == n_cons-1 = nlWildPat  -- Last constructor
+      | otherwise = nlConPat intDataCon_RDR [nlLitPat (HsIntPrim (toInteger tag))]
+      where 
+       tag = dataConTag dc
                          
        ------------ toConstr
     toCon_bind = mk_FunBind tycon_loc toConstr_RDR (map to_con_eqn data_cons)
@@ -1101,7 +1111,7 @@ gen_Data_binds fix_env tycon
     dataTypeOf_bind = mk_easy_FunBind
                         tycon_loc
                         dataTypeOf_RDR
-                       [wildPat]
+                       [nlWildPat]
                         emptyBag
                         (nlHsVar data_type_name)