Fix Trac #246: order of matching in record patterns
authorsimonpj@microsoft.com <unknown>
Mon, 30 Mar 2009 08:37:36 +0000 (08:37 +0000)
committersimonpj@microsoft.com <unknown>
Mon, 30 Mar 2009 08:37:36 +0000 (08:37 +0000)
While I was looking at the desugaring of pattern matching (fixing
Trac #3126) I finally got around to fixing another long-standing bug:
when matching in a record pattern, GHC should match left-to-right in
the programmer-specfied order, *not* left-to-right positionally in
the original record declaration.

Needless to say, that requires a little more code.
See Note [Record patterns] in MatchCon.lhs

compiler/deSugar/MatchCon.lhs

index bba9d42..2912e29 100644 (file)
@@ -28,9 +28,11 @@ import CoreSyn
 import MkCore
 import DsMonad
 import DsUtils
-import Util    ( takeList )
+import Util    ( all2, takeList, zipEqual )
+import ListSetOps ( runs )
 import Id
-import Var      (TyVar)
+import Var      ( Var )
+import NameEnv
 import SrcLoc
 import Outputable
 \end{code}
@@ -95,22 +97,31 @@ matchConFamily (var:vars) ty groups
   = do { alts <- mapM (matchOneCon vars ty) groups
        ; return (mkCoAlgCaseMatchResult var ty alts) }
 
+type ConArgPats = HsConDetails (LPat Id) (HsRecFields Id (LPat Id))
+
 matchOneCon :: [Id]
             -> Type
             -> [EquationInfo]
-            -> DsM (DataCon, [TyVar], MatchResult)
+            -> DsM (DataCon, [Var], MatchResult)
 matchOneCon vars ty (eqn1 : eqns)      -- All eqns for a single constructor
-  = do { (wraps, eqns') <- mapAndUnzipM shift (eqn1:eqns)
-       ; arg_vars <- selectMatchVars (take (dataConSourceArity con1) 
-                                           (eqn_pats (head eqns')))
-               -- Use the new argument patterns as a source of 
+  = do { arg_vars <- selectConMatchVars arg_tys args1
+               -- Use the first equation as a source of 
                -- suggestions for the new variables
-       ; match_result <- match (arg_vars ++ vars) ty eqns'
+
+       -- Divide into sub-groups; see Note [Record patterns]
+        ; let groups :: [[(ConArgPats, EquationInfo)]]
+             groups = runs compatible_pats [ (pat_args (firstPat eqn), eqn) 
+                                           | eqn <- eqn1:eqns ]
+
+       ; match_results <- mapM (match_group arg_vars) groups
+
        ; return (con1, tvs1 ++ dicts1 ++ arg_vars, 
-                 adjustMatchResult (foldr1 (.) wraps) match_result) }
+                 foldr1 combineMatchResults match_results) }
   where
     ConPatOut { pat_con = L _ con1, pat_ty = pat_ty1,
-               pat_tvs = tvs1, pat_dicts = dicts1 } = firstPat eqn1
+               pat_tvs = tvs1, pat_dicts = dicts1, pat_args = args1 }
+             = firstPat eqn1
+    fields1 = dataConFieldLabels con1
        
     arg_tys  = dataConInstOrigArgTys con1 inst_tys
     inst_tys = tcTyConAppArgs pat_ty1 ++ 
@@ -119,41 +130,102 @@ matchOneCon vars ty (eqn1 : eqns)        -- All eqns for a single constructor
        -- dataConInstOrigArgTys takes the univ and existential tyvars
        -- and returns the types of the *value* args, which is what we want
 
-    shift eqn@(EqnInfo { eqn_pats = ConPatOut{ pat_tvs = tvs, pat_dicts = ds, 
-                                              pat_binds = bind, pat_args = args
-                                             } : pats })
-       = do { prs <- dsLHsBinds bind
-            ; return (wrapBinds (tvs `zip` tvs1) 
-                      . wrapBinds (ds  `zip` dicts1)
-                      . mkCoreLet (Rec prs),
-                      eqn { eqn_pats = conArgPats con1 arg_tys args ++ pats }) }
-
-conArgPats :: DataCon 
-          -> [Type]    -- Instantiated argument types 
+    match_group :: [Id] -> [(ConArgPats, EquationInfo)] -> DsM MatchResult
+    -- All members of the group have compatible ConArgPats
+    match_group arg_vars arg_eqn_prs
+      = do { (wraps, eqns') <- mapAndUnzipM shift arg_eqn_prs
+          ; let group_arg_vars = select_arg_vars arg_vars arg_eqn_prs
+          ; match_result <- match (group_arg_vars ++ vars) ty eqns'
+          ; return (adjustMatchResult (foldr1 (.) wraps) match_result) }
+
+    shift (_, eqn@(EqnInfo { eqn_pats = ConPatOut{ pat_tvs = tvs, pat_dicts = ds, 
+                                                  pat_binds = bind, pat_args = args
+                                       } : pats }))
+      = do { prs <- dsLHsBinds bind
+          ; return (wrapBinds (tvs `zip` tvs1) 
+                   . wrapBinds (ds  `zip` dicts1)
+                   . mkCoreLet (Rec prs),
+                   eqn { eqn_pats = conArgPats arg_tys args ++ pats }) }
+
+    -- Choose the right arg_vars in the right order for this group
+    -- Note [Record patterns]
+    select_arg_vars arg_vars ((arg_pats, _) : _)
+      | RecCon flds <- arg_pats
+      , let rpats = rec_flds flds  
+      , not (null rpats)     -- Treated specially; cf conArgPats
+      = ASSERT2( length fields1 == length arg_vars, 
+                 ppr con1 $$ ppr fields1 $$ ppr arg_vars )
+        map lookup_fld rpats
+      | otherwise
+      = arg_vars
+      where
+        fld_var_env = mkNameEnv $ zipEqual "get_arg_vars" fields1 arg_vars
+       lookup_fld rpat = lookupNameEnv_NF fld_var_env 
+                                          (idName (unLoc (hsRecFieldId rpat)))
+
+-----------------
+compatible_pats :: (ConArgPats,a) -> (ConArgPats,a) -> Bool
+-- Two constructors have compatible argument patterns if the number
+-- and order of sub-matches is the same in both cases
+compatible_pats (RecCon flds1, _) (RecCon flds2, _) = same_fields flds1 flds2
+compatible_pats (RecCon flds1, _) _                 = null (rec_flds flds1)
+compatible_pats _                 (RecCon flds2, _) = null (rec_flds flds2)
+compatible_pats _                 _                 = True -- Prefix or infix con
+
+same_fields :: HsRecFields Id (LPat Id) -> HsRecFields Id (LPat Id) -> Bool
+same_fields flds1 flds2 
+  = all2 (\f1 f2 -> unLoc (hsRecFieldId f1) == unLoc (hsRecFieldId f2))
+        (rec_flds flds1) (rec_flds flds2)
+
+
+-----------------
+selectConMatchVars :: [Type] -> ConArgPats -> DsM [Id]
+selectConMatchVars arg_tys (RecCon {})      = newSysLocalsDs arg_tys
+selectConMatchVars _       (PrefixCon ps)   = selectMatchVars (map unLoc ps)
+selectConMatchVars _       (InfixCon p1 p2) = selectMatchVars [unLoc p1, unLoc p2]
+
+conArgPats :: [Type]   -- Instantiated argument types 
                        -- Used only to fill in the types of WildPats, which
                        -- are probably never looked at anyway
-          -> HsConDetails (LPat Id) (HsRecFields Id (LPat Id))
+          -> ConArgPats
           -> [Pat Id]
-conArgPats _data_con _arg_tys (PrefixCon ps)   = map unLoc ps
-conArgPats _data_con _arg_tys (InfixCon p1 p2) = [unLoc p1, unLoc p2]
-conArgPats  data_con  arg_tys (RecCon (HsRecFields rpats _))
-  | null rpats
-  =    -- Special case for C {}, which can be used for 
-       -- a constructor that isn't declared to have
-       -- fields at all
-    map WildPat arg_tys
-
-  | otherwise
-  = zipWith mk_pat (dataConFieldLabels data_con) arg_tys
-  where
-       -- mk_pat picks a WildPat of the appropriate type for absent fields,
-       -- and the specified pattern for present fields
-    mk_pat lbl arg_ty
-       = case [ pat | HsRecField sel_id pat _ <- rpats, idName (unLoc sel_id) == lbl ] of
-           (pat:pats) -> ASSERT( null pats ) unLoc pat
-           []         -> WildPat arg_ty
+conArgPats _arg_tys (PrefixCon ps)   = map unLoc ps
+conArgPats _arg_tys (InfixCon p1 p2) = [unLoc p1, unLoc p2]
+conArgPats  arg_tys (RecCon (HsRecFields { rec_flds = rpats }))
+  | null rpats = map WildPat arg_tys
+       -- Important special case for C {}, which can be used for a 
+       -- datacon that isn't declared to have fields at all
+  | otherwise  = map (unLoc . hsRecFieldArg) rpats
 \end{code}
 
+Note [Record patterns]
+~~~~~~~~~~~~~~~~~~~~~~
+Consider 
+        data T = T { x,y,z :: Bool }
+
+        f (T { y=True, x=False }) = ...
+
+We must match the patterns IN THE ORDER GIVEN, thus for the first
+one we match y=True before x=False.  See Trac #246; or imagine 
+matching against (T { y=False, x=undefined }): should fail without
+touching the undefined. 
+
+Now consider:
+
+        f (T { y=True, x=False }) = ...
+        f (T { x=True, y= False}) = ...
+
+In the first we must test y first; in the second we must test x 
+first.  So we must divide even the equations for a single constructor
+T into sub-goups, based on whether they match the same field in the
+same order.  That's what the (runs compatible_pats) grouping.
+
+All non-record patterns are "compatible" in this sense, because the
+positional patterns (T a b) and (a `T` b) all match the arguments
+in order.  Also T {} is special because it's equivalent to (T _ _).
+Hence the (null rpats) checks here and there.
+
+
 Note [Existentials in shift_con_pat]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider