[project @ 2005-03-01 05:49:43 by simonpj]
authorsimonpj <unknown>
Tue, 1 Mar 2005 05:49:49 +0000 (05:49 +0000)
committersimonpj <unknown>
Tue, 1 Mar 2005 05:49:49 +0000 (05:49 +0000)
Make desugaring of pattern-matching much more civilised.

Before this change we wrapped new bindings around the
right hand side; but that meant they ended up wrapped
in reverse order.  Now we accumulate the bindings
separately, in the eqn_wrap field of an EqnInfo.

This cures a desugaring bug encountered by Akos Korosmezey
immortalised as ds055

ghc/compiler/deSugar/DsMonad.lhs
ghc/compiler/deSugar/DsUtils.lhs
ghc/compiler/deSugar/Match.lhs
ghc/compiler/deSugar/MatchCon.lhs
ghc/compiler/deSugar/MatchLit.lhs

index 8fecc81..b82a30a 100644 (file)
@@ -25,7 +25,7 @@ module DsMonad (
 
        -- Data types
        DsMatchContext(..),
-       EquationInfo(..), MatchResult(..), 
+       EquationInfo(..), MatchResult(..), DsWrapper, idWrapper,
        CanItFail(..), orFail
     ) where
 
@@ -74,9 +74,13 @@ data DsMatchContext
   deriving ()
 
 data EquationInfo
-  = EqnInfo { eqn_pats :: [Pat Id],            -- The patterns for an eqn
+  = EqnInfo { eqn_wrap :: DsWrapper,   -- Bindings
+             eqn_pats :: [Pat Id],     -- The patterns for an eqn
              eqn_rhs  :: MatchResult } -- What to do after match
 
+type DsWrapper = CoreExpr -> CoreExpr
+idWrapper e = e
+
 -- The semantics of (match vs (EqnInfo wrap pats rhs)) is the MatchResult
 --     \fail. wrap (case vs of { pats -> rhs fail })
 -- where vs are not in the domain of wrap
index 4105c88..671697b 100644 (file)
@@ -10,16 +10,16 @@ module DsUtils (
        EquationInfo(..), 
        firstPat, shiftEqns,
 
-       mkDsLet,
+       mkDsLet, mkDsLets,
 
        MatchResult(..), CanItFail(..), 
        cantFailMatchResult, alwaysFailMatchResult,
        extractMatchResult, combineMatchResults, 
        adjustMatchResult,  adjustMatchResultDs,
-       mkCoLetsMatchResult, mkCoLetMatchResult,
+       mkCoLetMatchResult,
        mkGuardedMatchResult, 
        mkCoPrimCaseMatchResult, mkCoAlgCaseMatchResult,
-       bindInMatchResult, bindOneInMatchResult,
+       wrapBind, wrapBinds,
 
        mkErrorAppDs, mkNilExpr, mkConsExpr, mkListExpr,
        mkIntExpr, mkCharExpr,
@@ -191,13 +191,8 @@ firstPat :: EquationInfo -> Pat Id
 firstPat eqn = head (eqn_pats eqn)
 
 shiftEqns :: [EquationInfo] -> [EquationInfo]
--- Drop the outermost layer of the first pattern in each equation
-shiftEqns eqns = [ eqn { eqn_pats = shiftPats (eqn_pats eqn) }
-                | eqn <- eqns ]
-
-shiftPats :: [Pat Id] -> [Pat Id]
-shiftPats (ConPatOut _ _ _ _ (PrefixCon arg_pats) _ : pats) = map unLoc arg_pats ++ pats
-shiftPats (pat_with_no_sub_pats                            : pats) = pats
+-- Drop the first pattern in each equation
+shiftEqns eqns = [ eqn { eqn_pats = tail (eqn_pats eqn) } | eqn <- eqns ]
 \end{code}
 
 Functions on MatchResults
@@ -242,24 +237,16 @@ adjustMatchResultDs encl_fn (MatchResult can_it_fail body_fn)
   = MatchResult can_it_fail (\fail -> body_fn fail     `thenDs` \ body ->
                                      encl_fn body)
 
-bindInMatchResult :: [(Var,Var)] -> MatchResult -> MatchResult
-bindInMatchResult binds = adjustMatchResult (\e -> foldr bind e binds)
-  where
-    bind (new,old) body = bindMR new old body
-
-bindOneInMatchResult :: Var -> Var -> MatchResult -> MatchResult
-bindOneInMatchResult new old = adjustMatchResult (bindMR new old)
+wrapBinds :: [(Var,Var)] -> CoreExpr -> CoreExpr
+wrapBinds [] e = e
+wrapBinds ((new,old):prs) e = wrapBind new old (wrapBinds prs e)
 
-bindMR :: Var -> Var -> CoreExpr -> CoreExpr
-bindMR new old body
+wrapBind :: Var -> Var -> CoreExpr -> CoreExpr
+wrapBind new old body
   | new==old    = body
   | isTyVar new = App (Lam new body) (Type (mkTyVarTy old))
   | otherwise   = Let (NonRec new (Var old)) body
 
-mkCoLetsMatchResult :: [CoreBind] -> MatchResult -> MatchResult
-mkCoLetsMatchResult binds match_result
-  = adjustMatchResult (mkDsLets binds) match_result
-
 mkCoLetMatchResult :: CoreBind -> MatchResult -> MatchResult
 mkCoLetMatchResult bind match_result
   = adjustMatchResult (mkDsLet bind) match_result
@@ -292,7 +279,7 @@ mkCoAlgCaseMatchResult :: Id                                        -- Scrutinee
 mkCoAlgCaseMatchResult var ty match_alts 
   | isNewTyCon tycon           -- Newtype case; use a let
   = ASSERT( null (tail match_alts) && null (tail arg_ids1) )
-    mkCoLetsMatchResult [NonRec arg_id1 newtype_rhs] match_result1
+    mkCoLetMatchResult (NonRec arg_id1 newtype_rhs) match_result1
 
   | isPArrFakeAlts match_alts  -- Sugared parallel array; use a literal case 
   = MatchResult CanFail mk_parrCase
index ebe503a..43471d8 100644 (file)
@@ -248,7 +248,7 @@ match [] ty eqns_info
     returnDs (foldr1 combineMatchResults match_results)
   where
     match_results = [ ASSERT( null (eqn_pats eqn) ) 
-                     eqn_rhs eqn
+                     adjustMatchResult (eqn_wrap eqn) (eqn_rhs eqn)
                    | eqn <- eqns_info ]
 \end{code}
 
@@ -357,15 +357,15 @@ tidyEqnInfo :: Id -> EquationInfo -> DsM EquationInfo
        --      NPlusKPat
        -- but no other
 
-tidyEqnInfo v eqn@(EqnInfo { eqn_pats = pat : pats, eqn_rhs = rhs })
-  = tidy1 v pat rhs    `thenDs` \ (pat', rhs') ->
-    returnDs (eqn { eqn_pats = pat' : pats, eqn_rhs = rhs' })
+tidyEqnInfo v eqn@(EqnInfo { eqn_wrap = wrap, eqn_pats = pat : pats })
+  = tidy1 v wrap pat   `thenDs` \ (wrap', pat') ->
+    returnDs (eqn { eqn_wrap = wrap', eqn_pats = pat' : pats })
 
 tidy1 :: Id                    -- The Id being scrutinised
+      -> DsWrapper             -- Previous wrapping bindings
       -> Pat Id                -- The pattern against which it is to be matched
-      -> MatchResult           -- What to do afterwards
-      -> DsM (Pat Id,          -- Equivalent pattern
-             MatchResult)      -- Extra bindings around what to do afterwards
+      -> DsM (DsWrapper,       -- Extra bindings around what to do afterwards
+             Pat Id)           -- Equivalent pattern
 
 -- The extra bindings etc are all wrapped around the RHS of the match
 -- so they are only available when matching is complete.  But that's ok
@@ -392,25 +392,24 @@ tidy1 :: Id                       -- The Id being scrutinised
 --     NPat
 --     NPlusKPat
 
-tidy1 v (ParPat pat)      wrap = tidy1 v (unLoc pat) wrap 
-tidy1 v (SigPatOut pat _) wrap = tidy1 v (unLoc pat) wrap 
-tidy1 v (WildPat ty)      wrap = returnDs (WildPat ty, wrap)
+tidy1 v wrap (ParPat pat)      = tidy1 v wrap (unLoc pat) 
+tidy1 v wrap (SigPatOut pat _) = tidy1 v wrap (unLoc pat) 
+tidy1 v wrap (WildPat ty)      = returnDs (wrap, WildPat ty)
 
        -- case v of { x -> mr[] }
        -- = case v of { _ -> let x=v in mr[] }
-tidy1 v (VarPat var) rhs
-  = returnDs (WildPat (idType var), bindOneInMatchResult var v rhs)
+tidy1 v wrap (VarPat var)
+  = returnDs (wrap . wrapBind var v, WildPat (idType var)) 
 
-tidy1 v (VarPatOut var binds) rhs
+tidy1 v wrap (VarPatOut var binds)
   = do { prs <- dsHsNestedBinds binds
-       ; return (WildPat (idType var), 
-                 bindOneInMatchResult var v $
-                 mkCoLetMatchResult (Rec prs) rhs) }
+       ; return (wrap . wrapBind var v . mkDsLet (Rec prs),
+                 WildPat (idType var)) }
 
        -- case v of { x@p -> mr[] }
        -- = case v of { p -> let x=v in mr[] }
-tidy1 v (AsPat (L _ var) pat) rhs
-  = tidy1 v (unLoc pat) (bindOneInMatchResult var v rhs)
+tidy1 v wrap (AsPat (L _ var) pat)
+  = tidy1 v (wrap . wrapBind var v) (unLoc pat)
 
 
 {- now, here we handle lazy patterns:
@@ -424,23 +423,22 @@ tidy1 v (AsPat (L _ var) pat) rhs
     The case expr for v_i is just: match [v] [(p, [], \ x -> Var v_i)] any_expr
 -}
 
-tidy1 v (LazyPat pat) rhs
+tidy1 v wrap (LazyPat pat)
   = do { v' <- newSysLocalDs (idType v)
        ; sel_prs <- mkSelectorBinds pat (Var v)
        ; let sel_binds =  [NonRec b rhs | (b,rhs) <- sel_prs]
-       ; returnDs (WildPat (idType v), 
-                   bindOneInMatchResult v' v $
-                   mkCoLetsMatchResult sel_binds rhs) }
+       ; returnDs (wrap . wrapBind v' v . mkDsLets sel_binds,
+                   WildPat (idType v)) }
 
 -- re-express <con-something> as (ConPat ...) [directly]
 
-tidy1 v (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty) rhs
-  = returnDs (ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty, rhs)
+tidy1 v wrap (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty)
+  = returnDs (wrap, ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty)
   where
     tidy_ps = PrefixCon (tidy_con con pat_ty ps)
 
-tidy1 v (ListPat pats ty) rhs
-  = returnDs (unLoc list_ConPat, rhs)
+tidy1 v wrap (ListPat pats ty)
+  = returnDs (wrap, unLoc list_ConPat)
   where
     list_ty     = mkListTy ty
     list_ConPat = foldr (\ x y -> mkPrefixConPat consDataCon [x, y] list_ty)
@@ -449,40 +447,40 @@ tidy1 v (ListPat pats ty) rhs
 
 -- Introduce fake parallel array constructors to be able to handle parallel
 -- arrays with the existing machinery for constructor pattern
-tidy1 v (PArrPat pats ty) rhs
-  = returnDs (unLoc parrConPat, rhs)
+tidy1 v wrap (PArrPat pats ty)
+  = returnDs (wrap, unLoc parrConPat)
   where
     arity      = length pats
     parrConPat = mkPrefixConPat (parrFakeCon arity) pats (mkPArrTy ty)
 
-tidy1 v (TuplePat pats boxity) rhs
-  = returnDs (unLoc tuple_ConPat, rhs)
+tidy1 v wrap (TuplePat pats boxity)
+  = returnDs (wrap, unLoc tuple_ConPat)
   where
     arity = length pats
     tuple_ConPat = mkPrefixConPat (tupleCon boxity arity) pats
                                  (mkTupleTy boxity arity (map hsPatType pats))
 
-tidy1 v (DictPat dicts methods) rhs
+tidy1 v wrap (DictPat dicts methods)
   = case num_of_d_and_ms of
-       0 -> tidy1 v (TuplePat [] Boxed) rhs
-       1 -> tidy1 v (unLoc (head dict_and_method_pats)) rhs
-       _ -> tidy1 v (TuplePat dict_and_method_pats Boxed) rhs
+       0 -> tidy1 v wrap (TuplePat [] Boxed) 
+       1 -> tidy1 v wrap (unLoc (head dict_and_method_pats))
+       _ -> tidy1 v wrap (TuplePat dict_and_method_pats Boxed)
   where
     num_of_d_and_ms     = length dicts + length methods
     dict_and_method_pats = map nlVarPat (dicts ++ methods)
 
 -- LitPats: we *might* be able to replace these w/ a simpler form
-tidy1 v pat@(LitPat lit) rhs
-  = returnDs (unLoc (tidyLitPat lit (noLoc pat)), rhs)
+tidy1 v wrap pat@(LitPat lit)
+  = returnDs (wrap, unLoc (tidyLitPat lit (noLoc pat)))
 
 -- NPats: we *might* be able to replace these w/ a simpler form
-tidy1 v pat@(NPatOut lit lit_ty _) rhs
-  = returnDs (unLoc (tidyNPat lit lit_ty (noLoc pat)), rhs)
+tidy1 v wrap pat@(NPatOut lit lit_ty _)
+  = returnDs (wrap, unLoc (tidyNPat lit lit_ty (noLoc pat)))
 
 -- and everything else goes through unchanged...
 
-tidy1 v non_interesting_pat rhs
-  = returnDs (non_interesting_pat, rhs)
+tidy1 v wrap non_interesting_pat
+  = returnDs (wrap, non_interesting_pat)
 
 
 tidy_con data_con pat_ty (PrefixCon ps)   = ps
@@ -673,7 +671,8 @@ matchWrapper ctxt (MatchGroup matches match_ty)
     mk_eqn_info (L _ (Match pats _ grhss))
       = do { let upats = map unLoc pats
           ; match_result <- dsGRHSs ctxt upats grhss rhs_ty
-          ; return (EqnInfo { eqn_pats = upats, 
+          ; return (EqnInfo { eqn_wrap = idWrapper,
+                              eqn_pats = upats, 
                               eqn_rhs = match_result}) }
 
     match_fun dflags ds_ctxt
@@ -717,7 +716,8 @@ matchSinglePat :: CoreExpr -> DsMatchContext -> LPat Id
               -> Type -> MatchResult -> DsM MatchResult
 matchSinglePat (Var var) ctx pat ty match_result
   = getDOptsDs                                 `thenDs` \ dflags ->
-    match_fn dflags [var] ty [EqnInfo { eqn_pats = [unLoc pat],
+    match_fn dflags [var] ty [EqnInfo { eqn_wrap = idWrapper,
+                                       eqn_pats = [unLoc pat],
                                        eqn_rhs  = match_result }]
   where
     match_fn dflags
index c7e2b93..3787265 100644 (file)
@@ -106,34 +106,27 @@ wouldn't).  Cf.~@shift_lit_pats@ in @MatchLits@.
 match_con vars ty eqns
   = do { -- Make new vars for the con arguments; avoid new locals where possible
          arg_vars <- selectMatchVars (map unLoc arg_pats1) arg_tys
-
-       ; match_result <- match (arg_vars ++ vars) ty (shiftEqns eqns)
-
-       ; binds <- mapM ds_binds [ bind | ConPatOut _ _ _ bind _ _ <- pats,
-                                         not (isEmptyLHsBinds bind) ]
-
-       ; let match_result' = bindInMatchResult (line_up other_pats) $
-                             mkCoLetsMatchResult binds match_result
-       
-       ; return (data_con, tvs1 ++ dicts1 ++ arg_vars, match_result') }
+       ; eqns' <- mapM shift eqns 
+       ; match_result <- match (arg_vars ++ vars) ty eqns'
+       ; return (con, tvs1 ++ dicts1 ++ arg_vars, match_result) }
   where
-    pats@(pat1 : other_pats) = map firstPat eqns
-    ConPatOut (L _ data_con) tvs1 dicts1 _ (PrefixCon arg_pats1) pat_ty = pat1
-
-    ds_binds bind = do { prs <- dsHsNestedBinds bind; return (Rec prs) }
+    ConPatOut (L _ con) tvs1 dicts1 _ (PrefixCon arg_pats1) pat_ty = firstPat (head eqns)
 
-    line_up pats 
-       | null tvs1 && null dicts1 = []         -- Common case
-       | otherwise = [ pr | ConPatOut _ ts ds _ _ _ <- pats,
-                            pr <- (ts `zip` tvs1) ++ (ds `zip` dicts1)]
+    shift eqn@(EqnInfo { eqn_wrap = wrap, 
+                        eqn_pats = ConPatOut _ tvs ds bind (PrefixCon arg_pats) _ : pats })
+       = do { prs <- dsHsNestedBinds bind
+            ; return (eqn { eqn_wrap = wrap . wrapBinds (tvs `zip` tvs1) 
+                                            . wrapBinds (ds  `zip` dicts1)
+                                            . mkDsLet (Rec prs),
+                            eqn_pats = map unLoc arg_pats ++ pats }) }
 
        -- Get the arg types, which we use to type the new vars
        -- to match on, from the "outside"; the types of pats1 may 
        -- be more refined, and hence won't do
-    arg_tys = substTys (zipTopTvSubst (dataConTyVars data_con) inst_tys)
-                      (dataConOrigArgTys data_con)
-    inst_tys | isVanillaDataCon data_con = tcTyConAppArgs pat_ty       -- Newtypes opaque!
-            | otherwise                 = mkTyVarTys tvs1
+    arg_tys = substTys (zipTopTvSubst (dataConTyVars con) inst_tys)
+                      (dataConOrigArgTys con)
+    inst_tys | isVanillaDataCon con = tcTyConAppArgs pat_ty    -- Newtypes opaque!
+            | otherwise            = mkTyVarTys tvs1
 \end{code}
 
 Note [Existentials in shift_con_pat]
index 75a0a62..5ca0569 100644 (file)
@@ -167,12 +167,16 @@ matchNPats (var:vars) ty eqns
          return (foldr1 combineMatchResults match_results) }
   where
     match_group :: [EquationInfo] -> DsM MatchResult
-    match_group eqns
+    match_group (eqn1:eqns)
        = do { pred_expr <- dsExpr (HsApp (noLoc eq_chk) (nlHsVar var))
-            ; match_result <- match vars ty (shiftEqns eqns)
-            ; return (mkGuardedMatchResult pred_expr match_result) }
+            ; match_result <- match vars ty (eqn1' : shiftEqns eqns)
+            ; return (adjustMatchResult (eqn_wrap eqn1) $
+                       -- Bring the eqn1 wrapper stuff into scope because
+                       -- it may be used in pred_expr
+                      mkGuardedMatchResult pred_expr match_result) }
        where
-         NPatOut _ _ eq_chk = firstPat (head eqns)
+         NPatOut _ _ eq_chk : pats1 = eqn_pats eqn1
+         eqn1' = eqn1 { eqn_wrap = idWrapper, eqn_pats = pats1 }
 \end{code}
 
 
@@ -216,17 +220,23 @@ matchNPlusKPats all_vars@(var:vars) ty eqns
          return (foldr1 combineMatchResults match_results) }
   where
     match_group :: [EquationInfo] -> DsM MatchResult
-    match_group eqns
+    match_group (eqn1:eqns)
        = do { ge_expr      <- dsExpr (HsApp (noLoc ge)  (nlHsVar var))
             ; minusk_expr  <- dsExpr (HsApp (noLoc sub) (nlHsVar var))
-            ; match_result <- match vars ty (shiftEqns eqns)
-            ; return  (mkGuardedMatchResult ge_expr                 $
-                       mkCoLetsMatchResult [NonRec n1 minusk_expr]  $
-                       bindInMatchResult (map line_up other_pats)   $
+            ; match_result <- match vars ty (eqn1' : map shift eqns)
+            ; return  (adjustMatchResult (eqn_wrap eqn1)            $
+                       -- Bring the eqn1 wrapper stuff into scope because
+                       -- it may be used in ge_expr, minusk_expr
+                       mkGuardedMatchResult ge_expr                $
+                       mkCoLetMatchResult (NonRec n1 minusk_expr)  $
                        match_result) }
        where
-         (NPlusKPatOut (L _ n1) _ ge sub : other_pats) = map firstPat eqns 
-         line_up (NPlusKPatOut (L _ n) _ _ _) = (n,n1)
+         NPlusKPatOut (L _ n1) _ ge sub : pats1 = eqn_pats eqn1
+         eqn1' = eqn1 { eqn_wrap = idWrapper, eqn_pats = pats1 }
+
+         shift eqn@(EqnInfo { eqn_wrap = wrap,
+                              eqn_pats = NPlusKPatOut (L _ n) _ _ _ : pats })
+           = eqn { eqn_wrap = wrap . wrapBind n n1, eqn_pats = pats }  
 \end{code}