Add bang patterns
[ghc-hetmet.git] / ghc / compiler / deSugar / Match.lhs
index ebe503a..bbc37b3 100644 (file)
@@ -8,14 +8,14 @@ module Match ( match, matchWrapper, matchSimply, matchSinglePat ) where
 
 #include "HsVersions.h"
 
-import CmdLineOpts     ( DynFlag(..), dopt )
+import DynFlags        ( DynFlag(..), dopt )
 import HsSyn           
-import TcHsSyn         ( hsPatType )
+import TcHsSyn         ( mkVanillaTuplePat )
 import Check            ( check, ExhaustivePat )
 import CoreSyn
 import CoreUtils       ( bindNonRec, exprType )
 import DsMonad
-import DsBinds         ( dsHsNestedBinds )
+import DsBinds         ( dsLHsBinds )
 import DsGRHSs         ( dsGRHSs )
 import DsUtils
 import Id              ( idName, idType, Id )
@@ -24,12 +24,12 @@ import MatchCon             ( matchConFamily )
 import MatchLit                ( matchLiterals, matchNPlusKPats, matchNPats, tidyLitPat, tidyNPat )
 import PrelInfo                ( pAT_ERROR_ID )
 import TcType          ( Type, tcTyConAppArgs )
-import Type            ( splitFunTysN )
-import TysWiredIn      ( consDataCon, mkTupleTy, mkListTy,
+import Type            ( splitFunTysN, mkTyVarTys )
+import TysWiredIn      ( consDataCon, mkListTy, unitTy,
                          tupleCon, parrFakeCon, mkPArrTy )
 import BasicTypes      ( Boxity(..) )
 import ListSetOps      ( runs )
-import SrcLoc          ( noSrcSpan, noLoc, unLoc, Located(..) )
+import SrcLoc          ( noLoc, unLoc, Located(..) )
 import Util             ( lengthExceeds, notNull )
 import Name            ( Name )
 import Outputable
@@ -90,19 +90,21 @@ The next two functions create the warning message.
 
 \begin{code}
 dsShadowWarn :: DsMatchContext -> [EquationInfo] -> DsM ()
-dsShadowWarn ctx@(DsMatchContext kind _ _) qs = dsWarn warn 
-       where
-         warn | qs `lengthExceeds` maximum_output
-               = pp_context ctx (ptext SLIT("are overlapped"))
-                           (\ f -> vcat (map (ppr_eqn f kind) (take maximum_output qs)) $$
-                           ptext SLIT("..."))
-              | otherwise
-               = pp_context ctx (ptext SLIT("are overlapped"))
-                           (\ f -> vcat $ map (ppr_eqn f kind) qs)
+dsShadowWarn ctx@(DsMatchContext kind _ loc) qs
+  = putSrcSpanDs loc (dsWarn warn)
+  where
+    warn | qs `lengthExceeds` maximum_output
+         = pp_context ctx (ptext SLIT("are overlapped"))
+                     (\ f -> vcat (map (ppr_eqn f kind) (take maximum_output qs)) $$
+                     ptext SLIT("..."))
+        | otherwise
+         = pp_context ctx (ptext SLIT("are overlapped"))
+                     (\ f -> vcat $ map (ppr_eqn f kind) qs)
 
 
 dsIncompleteWarn :: DsMatchContext -> [ExhaustivePat] -> DsM ()
-dsIncompleteWarn ctx@(DsMatchContext kind _ _) pats = dsWarn warn 
+dsIncompleteWarn ctx@(DsMatchContext kind _ loc) pats 
+  = putSrcSpanDs loc (dsWarn warn)
        where
          warn = pp_context ctx (ptext SLIT("are non-exhaustive"))
                            (\f -> hang (ptext SLIT("Patterns not matched:"))
@@ -113,12 +115,9 @@ dsIncompleteWarn ctx@(DsMatchContext kind _ _) pats = dsWarn warn
          dots | pats `lengthExceeds` maximum_output = ptext SLIT("...")
               | otherwise                           = empty
 
-pp_context NoMatchContext msg rest_of_msg_fun
-  = (noSrcSpan, ptext SLIT("Some match(es)") <+> hang msg 8 (rest_of_msg_fun id))
-
-pp_context (DsMatchContext kind pats loc) msg rest_of_msg_fun
-  = (loc, vcat [ptext SLIT("Pattern match(es)") <+> msg,
-               sep [ptext SLIT("In") <+> ppr_match <> char ':', nest 4 (rest_of_msg_fun pref)]])
+pp_context (DsMatchContext kind pats _loc) msg rest_of_msg_fun
+  = vcat [ptext SLIT("Pattern match(es)") <+> msg,
+         sep [ptext SLIT("In") <+> ppr_match <> char ':', nest 4 (rest_of_msg_fun pref)]]
   where
     (ppr_match, pref)
        = case kind of
@@ -248,7 +247,7 @@ match [] ty eqns_info
     returnDs (foldr1 combineMatchResults match_results)
   where
     match_results = [ ASSERT( null (eqn_pats eqn) ) 
-                     eqn_rhs eqn
+                     adjustMatchResult (eqn_wrap eqn) (eqn_rhs eqn)
                    | eqn <- eqns_info ]
 \end{code}
 
@@ -284,19 +283,19 @@ match vars@(v:_) ty eqns_info
  
     match_block eqns
       = case firstPat (head eqns) of
-         WildPat {}      -> matchVariables  vars ty eqns
-         ConPatOut {}    -> matchConFamily  vars ty eqns
-         NPlusKPatOut {} -> matchNPlusKPats vars ty eqns
-         NPatOut {}      -> matchNPats      vars ty eqns
-         LitPat {}       -> matchLiterals   vars ty eqns
+         WildPat {}   -> matchVariables  vars ty eqns
+         ConPatOut {} -> matchConFamily  vars ty eqns
+         NPlusKPat {} -> matchNPlusKPats vars ty eqns
+         NPat {}      -> matchNPats      vars ty eqns
+         LitPat {}    -> matchLiterals   vars ty eqns
 
 -- After tidying, there are only five kinds of patterns
-samePatFamily (WildPat {})     (WildPat {})      = True
-samePatFamily (ConPatOut {})   (ConPatOut {})    = True
-samePatFamily (NPlusKPatOut {}) (NPlusKPatOut {}) = True
-samePatFamily (NPatOut {})     (NPatOut {})      = True
-samePatFamily (LitPat {})       (LitPat {})      = True
-samePatFamily _                        _                 = False
+samePatFamily (WildPat {})   (WildPat {})   = True
+samePatFamily (ConPatOut {}) (ConPatOut {}) = True
+samePatFamily (NPlusKPat {}) (NPlusKPat {}) = True
+samePatFamily (NPat {})             (NPat {})      = True
+samePatFamily (LitPat {})    (LitPat {})    = True
+samePatFamily _                     _              = False
 
 matchVariables :: [Id] -> Type -> [EquationInfo] -> DsM MatchResult
 -- Real true variables, just like in matchVar, SLPJ p 94
@@ -344,7 +343,7 @@ Float,      Double, at least) are converted to unboxed form; e.g.,
 
 \begin{code}
 tidyEqnInfo :: Id -> EquationInfo -> DsM EquationInfo
-       -- DsM'd because of internal call to dsHsNestedBinds
+       -- DsM'd because of internal call to dsLHsBinds
        --      and mkSelectorBinds.
        -- "tidy1" does the interesting stuff, looking at
        -- one pattern and fiddling the list of bindings.
@@ -357,15 +356,15 @@ tidyEqnInfo :: Id -> EquationInfo -> DsM EquationInfo
        --      NPlusKPat
        -- but no other
 
-tidyEqnInfo v eqn@(EqnInfo { eqn_pats = pat : pats, eqn_rhs = rhs })
-  = tidy1 v pat rhs    `thenDs` \ (pat', rhs') ->
-    returnDs (eqn { eqn_pats = pat' : pats, eqn_rhs = rhs' })
+tidyEqnInfo v eqn@(EqnInfo { eqn_wrap = wrap, eqn_pats = pat : pats })
+  = tidy1 v wrap pat   `thenDs` \ (wrap', pat') ->
+    returnDs (eqn { eqn_wrap = wrap', eqn_pats = pat' : pats })
 
 tidy1 :: Id                    -- The Id being scrutinised
+      -> DsWrapper             -- Previous wrapping bindings
       -> Pat Id                -- The pattern against which it is to be matched
-      -> MatchResult           -- What to do afterwards
-      -> DsM (Pat Id,          -- Equivalent pattern
-             MatchResult)      -- Extra bindings around what to do afterwards
+      -> DsM (DsWrapper,       -- Extra bindings around what to do afterwards
+             Pat Id)           -- Equivalent pattern
 
 -- The extra bindings etc are all wrapped around the RHS of the match
 -- so they are only available when matching is complete.  But that's ok
@@ -392,26 +391,27 @@ tidy1 :: Id                       -- The Id being scrutinised
 --     NPat
 --     NPlusKPat
 
-tidy1 v (ParPat pat)      wrap = tidy1 v (unLoc pat) wrap 
-tidy1 v (SigPatOut pat _) wrap = tidy1 v (unLoc pat) wrap 
-tidy1 v (WildPat ty)      wrap = returnDs (WildPat ty, wrap)
+tidy1 v wrap (ParPat pat)      = tidy1 v wrap (unLoc pat) 
+tidy1 v wrap (SigPatOut pat _) = tidy1 v wrap (unLoc pat) 
+tidy1 v wrap (WildPat ty)      = returnDs (wrap, WildPat ty)
 
        -- case v of { x -> mr[] }
        -- = case v of { _ -> let x=v in mr[] }
-tidy1 v (VarPat var) rhs
-  = returnDs (WildPat (idType var), bindOneInMatchResult var v rhs)
+tidy1 v wrap (VarPat var)
+  = returnDs (wrap . wrapBind var v, WildPat (idType var)) 
 
-tidy1 v (VarPatOut var binds) rhs
-  = do { prs <- dsHsNestedBinds binds
-       ; return (WildPat (idType var), 
-                 bindOneInMatchResult var v $
-                 mkCoLetMatchResult (Rec prs) rhs) }
+tidy1 v wrap (VarPatOut var binds)
+  = do { prs <- dsLHsBinds binds
+       ; return (wrap . wrapBind var v . mkDsLet (Rec prs),
+                 WildPat (idType var)) }
 
        -- case v of { x@p -> mr[] }
        -- = case v of { p -> let x=v in mr[] }
-tidy1 v (AsPat (L _ var) pat) rhs
-  = tidy1 v (unLoc pat) (bindOneInMatchResult var v rhs)
+tidy1 v wrap (AsPat (L _ var) pat)
+  = tidy1 v (wrap . wrapBind var v) (unLoc pat)
 
+tidy1 v wrap (BangPat pat)
+  = tidy1 v (wrap . seqVar v) (unLoc pat)
 
 {- now, here we handle lazy patterns:
     tidy1 v ~p bs = (v, v1 = case v of p -> v1 :
@@ -424,23 +424,22 @@ tidy1 v (AsPat (L _ var) pat) rhs
     The case expr for v_i is just: match [v] [(p, [], \ x -> Var v_i)] any_expr
 -}
 
-tidy1 v (LazyPat pat) rhs
+tidy1 v wrap (LazyPat pat)
   = do { v' <- newSysLocalDs (idType v)
        ; sel_prs <- mkSelectorBinds pat (Var v)
        ; let sel_binds =  [NonRec b rhs | (b,rhs) <- sel_prs]
-       ; returnDs (WildPat (idType v), 
-                   bindOneInMatchResult v' v $
-                   mkCoLetsMatchResult sel_binds rhs) }
+       ; returnDs (wrap . wrapBind v' v . mkDsLets sel_binds,
+                   WildPat (idType v)) }
 
 -- re-express <con-something> as (ConPat ...) [directly]
 
-tidy1 v (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty) rhs
-  = returnDs (ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty, rhs)
+tidy1 v wrap (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty)
+  = returnDs (wrap, ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty)
   where
-    tidy_ps = PrefixCon (tidy_con con pat_ty ps)
+    tidy_ps = PrefixCon (tidy_con con ex_tvs pat_ty ps)
 
-tidy1 v (ListPat pats ty) rhs
-  = returnDs (unLoc list_ConPat, rhs)
+tidy1 v wrap (ListPat pats ty)
+  = returnDs (wrap, unLoc list_ConPat)
   where
     list_ty     = mkListTy ty
     list_ConPat = foldr (\ x y -> mkPrefixConPat consDataCon [x, y] list_ty)
@@ -449,45 +448,44 @@ tidy1 v (ListPat pats ty) rhs
 
 -- Introduce fake parallel array constructors to be able to handle parallel
 -- arrays with the existing machinery for constructor pattern
-tidy1 v (PArrPat pats ty) rhs
-  = returnDs (unLoc parrConPat, rhs)
+tidy1 v wrap (PArrPat pats ty)
+  = returnDs (wrap, unLoc parrConPat)
   where
     arity      = length pats
     parrConPat = mkPrefixConPat (parrFakeCon arity) pats (mkPArrTy ty)
 
-tidy1 v (TuplePat pats boxity) rhs
-  = returnDs (unLoc tuple_ConPat, rhs)
+tidy1 v wrap (TuplePat pats boxity ty)
+  = returnDs (wrap, unLoc tuple_ConPat)
   where
     arity = length pats
-    tuple_ConPat = mkPrefixConPat (tupleCon boxity arity) pats
-                                 (mkTupleTy boxity arity (map hsPatType pats))
+    tuple_ConPat = mkPrefixConPat (tupleCon boxity arity) pats ty
 
-tidy1 v (DictPat dicts methods) rhs
+tidy1 v wrap (DictPat dicts methods)
   = case num_of_d_and_ms of
-       0 -> tidy1 v (TuplePat [] Boxed) rhs
-       1 -> tidy1 v (unLoc (head dict_and_method_pats)) rhs
-       _ -> tidy1 v (TuplePat dict_and_method_pats Boxed) rhs
+       0 -> tidy1 v wrap (TuplePat [] Boxed unitTy) 
+       1 -> tidy1 v wrap (unLoc (head dict_and_method_pats))
+       _ -> tidy1 v wrap (mkVanillaTuplePat dict_and_method_pats Boxed)
   where
     num_of_d_and_ms     = length dicts + length methods
     dict_and_method_pats = map nlVarPat (dicts ++ methods)
 
 -- LitPats: we *might* be able to replace these w/ a simpler form
-tidy1 v pat@(LitPat lit) rhs
-  = returnDs (unLoc (tidyLitPat lit (noLoc pat)), rhs)
+tidy1 v wrap pat@(LitPat lit)
+  = returnDs (wrap, unLoc (tidyLitPat lit (noLoc pat)))
 
 -- NPats: we *might* be able to replace these w/ a simpler form
-tidy1 v pat@(NPatOut lit lit_ty _) rhs
-  = returnDs (unLoc (tidyNPat lit lit_ty (noLoc pat)), rhs)
+tidy1 v wrap pat@(NPat lit mb_neg _ lit_ty)
+  = returnDs (wrap, unLoc (tidyNPat lit mb_neg lit_ty (noLoc pat)))
 
 -- and everything else goes through unchanged...
 
-tidy1 v non_interesting_pat rhs
-  = returnDs (non_interesting_pat, rhs)
+tidy1 v wrap non_interesting_pat
+  = returnDs (wrap, non_interesting_pat)
 
 
-tidy_con data_con pat_ty (PrefixCon ps)   = ps
-tidy_con data_con pat_ty (InfixCon p1 p2) = [p1,p2]
-tidy_con data_con pat_ty (RecCon rpats)
+tidy_con data_con ex_tvs pat_ty (PrefixCon ps)   = ps
+tidy_con data_con ex_tvs pat_ty (InfixCon p1 p2) = [p1,p2]
+tidy_con data_con ex_tvs pat_ty (RecCon rpats)
   | null rpats
   =    -- Special case for C {}, which can be used for 
        -- a constructor that isn't declared to have
@@ -495,14 +493,13 @@ tidy_con data_con pat_ty (RecCon rpats)
     map (noLoc . WildPat) con_arg_tys'
 
   | otherwise
-  = ASSERT( isVanillaDataCon data_con )
-       -- We're in a record case, so the data con must be vanilla
-       -- and hence no existentials to worry about
-    map mk_pat tagged_arg_tys
+  = map mk_pat tagged_arg_tys
   where
        -- Boring stuff to find the arg-tys of the constructor
        
-    inst_tys         = tcTyConAppArgs pat_ty   -- Newtypes must be opaque
+    inst_tys | isVanillaDataCon data_con = tcTyConAppArgs pat_ty       -- Newtypes must be opaque
+            | otherwise                 = mkTyVarTys ex_tvs
+
     con_arg_tys'     = dataConInstOrigArgTys data_con inst_tys
     tagged_arg_tys   = con_arg_tys' `zip` dataConFieldLabels data_con
 
@@ -673,7 +670,8 @@ matchWrapper ctxt (MatchGroup matches match_ty)
     mk_eqn_info (L _ (Match pats _ grhss))
       = do { let upats = map unLoc pats
           ; match_result <- dsGRHSs ctxt upats grhss rhs_ty
-          ; return (EqnInfo { eqn_pats = upats, 
+          ; return (EqnInfo { eqn_wrap = idWrapper,
+                              eqn_pats = upats, 
                               eqn_rhs = match_result}) }
 
     match_fun dflags ds_ctxt
@@ -701,32 +699,35 @@ matchSimply :: CoreExpr                   -- Scrutinee
            -> CoreExpr                 -- Return this if it doesn't
            -> DsM CoreExpr
 
-matchSimply scrut kind pat result_expr fail_expr
-  = getSrcSpanDs                               `thenDs` \ locn ->
-    let
-      ctx         = DsMatchContext kind [unLoc pat] locn
+matchSimply scrut hs_ctx pat result_expr fail_expr
+  = let
       match_result = cantFailMatchResult result_expr
       rhs_ty      = exprType fail_expr
        -- Use exprType of fail_expr, because won't refine in the case of failure!
     in 
-    matchSinglePat scrut ctx pat rhs_ty match_result   `thenDs` \ match_result' ->
+    matchSinglePat scrut hs_ctx pat rhs_ty match_result        `thenDs` \ match_result' ->
     extractMatchResult match_result' fail_expr
 
 
-matchSinglePat :: CoreExpr -> DsMatchContext -> LPat Id
+matchSinglePat :: CoreExpr -> HsMatchContext Name -> LPat Id
               -> Type -> MatchResult -> DsM MatchResult
-matchSinglePat (Var var) ctx pat ty match_result
-  = getDOptsDs                                 `thenDs` \ dflags ->
-    match_fn dflags [var] ty [EqnInfo { eqn_pats = [unLoc pat],
+matchSinglePat (Var var) hs_ctx (L _ pat) ty match_result
+  = getDOptsDs                         `thenDs` \ dflags ->
+    getSrcSpanDs                       `thenDs` \ locn ->
+    let
+       match_fn dflags
+           | dopt Opt_WarnSimplePatterns dflags = matchCheck ds_ctx
+          | otherwise                          = match
+          where
+            ds_ctx = DsMatchContext hs_ctx [pat] locn
+    in
+    match_fn dflags [var] ty [EqnInfo { eqn_wrap = idWrapper,
+                                       eqn_pats = [pat],
                                        eqn_rhs  = match_result }]
-  where
-    match_fn dflags
-       | dopt Opt_WarnSimplePatterns dflags = matchCheck ctx
-       | otherwise                         = match
 
-matchSinglePat scrut ctx pat ty match_result
+matchSinglePat scrut hs_ctx pat ty match_result
   = selectSimpleMatchVarL pat                          `thenDs` \ var ->
-    matchSinglePat (Var var) ctx pat ty match_result   `thenDs` \ match_result' ->
+    matchSinglePat (Var var) hs_ctx pat ty match_result        `thenDs` \ match_result' ->
     returnDs (adjustMatchResult (bindNonRec var scrut) match_result')
 \end{code}