Implement generalised list comprehensions
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index da1d0e0..e07e6da 100644 (file)
@@ -39,8 +39,12 @@ import TysWiredIn
 import PrelNames
 import Id
 import TyCon
 import PrelNames
 import Id
 import TyCon
+import TysPrim
 import Outputable
 import Outputable
+import Util
 import SrcLoc
 import SrcLoc
+
+import Control.Monad( liftM )
 \end{code}
 
 %************************************************************************
 \end{code}
 
 %************************************************************************
@@ -391,9 +395,72 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside
                      ; return (ids, pairs', thing) }
           ; return ( (stmts', ids) : pairs', thing ) }
 
                      ; return (ids, pairs', thing) }
           ; return ( (stmts', ids) : pairs', thing ) }
 
+tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty thing_inside = do
+    (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
+        tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+            let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    
+            (usingExpr', maybeByExpr') <- 
+                case maybeByExpr of
+                    Nothing -> do
+                        -- We must validate that usingExpr :: forall a. [a] -> [a]
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListTy))
+                        return (usingExpr', Nothing)
+                    Just byExpr -> do
+                        -- We must infer a type such that e :: t and then check that usingExpr :: forall a. (a -> t) -> [a] -> [a]
+                        (byExpr', tTy) <- tcInferRho byExpr
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListTy)))
+                        return (usingExpr', Just byExpr')
+            
+            binders' <- tcLookupLocalIds binders
+            thing <- thing_inside elt_ty'
+            
+            return (binders', usingExpr', maybeByExpr', thing)
+
+    return (TransformStmt (stmts', binders') usingExpr' maybeByExpr', thing)
+
+tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_inside = do
+        (stmts', (bindersMap', groupByClause', thing)) <-
+            tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+                let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    alphaListListTy = mkTyConApp m_tc [alphaListTy]
+            
+                groupByClause' <- 
+                    case groupByClause of
+                        GroupByNothing usingExpr ->
+                            -- We must validate that usingExpr :: forall a. [a] -> [[a]]
+                            tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
+                        GroupBySomething eitherUsingExpr byExpr -> do
+                            -- We must infer a type such that byExpr :: t
+                            (byExpr', tTy) <- tcInferRho byExpr
+                            
+                            -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
+                            let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
+                            eitherUsingExpr' <- 
+                                case eitherUsingExpr of
+                                    Left usingExpr  -> (tcPolyExpr usingExpr expectedUsingType) >>= (return . Left)
+                                    Right usingExpr -> (tcPolyExpr (noLoc usingExpr) expectedUsingType) >>= (return . Right . unLoc)
+                            return $ GroupBySomething eitherUsingExpr' byExpr'
+            
+                -- Find the IDs and types of all old binders
+                let (oldBinders, newBinders) = unzip bindersMap
+                oldBinders' <- tcLookupLocalIds oldBinders
+                
+                -- Ensure that every old binder of type b is linked up with its new binder which should have type [b]
+                let newBinders' = zipWith associateNewBinder oldBinders' newBinders
+            
+                -- Type check the thing in the environment with these new binders and return the result
+                thing <- tcExtendIdEnv newBinders' (thing_inside elt_ty')
+                return (zipEqual "tcLcStmt: Old and new binder lists were not of the same length" oldBinders' newBinders', groupByClause', thing)
+        
+        return (GroupStmt (stmts', bindersMap') groupByClause', thing)
+    where
+        associateNewBinder :: TcId -> Name -> TcId
+        associateNewBinder oldBinder newBinder = mkLocalId newBinder (mkTyConApp m_tc [idType oldBinder])
+    
 tcLcStmt m_tc ctxt stmt elt_ty thing_inside
   = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
 tcLcStmt m_tc ctxt stmt elt_ty thing_inside
   = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
-
+        
 --------------------------------
 --     Do-notation
 -- The main excitement here is dealing with rebindable syntax
 --------------------------------
 --     Do-notation
 -- The main excitement here is dealing with rebindable syntax