Refactor part of the renamer to fix Trac #3901
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index 6d917d1..cbe5940 100644 (file)
@@ -392,7 +392,7 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside
                      ; return (ids, pairs', thing) }
           ; return ( (stmts', ids) : pairs', thing ) }
 
-tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty thing_inside = do
+tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr) elt_ty thing_inside = do
     (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
         tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
             let alphaListTy = mkTyConApp m_tc [alphaTy]
@@ -414,46 +414,47 @@ tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty
             
             return (binders', usingExpr', maybeByExpr', thing)
 
-    return (TransformStmt (stmts', binders') usingExpr' maybeByExpr', thing)
+    return (TransformStmt stmts' binders' usingExpr' maybeByExpr', thing)
 
-tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_inside = do
-        (stmts', (bindersMap', groupByClause', thing)) <-
+tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside
+  = do { let (bndr_names, list_bndr_names) = unzip bindersMap
+
+       ; (stmts', (bndr_ids, by', using_ty, elt_ty')) <-
             tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
-                let alphaListTy = mkTyConApp m_tc [alphaTy]
-                    alphaListListTy = mkTyConApp m_tc [alphaListTy]
-            
-                groupByClause' <- 
-                    case groupByClause of
-                        GroupByNothing usingExpr ->
-                            -- We must validate that usingExpr :: forall a. [a] -> [[a]]
-                            tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
-                        GroupBySomething eitherUsingExpr byExpr -> do
-                            -- We must infer a type such that byExpr :: t
-                            (byExpr', tTy) <- tcInferRhoNC byExpr
-                            
-                            -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
-                            let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
-                            eitherUsingExpr' <- 
-                                case eitherUsingExpr of
-                                    Left usingExpr  -> (tcPolyExpr usingExpr expectedUsingType) >>= (return . Left)
-                                    Right usingExpr -> (tcPolyExpr (noLoc usingExpr) expectedUsingType) >>= (return . Right . unLoc)
-                            return $ GroupBySomething eitherUsingExpr' byExpr'
-            
-                -- Find the IDs and types of all old binders
-                let (oldBinders, newBinders) = unzip bindersMap
-                oldBinders' <- tcLookupLocalIds oldBinders
+               (by', using_ty) <- case by of
+                                     Nothing   -> -- check that using :: forall a. [a] -> [[a]]
+                                                  return (Nothing, mkForAllTy alphaTyVar $
+                                                                   alphaListTy `mkFunTy` alphaListListTy)
+                                                       
+                                    Just by_e -> -- check that using :: forall a. (a -> t) -> [a] -> [[a]]
+                                                 -- where by :: t
+                                                  do { (by_e', t_ty) <- tcInferRhoNC by_e
+                                                     ; return (Just by_e', mkForAllTy alphaTyVar $
+                                                                           (alphaTy `mkFunTy` t_ty) 
+                                                                              `mkFunTy` alphaListTy 
+                                                                              `mkFunTy` alphaListListTy) }
+                -- Find the Ids (and hence types) of all old binders
+                bndr_ids <- tcLookupLocalIds bndr_names
                 
+                return (bndr_ids, by', using_ty, elt_ty')
+        
                 -- Ensure that every old binder of type b is linked up with its new binder which should have type [b]
-                let newBinders' = zipWith associateNewBinder oldBinders' newBinders
+       ; let list_bndr_ids = zipWith mk_list_bndr list_bndr_names bndr_ids
+             bindersMap' = bndr_ids `zip` list_bndr_ids
             
-                -- Type check the thing in the environment with these new binders and return the result
-                thing <- tcExtendIdEnv newBinders' (thing_inside elt_ty')
-                return (zipEqual "tcLcStmt: Old and new binder lists were not of the same length" oldBinders' newBinders', groupByClause', thing)
-        
-        return (GroupStmt (stmts', bindersMap') groupByClause', thing)
-    where
-        associateNewBinder :: TcId -> Name -> TcId
-        associateNewBinder oldBinder newBinder = mkLocalId newBinder (mkTyConApp m_tc [idType oldBinder])
+       ; using' <- case using of
+                     Left  e -> do { e' <- tcPolyExpr e         using_ty; return (Left  e') }
+                     Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) }
+
+             -- Type check the thing in the environment with these new binders and return the result
+       ; thing <- tcExtendIdEnv list_bndr_ids (thing_inside elt_ty')
+       ; return (GroupStmt stmts' bindersMap' by' using', thing) }
+  where
+    alphaListTy = mkTyConApp m_tc [alphaTy]
+    alphaListListTy = mkTyConApp m_tc [alphaListTy]
+            
+    mk_list_bndr :: Name -> TcId -> TcId
+    mk_list_bndr list_bndr_name bndr_id = mkLocalId list_bndr_name (mkTyConApp m_tc [idType bndr_id])
     
 tcLcStmt _ _ stmt _ _
   = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)