[project @ 2006-01-06 16:30:17 by simonmar]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcGenDeriv.lhs
index 8f8168b..94bb152 100644 (file)
@@ -54,7 +54,6 @@ import TysWiredIn     ( charDataCon, intDataCon, floatDataCon, doubleDataCon,
                          intDataCon_RDR, true_RDR, false_RDR )
 import Util            ( zipWithEqual, isSingleton,
                          zipWith3Equal, nOfThem, zipEqual )
-import Char            ( isAlpha )
 import Constants
 import List            ( partition, intersperse )
 import Outputable
@@ -167,7 +166,7 @@ gen_Eq_binds tycon
     in
     listToBag [
       mk_FunBind tycon_loc eq_RDR ((map pats_etc nonnullary_cons) ++ rest),
-      mk_easy_FunBind tycon_loc ne_RDR [a_Pat, b_Pat] emptyLHsBinds (
+      mk_easy_FunBind tycon_loc ne_RDR [a_Pat, b_Pat] (
        nlHsApp (nlHsVar not_RDR) (nlHsPar (nlHsVarApps eq_RDR [a_RDR, b_RDR])))
     ]
   where
@@ -298,8 +297,10 @@ gen_Ord_binds tycon
     tycon_loc = getSrcSpan tycon
     --------------------------------------------------------------------
 
-    compare = mk_easy_FunBind tycon_loc compare_RDR
-                                 [a_Pat, b_Pat] (unitBag cmp_eq) compare_rhs
+    compare = L tycon_loc (FunBind (L tycon_loc compare_RDR) False compare_matches placeHolderNames)
+    compare_matches = mkMatchGroup [mkMatch [a_Pat, b_Pat] compare_rhs cmp_eq_binds]
+    cmp_eq_binds    = HsValBinds (ValBindsIn (unitBag cmp_eq) [])
+
     compare_rhs
        | single_con_type = cmp_eq_Expr a_Expr b_Expr
        | otherwise
@@ -417,7 +418,7 @@ gen_Enum_binds tycon
     occ_nm    = getOccString tycon
 
     succ_enum
-      = mk_easy_FunBind tycon_loc succ_RDR [a_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc succ_RDR [a_Pat] $
        untag_Expr tycon [(a_RDR, ah_RDR)] $
        nlHsIf (nlHsApps eq_RDR [nlHsVar (maxtag_RDR tycon),
                               nlHsVarApps intDataCon_RDR [ah_RDR]])
@@ -427,7 +428,7 @@ gen_Enum_binds tycon
                                        nlHsIntLit 1]))
                    
     pred_enum
-      = mk_easy_FunBind tycon_loc pred_RDR [a_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc pred_RDR [a_Pat] $
        untag_Expr tycon [(a_RDR, ah_RDR)] $
        nlHsIf (nlHsApps eq_RDR [nlHsIntLit 0,
                               nlHsVarApps intDataCon_RDR [ah_RDR]])
@@ -437,7 +438,7 @@ gen_Enum_binds tycon
                                               nlHsLit (HsInt (-1))]))
 
     to_enum
-      = mk_easy_FunBind tycon_loc toEnum_RDR [a_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc toEnum_RDR [a_Pat] $
        nlHsIf (nlHsApps and_RDR
                [nlHsApps ge_RDR [nlHsVar a_RDR, nlHsIntLit 0],
                  nlHsApps le_RDR [nlHsVar a_RDR, nlHsVar (maxtag_RDR tycon)]])
@@ -445,7 +446,7 @@ gen_Enum_binds tycon
             (illegal_toEnum_tag occ_nm (maxtag_RDR tycon))
 
     enum_from
-      = mk_easy_FunBind tycon_loc enumFrom_RDR [a_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc enumFrom_RDR [a_Pat] $
          untag_Expr tycon [(a_RDR, ah_RDR)] $
          nlHsApps map_RDR 
                [nlHsVar (tag2con_RDR tycon),
@@ -454,7 +455,7 @@ gen_Enum_binds tycon
                            (nlHsVar (maxtag_RDR tycon)))]
 
     enum_from_then
-      = mk_easy_FunBind tycon_loc enumFromThen_RDR [a_Pat, b_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc enumFromThen_RDR [a_Pat, b_Pat] $
          untag_Expr tycon [(a_RDR, ah_RDR), (b_RDR, bh_RDR)] $
          nlHsApp (nlHsVarApps map_RDR [tag2con_RDR tycon]) $
            nlHsPar (enum_from_then_to_Expr
@@ -467,7 +468,7 @@ gen_Enum_binds tycon
                           ))
 
     from_enum
-      = mk_easy_FunBind tycon_loc fromEnum_RDR [a_Pat] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc fromEnum_RDR [a_Pat] $
          untag_Expr tycon [(a_RDR, ah_RDR)] $
          (nlHsVarApps intDataCon_RDR [ah_RDR])
 \end{code}
@@ -533,11 +534,11 @@ instance ... Ix (Foo ...) where
        map tag2con_Foo (enumFromTo (I# a#) (I# b#))
        }}
 
-    index c@(a, b) d
-      = if inRange c d
-       then case (con2tag_Foo d -# con2tag_Foo a) of
+    -- Generate code for unsafeIndex, becuase using index leads
+    -- to lots of redundant range tests
+    unsafeIndex c@(a, b) d
+      = case (con2tag_Foo d -# con2tag_Foo a) of
               r# -> I# r#
-       else error "Ix.Foo.index: out of range"
 
     inRange (a, b) c
       = let
@@ -573,15 +574,13 @@ gen_Ix_binds tycon
     then enum_ixes
     else single_con_ixes
   where
-    tycon_str = getOccString tycon
     tycon_loc = getSrcSpan tycon
 
     --------------------------------------------------------------
     enum_ixes = listToBag [ enum_range, enum_index, enum_inRange ]
 
     enum_range
-      = mk_easy_FunBind tycon_loc range_RDR 
-               [nlTuplePat [a_Pat, b_Pat] Boxed] emptyLHsBinds $
+      = mk_easy_FunBind tycon_loc range_RDR [nlTuplePat [a_Pat, b_Pat] Boxed] $
          untag_Expr tycon [(a_RDR, ah_RDR)] $
          untag_Expr tycon [(b_RDR, bh_RDR)] $
          nlHsApp (nlHsVarApps map_RDR [tag2con_RDR tycon]) $
@@ -590,11 +589,10 @@ gen_Ix_binds tycon
                        (nlHsVarApps intDataCon_RDR [bh_RDR]))
 
     enum_index
-      = mk_easy_FunBind tycon_loc index_RDR 
+      = mk_easy_FunBind tycon_loc unsafeIndex_RDR 
                [noLoc (AsPat (noLoc c_RDR) 
                           (nlTuplePat [a_Pat, nlWildPat] Boxed)), 
-                               d_Pat] emptyLHsBinds (
-       nlHsIf (nlHsPar (nlHsVarApps inRange_RDR [c_RDR, d_RDR])) (
+                               d_Pat] (
           untag_Expr tycon [(a_RDR, ah_RDR)] (
           untag_Expr tycon [(d_RDR, dh_RDR)] (
           let
@@ -604,13 +602,10 @@ gen_Ix_binds tycon
             (genOpApp (nlHsVar dh_RDR) minusInt_RDR (nlHsVar ah_RDR))
             [mkSimpleHsAlt (nlVarPat c_RDR) rhs]
           ))
-       ) {-else-} (
-          nlHsApp (nlHsVar error_RDR) (nlHsLit (HsString (mkFastString ("Ix."++tycon_str++".index: out of range\n"))))
-       ))
+       )
 
     enum_inRange
-      = mk_easy_FunBind tycon_loc inRange_RDR 
-         [nlTuplePat [a_Pat, b_Pat] Boxed, c_Pat] emptyLHsBinds (
+      = mk_easy_FunBind tycon_loc inRange_RDR [nlTuplePat [a_Pat, b_Pat] Boxed, c_Pat] $
          untag_Expr tycon [(a_RDR, ah_RDR)] (
          untag_Expr tycon [(b_RDR, bh_RDR)] (
          untag_Expr tycon [(c_RDR, ch_RDR)] (
@@ -618,7 +613,7 @@ gen_Ix_binds tycon
             (genOpApp (nlHsVar ch_RDR) leInt_RDR (nlHsVar bh_RDR))
          ) {-else-} (
             false_Expr
-         )))))
+         ))))
 
     --------------------------------------------------------------
     single_con_ixes 
@@ -644,50 +639,43 @@ gen_Ix_binds tycon
     --------------------------------------------------------------
     single_con_range
       = mk_easy_FunBind tycon_loc range_RDR 
-         [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] emptyLHsBinds $
-       nlHsDo ListComp stmts
+         [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] $
+       nlHsDo ListComp stmts con_expr
       where
        stmts = zipWith3Equal "single_con_range" mk_qual as_needed bs_needed cs_needed
-               ++
-               [nlResultStmt con_expr]
 
-       mk_qual a b c = nlBindStmt (nlVarPat c)
+       mk_qual a b c = noLoc $ mkBindStmt (nlVarPat c)
                                 (nlHsApp (nlHsVar range_RDR) 
                                        (nlTuple [nlHsVar a, nlHsVar b] Boxed))
 
     ----------------
     single_con_index
-      = mk_easy_FunBind tycon_loc index_RDR 
+      = mk_easy_FunBind tycon_loc unsafeIndex_RDR 
                [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed, 
-                con_pat cs_needed] (unitBag range_size) (
-       foldl mk_index (nlHsIntLit 0) (zip3 as_needed bs_needed cs_needed))
+                con_pat cs_needed] 
+               (mk_index (zip3 as_needed bs_needed cs_needed))
       where
-       mk_index multiply_by (l, u, i)
+       -- index (l1,u1) i1 + rangeSize (l1,u1) * (index (l2,u2) i2 + ...)
+       mk_index []        = nlHsIntLit 0
+       mk_index [(l,u,i)] = mk_one l u i
+       mk_index ((l,u,i) : rest)
          = genOpApp (
-              (nlHsApps index_RDR [nlTuple [nlHsVar l, nlHsVar u] Boxed,  
-                                   nlHsVar i])
-          ) plus_RDR (
+               mk_one l u i
+           ) plus_RDR (
                genOpApp (
-                   (nlHsApp (nlHsVar rangeSize_RDR) 
+                   (nlHsApp (nlHsVar unsafeRangeSize_RDR) 
                           (nlTuple [nlHsVar l, nlHsVar u] Boxed))
-               ) times_RDR multiply_by
+               ) times_RDR (mk_index rest)
           )
-
-       range_size
-         = mk_easy_FunBind tycon_loc rangeSize_RDR 
-                       [nlTuplePat [a_Pat, b_Pat] Boxed] emptyLHsBinds (
-               genOpApp (
-                   (nlHsApps index_RDR [nlTuple [a_Expr, b_Expr] Boxed,
-                                        b_Expr])
-               ) plus_RDR (nlHsIntLit 1))
+       mk_one l u i
+         = nlHsApps unsafeIndex_RDR [nlTuple [nlHsVar l, nlHsVar u] Boxed, nlHsVar i]
 
     ------------------
     single_con_inRange
       = mk_easy_FunBind tycon_loc inRange_RDR 
                [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed, 
-                con_pat cs_needed]
-                          emptyLHsBinds (
-         foldl1 and_Expr (zipWith3Equal "single_con_inRange" in_range as_needed bs_needed cs_needed))
+                con_pat cs_needed] $
+         foldl1 and_Expr (zipWith3Equal "single_con_inRange" in_range as_needed bs_needed cs_needed)
       where
        in_range a b c = nlHsApps inRange_RDR [nlTuple [nlHsVar a, nlHsVar b] Boxed,
                                               nlHsVar c]
@@ -762,38 +750,41 @@ gen_Read_binds get_fixity tycon
     read_nullary_cons 
       = case nullary_cons of
            []    -> []
-           [con] -> [nlHsDo DoExpr [bindLex (ident_pat (data_con_str con)),
-                                    result_stmt con []]]
+           [con] -> [nlHsDo DoExpr [bindLex (ident_pat (data_con_str con))]
+                                   (result_expr con [])]
             _     -> [nlHsApp (nlHsVar choose_RDR) 
                            (nlList (map mk_pair nullary_cons))]
     
-    mk_pair con = nlTuple [nlHsLit (data_con_str con),
-                                nlHsApp (nlHsVar returnM_RDR) (nlHsVar (getRdrName con))]
-                               Boxed
+    mk_pair con = nlTuple [nlHsLit (mkHsString (data_con_str con)),
+                                  nlHsApp (nlHsVar returnM_RDR) (nlHsVar (getRdrName con))]
+                                  Boxed
     
     read_non_nullary_con data_con
-      = nlHsApps prec_RDR [nlHsIntLit prec, nlHsDo DoExpr stmts]
+      = nlHsApps prec_RDR [nlHsIntLit prec, nlHsDo DoExpr stmts body]
       where
                stmts | is_infix          = infix_stmts
              | length labels > 0 = lbl_stmts
              | otherwise         = prefix_stmts
      
+       body = result_expr data_con as_needed
+       con_str = data_con_str data_con
+       
                prefix_stmts            -- T a b c
-                 = [bindLex (ident_pat (data_con_str_w_parens data_con))]
+                 = [bindLex (ident_pat (wrapOpParens con_str))]
                    ++ read_args
-                   ++ [result_stmt data_con as_needed]
         
-               infix_stmts             -- a %% b
-                 = [read_a1, 
-            bindLex (symbol_pat (data_con_str data_con)),
-            read_a2,
-            result_stmt data_con [a1,a2]]
+               infix_stmts             -- a %% b, or  a `T` b 
+                 = [read_a1]
+           ++  (if isSym con_str
+                then [bindLex (symbol_pat con_str)]
+                else [read_punc "`", bindLex (ident_pat con_str), read_punc "`"])
+           ++ [read_a2]
      
                lbl_stmts               -- T { f1 = a, f2 = b }
-                 = [bindLex (ident_pat (data_con_str_w_parens data_con)),
+                 = [bindLex (ident_pat (wrapOpParens con_str)),
                     read_punc "{"]
                    ++ concat (intersperse [read_punc ","] field_stmts)
-                   ++ [read_punc "}", result_stmt data_con as_needed]
+                   ++ [read_punc "}"]
      
                field_stmts  = zipWithEqual "lbl_stmts" read_field labels as_needed
      
@@ -804,48 +795,44 @@ gen_Read_binds get_fixity tycon
                as_needed    = take con_arity as_RDRs
        read_args    = zipWithEqual "gen_Read_binds" read_arg as_needed (dataConOrigArgTys data_con)
                (read_a1:read_a2:_) = read_args
-       (a1:a2:_)           = as_needed
                prec         = getPrec is_infix get_fixity dc_nm
 
     ------------------------------------------------------------------------
     --         Helpers
     ------------------------------------------------------------------------
     mk_alt e1 e2     = genOpApp e1 alt_RDR e2
-    bindLex pat             = nlBindStmt pat (nlHsVar lexP_RDR)
-    result_stmt c as = nlResultStmt (nlHsApp (nlHsVar returnM_RDR) (con_app c as))
+    bindLex pat             = noLoc (mkBindStmt pat (nlHsVar lexP_RDR))
     con_app c as     = nlHsVarApps (getRdrName c) as
+    result_expr c as = nlHsApp (nlHsVar returnM_RDR) (con_app c as)
     
-    punc_pat s   = nlConPat punc_RDR  [nlLitPat (mkHsString s)]          -- Punc 'c'
-    ident_pat s  = nlConPat ident_RDR [nlLitPat s]               -- Ident "foo"
-    symbol_pat s = nlConPat symbol_RDR [nlLitPat s]              -- Symbol ">>"
+    punc_pat s   = nlConPat punc_RDR   [nlLitPat (mkHsString s)]  -- Punc 'c'
+    ident_pat s  = nlConPat ident_RDR  [nlLitPat (mkHsString s)]  -- Ident "foo"
+    symbol_pat s = nlConPat symbol_RDR [nlLitPat (mkHsString s)]  -- Symbol ">>"
     
-    data_con_str          con = mkHsString (occNameUserString (getOccName con))
-    data_con_str_w_parens con = mkHsString (occNameUserString_with_parens (getOccName con))
+    data_con_str con = occNameString (getOccName con)
     
     read_punc c = bindLex (punc_pat c)
     read_arg a ty 
        | isUnLiftedType ty = pprPanic "Error in deriving:" (text "Can't read unlifted types yet:" <+> ppr ty)
-       | otherwise = nlBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR])
+       | otherwise = noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR]))
     
     read_field lbl a = read_lbl lbl ++
                       [read_punc "=",
-                       nlBindStmt (nlVarPat a) (nlHsVarApps reset_RDR [readPrec_RDR])]
+                       noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps reset_RDR [readPrec_RDR]))]
 
        -- When reading field labels we might encounter
        --      a  = 3
        --      _a = 3
        -- or   (#) = 4
        -- Note the parens!
-    read_lbl lbl | is_id_start (head lbl_str) 
-                = [bindLex (ident_pat lbl_lit)]
-                | otherwise
+    read_lbl lbl | isSym lbl_str 
                 = [read_punc "(", 
-                   bindLex (symbol_pat lbl_lit),
+                   bindLex (symbol_pat lbl_str),
                    read_punc ")"]
+                | otherwise
+                = [bindLex (ident_pat lbl_str)]
                 where  
-                  lbl_str = occNameUserString (getOccName lbl) 
-                  lbl_lit = mkHsString lbl_str
-                  is_id_start c = isAlpha c || c == '_'
+                  lbl_str = occNameString (getOccName lbl) 
 \end{code}
 
 
@@ -912,11 +899,12 @@ gen_Show_binds get_fixity tycon
 
             dc_nm          = getName data_con
             dc_occ_nm      = getOccName data_con
-             con_str        = occNameUserString dc_occ_nm
-            op_con_str     = occNameUserString_with_parens dc_occ_nm
+             con_str        = occNameString dc_occ_nm
+            op_con_str     = wrapOpParens con_str
+            backquote_str  = wrapOpBackquotes con_str
 
             show_thingies 
-               | is_infix      = [show_arg1, mk_showString_app (" " ++ con_str ++ " "), show_arg2]
+               | is_infix      = [show_arg1, mk_showString_app (" " ++ backquote_str ++ " "), show_arg2]
                | record_syntax = mk_showString_app (op_con_str ++ " {") : 
                                  show_record_args ++ [mk_showString_app "}"]
                | otherwise     = mk_showString_app (op_con_str ++ " ") : show_prefix_args
@@ -928,7 +916,7 @@ gen_Show_binds get_fixity tycon
                        -- it seems tidier to have them both sides.
                 where
                   occ_nm   = getOccName l
-                  nm       = occNameUserString_with_parens occ_nm
+                  nm       = wrapOpParens (occNameString occ_nm)
 
              show_args                      = zipWith show_arg bs_needed arg_tys
             (show_arg1:show_arg2:_) = show_args
@@ -954,12 +942,17 @@ gen_Show_binds get_fixity tycon
             arg_prec | record_syntax = 0       -- Record fields don't need parens
                      | otherwise     = con_prec_plus_one
 
-occNameUserString_with_parens :: OccName -> String
-occNameUserString_with_parens occ
-  | isSymOcc occ = '(':nm ++ ")"
-  | otherwise    = nm
-  where
-   nm = occNameUserString occ
+wrapOpParens :: String -> String
+wrapOpParens s | isSym s   = '(' : s ++ ")"
+              | otherwise = s
+
+wrapOpBackquotes :: String -> String
+wrapOpBackquotes s | isSym s   = s
+                  | otherwise = '`' : s ++ "`"
+
+isSym :: String -> Bool
+isSym ""     = False
+isSym (c:cs) = startsVarSym c || startsConSym c
 
 mk_showString_app str = nlHsApp (nlHsVar showString_RDR) (nlHsLit (mkHsString str))
 \end{code}
@@ -1005,7 +998,7 @@ gen_Typeable_binds tycon
   = unitBag $
        mk_easy_FunBind tycon_loc 
                (mk_typeOf_RDR tycon)   -- Name of appropriate type0f function
-               [nlWildPat] emptyLHsBinds
+               [nlWildPat] 
                (nlHsApps mkTypeRep_RDR [tycon_rep, nlList []])
   where
     tycon_loc = getSrcSpan tycon
@@ -1111,10 +1104,9 @@ gen_Data_binds fix_env tycon
                         tycon_loc
                         dataTypeOf_RDR
                        [nlWildPat]
-                        emptyLHsBinds
                         (nlHsVar data_type_name)
 
-       ------------ $dT
+       ------------  $dT
 
     data_type_name = mkDerivedRdrName tycon_name mkDataTOcc
     datatype_bind  = mkVarBind
@@ -1127,7 +1119,7 @@ gen_Data_binds fix_env tycon
     constrs = [nlHsVar (mk_constr_name con) | con <- data_cons]
 
 
-       ------------ $cT1 etc
+       ------------  $cT1 etc
     mk_constr_name con = mkDerivedRdrName (dataConName con) mkDataCOcc
     mk_con_bind dc = mkVarBind
                        tycon_loc
@@ -1136,7 +1128,7 @@ gen_Data_binds fix_env tycon
     constr_args dc =
         [ -- nlHsIntLit (toInteger (dataConTag dc)),           -- Tag
           nlHsVar data_type_name,                              -- DataType
-          nlHsLit (mkHsString (occNameUserString dc_occ)),     -- String name
+          nlHsLit (mkHsString (occNameString dc_occ)), -- String name
            nlList  labels,                                     -- Field labels
           nlHsVar fixity]                                      -- Fixity
        where
@@ -1435,7 +1427,6 @@ bh_RDR            = mkVarUnqual FSLIT("b#")
 ch_RDR         = mkVarUnqual FSLIT("c#")
 dh_RDR         = mkVarUnqual FSLIT("d#")
 cmp_eq_RDR     = mkVarUnqual FSLIT("cmp_eq")
-rangeSize_RDR  = mkVarUnqual FSLIT("rangeSize")
 
 as_RDRs                = [ mkVarUnqual (mkFastString ("a"++show i)) | i <- [(1::Int) .. ] ]
 bs_RDRs                = [ mkVarUnqual (mkFastString ("b"++show i)) | i <- [(1::Int) .. ] ]
@@ -1467,7 +1458,7 @@ mk_tc_deriv_name tycon str
   = mkDerivedRdrName tc_name mk_occ
   where
     tc_name = tyConName tycon
-    mk_occ tc_occ = mkOccFS varName (mkFastString new_str)
+    mk_occ tc_occ = mkVarOccFS (mkFastString new_str)
                  where
                    new_str = str ++ occNameString tc_occ ++ "#"
 \end{code}