[project @ 2005-04-04 11:55:11 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcGenDeriv.lhs
index 1788cf6..b184513 100644 (file)
@@ -533,11 +533,11 @@ instance ... Ix (Foo ...) where
        map tag2con_Foo (enumFromTo (I# a#) (I# b#))
        }}
 
-    index c@(a, b) d
-      = if inRange c d
-       then case (con2tag_Foo d -# con2tag_Foo a) of
+    -- Generate code for unsafeIndex, becuase using index leads
+    -- to lots of redundant range tests
+    unsafeIndex c@(a, b) d
+      = case (con2tag_Foo d -# con2tag_Foo a) of
               r# -> I# r#
-       else error "Ix.Foo.index: out of range"
 
     inRange (a, b) c
       = let
@@ -573,7 +573,6 @@ gen_Ix_binds tycon
     then enum_ixes
     else single_con_ixes
   where
-    tycon_str = getOccString tycon
     tycon_loc = getSrcSpan tycon
 
     --------------------------------------------------------------
@@ -590,11 +589,10 @@ gen_Ix_binds tycon
                        (nlHsVarApps intDataCon_RDR [bh_RDR]))
 
     enum_index
-      = mk_easy_FunBind tycon_loc index_RDR 
+      = mk_easy_FunBind tycon_loc unsafeIndex_RDR 
                [noLoc (AsPat (noLoc c_RDR) 
                           (nlTuplePat [a_Pat, nlWildPat] Boxed)), 
                                d_Pat] emptyLHsBinds (
-       nlHsIf (nlHsPar (nlHsVarApps inRange_RDR [c_RDR, d_RDR])) (
           untag_Expr tycon [(a_RDR, ah_RDR)] (
           untag_Expr tycon [(d_RDR, dh_RDR)] (
           let
@@ -604,9 +602,7 @@ gen_Ix_binds tycon
             (genOpApp (nlHsVar dh_RDR) minusInt_RDR (nlHsVar ah_RDR))
             [mkSimpleHsAlt (nlVarPat c_RDR) rhs]
           ))
-       ) {-else-} (
-          nlHsApp (nlHsVar error_RDR) (nlHsLit (HsString (mkFastString ("Ix."++tycon_str++".index: out of range\n"))))
-       ))
+       )
 
     enum_inRange
       = mk_easy_FunBind tycon_loc inRange_RDR 
@@ -645,41 +641,35 @@ gen_Ix_binds tycon
     single_con_range
       = mk_easy_FunBind tycon_loc range_RDR 
          [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] emptyLHsBinds $
-       nlHsDo ListComp stmts
+       nlHsDo ListComp stmts con_expr
       where
        stmts = zipWith3Equal "single_con_range" mk_qual as_needed bs_needed cs_needed
-               ++
-               [nlResultStmt con_expr]
 
-       mk_qual a b c = nlBindStmt (nlVarPat c)
+       mk_qual a b c = noLoc $ mkBindStmt (nlVarPat c)
                                 (nlHsApp (nlHsVar range_RDR) 
                                        (nlTuple [nlHsVar a, nlHsVar b] Boxed))
 
     ----------------
     single_con_index
-      = mk_easy_FunBind tycon_loc index_RDR 
+      = mk_easy_FunBind tycon_loc unsafeIndex_RDR 
                [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed, 
-                con_pat cs_needed] (unitBag range_size) (
-       foldl mk_index (nlHsIntLit 0) (zip3 as_needed bs_needed cs_needed))
+                con_pat cs_needed] emptyBag
+               (mk_index (zip3 as_needed bs_needed cs_needed))
       where
-       mk_index multiply_by (l, u, i)
+       -- index (l1,u1) i1 + rangeSize (l1,u1) * (index (l2,u2) i2 + ...)
+       mk_index []        = nlHsIntLit 0
+       mk_index [(l,u,i)] = mk_one l u i
+       mk_index ((l,u,i) : rest)
          = genOpApp (
-              (nlHsApps index_RDR [nlTuple [nlHsVar l, nlHsVar u] Boxed,  
-                                   nlHsVar i])
-          ) plus_RDR (
+               mk_one l u i
+           ) plus_RDR (
                genOpApp (
-                   (nlHsApp (nlHsVar rangeSize_RDR) 
+                   (nlHsApp (nlHsVar unsafeRangeSize_RDR) 
                           (nlTuple [nlHsVar l, nlHsVar u] Boxed))
-               ) times_RDR multiply_by
+               ) times_RDR (mk_index rest)
           )
-
-       range_size
-         = mk_easy_FunBind tycon_loc rangeSize_RDR 
-                       [nlTuplePat [a_Pat, b_Pat] Boxed] emptyLHsBinds (
-               genOpApp (
-                   (nlHsApps index_RDR [nlTuple [a_Expr, b_Expr] Boxed,
-                                        b_Expr])
-               ) plus_RDR (nlHsIntLit 1))
+       mk_one l u i
+         = nlHsApps unsafeIndex_RDR [nlTuple [nlHsVar l, nlHsVar u] Boxed, nlHsVar i]
 
     ------------------
     single_con_inRange
@@ -762,8 +752,8 @@ gen_Read_binds get_fixity tycon
     read_nullary_cons 
       = case nullary_cons of
            []    -> []
-           [con] -> [nlHsDo DoExpr [bindLex (ident_pat (data_con_str con)),
-                                    result_stmt con []]]
+           [con] -> [nlHsDo DoExpr [bindLex (ident_pat (data_con_str con))]
+                                   (result_expr con [])]
             _     -> [nlHsApp (nlHsVar choose_RDR) 
                            (nlList (map mk_pair nullary_cons))]
     
@@ -772,28 +762,28 @@ gen_Read_binds get_fixity tycon
                                Boxed
     
     read_non_nullary_con data_con
-      = nlHsApps prec_RDR [nlHsIntLit prec, nlHsDo DoExpr stmts]
+      = nlHsApps prec_RDR [nlHsIntLit prec, nlHsDo DoExpr stmts body]
       where
                stmts | is_infix          = infix_stmts
              | length labels > 0 = lbl_stmts
              | otherwise         = prefix_stmts
      
+       body = result_expr data_con as_needed
+       
                prefix_stmts            -- T a b c
                  = [bindLex (ident_pat (data_con_str_w_parens data_con))]
                    ++ read_args
-                   ++ [result_stmt data_con as_needed]
         
                infix_stmts             -- a %% b
                  = [read_a1, 
             bindLex (symbol_pat (data_con_str data_con)),
-            read_a2,
-            result_stmt data_con [a1,a2]]
+            read_a2]
      
                lbl_stmts               -- T { f1 = a, f2 = b }
                  = [bindLex (ident_pat (data_con_str_w_parens data_con)),
                     read_punc "{"]
                    ++ concat (intersperse [read_punc ","] field_stmts)
-                   ++ [read_punc "}", result_stmt data_con as_needed]
+                   ++ [read_punc "}"]
      
                field_stmts  = zipWithEqual "lbl_stmts" read_field labels as_needed
      
@@ -804,16 +794,15 @@ gen_Read_binds get_fixity tycon
                as_needed    = take con_arity as_RDRs
        read_args    = zipWithEqual "gen_Read_binds" read_arg as_needed (dataConOrigArgTys data_con)
                (read_a1:read_a2:_) = read_args
-       (a1:a2:_)           = as_needed
                prec         = getPrec is_infix get_fixity dc_nm
 
     ------------------------------------------------------------------------
     --         Helpers
     ------------------------------------------------------------------------
     mk_alt e1 e2     = genOpApp e1 alt_RDR e2
-    bindLex pat             = nlBindStmt pat (nlHsVar lexP_RDR)
-    result_stmt c as = nlResultStmt (nlHsApp (nlHsVar returnM_RDR) (con_app c as))
+    bindLex pat             = noLoc (mkBindStmt pat (nlHsVar lexP_RDR))
     con_app c as     = nlHsVarApps (getRdrName c) as
+    result_expr c as = nlHsApp (nlHsVar returnM_RDR) (con_app c as)
     
     punc_pat s   = nlConPat punc_RDR  [nlLitPat (mkHsString s)]          -- Punc 'c'
     ident_pat s  = nlConPat ident_RDR [nlLitPat s]               -- Ident "foo"
@@ -825,11 +814,11 @@ gen_Read_binds get_fixity tycon
     read_punc c = bindLex (punc_pat c)
     read_arg a ty 
        | isUnLiftedType ty = pprPanic "Error in deriving:" (text "Can't read unlifted types yet:" <+> ppr ty)
-       | otherwise = nlBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR])
+       | otherwise = noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR]))
     
     read_field lbl a = read_lbl lbl ++
                       [read_punc "=",
-                       nlBindStmt (nlVarPat a) (nlHsVarApps reset_RDR [readPrec_RDR])]
+                       noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps reset_RDR [readPrec_RDR]))]
 
        -- When reading field labels we might encounter
        --      a  = 3
@@ -1435,7 +1424,6 @@ bh_RDR            = mkVarUnqual FSLIT("b#")
 ch_RDR         = mkVarUnqual FSLIT("c#")
 dh_RDR         = mkVarUnqual FSLIT("d#")
 cmp_eq_RDR     = mkVarUnqual FSLIT("cmp_eq")
-rangeSize_RDR  = mkVarUnqual FSLIT("rangeSize")
 
 as_RDRs                = [ mkVarUnqual (mkFastString ("a"++show i)) | i <- [(1::Int) .. ] ]
 bs_RDRs                = [ mkVarUnqual (mkFastString ("b"++show i)) | i <- [(1::Int) .. ] ]