Print infix function definitions correctly in HsSyn
[ghc-hetmet.git] / compiler / hsSyn / HsExpr.lhs
index 9161d46..8830155 100644 (file)
@@ -674,8 +674,8 @@ pprMatches ctxt (MatchGroup matches ty) = vcat (map (pprMatch ctxt) (map unLoc m
                                           -- a place-holder before typechecking
 
 -- Exported to HsBinds, which can't see the defn of HsMatchContext
-pprFunBind :: (OutputableBndr id) => id -> MatchGroup id -> SDoc
-pprFunBind fun matches = pprMatches (FunRhs fun) matches
+pprFunBind :: (OutputableBndr id) => id -> Bool -> MatchGroup id -> SDoc
+pprFunBind fun inf matches = pprMatches (FunRhs fun inf) matches
 
 -- Exported to HsBinds, which can't see the defn of HsMatchContext
 pprPatBind :: (OutputableBndr bndr, OutputableBndr id)
@@ -685,14 +685,29 @@ pprPatBind pat grhss = sep [ppr pat, nest 4 (pprGRHSs PatBindRhs grhss)]
 
 pprMatch :: OutputableBndr id => HsMatchContext id -> Match id -> SDoc
 pprMatch ctxt (Match pats maybe_ty grhss)
-  = pp_name ctxt <+> sep [sep (map ppr pats), 
-                    ppr_maybe_ty, 
-                    nest 2 (pprGRHSs ctxt grhss)]
+  = herald <+> sep [sep (map ppr other_pats), 
+                   ppr_maybe_ty, 
+                   nest 2 (pprGRHSs ctxt grhss)]
   where
-    pp_name (FunRhs fun) = ppr fun     -- Not pprBndr; the AbsBinds will
-                                       -- have printed the signature
-    pp_name LambdaExpr   = char '\\'
-    pp_name other       = empty
+    (herald, other_pats) 
+       = case ctxt of
+           FunRhs fun is_infix
+               | not is_infix -> (ppr fun, pats)
+                       -- f x y z = e
+                       -- Not pprBndr; the AbsBinds will
+                       -- have printed the signature
+
+               | null pats3 -> (pp_infix, [])
+                       -- x &&& y = e
+
+               | otherwise -> (parens pp_infix, pats3)
+                       -- (x &&& y) z = e
+               where
+                 (pat1:pat2:pats3) = pats
+                 pp_infix = ppr pat1 <+> ppr fun <+> ppr pat2
+
+           LambdaExpr -> (char '\\', pats)
+           other      -> (empty,     pats)
 
     ppr_maybe_ty = case maybe_ty of
                        Just ty -> dcolon <+> ppr ty
@@ -918,7 +933,7 @@ pp_dotdot = ptext SLIT(" .. ")
 
 \begin{code}
 data HsMatchContext id -- Context of a Match
-  = FunRhs id                  -- Function binding for f
+  = FunRhs id Bool             -- Function binding for f; True <=> written infix
   | CaseAlt                    -- Guard on a case alternative
   | LambdaExpr                 -- Pattern of a lambda
   | ProcExpr                   -- Pattern of a proc
@@ -952,7 +967,7 @@ isListCompExpr _        = False
 \end{code}
 
 \begin{code}
-matchSeparator (FunRhs _)   = ptext SLIT("=")
+matchSeparator (FunRhs {})  = ptext SLIT("=")
 matchSeparator CaseAlt      = ptext SLIT("->") 
 matchSeparator LambdaExpr   = ptext SLIT("->") 
 matchSeparator ProcExpr     = ptext SLIT("->") 
@@ -962,7 +977,7 @@ matchSeparator RecUpd       = panic "unused"
 \end{code}
 
 \begin{code}
-pprMatchContext (FunRhs fun)     = ptext SLIT("the definition of") <+> quotes (ppr fun)
+pprMatchContext (FunRhs fun _)           = ptext SLIT("the definition of") <+> quotes (ppr fun)
 pprMatchContext CaseAlt                  = ptext SLIT("a case alternative")
 pprMatchContext RecUpd           = ptext SLIT("a record-update construct")
 pprMatchContext PatBindRhs       = ptext SLIT("a pattern binding")
@@ -993,7 +1008,7 @@ pprStmtResultContext other      = ptext SLIT("the result of") <+> pprStmtContext
 -}
 
 -- Used to generate the string for a *runtime* error message
-matchContextErrString (FunRhs fun)              = "function " ++ showSDoc (ppr fun)
+matchContextErrString (FunRhs fun _)                    = "function " ++ showSDoc (ppr fun)
 matchContextErrString CaseAlt                   = "case"
 matchContextErrString PatBindRhs                = "pattern binding"
 matchContextErrString RecUpd                    = "record update"