From: simonpj Date: Tue, 1 Mar 2005 05:49:49 +0000 (+0000) Subject: [project @ 2005-03-01 05:49:43 by simonpj] X-Git-Tag: Initial_conversion_from_CVS_complete~1006 X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=6d36af4aff6e12afa50dae2fad3993c385f8081d;hp=b4dae163a4830e1984a656cdf66df957e840c77d [project @ 2005-03-01 05:49:43 by simonpj] Make desugaring of pattern-matching much more civilised. Before this change we wrapped new bindings around the right hand side; but that meant they ended up wrapped in reverse order. Now we accumulate the bindings separately, in the eqn_wrap field of an EqnInfo. This cures a desugaring bug encountered by Akos Korosmezey immortalised as ds055 --- diff --git a/ghc/compiler/deSugar/DsMonad.lhs b/ghc/compiler/deSugar/DsMonad.lhs index 8fecc81..b82a30a 100644 --- a/ghc/compiler/deSugar/DsMonad.lhs +++ b/ghc/compiler/deSugar/DsMonad.lhs @@ -25,7 +25,7 @@ module DsMonad ( -- Data types DsMatchContext(..), - EquationInfo(..), MatchResult(..), + EquationInfo(..), MatchResult(..), DsWrapper, idWrapper, CanItFail(..), orFail ) where @@ -74,9 +74,13 @@ data DsMatchContext deriving () data EquationInfo - = EqnInfo { eqn_pats :: [Pat Id], -- The patterns for an eqn + = EqnInfo { eqn_wrap :: DsWrapper, -- Bindings + eqn_pats :: [Pat Id], -- The patterns for an eqn eqn_rhs :: MatchResult } -- What to do after match +type DsWrapper = CoreExpr -> CoreExpr +idWrapper e = e + -- The semantics of (match vs (EqnInfo wrap pats rhs)) is the MatchResult -- \fail. wrap (case vs of { pats -> rhs fail }) -- where vs are not in the domain of wrap diff --git a/ghc/compiler/deSugar/DsUtils.lhs b/ghc/compiler/deSugar/DsUtils.lhs index 4105c88..671697b 100644 --- a/ghc/compiler/deSugar/DsUtils.lhs +++ b/ghc/compiler/deSugar/DsUtils.lhs @@ -10,16 +10,16 @@ module DsUtils ( EquationInfo(..), firstPat, shiftEqns, - mkDsLet, + mkDsLet, mkDsLets, MatchResult(..), CanItFail(..), cantFailMatchResult, alwaysFailMatchResult, extractMatchResult, combineMatchResults, adjustMatchResult, adjustMatchResultDs, - mkCoLetsMatchResult, mkCoLetMatchResult, + mkCoLetMatchResult, mkGuardedMatchResult, mkCoPrimCaseMatchResult, mkCoAlgCaseMatchResult, - bindInMatchResult, bindOneInMatchResult, + wrapBind, wrapBinds, mkErrorAppDs, mkNilExpr, mkConsExpr, mkListExpr, mkIntExpr, mkCharExpr, @@ -191,13 +191,8 @@ firstPat :: EquationInfo -> Pat Id firstPat eqn = head (eqn_pats eqn) shiftEqns :: [EquationInfo] -> [EquationInfo] --- Drop the outermost layer of the first pattern in each equation -shiftEqns eqns = [ eqn { eqn_pats = shiftPats (eqn_pats eqn) } - | eqn <- eqns ] - -shiftPats :: [Pat Id] -> [Pat Id] -shiftPats (ConPatOut _ _ _ _ (PrefixCon arg_pats) _ : pats) = map unLoc arg_pats ++ pats -shiftPats (pat_with_no_sub_pats : pats) = pats +-- Drop the first pattern in each equation +shiftEqns eqns = [ eqn { eqn_pats = tail (eqn_pats eqn) } | eqn <- eqns ] \end{code} Functions on MatchResults @@ -242,24 +237,16 @@ adjustMatchResultDs encl_fn (MatchResult can_it_fail body_fn) = MatchResult can_it_fail (\fail -> body_fn fail `thenDs` \ body -> encl_fn body) -bindInMatchResult :: [(Var,Var)] -> MatchResult -> MatchResult -bindInMatchResult binds = adjustMatchResult (\e -> foldr bind e binds) - where - bind (new,old) body = bindMR new old body - -bindOneInMatchResult :: Var -> Var -> MatchResult -> MatchResult -bindOneInMatchResult new old = adjustMatchResult (bindMR new old) +wrapBinds :: [(Var,Var)] -> CoreExpr -> CoreExpr +wrapBinds [] e = e +wrapBinds ((new,old):prs) e = wrapBind new old (wrapBinds prs e) -bindMR :: Var -> Var -> CoreExpr -> CoreExpr -bindMR new old body +wrapBind :: Var -> Var -> CoreExpr -> CoreExpr +wrapBind new old body | new==old = body | isTyVar new = App (Lam new body) (Type (mkTyVarTy old)) | otherwise = Let (NonRec new (Var old)) body -mkCoLetsMatchResult :: [CoreBind] -> MatchResult -> MatchResult -mkCoLetsMatchResult binds match_result - = adjustMatchResult (mkDsLets binds) match_result - mkCoLetMatchResult :: CoreBind -> MatchResult -> MatchResult mkCoLetMatchResult bind match_result = adjustMatchResult (mkDsLet bind) match_result @@ -292,7 +279,7 @@ mkCoAlgCaseMatchResult :: Id -- Scrutinee mkCoAlgCaseMatchResult var ty match_alts | isNewTyCon tycon -- Newtype case; use a let = ASSERT( null (tail match_alts) && null (tail arg_ids1) ) - mkCoLetsMatchResult [NonRec arg_id1 newtype_rhs] match_result1 + mkCoLetMatchResult (NonRec arg_id1 newtype_rhs) match_result1 | isPArrFakeAlts match_alts -- Sugared parallel array; use a literal case = MatchResult CanFail mk_parrCase diff --git a/ghc/compiler/deSugar/Match.lhs b/ghc/compiler/deSugar/Match.lhs index ebe503a..43471d8 100644 --- a/ghc/compiler/deSugar/Match.lhs +++ b/ghc/compiler/deSugar/Match.lhs @@ -248,7 +248,7 @@ match [] ty eqns_info returnDs (foldr1 combineMatchResults match_results) where match_results = [ ASSERT( null (eqn_pats eqn) ) - eqn_rhs eqn + adjustMatchResult (eqn_wrap eqn) (eqn_rhs eqn) | eqn <- eqns_info ] \end{code} @@ -357,15 +357,15 @@ tidyEqnInfo :: Id -> EquationInfo -> DsM EquationInfo -- NPlusKPat -- but no other -tidyEqnInfo v eqn@(EqnInfo { eqn_pats = pat : pats, eqn_rhs = rhs }) - = tidy1 v pat rhs `thenDs` \ (pat', rhs') -> - returnDs (eqn { eqn_pats = pat' : pats, eqn_rhs = rhs' }) +tidyEqnInfo v eqn@(EqnInfo { eqn_wrap = wrap, eqn_pats = pat : pats }) + = tidy1 v wrap pat `thenDs` \ (wrap', pat') -> + returnDs (eqn { eqn_wrap = wrap', eqn_pats = pat' : pats }) tidy1 :: Id -- The Id being scrutinised + -> DsWrapper -- Previous wrapping bindings -> Pat Id -- The pattern against which it is to be matched - -> MatchResult -- What to do afterwards - -> DsM (Pat Id, -- Equivalent pattern - MatchResult) -- Extra bindings around what to do afterwards + -> DsM (DsWrapper, -- Extra bindings around what to do afterwards + Pat Id) -- Equivalent pattern -- The extra bindings etc are all wrapped around the RHS of the match -- so they are only available when matching is complete. But that's ok @@ -392,25 +392,24 @@ tidy1 :: Id -- The Id being scrutinised -- NPat -- NPlusKPat -tidy1 v (ParPat pat) wrap = tidy1 v (unLoc pat) wrap -tidy1 v (SigPatOut pat _) wrap = tidy1 v (unLoc pat) wrap -tidy1 v (WildPat ty) wrap = returnDs (WildPat ty, wrap) +tidy1 v wrap (ParPat pat) = tidy1 v wrap (unLoc pat) +tidy1 v wrap (SigPatOut pat _) = tidy1 v wrap (unLoc pat) +tidy1 v wrap (WildPat ty) = returnDs (wrap, WildPat ty) -- case v of { x -> mr[] } -- = case v of { _ -> let x=v in mr[] } -tidy1 v (VarPat var) rhs - = returnDs (WildPat (idType var), bindOneInMatchResult var v rhs) +tidy1 v wrap (VarPat var) + = returnDs (wrap . wrapBind var v, WildPat (idType var)) -tidy1 v (VarPatOut var binds) rhs +tidy1 v wrap (VarPatOut var binds) = do { prs <- dsHsNestedBinds binds - ; return (WildPat (idType var), - bindOneInMatchResult var v $ - mkCoLetMatchResult (Rec prs) rhs) } + ; return (wrap . wrapBind var v . mkDsLet (Rec prs), + WildPat (idType var)) } -- case v of { x@p -> mr[] } -- = case v of { p -> let x=v in mr[] } -tidy1 v (AsPat (L _ var) pat) rhs - = tidy1 v (unLoc pat) (bindOneInMatchResult var v rhs) +tidy1 v wrap (AsPat (L _ var) pat) + = tidy1 v (wrap . wrapBind var v) (unLoc pat) {- now, here we handle lazy patterns: @@ -424,23 +423,22 @@ tidy1 v (AsPat (L _ var) pat) rhs The case expr for v_i is just: match [v] [(p, [], \ x -> Var v_i)] any_expr -} -tidy1 v (LazyPat pat) rhs +tidy1 v wrap (LazyPat pat) = do { v' <- newSysLocalDs (idType v) ; sel_prs <- mkSelectorBinds pat (Var v) ; let sel_binds = [NonRec b rhs | (b,rhs) <- sel_prs] - ; returnDs (WildPat (idType v), - bindOneInMatchResult v' v $ - mkCoLetsMatchResult sel_binds rhs) } + ; returnDs (wrap . wrapBind v' v . mkDsLets sel_binds, + WildPat (idType v)) } -- re-express as (ConPat ...) [directly] -tidy1 v (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty) rhs - = returnDs (ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty, rhs) +tidy1 v wrap (ConPatOut (L loc con) ex_tvs dicts binds ps pat_ty) + = returnDs (wrap, ConPatOut (L loc con) ex_tvs dicts binds tidy_ps pat_ty) where tidy_ps = PrefixCon (tidy_con con pat_ty ps) -tidy1 v (ListPat pats ty) rhs - = returnDs (unLoc list_ConPat, rhs) +tidy1 v wrap (ListPat pats ty) + = returnDs (wrap, unLoc list_ConPat) where list_ty = mkListTy ty list_ConPat = foldr (\ x y -> mkPrefixConPat consDataCon [x, y] list_ty) @@ -449,40 +447,40 @@ tidy1 v (ListPat pats ty) rhs -- Introduce fake parallel array constructors to be able to handle parallel -- arrays with the existing machinery for constructor pattern -tidy1 v (PArrPat pats ty) rhs - = returnDs (unLoc parrConPat, rhs) +tidy1 v wrap (PArrPat pats ty) + = returnDs (wrap, unLoc parrConPat) where arity = length pats parrConPat = mkPrefixConPat (parrFakeCon arity) pats (mkPArrTy ty) -tidy1 v (TuplePat pats boxity) rhs - = returnDs (unLoc tuple_ConPat, rhs) +tidy1 v wrap (TuplePat pats boxity) + = returnDs (wrap, unLoc tuple_ConPat) where arity = length pats tuple_ConPat = mkPrefixConPat (tupleCon boxity arity) pats (mkTupleTy boxity arity (map hsPatType pats)) -tidy1 v (DictPat dicts methods) rhs +tidy1 v wrap (DictPat dicts methods) = case num_of_d_and_ms of - 0 -> tidy1 v (TuplePat [] Boxed) rhs - 1 -> tidy1 v (unLoc (head dict_and_method_pats)) rhs - _ -> tidy1 v (TuplePat dict_and_method_pats Boxed) rhs + 0 -> tidy1 v wrap (TuplePat [] Boxed) + 1 -> tidy1 v wrap (unLoc (head dict_and_method_pats)) + _ -> tidy1 v wrap (TuplePat dict_and_method_pats Boxed) where num_of_d_and_ms = length dicts + length methods dict_and_method_pats = map nlVarPat (dicts ++ methods) -- LitPats: we *might* be able to replace these w/ a simpler form -tidy1 v pat@(LitPat lit) rhs - = returnDs (unLoc (tidyLitPat lit (noLoc pat)), rhs) +tidy1 v wrap pat@(LitPat lit) + = returnDs (wrap, unLoc (tidyLitPat lit (noLoc pat))) -- NPats: we *might* be able to replace these w/ a simpler form -tidy1 v pat@(NPatOut lit lit_ty _) rhs - = returnDs (unLoc (tidyNPat lit lit_ty (noLoc pat)), rhs) +tidy1 v wrap pat@(NPatOut lit lit_ty _) + = returnDs (wrap, unLoc (tidyNPat lit lit_ty (noLoc pat))) -- and everything else goes through unchanged... -tidy1 v non_interesting_pat rhs - = returnDs (non_interesting_pat, rhs) +tidy1 v wrap non_interesting_pat + = returnDs (wrap, non_interesting_pat) tidy_con data_con pat_ty (PrefixCon ps) = ps @@ -673,7 +671,8 @@ matchWrapper ctxt (MatchGroup matches match_ty) mk_eqn_info (L _ (Match pats _ grhss)) = do { let upats = map unLoc pats ; match_result <- dsGRHSs ctxt upats grhss rhs_ty - ; return (EqnInfo { eqn_pats = upats, + ; return (EqnInfo { eqn_wrap = idWrapper, + eqn_pats = upats, eqn_rhs = match_result}) } match_fun dflags ds_ctxt @@ -717,7 +716,8 @@ matchSinglePat :: CoreExpr -> DsMatchContext -> LPat Id -> Type -> MatchResult -> DsM MatchResult matchSinglePat (Var var) ctx pat ty match_result = getDOptsDs `thenDs` \ dflags -> - match_fn dflags [var] ty [EqnInfo { eqn_pats = [unLoc pat], + match_fn dflags [var] ty [EqnInfo { eqn_wrap = idWrapper, + eqn_pats = [unLoc pat], eqn_rhs = match_result }] where match_fn dflags diff --git a/ghc/compiler/deSugar/MatchCon.lhs b/ghc/compiler/deSugar/MatchCon.lhs index c7e2b93..3787265 100644 --- a/ghc/compiler/deSugar/MatchCon.lhs +++ b/ghc/compiler/deSugar/MatchCon.lhs @@ -106,34 +106,27 @@ wouldn't). Cf.~@shift_lit_pats@ in @MatchLits@. match_con vars ty eqns = do { -- Make new vars for the con arguments; avoid new locals where possible arg_vars <- selectMatchVars (map unLoc arg_pats1) arg_tys - - ; match_result <- match (arg_vars ++ vars) ty (shiftEqns eqns) - - ; binds <- mapM ds_binds [ bind | ConPatOut _ _ _ bind _ _ <- pats, - not (isEmptyLHsBinds bind) ] - - ; let match_result' = bindInMatchResult (line_up other_pats) $ - mkCoLetsMatchResult binds match_result - - ; return (data_con, tvs1 ++ dicts1 ++ arg_vars, match_result') } + ; eqns' <- mapM shift eqns + ; match_result <- match (arg_vars ++ vars) ty eqns' + ; return (con, tvs1 ++ dicts1 ++ arg_vars, match_result) } where - pats@(pat1 : other_pats) = map firstPat eqns - ConPatOut (L _ data_con) tvs1 dicts1 _ (PrefixCon arg_pats1) pat_ty = pat1 - - ds_binds bind = do { prs <- dsHsNestedBinds bind; return (Rec prs) } + ConPatOut (L _ con) tvs1 dicts1 _ (PrefixCon arg_pats1) pat_ty = firstPat (head eqns) - line_up pats - | null tvs1 && null dicts1 = [] -- Common case - | otherwise = [ pr | ConPatOut _ ts ds _ _ _ <- pats, - pr <- (ts `zip` tvs1) ++ (ds `zip` dicts1)] + shift eqn@(EqnInfo { eqn_wrap = wrap, + eqn_pats = ConPatOut _ tvs ds bind (PrefixCon arg_pats) _ : pats }) + = do { prs <- dsHsNestedBinds bind + ; return (eqn { eqn_wrap = wrap . wrapBinds (tvs `zip` tvs1) + . wrapBinds (ds `zip` dicts1) + . mkDsLet (Rec prs), + eqn_pats = map unLoc arg_pats ++ pats }) } -- Get the arg types, which we use to type the new vars -- to match on, from the "outside"; the types of pats1 may -- be more refined, and hence won't do - arg_tys = substTys (zipTopTvSubst (dataConTyVars data_con) inst_tys) - (dataConOrigArgTys data_con) - inst_tys | isVanillaDataCon data_con = tcTyConAppArgs pat_ty -- Newtypes opaque! - | otherwise = mkTyVarTys tvs1 + arg_tys = substTys (zipTopTvSubst (dataConTyVars con) inst_tys) + (dataConOrigArgTys con) + inst_tys | isVanillaDataCon con = tcTyConAppArgs pat_ty -- Newtypes opaque! + | otherwise = mkTyVarTys tvs1 \end{code} Note [Existentials in shift_con_pat] diff --git a/ghc/compiler/deSugar/MatchLit.lhs b/ghc/compiler/deSugar/MatchLit.lhs index 75a0a62..5ca0569 100644 --- a/ghc/compiler/deSugar/MatchLit.lhs +++ b/ghc/compiler/deSugar/MatchLit.lhs @@ -167,12 +167,16 @@ matchNPats (var:vars) ty eqns return (foldr1 combineMatchResults match_results) } where match_group :: [EquationInfo] -> DsM MatchResult - match_group eqns + match_group (eqn1:eqns) = do { pred_expr <- dsExpr (HsApp (noLoc eq_chk) (nlHsVar var)) - ; match_result <- match vars ty (shiftEqns eqns) - ; return (mkGuardedMatchResult pred_expr match_result) } + ; match_result <- match vars ty (eqn1' : shiftEqns eqns) + ; return (adjustMatchResult (eqn_wrap eqn1) $ + -- Bring the eqn1 wrapper stuff into scope because + -- it may be used in pred_expr + mkGuardedMatchResult pred_expr match_result) } where - NPatOut _ _ eq_chk = firstPat (head eqns) + NPatOut _ _ eq_chk : pats1 = eqn_pats eqn1 + eqn1' = eqn1 { eqn_wrap = idWrapper, eqn_pats = pats1 } \end{code} @@ -216,17 +220,23 @@ matchNPlusKPats all_vars@(var:vars) ty eqns return (foldr1 combineMatchResults match_results) } where match_group :: [EquationInfo] -> DsM MatchResult - match_group eqns + match_group (eqn1:eqns) = do { ge_expr <- dsExpr (HsApp (noLoc ge) (nlHsVar var)) ; minusk_expr <- dsExpr (HsApp (noLoc sub) (nlHsVar var)) - ; match_result <- match vars ty (shiftEqns eqns) - ; return (mkGuardedMatchResult ge_expr $ - mkCoLetsMatchResult [NonRec n1 minusk_expr] $ - bindInMatchResult (map line_up other_pats) $ + ; match_result <- match vars ty (eqn1' : map shift eqns) + ; return (adjustMatchResult (eqn_wrap eqn1) $ + -- Bring the eqn1 wrapper stuff into scope because + -- it may be used in ge_expr, minusk_expr + mkGuardedMatchResult ge_expr $ + mkCoLetMatchResult (NonRec n1 minusk_expr) $ match_result) } where - (NPlusKPatOut (L _ n1) _ ge sub : other_pats) = map firstPat eqns - line_up (NPlusKPatOut (L _ n) _ _ _) = (n,n1) + NPlusKPatOut (L _ n1) _ ge sub : pats1 = eqn_pats eqn1 + eqn1' = eqn1 { eqn_wrap = idWrapper, eqn_pats = pats1 } + + shift eqn@(EqnInfo { eqn_wrap = wrap, + eqn_pats = NPlusKPatOut (L _ n) _ _ _ : pats }) + = eqn { eqn_wrap = wrap . wrapBind n n1, eqn_pats = pats } \end{code}