[project @ 2003-04-16 13:34:13 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcGenDeriv.lhs
index 41ba931..bafa008 100644 (file)
@@ -49,11 +49,11 @@ import Name         ( getOccString, getOccName, getSrcLoc, occNameString,
                        )
 
 import HscTypes                ( FixityEnv, lookupFixity )
-import PrelInfo                -- Lots of Names
+import PrelNames       -- Lots of Names
 import PrimOp          -- Lots of Names
 import SrcLoc          ( generatedSrcLoc, SrcLoc )
 import TyCon           ( TyCon, isNewTyCon, tyConDataCons, isEnumerationTyCon,
-                         maybeTyConSingleCon, tyConFamilySize
+                         maybeTyConSingleCon, tyConFamilySize, tyConTyVars
                        )
 import TcType          ( isUnLiftedType, tcEqType, Type )
 import TysPrim         ( charPrimTy, intPrimTy, wordPrimTy, addrPrimTy,
@@ -62,7 +62,6 @@ import TysPrim                ( charPrimTy, intPrimTy, wordPrimTy, addrPrimTy,
 import Util            ( zipWithEqual, isSingleton,
                          zipWith3Equal, nOfThem, zipEqual )
 import Panic           ( panic, assertPanic )
-import Maybes          ( maybeToBool )
 import Char            ( ord, isAlpha )
 import Constants
 import List            ( partition, intersperse )
@@ -319,55 +318,39 @@ gen_Ord_binds tycon
     tycon_loc = getSrcLoc tycon
     --------------------------------------------------------------------
     compare = mk_easy_FunMonoBind tycon_loc compare_RDR
-               [a_Pat, b_Pat]
-               [cmp_eq]
-           (if maybeToBool (maybeTyConSingleCon tycon) then
-
---             cmp_eq_Expr ltTag_Expr eqTag_Expr gtTag_Expr a_Expr b_Expr
--- Weird.  Was: case (cmp a b) of { LT -> LT; EQ -> EQ; GT -> GT }
-
-               cmp_eq_Expr a_Expr b_Expr
-            else
-               untag_Expr tycon [(a_RDR, ah_RDR), (b_RDR, bh_RDR)]
+                                 [a_Pat, b_Pat] [cmp_eq] compare_rhs
+    compare_rhs
+       | single_con_type = cmp_eq_Expr a_Expr b_Expr
+       | otherwise
+       = untag_Expr tycon [(a_RDR, ah_RDR), (b_RDR, bh_RDR)]
                  (cmp_tags_Expr eqInt_RDR ah_RDR bh_RDR
-                       -- True case; they are equal
-                       -- If an enumeration type we are done; else
-                       -- recursively compare their components
-                   (if isEnumerationTyCon tycon then
-                       eqTag_Expr
-                    else
---                     cmp_eq_Expr ltTag_Expr eqTag_Expr gtTag_Expr a_Expr b_Expr
--- Ditto
-                       cmp_eq_Expr a_Expr b_Expr
-                   )
+                       (cmp_eq_Expr a_Expr b_Expr)     -- True case
                        -- False case; they aren't equal
                        -- So we need to do a less-than comparison on the tags
-                   (cmp_tags_Expr ltInt_RDR ah_RDR bh_RDR ltTag_Expr gtTag_Expr)))
+                       (cmp_tags_Expr ltInt_RDR ah_RDR bh_RDR ltTag_Expr gtTag_Expr))
 
     tycon_data_cons = tyConDataCons tycon
+    single_con_type = isSingleton tycon_data_cons
     (nullary_cons, nonnullary_cons)
        | isNewTyCon tycon = ([], tyConDataCons tycon)
        | otherwise       = partition isNullaryDataCon tycon_data_cons
 
-    cmp_eq =
-       mk_FunMonoBind tycon_loc 
-                      cmp_eq_RDR 
-                      (if null nonnullary_cons && isSingleton nullary_cons then
-                          -- catch this specially to avoid warnings
-                          -- about overlapping patterns from the desugarer.
-                         let 
-                          data_con     = head nullary_cons
-                          data_con_RDR = getRdrName data_con
-                           pat          = mkNullaryConPat data_con_RDR
-                          in
-                         [([pat,pat], eqTag_Expr)]
-                      else
-                         map pats_etc nonnullary_cons ++
-                         -- leave out wildcards to silence desugarer.
-                         (if isSingleton tycon_data_cons then
-                             []
-                          else
-                              [([wildPat, wildPat], default_rhs)]))
+    cmp_eq = mk_FunMonoBind tycon_loc cmp_eq_RDR cmp_eq_match
+    cmp_eq_match
+      | isEnumerationTyCon tycon
+                          -- We know the tags are equal, so if it's an enumeration TyCon,
+                          -- then there is nothing left to do
+                          -- Catch this specially to avoid warnings
+                          -- about overlapping patterns from the desugarer,
+                          -- and to avoid unnecessary pattern-matching
+      = [([wildPat,wildPat], eqTag_Expr)]
+      | otherwise
+      = map pats_etc nonnullary_cons ++
+       (if single_con_type then        -- Omit wildcards when there's just one 
+             []                        -- constructor, to silence desugarer
+       else
+              [([wildPat, wildPat], default_rhs)])
+
       where
        pats_etc data_con
          = ([con1_pat, con2_pat],
@@ -383,11 +366,11 @@ gen_Ord_binds tycon
            tys_needed  = dataConOrigArgTys data_con
 
            nested_compare_expr [ty] [a] [b]
-             = careful_compare_Case ty ltTag_Expr eqTag_Expr gtTag_Expr (HsVar a) (HsVar b)
+             = careful_compare_Case ty eqTag_Expr (HsVar a) (HsVar b)
 
            nested_compare_expr (ty:tys) (a:as) (b:bs)
              = let eq_expr = nested_compare_expr tys as bs
-               in  careful_compare_Case ty ltTag_Expr eq_expr gtTag_Expr (HsVar a) (HsVar b)
+               in  careful_compare_Case ty eq_expr (HsVar a) (HsVar b)
 
        default_rhs | null nullary_cons = impossible_Expr       -- Keep desugarer from complaining about
                                                                -- inexhaustive patterns
@@ -871,10 +854,11 @@ gen_Read_binds get_fixity tycon
                        BindStmt (VarPat a) (mkHsVarApps reset_RDR [readPrec_RDR]) loc]
 
        -- When reading field labels we might encounter
-       --      a = 3
+       --      a  = 3
+       --      _a = 3
        -- or   (#) = 4
        -- Note the parens!
-    read_lbl lbl | isAlpha (head lbl_str) 
+    read_lbl lbl | is_id_start (head lbl_str) 
                 = [bindLex (ident_pat lbl_lit)]
                 | otherwise
                 = [read_punc "(", 
@@ -883,6 +867,7 @@ gen_Read_binds get_fixity tycon
                 where  
                   lbl_str = occNameUserString (getOccName (fieldLabelName lbl)) 
                   lbl_lit = mkHsString lbl_str
+                  is_id_start c = isAlpha c || c == '_'
 \end{code}
 
 
@@ -1046,13 +1031,29 @@ gen_tag_n_con_monobind
 
 gen_tag_n_con_monobind (rdr_name, tycon, GenCon2Tag)
   | lots_of_constructors
-  = mk_FunMonoBind (getSrcLoc tycon) rdr_name 
-       [([VarPat a_RDR], HsApp getTag_Expr a_Expr)]
+  = mk_FunMonoBind loc rdr_name [([], get_tag_rhs)]
 
   | otherwise
-  = mk_FunMonoBind (getSrcLoc tycon) rdr_name (map mk_stuff (tyConDataCons tycon))
+  = mk_FunMonoBind loc rdr_name (map mk_stuff (tyConDataCons tycon))
 
   where
+    loc = getSrcLoc tycon
+
+       -- Give a signature to the bound variable, so 
+       -- that the case expression generated by getTag is
+       -- monomorphic.  In the push-enter model we get better code.
+    get_tag_rhs = ExprWithTySig 
+                       (HsLam (mk_match loc [VarPat a_RDR] 
+                                            (HsApp getTag_Expr a_Expr) 
+                                            EmptyBinds))
+                       (HsForAllTy Nothing [] con2tag_ty)
+                               -- Nothing => implicit quantification
+
+    con2tag_ty = foldl HsAppTy (HsTyVar (getRdrName tycon)) 
+                    [HsTyVar (getRdrName tv) | tv <- tyConTyVars tycon]
+               `HsFunTy` 
+               HsTyVar (getRdrName intPrimTyConName)
+
     lots_of_constructors = tyConFamilySize tycon > mAX_FAMILY_SIZE_FOR_VEC_RETURNS
 
     mk_stuff :: DataCon -> ([RdrNamePat], RdrNameHsExpr)
@@ -1106,7 +1107,7 @@ mk_easy_FunMonoBind loc fun pats binds expr
   = FunMonoBind fun False{-not infix-} [mk_easy_Match loc pats binds expr] loc
 
 mk_easy_Match loc pats binds expr
-  = mk_match loc pats expr (mkMonoBind (andMonoBindList binds) [] Recursive)
+  = mk_match loc pats expr (mkMonoBind Recursive (andMonoBindList binds))
        -- The renamer expects everything in its input to be a
        -- "recursive" MonoBinds, and it is its job to sort things out
        -- from there.
@@ -1145,34 +1146,35 @@ ToDo: Better SrcLocs.
 
 \begin{code}
 compare_gen_Case ::
-         RdrName
-         -> RdrNameHsExpr -> RdrNameHsExpr -> RdrNameHsExpr
+         RdrNameHsExpr -- What to do for equality
          -> RdrNameHsExpr -> RdrNameHsExpr
          -> RdrNameHsExpr
 careful_compare_Case :: -- checks for primitive types...
          Type
-         -> RdrNameHsExpr -> RdrNameHsExpr -> RdrNameHsExpr
+         -> RdrNameHsExpr      -- What to do for equality
          -> RdrNameHsExpr -> RdrNameHsExpr
          -> RdrNameHsExpr
 
 cmp_eq_Expr a b = HsApp (HsApp (HsVar cmp_eq_RDR) a) b
        -- Was: compare_gen_Case cmp_eq_RDR
 
-compare_gen_Case fun lt eq gt a b
-  = HsCase (HsPar (HsApp (HsApp (HsVar fun) a) b)) {-of-}
-      [mkSimpleMatch [mkNullaryConPat ltTag_RDR] lt placeHolderType generatedSrcLoc,
+compare_gen_Case (HsVar eq_tag) a b | eq_tag == eqTag_RDR
+  = HsApp (HsApp (HsVar compare_RDR) a) b      -- Simple case 
+compare_gen_Case eq a b                                -- General case
+  = HsCase (HsPar (HsApp (HsApp (HsVar compare_RDR) a) b)) {-of-}
+      [mkSimpleMatch [mkNullaryConPat ltTag_RDR] ltTag_Expr placeHolderType generatedSrcLoc,
        mkSimpleMatch [mkNullaryConPat eqTag_RDR] eq placeHolderType generatedSrcLoc,
-       mkSimpleMatch [mkNullaryConPat gtTag_RDR] gt placeHolderType generatedSrcLoc]
+       mkSimpleMatch [mkNullaryConPat gtTag_RDR] gtTag_Expr placeHolderType generatedSrcLoc]
       generatedSrcLoc
 
-careful_compare_Case ty lt eq gt a b
+careful_compare_Case ty eq a b
   | not (isUnLiftedType ty) =
-       compare_gen_Case compare_RDR lt eq gt a b
+       compare_gen_Case eq a b
   | otherwise               =
          -- we have to do something special for primitive things...
        HsIf (genOpApp a relevant_eq_op b)
            eq
-           (HsIf (genOpApp a relevant_lt_op b) lt gt generatedSrcLoc)
+           (HsIf (genOpApp a relevant_lt_op b) ltTag_Expr gtTag_Expr generatedSrcLoc)
            generatedSrcLoc
   where
     relevant_eq_op = assoc_ty_id eq_op_tbl ty