Fix Trac #2206: ensure the return type is rigid in a GADT match
authorsimonpj@microsoft.com <unknown>
Thu, 10 Apr 2008 11:15:14 +0000 (11:15 +0000)
committersimonpj@microsoft.com <unknown>
Thu, 10 Apr 2008 11:15:14 +0000 (11:15 +0000)
compiler/typecheck/TcPat.lhs

index 61ee938..f07ce91 100644 (file)
@@ -119,10 +119,11 @@ tc_lam_pats :: PatCtxt
 tc_lam_pats ctxt pat_ty_prs res_ty thing_inside 
   =  do        { let init_state = PS { pat_ctxt = ctxt, pat_eqs = False }
 
-       ; (pats', ex_tvs, res) <- tcMultiple tc_lpat_pr pat_ty_prs init_state $ \ pstate' ->
-                                 if (pat_eqs pstate' && (not $ isRigidTy res_ty))
-                                    then failWithTc (nonRigidResult res_ty)
-                                    else thing_inside res_ty
+       ; (pats', ex_tvs, res) <- do { traceTc (text "tc_lam_pats" <+> (ppr pat_ty_prs $$ ppr res_ty)) 
+                                 ; tcMultiple tc_lpat_pr pat_ty_prs init_state $ \ pstate' ->
+                                   if (pat_eqs pstate' && (not $ isRigidTy res_ty))
+                                    then nonRigidResult res_ty
+                                    else thing_inside res_ty }
 
        ; let tys = map snd pat_ty_prs
        ; tcCheckExistentialPat pats' ex_tvs tys res_ty
@@ -152,8 +153,9 @@ tcCheckExistentialPat pats ex_tvs pat_tys body_ty
 
 data PatState = PS {
        pat_ctxt :: PatCtxt,
-       pat_eqs  :: Bool        -- <=> there are GADT equational constraints 
-                               --     for refinement 
+       pat_eqs  :: Bool        -- <=> there are any equational constraints 
+                               -- Used at the end to say whether the result
+                               -- type must be rigid
   }
 
 data PatCtxt 
@@ -645,12 +647,7 @@ tcConPat pstate con_span data_con tycon pat_ty arg_pats thing_inside
 
          else do   -- The general case, with existential, and local equality 
                     -- constraints
-       { let eq_preds = [mkEqPred (mkTyVarTy tv, ty) | (tv, ty) <- eq_spec]
-             theta'   = substTheta tenv (eq_preds ++ full_theta)
-                           -- order is *important* as we generate the list of
-                           -- dictionary binders from theta'
-             ctxt     = pat_ctxt pstate
-       ; checkTc (case ctxt of { ProcPat -> False; other -> True })
+       { checkTc (case pat_ctxt pstate of { ProcPat -> False; other -> True })
                  (existentialProcPat data_con)
 
           -- Need to test for rigidity if *any* constraints in theta as class
@@ -661,11 +658,19 @@ tcConPat pstate con_span data_con tycon pat_ty arg_pats thing_inside
           -- FIXME: AT THE MOMENT WE CHEAT!  We only perform the rigidity test
           --   if we explicit or implicit (by a GADT def) have equality 
           --   constraints.
-        ; unless (all (not . isEqPred) theta') $
-            checkTc (isRigidTy pat_ty) (nonRigidMatch data_con)
+        ; let eq_preds = [mkEqPred (mkTyVarTy tv, ty) | (tv, ty) <- eq_spec]
+             theta'   = substTheta tenv (eq_preds ++ full_theta)
+                           -- order is *important* as we generate the list of
+                           -- dictionary binders from theta'
+             no_equalities = not (any isEqPred theta')
+             pstate' | no_equalities = pstate
+                     | otherwise     = pstate { pat_eqs = True }
+
+       ; unless no_equalities (checkTc (isRigidTy pat_ty)
+                                        (nonRigidMatch data_con))
 
        ; ((arg_pats', inner_tvs, res), lie_req) <- getLIE $
-               tcConArgs data_con arg_tys' arg_pats pstate thing_inside
+               tcConArgs data_con arg_tys' arg_pats pstate' thing_inside
 
        ; loc <- getInstLoc origin
        ; dicts <- newDictBndrs loc theta'
@@ -1034,8 +1039,12 @@ nonRigidMatch con
        2 (ptext SLIT("Solution: add a type signature"))
 
 nonRigidResult res_ty
-  =  hang (ptext SLIT("GADT pattern match with non-rigid result type") <+> quotes (ppr res_ty))
-       2 (ptext SLIT("Solution: add a type signature"))
+  = do { env0 <- tcInitTidyEnv
+       ; let (env1, res_ty') = tidyOpenType env0 res_ty
+             msg = hang (ptext SLIT("GADT pattern match with non-rigid result type")
+                               <+> quotes (ppr res_ty'))
+                        2 (ptext SLIT("Solution: add a type signature"))
+       ; failWithTcM (env1, msg) }
 
 inaccessibleAlt msg
   = hang (ptext SLIT("Inaccessible case alternative:")) 2 msg