From: simonpj@microsoft.com Date: Thu, 4 Mar 2010 12:53:37 +0000 (+0000) Subject: Refactor part of the renamer to fix Trac #3901 X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=f1cc3eb980a634e62f2739a7a25387c902fa9d8a Refactor part of the renamer to fix Trac #3901 This one was bigger than I anticipated! The problem was that were were gathering the binders from a pattern before renaming -- but with record wild-cards we don't know what variables are bound by C {..} until after the renamer has filled in the "..". So this patch does the following * Change all the collect-X-Binders functions in HsUtils so that they expect to only be called *after* renaming. That means they don't need to return [Located id] but just [id]. Which turned out to be a very worthwhile simplification all by itself. * Refactor the renamer, and in ptic RnExpr.rnStmt, so that it doesn't need to use collectLStmtsBinders on pre-renamed Stmts. * This in turn required me to understand how GroupStmt and TransformStmts were renamed. Quite fiddly. I rewrote most of it; result is much shorter. * In doing so I flattened HsExpr.GroupByClause into its parent GroupStmt, with trivial knock-on effects in other files. Blargh. --- diff --git a/compiler/deSugar/Coverage.lhs b/compiler/deSugar/Coverage.lhs index 52c0f04..6bdc8a1 100644 --- a/compiler/deSugar/Coverage.lhs +++ b/compiler/deSugar/Coverage.lhs @@ -24,6 +24,7 @@ import HscTypes import StaticFlags import TyCon import FiniteMap +import MonadUtils import Maybes import Data.Array @@ -290,7 +291,7 @@ addTickHsExpr (HsIf e1 e2 e3) = (addTickLHsExprOptAlt True e2) (addTickLHsExprOptAlt True e3) addTickHsExpr (HsLet binds e) = - bindLocals (map unLoc $ collectLocalBinders binds) $ + bindLocals (collectLocalBinders binds) $ liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsExprNeverOrAlways e) @@ -398,7 +399,7 @@ addTickGRHSs isOneOfMany (GRHSs guarded local_binds) = do guarded' <- mapM (liftL (addTickGRHS isOneOfMany)) guarded return $ GRHSs guarded' local_binds' where - binders = map unLoc (collectLocalBinders local_binds) + binders = collectLocalBinders local_binds addTickGRHS :: Bool -> GRHS Id -> TM (GRHS Id) addTickGRHS isOneOfMany (GRHS stmts expr) = do @@ -420,7 +421,7 @@ addTickLStmts' isGuard lstmts res a <- res return (lstmts', a) where - binders = map unLoc (collectLStmtsBinders lstmts) + binders = collectLStmtsBinders lstmts addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id -> TM (Stmt Id) addTickStmt _isGuard (BindStmt pat e bind fail) = do @@ -440,25 +441,21 @@ addTickStmt _isGuard (LetStmt binds) = do addTickStmt isGuard (ParStmt pairs) = do liftM ParStmt (mapM (addTickStmtAndBinders isGuard) pairs) -addTickStmt isGuard (TransformStmt (stmts, ids) usingExpr maybeByExpr) = do - liftM3 TransformStmt - (addTickStmtAndBinders isGuard (stmts, ids)) + +addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr) = do + liftM4 TransformStmt + (addTickLStmts isGuard stmts) + (return ids) (addTickLHsExprAlways usingExpr) (addTickMaybeByLHsExpr maybeByExpr) -addTickStmt isGuard (GroupStmt (stmts, binderMap) groupByClause) = do - liftM2 GroupStmt - (addTickStmtAndBinders isGuard (stmts, binderMap)) - (case groupByClause of - GroupByNothing usingExpr -> addTickLHsExprAlways usingExpr >>= (return . GroupByNothing) - GroupBySomething eitherUsingExpr byExpr -> do - eitherUsingExpr' <- mapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) eitherUsingExpr - byExpr' <- addTickLHsExprAlways byExpr - return $ GroupBySomething eitherUsingExpr' byExpr') - where - mapEitherM f g x = do - case x of - Left a -> f a >>= (return . Left) - Right b -> g b >>= (return . Right) + +addTickStmt isGuard (GroupStmt stmts binderMap by using) = do + liftM4 GroupStmt + (addTickLStmts isGuard stmts) + (return binderMap) + (fmapMaybeM addTickLHsExprAlways by) + (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using) + addTickStmt isGuard stmt@(RecStmt {}) = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt) ; ret' <- addTickSyntaxExpr hpcSrcSpan (recS_ret_fn stmt) diff --git a/compiler/deSugar/DsArrows.lhs b/compiler/deSugar/DsArrows.lhs index 48700f6..b1a4c59 100644 --- a/compiler/deSugar/DsArrows.lhs +++ b/compiler/deSugar/DsArrows.lhs @@ -14,8 +14,7 @@ import Match import DsUtils import DsMonad -import HsSyn hiding (collectPatBinders, collectLocatedPatBinders, collectl, - collectPatsBinders, collectLocatedPatsBinders) +import HsSyn hiding (collectPatBinders, collectPatsBinders ) import TcHsSyn -- NB: The desugarer, which straddles the source and Core worlds, sometimes @@ -526,7 +525,7 @@ dsCmd ids local_vars env_ids stack res_ty (HsCase exp (MatchGroup matches match_ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do let - defined_vars = mkVarSet (map unLoc (collectLocalBinders binds)) + defined_vars = mkVarSet (collectLocalBinders binds) local_vars' = local_vars `unionVarSet` defined_vars (core_body, _free_vars, env_ids') <- dsfixCmd ids local_vars' stack res_ty body @@ -633,7 +632,7 @@ dsCmdDo ids local_vars env_ids res_ty [] body dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) body = do let - bound_vars = mkVarSet (map unLoc (collectLStmtBinders stmt)) + bound_vars = mkVarSet (collectLStmtBinders stmt) local_vars' = local_vars `unionVarSet` bound_vars (core_stmts, _, env_ids') <- fixDs (\ ~(_,_,env_ids') -> do (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts body @@ -923,7 +922,7 @@ dsCmdStmts ids local_vars env_ids out_ids [stmt] dsCmdStmts ids local_vars env_ids out_ids (stmt:stmts) = do let - bound_vars = mkVarSet (map unLoc (collectLStmtBinders stmt)) + bound_vars = mkVarSet (collectLStmtBinders stmt) local_vars' = local_vars `unionVarSet` bound_vars (core_stmts, _fv_stmts, env_ids') <- dsfixCmdStmts ids local_vars' out_ids stmts (core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars env_ids env_ids' stmt @@ -963,10 +962,10 @@ leavesMatch (L _ (Match pats _ (GRHSs grhss binds))) = let defined_vars = mkVarSet (collectPatsBinders pats) `unionVarSet` - mkVarSet (map unLoc (collectLocalBinders binds)) + mkVarSet (collectLocalBinders binds) in [(expr, - mkVarSet (map unLoc (collectLStmtsBinders stmts)) + mkVarSet (collectLStmtsBinders stmts) `unionVarSet` defined_vars) | L _ (GRHS stmts expr) <- grhss] \end{code} @@ -1009,6 +1008,8 @@ foldb f xs = foldb f (fold_pairs xs) fold_pairs (x1:x2:xs) = f x1 x2:fold_pairs xs \end{code} +Note [Dictionary binders in ConPatOut] See also same Note in HsUtils +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The following functions to collect value variables from patterns are copied from HsUtils, with one change: we also collect the dictionary bindings (pat_binds) from ConPatOut. We need them for cases like @@ -1029,29 +1030,24 @@ these bindings. \begin{code} collectPatBinders :: OutputableBndr a => LPat a -> [a] -collectPatBinders pat = map unLoc (collectLocatedPatBinders pat) - -collectLocatedPatBinders :: OutputableBndr a => LPat a -> [Located a] -collectLocatedPatBinders pat = collectl pat [] +collectPatBinders pat = collectl pat [] collectPatsBinders :: OutputableBndr a => [LPat a] -> [a] -collectPatsBinders pats = map unLoc (collectLocatedPatsBinders pats) - -collectLocatedPatsBinders :: OutputableBndr a => [LPat a] -> [Located a] -collectLocatedPatsBinders pats = foldr collectl [] pats +collectPatsBinders pats = foldr collectl [] pats --------------------- -collectl :: OutputableBndr a => LPat a -> [Located a] -> [Located a] -collectl (L l pat) bndrs +collectl :: OutputableBndr a => LPat a -> [a] -> [a] +-- See Note [Dictionary binders in ConPatOut] +collectl (L _ pat) bndrs = go pat where - go (VarPat var) = L l var : bndrs - go (VarPatOut var bs) = L l var : collectHsBindLocatedBinders bs + go (VarPat var) = var : bndrs + go (VarPatOut var bs) = var : collectHsBindsBinders bs ++ bndrs go (WildPat _) = bndrs go (LazyPat pat) = collectl pat bndrs go (BangPat pat) = collectl pat bndrs - go (AsPat a pat) = a : collectl pat bndrs + go (AsPat (L _ a) pat) = a : collectl pat bndrs go (ParPat pat) = collectl pat bndrs go (ListPat pats _) = foldr collectl bndrs pats @@ -1060,11 +1056,11 @@ collectl (L l pat) bndrs go (ConPatIn _ ps) = foldr collectl bndrs (hsConPatArgs ps) go (ConPatOut {pat_args=ps, pat_binds=ds}) = - collectHsBindLocatedBinders ds + collectHsBindsBinders ds ++ foldr collectl bndrs (hsConPatArgs ps) go (LitPat _) = bndrs go (NPat _ _ _) = bndrs - go (NPlusKPat n _ _ _) = n : bndrs + go (NPlusKPat (L _ n) _ _ _) = n : bndrs go (SigPatIn pat _) = collectl pat bndrs go (SigPatOut pat _) = collectl pat bndrs diff --git a/compiler/deSugar/DsListComp.lhs b/compiler/deSugar/DsListComp.lhs index e7c1f20..46ae129 100644 --- a/compiler/deSugar/DsListComp.lhs +++ b/compiler/deSugar/DsListComp.lhs @@ -38,8 +38,6 @@ import PrelInfo import SrcLoc import Outputable import FastString - -import Control.Monad ( liftM2 ) \end{code} List comprehensions may be desugared in one of two ways: ``ordinary'' @@ -95,7 +93,7 @@ dsInnerListComp (stmts, bndrs) = do -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsTransformStmt (TransformStmt (stmts, binders) usingExpr maybeByExpr) = do +dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr) = do (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders) usingExpr' <- dsLExpr usingExpr @@ -120,7 +118,7 @@ dsTransformStmt (TransformStmt (stmts, binders) usingExpr maybeByExpr) = do -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsGroupStmt (GroupStmt (stmts, binderMap) groupByClause) = do +dsGroupStmt (GroupStmt stmts binderMap by using) = do let (fromBinders, toBinders) = unzip binderMap fromBindersTypes = map idType fromBinders @@ -129,23 +127,19 @@ dsGroupStmt (GroupStmt (stmts, binderMap) groupByClause) = do toBindersTupleType = mkBigCoreTupTy toBindersTypes -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders - (expr, fromBindersTupleType) <- dsInnerListComp (stmts, fromBinders) + (expr, from_tup_ty) <- dsInnerListComp (stmts, fromBinders) -- Work out what arguments should be supplied to that expression: i.e. is an extraction -- function required? If so, create that desugared function and add to arguments - (usingExpr', usingArgs) <- - case groupByClause of - GroupByNothing usingExpr -> liftM2 (,) (dsLExpr usingExpr) (return [expr]) - GroupBySomething usingExpr byExpr -> do - usingExpr' <- dsLExpr (either id noLoc usingExpr) - - byExpr' <- dsLExpr byExpr - - us <- newUniqueSupply - [fromBindersTuple] <- newSysLocalsDs [fromBindersTupleType] - let byExprWrapper = mkTupleCase us fromBinders byExpr' fromBindersTuple (Var fromBindersTuple) - - return (usingExpr', [Lam fromBindersTuple byExprWrapper, expr]) + usingExpr' <- dsLExpr (either id noLoc using) + usingArgs <- case by of + Nothing -> return [expr] + Just by_e -> do { by_e' <- dsLExpr by_e + ; us <- newUniqueSupply + ; [from_tup_id] <- newSysLocalsDs [from_tup_ty] + ; let by_wrap = mkTupleCase us fromBinders by_e' + from_tup_id (Var from_tup_id) + ; return [Lam from_tup_id by_wrap, expr] } -- Create an unzip function for the appropriate arity and element types and find "map" (unzip_fn, unzip_rhs) <- mkUnzipBind fromBindersTypes @@ -153,12 +147,12 @@ dsGroupStmt (GroupStmt (stmts, binderMap) groupByClause) = do -- Generate the expressions to build the grouped list let -- First we apply the grouping function to the inner list - inner_list_expr = mkApps usingExpr' ((Type fromBindersTupleType) : usingArgs) + inner_list_expr = mkApps usingExpr' ((Type from_tup_ty) : usingArgs) -- Then we map our "unzip" across it to turn the lists of tuples into tuples of lists -- We make sure we instantiate the type variable "a" to be a list of "from" tuples and -- the "b" to be a tuple of "to" lists! unzipped_inner_list_expr = mkApps (Var map_id) - [Type (mkListTy fromBindersTupleType), Type toBindersTupleType, Var unzip_fn, inner_list_expr] + [Type (mkListTy from_tup_ty), Type toBindersTupleType, Var unzip_fn, inner_list_expr] -- Then finally we bind the unzip function around that expression bound_unzipped_inner_list_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_list_expr @@ -270,11 +264,11 @@ deListComp (LetStmt binds : quals) body list = do core_rest <- deListComp quals body list dsLocalBinds binds core_rest -deListComp (stmt@(TransformStmt _ _ _) : quals) body list = do +deListComp (stmt@(TransformStmt {}) : quals) body list = do (inner_list_expr, pat) <- dsTransformStmt stmt deBindComp pat inner_list_expr quals body list -deListComp (stmt@(GroupStmt _ _) : quals) body list = do +deListComp (stmt@(GroupStmt {}) : quals) body list = do (inner_list_expr, pat) <- dsGroupStmt stmt deBindComp pat inner_list_expr quals body list @@ -362,12 +356,12 @@ dfListComp c_id n_id (LetStmt binds : quals) body = do core_rest <- dfListComp c_id n_id quals body dsLocalBinds binds core_rest -dfListComp c_id n_id (stmt@(TransformStmt _ _ _) : quals) body = do +dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) body = do (inner_list_expr, pat) <- dsTransformStmt stmt -- Anyway, we bind the newly transformed list via the generic binding function dfBindComp c_id n_id (pat, inner_list_expr) quals body -dfListComp c_id n_id (stmt@(GroupStmt _ _) : quals) body = do +dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) body = do (inner_list_expr, pat) <- dsGroupStmt stmt -- Anyway, we bind the newly grouped list via the generic binding function dfBindComp c_id n_id (pat, inner_list_expr) quals body @@ -604,7 +598,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do -- dePArrComp (LetStmt ds : qs) body pa cea = do mapP <- dsLookupGlobalId mapPName - let xs = map unLoc (collectLocalBinders ds) + let xs = collectLocalBinders ds ty'cea = parrElemType cea v <- newSysLocalDs ty'cea clet <- dsLocalBinds ds (mkCoreTup (map Var xs)) diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs index 43c4622..f071a17 100644 --- a/compiler/deSugar/DsMeta.hs +++ b/compiler/deSugar/DsMeta.hs @@ -106,7 +106,7 @@ repTopP pat = do { ss <- mkGenSyms (collectPatBinders pat) repTopDs :: HsGroup Name -> DsM (Core (TH.Q [TH.Dec])) repTopDs group - = do { let { bndrs = map unLoc (groupBinders group) } ; + = do { let { bndrs = groupBinders group } ; ss <- mkGenSyms bndrs ; -- Bind all the names mainly to avoid repeated use of explicit strings. @@ -135,13 +135,13 @@ repTopDs group -- Do *not* gensym top-level binders } -groupBinders :: HsGroup Name -> [Located Name] +groupBinders :: HsGroup Name -> [Name] groupBinders (HsGroup { hs_valds = val_decls, hs_tyclds = tycl_decls, hs_instds = inst_decls, hs_fords = foreign_decls }) -- Collect the binders of a Group = collectHsValBinders val_decls ++ - [n | d <- tycl_decls ++ assoc_tycl_decls, n <- tyClDeclNames (unLoc d)] ++ - [n | L _ (ForeignImport n _ _) <- foreign_decls] + [n | d <- tycl_decls ++ assoc_tycl_decls, L _ n <- tyClDeclNames (unLoc d)] ++ + [n | L _ (ForeignImport (L _ n) _ _) <- foreign_decls] where assoc_tycl_decls = concat [ats | L _ (InstDecl _ _ _ ats) <- inst_decls] @@ -317,7 +317,7 @@ repInstD' (L loc (InstDecl ty binds _ ats)) -- Ignore user pragmas for now -- appear in the resulting data structure do { cxt1 <- repContext cxt ; inst_ty1 <- repPredTy (HsClassP cls tys) - ; ss <- mkGenSyms (collectHsBindBinders binds) + ; ss <- mkGenSyms (collectHsBindsBinders binds) ; binds1 <- addBinds ss (rep_binds binds) ; ats1 <- repLAssocFamInst ats ; decls1 <- coreList decQTyConName (ats1 ++ binds1) @@ -900,7 +900,7 @@ repBinds EmptyLocalBinds repBinds b@(HsIPBinds _) = notHandled "Implicit parameters" (ppr b) repBinds (HsValBinds decs) - = do { let { bndrs = map unLoc (collectHsValBinders decs) } + = do { let { bndrs = collectHsValBinders decs } -- No need to worrry about detailed scopes within -- the binding group, because we are talking Names -- here, so we can safely treat it as a mutually diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs index fd4f6db..a328cee 100644 --- a/compiler/hsSyn/HsExpr.lhs +++ b/compiler/hsSyn/HsExpr.lhs @@ -808,15 +808,6 @@ type LStmtLR idL idR = Located (StmtLR idL idR) type Stmt id = StmtLR id id -data GroupByClause id - = GroupByNothing (LHsExpr id) -- Using expression, i.e. - -- "then group using f" ==> GroupByNothing f - | GroupBySomething (Either (LHsExpr id) (SyntaxExpr id)) (LHsExpr id) - -- "then group using f by e" ==> GroupBySomething (Left f) e - -- "then group by e" ==> GroupBySomething (Right _) e: in - -- this case the expression is filled - -- in by the renamer - -- The SyntaxExprs in here are used *only* for do-notation, which -- has rebindable syntax. Otherwise they are unused. data StmtLR idL idR @@ -838,16 +829,33 @@ data StmtLR idL idR -- After renaming, the ids are the binders bound by the stmts and used -- after them - | TransformStmt ([LStmt idL], [idR]) (LHsExpr idR) (Maybe (LHsExpr idR)) - -- After renaming, the IDs are the binders occurring within this - -- transform statement that are used after it - -- "qs, then f by e" ==> TransformStmt (qs, binders) f (Just e) - -- "qs, then f" ==> TransformStmt (qs, binders) f Nothing + -- "qs, then f by e" ==> TransformStmt qs binders f (Just e) + -- "qs, then f" ==> TransformStmt qs binders f Nothing + | TransformStmt + [LStmt idL] -- Stmts are the ones to the left of the 'then' + + [idR] -- After renaming, the IDs are the binders occurring + -- within this transform statement that are used after it + + (LHsExpr idR) -- "then f" + + (Maybe (LHsExpr idR)) -- "by e" (optional) - | GroupStmt ([LStmt idL], [(idR, idR)]) (GroupByClause idR) - -- After renaming, the IDs are the binders occurring within this - -- transform statement that are used after it which are paired with - -- the names which they group over in statements + | GroupStmt + [LStmt idL] -- Stmts to the *left* of the 'group' + -- which generates the tuples to be grouped + + [(idR, idR)] -- After renaming, the IDs are the binders + -- occurring within this transform statement that + -- are used after it which are paired with the + -- names which they group over in statements + + (Maybe (LHsExpr idR)) -- "by e" (optional) + + (Either -- "using f" + (LHsExpr idR) -- Left f => explicit "using f" + (SyntaxExpr idR)) -- Right f => implicit; filled in with 'groupWith' + -- Recursive statement (see Note [RecStmt] below) | RecStmt @@ -959,43 +967,57 @@ pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] pprStmt (ExprStmt expr _ _) = ppr expr pprStmt (ParStmt stmtss) = hsep (map doStmts stmtss) where doStmts stmts = ptext (sLit "| ") <> ppr stmts -pprStmt (TransformStmt (stmts, _) usingExpr maybeByExpr) - = (hsep [stmtsDoc, ptext (sLit "then"), ppr usingExpr, byExprDoc]) - where stmtsDoc = interpp'SP stmts - byExprDoc = maybe empty (\byExpr -> hsep [ptext (sLit "by"), ppr byExpr]) maybeByExpr -pprStmt (GroupStmt (stmts, _) groupByClause) = (hsep [stmtsDoc, ptext (sLit "then group"), pprGroupByClause groupByClause]) - where stmtsDoc = interpp'SP stmts -pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids, recS_later_ids = later_ids }) + +pprStmt (TransformStmt stmts _ using by) + = sep (ppr_lc_stmts stmts ++ [pprTransformStmt using by]) + +pprStmt (GroupStmt stmts _ by using) + = sep (ppr_lc_stmts stmts ++ [pprGroupStmt by using]) + +pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids + , recS_later_ids = later_ids }) = ptext (sLit "rec") <+> vcat [ braces (vcat (map ppr segment)) , ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids , ptext (sLit "later_ids=") <> ppr later_ids])] -pprGroupByClause :: (OutputableBndr id) => GroupByClause id -> SDoc -pprGroupByClause (GroupByNothing usingExpr) = hsep [ptext (sLit "using"), ppr usingExpr] -pprGroupByClause (GroupBySomething eitherUsingExpr byExpr) = hsep [ptext (sLit "by"), ppr byExpr, usingExprDoc] - where usingExprDoc = either (\usingExpr -> hsep [ptext (sLit "using"), ppr usingExpr]) (const empty) eitherUsingExpr +pprTransformStmt :: OutputableBndr id => LHsExpr id -> Maybe (LHsExpr id) -> SDoc +pprTransformStmt using by = sep [ ptext (sLit "then"), nest 2 (ppr using), nest 2 (pprBy by)] + +pprGroupStmt :: OutputableBndr id => Maybe (LHsExpr id) + -> Either (LHsExpr id) (SyntaxExpr is) + -> SDoc +pprGroupStmt by using + = sep [ ptext (sLit "then group"), nest 2 (pprBy by), nest 2 (ppr_using using)] + where + ppr_using (Right _) = empty + ppr_using (Left e) = ptext (sLit "using") <+> ppr e + +pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc +pprBy Nothing = empty +pprBy (Just e) = ptext (sLit "by") <+> ppr e pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> LHsExpr id -> SDoc pprDo DoExpr stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body pprDo GhciStmt stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body pprDo (MDoExpr _) stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body -pprDo ListComp stmts body = pprComp brackets stmts body -pprDo PArrComp stmts body = pprComp pa_brackets stmts body +pprDo ListComp stmts body = brackets $ pprComp stmts body +pprDo PArrComp stmts body = pa_brackets $ pprComp stmts body pprDo _ _ _ = panic "pprDo" -- PatGuard, ParStmtCxt ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc -- Print a bunch of do stmts, with explicit braces and semicolons, -- so that we are not vulnerable to layout bugs ppr_do_stmts stmts body - = lbrace <+> pprDeeperList vcat ([ ppr s <> semi | s <- stmts] ++ [ppr body]) + = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts] ++ [ppr body]) <+> rbrace -pprComp :: OutputableBndr id => (SDoc -> SDoc) -> [LStmt id] -> LHsExpr id -> SDoc -pprComp brack quals body - = brack $ - hang (ppr body <+> char '|') - 4 (interpp'SP quals) +ppr_lc_stmts :: OutputableBndr id => [LStmt id] -> [SDoc] +ppr_lc_stmts stmts = [ppr s <> comma | s <- stmts] + +pprComp :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc +pprComp quals body -- Prints: body | qual1, ..., qualn + = hang (ppr body <+> char '|') 2 (interpp'SP quals) \end{code} %************************************************************************ @@ -1202,5 +1224,10 @@ pprMatchInCtxt ctxt match = hang (ptext (sLit "In") <+> pprMatchContext ctxt <> pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR) => HsStmtContext idL -> StmtLR idL idR -> SDoc pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprStmtContext ctxt <> colon) - 4 (ppr stmt) + 4 (ppr_stmt stmt) + where + -- For Group and Transform Stmts, don't print the nested stmts! + ppr_stmt (GroupStmt _ _ by using) = pprGroupStmt by using + ppr_stmt (TransformStmt _ _ using by) = pprTransformStmt using by + ppr_stmt stmt = pprStmt stmt \end{code} diff --git a/compiler/hsSyn/HsPat.lhs b/compiler/hsSyn/HsPat.lhs index 5065375..8ab583a 100644 --- a/compiler/hsSyn/HsPat.lhs +++ b/compiler/hsSyn/HsPat.lhs @@ -195,7 +195,7 @@ data HsRecFields id arg -- A bunch of record fields data HsRecField id arg = HsRecField { hsRecFieldId :: Located id, - hsRecFieldArg :: arg, + hsRecFieldArg :: arg, -- Filled in by renamer hsRecPun :: Bool -- Note [Punning] } diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs index 14193e0..d5ff6f5 100644 --- a/compiler/hsSyn/HsUtils.lhs +++ b/compiler/hsSyn/HsUtils.lhs @@ -14,7 +14,51 @@ which deal with the intantiated versions are located elsewhere: Id typecheck/TcHsSyn \begin{code} -module HsUtils where +module HsUtils( + -- Terms + mkHsPar, mkHsApp, mkHsConApp, mkSimpleHsAlt, + mkSimpleMatch, unguardedGRHSs, unguardedRHS, + mkMatchGroup, mkMatch, mkHsLam, + mkHsWrap, mkLHsWrap, mkHsWrapCoI, coiToHsWrapper, mkHsDictLet, + mkHsOpApp, mkHsDo, + + nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, + nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList, + mkLHsTupleExpr, mkLHsVarTuple, missingTupArg, + + -- Bindigns + mkFunBind, mkVarBind, mkHsVarBind, mk_easy_FunBind, mk_FunBind, + + -- Literals + mkHsIntegral, mkHsFractional, mkHsIsString, mkHsString, + + -- Patterns + mkNPat, mkNPlusKPat, nlVarPat, nlLitPat, nlConVarPat, nlConPat, nlInfixConPat, + nlNullaryConPat, nlWildConPat, nlWildPat, nlTuplePat, + + -- Types + mkHsAppTy, userHsTyVarBndrs, + nlHsAppTy, nlHsTyVar, nlHsFunTy, nlHsTyConApp, + + -- Stmts + mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, + mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt, + emptyRecStmt, mkRecStmt, + + -- Template Haskell + unqualSplice, mkHsSpliceTy, mkHsSplice, mkHsQuasiQuote, unqualQuasiQuote, + + -- Flags + noRebindableInfo, + + -- Collecting binders + collectLocalBinders, collectHsValBinders, + collectHsBindsBinders, collectHsBindBinders, collectMethodBinders, + collectPatBinders, collectPatsBinders, + collectLStmtsBinders, collectStmtsBinders, + collectLStmtBinders, collectStmtBinders, + collectSigTysFromPats, collectSigTysFromPat + ) where import HsBinds import HsExpr @@ -135,10 +179,6 @@ mkNPlusKPat :: Located id -> HsOverLit id -> Pat id mkTransformStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR -mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR -mkGroupByStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR -mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR - mkExprStmt :: LHsExpr idR -> StmtLR idL idR mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR @@ -158,12 +198,16 @@ mkHsDo ctxt stmts body = HsDo ctxt stmts body placeHolderType mkNPat lit neg = NPat lit neg noSyntaxExpr mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr -mkTransformStmt stmts usingExpr = TransformStmt (stmts, []) usingExpr Nothing -mkTransformByStmt stmts usingExpr byExpr = TransformStmt (stmts, []) usingExpr (Just byExpr) +mkTransformStmt stmts usingExpr = TransformStmt stmts [] usingExpr Nothing +mkTransformByStmt stmts usingExpr byExpr = TransformStmt stmts [] usingExpr (Just byExpr) -mkGroupUsingStmt stmts usingExpr = GroupStmt (stmts, []) (GroupByNothing usingExpr) -mkGroupByStmt stmts byExpr = GroupStmt (stmts, []) (GroupBySomething (Right noSyntaxExpr) byExpr) -mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt (stmts, []) (GroupBySomething (Left usingExpr) byExpr) +mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR +mkGroupByStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR +mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR + +mkGroupUsingStmt stmts usingExpr = GroupStmt stmts [] Nothing (Left usingExpr) +mkGroupByStmt stmts byExpr = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr) +mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr) mkExprStmt expr = ExprStmt expr noSyntaxExpr placeHolderType mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr @@ -362,7 +406,7 @@ mkMatch pats expr binds %************************************************************************ %* * - Collecting binders from HsBindGroups and HsBinds + Collecting binders %* * %************************************************************************ @@ -376,126 +420,116 @@ where it should return [x, y, f, a, b] (remember, order important). +Note [Collect binders only after renaming] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +These functions should only be used on HsSyn *after* the renamer, +to reuturn a [Name] or [Id]. Before renaming the record punning +and wild-card mechanism makes it hard to know what is bound. +So these functions should not be applied to (HsSyn RdrName) + \begin{code} -collectLocalBinders :: HsLocalBindsLR idL idR -> [Located idL] +----------------- Bindings -------------------------- +collectLocalBinders :: HsLocalBindsLR idL idR -> [idL] collectLocalBinders (HsValBinds val_binds) = collectHsValBinders val_binds collectLocalBinders (HsIPBinds _) = [] collectLocalBinders EmptyLocalBinds = [] -collectHsValBinders :: HsValBindsLR idL idR -> [Located idL] -collectHsValBinders (ValBindsIn binds _) = collectHsBindLocatedBinders binds +collectHsValBinders :: HsValBindsLR idL idR -> [idL] +collectHsValBinders (ValBindsIn binds _) = collectHsBindsBinders binds collectHsValBinders (ValBindsOut binds _) = foldr collect_one [] binds where - collect_one (_,binds) acc = foldrBag (collectAcc . unLoc) acc binds - -collectAcc :: HsBindLR idL idR -> [Located idL] -> [Located idL] -collectAcc (PatBind { pat_lhs = p }) acc = collectLocatedPatBinders p ++ acc -collectAcc (FunBind { fun_id = f }) acc = f : acc -collectAcc (VarBind { var_id = f }) acc = noLoc f : acc -collectAcc (AbsBinds { abs_exports = dbinds, abs_binds = _binds }) acc - = [noLoc dp | (_,dp,_,_) <- dbinds] ++ acc - -- ++ foldr collectAcc acc binds + collect_one (_,binds) acc = collect_binds binds acc + +collectHsBindBinders :: HsBindLR idL idR -> [idL] +collectHsBindBinders b = collect_bind b [] + +collect_bind :: HsBindLR idL idR -> [idL] -> [idL] +collect_bind (PatBind { pat_lhs = p }) acc = collect_lpat p acc +collect_bind (FunBind { fun_id = L _ f }) acc = f : acc +collect_bind (VarBind { var_id = f }) acc = f : acc +collect_bind (AbsBinds { abs_exports = dbinds, abs_binds = _binds }) acc + = [dp | (_,dp,_,_) <- dbinds] ++ acc + -- ++ foldr collect_bind acc binds -- I don't think we want the binders from the nested binds -- The only time we collect binders from a typechecked -- binding (hence see AbsBinds) is in zonking in TcHsSyn -collectHsBindBinders :: LHsBindsLR idL idR -> [idL] -collectHsBindBinders binds = map unLoc (collectHsBindLocatedBinders binds) - -collectHsBindLocatedBinders :: LHsBindsLR idL idR -> [Located idL] -collectHsBindLocatedBinders binds = foldrBag (collectAcc . unLoc) [] binds -\end{code} +collectHsBindsBinders :: LHsBindsLR idL idR -> [idL] +collectHsBindsBinders binds = collect_binds binds [] +collect_binds :: LHsBindsLR idL idR -> [idL] -> [idL] +collect_binds binds acc = foldrBag (collect_bind . unLoc) acc binds -%************************************************************************ -%* * - Getting binders from statements -%* * -%************************************************************************ +collectMethodBinders :: LHsBindsLR RdrName idR -> [Located RdrName] +-- Used exclusively for the bindings of an instance decl which are all FunBinds +collectMethodBinders binds = foldrBag get [] binds + where + get (L _ (FunBind { fun_id = f })) fs = f : fs + get _ fs = fs + -- Someone else complains about non-FunBinds -\begin{code} -collectLStmtsBinders :: [LStmtLR idL idR] -> [Located idL] +----------------- Statements -------------------------- +collectLStmtsBinders :: [LStmtLR idL idR] -> [idL] collectLStmtsBinders = concatMap collectLStmtBinders -collectStmtsBinders :: [StmtLR idL idR] -> [Located idL] +collectStmtsBinders :: [StmtLR idL idR] -> [idL] collectStmtsBinders = concatMap collectStmtBinders -collectLStmtBinders :: LStmtLR idL idR -> [Located idL] +collectLStmtBinders :: LStmtLR idL idR -> [idL] collectLStmtBinders = collectStmtBinders . unLoc -collectStmtBinders :: StmtLR idL idR -> [Located idL] +collectStmtBinders :: StmtLR idL idR -> [idL] -- Id Binders for a Stmt... [but what about pattern-sig type vars]? -collectStmtBinders (BindStmt pat _ _ _) = collectLocatedPatBinders pat +collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat collectStmtBinders (LetStmt binds) = collectLocalBinders binds collectStmtBinders (ExprStmt _ _ _) = [] collectStmtBinders (ParStmt xs) = collectLStmtsBinders $ concatMap fst xs -collectStmtBinders (TransformStmt (stmts, _) _ _) = collectLStmtsBinders stmts -collectStmtBinders (GroupStmt (stmts, _) _) = collectLStmtsBinders stmts -collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss -\end{code} +collectStmtBinders (TransformStmt stmts _ _ _) = collectLStmtsBinders stmts +collectStmtBinders (GroupStmt stmts _ _ _) = collectLStmtsBinders stmts +collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss -%************************************************************************ -%* * -%* Gathering stuff out of patterns -%* * -%************************************************************************ - -This function @collectPatBinders@ works with the ``collectBinders'' -functions for @HsBinds@, etc. The order in which the binders are -collected is important; see @HsBinds.lhs@. - -It collects the bounds *value* variables in renamed patterns; type variables -are *not* collected. - -\begin{code} +----------------- Patterns -------------------------- collectPatBinders :: LPat a -> [a] -collectPatBinders pat = map unLoc (collectLocatedPatBinders pat) - -collectLocatedPatBinders :: LPat a -> [Located a] -collectLocatedPatBinders pat = collectl pat [] +collectPatBinders pat = collect_lpat pat [] collectPatsBinders :: [LPat a] -> [a] -collectPatsBinders pats = map unLoc (collectLocatedPatsBinders pats) - -collectLocatedPatsBinders :: [LPat a] -> [Located a] -collectLocatedPatsBinders pats = foldr collectl [] pats +collectPatsBinders pats = foldr collect_lpat [] pats ---------------------- -collectl :: LPat name -> [Located name] -> [Located name] -collectl (L l pat) bndrs +------------- +collect_lpat :: LPat name -> [name] -> [name] +collect_lpat (L _ pat) bndrs = go pat where - go (VarPat var) = L l var : bndrs - go (VarPatOut var bs) = L l var : collectHsBindLocatedBinders bs - ++ bndrs + go (VarPat var) = var : bndrs + go (VarPatOut var bs) = var : collect_binds bs bndrs go (WildPat _) = bndrs - go (LazyPat pat) = collectl pat bndrs - go (BangPat pat) = collectl pat bndrs - go (AsPat a pat) = a : collectl pat bndrs - go (ViewPat _ pat _) = collectl pat bndrs - go (ParPat pat) = collectl pat bndrs + go (LazyPat pat) = collect_lpat pat bndrs + go (BangPat pat) = collect_lpat pat bndrs + go (AsPat (L _ a) pat) = a : collect_lpat pat bndrs + go (ViewPat _ pat _) = collect_lpat pat bndrs + go (ParPat pat) = collect_lpat pat bndrs - go (ListPat pats _) = foldr collectl bndrs pats - go (PArrPat pats _) = foldr collectl bndrs pats - go (TuplePat pats _ _) = foldr collectl bndrs pats + go (ListPat pats _) = foldr collect_lpat bndrs pats + go (PArrPat pats _) = foldr collect_lpat bndrs pats + go (TuplePat pats _ _) = foldr collect_lpat bndrs pats - go (ConPatIn _ ps) = foldr collectl bndrs (hsConPatArgs ps) - go (ConPatOut {pat_args=ps}) = foldr collectl bndrs (hsConPatArgs ps) + go (ConPatIn _ ps) = foldr collect_lpat bndrs (hsConPatArgs ps) + go (ConPatOut {pat_args=ps}) = foldr collect_lpat bndrs (hsConPatArgs ps) -- See Note [Dictionary binders in ConPatOut] go (LitPat _) = bndrs go (NPat _ _ _) = bndrs - go (NPlusKPat n _ _ _) = n : bndrs + go (NPlusKPat (L _ n) _ _ _) = n : bndrs - go (SigPatIn pat _) = collectl pat bndrs - go (SigPatOut pat _) = collectl pat bndrs + go (SigPatIn pat _) = collect_lpat pat bndrs + go (SigPatOut pat _) = collect_lpat pat bndrs go (QuasiQuotePat _) = bndrs go (TypePat _) = bndrs - go (CoPat _ pat _) = collectl (noLoc pat) bndrs + go (CoPat _ pat _) = go pat \end{code} -Note [Dictionary binders in ConPatOut] +Note [Dictionary binders in ConPatOut] See also same Note in DsArrows ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Do *not* gather (a) dictionary and (b) dictionary bindings as binders of a ConPatOut pattern. For most calls it doesn't matter, because @@ -521,27 +555,33 @@ and *also* uses that dictionary to match the (n+1) pattern. Yet, the variables bound by the lazy pattern are n,m, *not* the dictionary d. So in mkSelectorBinds in DsUtils, we want just m,n as the variables bound. +%************************************************************************ +%* * + Collecting type signatures from patterns +%* * +%************************************************************************ + \begin{code} collectSigTysFromPats :: [InPat name] -> [LHsType name] -collectSigTysFromPats pats = foldr collect_lpat [] pats +collectSigTysFromPats pats = foldr collect_sig_lpat [] pats collectSigTysFromPat :: InPat name -> [LHsType name] -collectSigTysFromPat pat = collect_lpat pat [] - -collect_lpat :: InPat name -> [LHsType name] -> [LHsType name] -collect_lpat pat acc = collect_pat (unLoc pat) acc - -collect_pat :: Pat name -> [LHsType name] -> [LHsType name] -collect_pat (SigPatIn pat ty) acc = collect_lpat pat (ty:acc) -collect_pat (TypePat ty) acc = ty:acc - -collect_pat (LazyPat pat) acc = collect_lpat pat acc -collect_pat (BangPat pat) acc = collect_lpat pat acc -collect_pat (AsPat _ pat) acc = collect_lpat pat acc -collect_pat (ParPat pat) acc = collect_lpat pat acc -collect_pat (ListPat pats _) acc = foldr collect_lpat acc pats -collect_pat (PArrPat pats _) acc = foldr collect_lpat acc pats -collect_pat (TuplePat pats _ _) acc = foldr collect_lpat acc pats -collect_pat (ConPatIn _ ps) acc = foldr collect_lpat acc (hsConPatArgs ps) -collect_pat _ acc = acc -- Literals, vars, wildcard +collectSigTysFromPat pat = collect_sig_lpat pat [] + +collect_sig_lpat :: InPat name -> [LHsType name] -> [LHsType name] +collect_sig_lpat pat acc = collect_sig_pat (unLoc pat) acc + +collect_sig_pat :: Pat name -> [LHsType name] -> [LHsType name] +collect_sig_pat (SigPatIn pat ty) acc = collect_sig_lpat pat (ty:acc) +collect_sig_pat (TypePat ty) acc = ty:acc + +collect_sig_pat (LazyPat pat) acc = collect_sig_lpat pat acc +collect_sig_pat (BangPat pat) acc = collect_sig_lpat pat acc +collect_sig_pat (AsPat _ pat) acc = collect_sig_lpat pat acc +collect_sig_pat (ParPat pat) acc = collect_sig_lpat pat acc +collect_sig_pat (ListPat pats _) acc = foldr collect_sig_lpat acc pats +collect_sig_pat (PArrPat pats _) acc = foldr collect_sig_lpat acc pats +collect_sig_pat (TuplePat pats _ _) acc = foldr collect_sig_lpat acc pats +collect_sig_pat (ConPatIn _ ps) acc = foldr collect_sig_lpat acc (hsConPatArgs ps) +collect_sig_pat _ acc = acc -- Literals, vars, wildcard \end{code} diff --git a/compiler/rename/RnBinds.lhs b/compiler/rename/RnBinds.lhs index 5773108..2cf2bdc 100644 --- a/compiler/rename/RnBinds.lhs +++ b/compiler/rename/RnBinds.lhs @@ -179,7 +179,7 @@ rnTopBinds :: HsValBinds RdrName -> RnM (HsValBinds Name, DefUses) rnTopBinds b = do nl <- rnTopBindsLHS emptyFsEnv b - let bound_names = map unLoc (collectHsValBinders nl) + let bound_names = collectHsValBinders nl bindLocalNames bound_names $ rnTopBindsRHS (mkNameSet bound_names) nl @@ -261,7 +261,7 @@ rnValBindsLHS fix_env binds -- g = let f = ... in f -- should. ; binds' <- rnValBindsLHSFromDoc (localRecNameMaker fix_env) binds - ; let bound_names = map unLoc $ collectHsValBinders binds' + ; let bound_names = collectHsValBinders binds' ; envs <- getRdrEnvs ; checkDupAndShadowedNames envs bound_names ; return (bound_names, binds') } @@ -276,7 +276,7 @@ rnValBindsLHSFromDoc topP (ValBindsIn mbinds sigs) = do { mbinds' <- mapBagM (rnBindLHS topP doc) mbinds ; return $ ValBindsIn mbinds' sigs } where - bndrs = collectHsBindBinders mbinds + bndrs = collectHsBindsBinders mbinds doc = text "In the binding group for:" <+> pprWithCommas ppr bndrs rnValBindsLHSFromDoc _ b = pprPanic "rnValBindsLHSFromDoc" (ppr b) diff --git a/compiler/rename/RnEnv.lhs b/compiler/rename/RnEnv.lhs index c6d5052..c3b5592 100644 --- a/compiler/rename/RnEnv.lhs +++ b/compiler/rename/RnEnv.lhs @@ -26,8 +26,8 @@ module RnEnv ( bindTyVarsRn, extendTyVarEnvFVRn, checkDupRdrNames, checkDupAndShadowedRdrNames, - checkDupAndShadowedNames, - mapFvRn, mapFvRnCPS, + checkDupNames, checkDupAndShadowedNames, + addFvRn, mapFvRn, mapMaybeFvRn, mapFvRnCPS, warnUnusedMatches, warnUnusedModules, warnUnusedImports, warnUnusedTopBinds, warnUnusedLocalBinds, dataTcOccs, unknownNameErr, kindSigErr, perhapsForallMsg @@ -989,11 +989,19 @@ checkShadowedOccs (global_env,local_env) loc_occs \begin{code} -- A useful utility +addFvRn :: FreeVars -> RnM (thing, FreeVars) -> RnM (thing, FreeVars) +addFvRn fvs1 thing_inside = do { (res, fvs2) <- thing_inside + ; return (res, fvs1 `plusFV` fvs2) } + mapFvRn :: (a -> RnM (b, FreeVars)) -> [a] -> RnM ([b], FreeVars) mapFvRn f xs = do stuff <- mapM f xs case unzip stuff of (ys, fvs_s) -> return (ys, plusFVs fvs_s) +mapMaybeFvRn :: (a -> RnM (b, FreeVars)) -> Maybe a -> RnM (Maybe b, FreeVars) +mapMaybeFvRn _ Nothing = return (Nothing, emptyFVs) +mapMaybeFvRn f (Just x) = do { (y, fvs) <- f x; return (Just y, fvs) } + -- because some of the rename functions are CPSed: -- maps the function across the list from left to right; -- collects all the free vars into one set diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs index 6dc6801..78088d5 100644 --- a/compiler/rename/RnExpr.lhs +++ b/compiler/rename/RnExpr.lhs @@ -42,7 +42,6 @@ import UniqSet import Data.List import Util ( isSingleton ) import ListSetOps ( removeDups ) -import Maybes ( expectJust ) import Outputable import SrcLoc import FastString @@ -538,8 +537,8 @@ methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName methodNamesStmt (LetStmt _) = emptyFVs methodNamesStmt (ParStmt _) = emptyFVs -methodNamesStmt (TransformStmt _ _ _) = emptyFVs -methodNamesStmt (GroupStmt _ _) = emptyFVs +methodNamesStmt (TransformStmt {}) = emptyFVs +methodNamesStmt (GroupStmt {}) = emptyFVs -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error -- here so we just do what's convenient \end{code} @@ -635,33 +634,43 @@ rnBracket (DecBrG _) = panic "rnBracket: unexpected DecBrG" rnStmts :: HsStmtContext Name -> [LStmt RdrName] -> RnM (thing, FreeVars) -> RnM (([LStmt Name], thing), FreeVars) +-- Variables bound by the Stmts, and mentioned in thing_inside, +-- do not appear in the result FreeVars -rnStmts (MDoExpr _) = rnMDoStmts -rnStmts ctxt = rnNormalStmts ctxt +rnStmts (MDoExpr _) stmts thing_inside = rnMDoStmts stmts thing_inside +rnStmts ctxt stmts thing_inside = rnNormalStmts ctxt stmts (\ _ -> thing_inside) rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName] - -> RnM (thing, FreeVars) + -> ([Name] -> RnM (thing, FreeVars)) -> RnM (([LStmt Name], thing), FreeVars) +-- Variables bound by the Stmts, and mentioned in thing_inside, +-- do not appear in the result FreeVars +-- +-- Renaming a single RecStmt can give a sequence of smaller Stmts + rnNormalStmts _ [] thing_inside - = do { (thing, fvs) <- thing_inside - ; return (([],thing), fvs) } + = do { (res, fvs) <- thing_inside [] + ; return (([], res), fvs) } rnNormalStmts ctxt (stmt@(L loc _) : stmts) thing_inside = do { ((stmts1, (stmts2, thing)), fvs) - <- setSrcSpan loc $ - rnStmt ctxt stmt $ - rnNormalStmts ctxt stmts thing_inside + <- setSrcSpan loc $ + rnStmt ctxt stmt $ \ bndrs1 -> + rnNormalStmts ctxt stmts $ \ bndrs2 -> + thing_inside (bndrs1 ++ bndrs2) ; return (((stmts1 ++ stmts2), thing), fvs) } rnStmt :: HsStmtContext Name -> LStmt RdrName - -> RnM (thing, FreeVars) + -> ([Name] -> RnM (thing, FreeVars)) -> RnM (([LStmt Name], thing), FreeVars) +-- Variables bound by the Stmt, and mentioned in thing_inside, +-- do not appear in the result FreeVars rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside = do { (expr', fv_expr) <- rnLExpr expr ; (then_op, fvs1) <- lookupSyntaxName thenMName - ; (thing, fvs2) <- thing_inside + ; (thing, fvs2) <- thing_inside [] ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing), fv_expr `plusFV` fvs1 `plusFV` fvs2) } @@ -671,7 +680,7 @@ rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside ; (bind_op, fvs1) <- lookupSyntaxName bindMName ; (fail_op, fvs2) <- lookupSyntaxName failMName ; rnPat (StmtCtxt ctxt) pat $ \ pat' -> do - { (thing, fvs3) <- thing_inside + { (thing, fvs3) <- thing_inside (collectPatBinders pat') ; return (([L loc (BindStmt pat' expr' bind_op fail_op)], thing), fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }} -- fv_expr shouldn't really be filtered by the rnPatsAndThen @@ -680,7 +689,7 @@ rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside rnStmt ctxt (L loc (LetStmt binds)) thing_inside = do { checkLetStmt ctxt binds ; rnLocalBindsAndThen binds $ \binds' -> do - { (thing, fvs) <- thing_inside + { (thing, fvs) <- thing_inside (collectLocalBinders binds') ; return (([L loc (LetStmt binds')], thing), fvs) } } rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside @@ -697,7 +706,9 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside -- context.) ; rn_rec_stmts_and_then rec_stmts $ \ segs -> do - { (thing, fvs_later) <- thing_inside + { let bndrs = nameSetToList $ foldr (unionNameSets . (\(ds,_,_,_) -> ds)) + emptyNameSet segs + ; (thing, fvs_later) <- thing_inside bndrs ; (return_op, fvs1) <- lookupSyntaxName returnMName ; (mfix_op, fvs2) <- lookupSyntaxName mfixName ; (bind_op, fvs3) <- lookupSyntaxName bindMName @@ -730,146 +741,103 @@ rnStmt ctxt (L loc (ParStmt segs)) thing_inside ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside ; return (([L loc (ParStmt segs')], thing), fvs) } -rnStmt ctxt (L loc (TransformStmt (stmts, _) usingExpr maybeByExpr)) thing_inside = do - checkTransformStmt ctxt - - (usingExpr', fv_usingExpr) <- rnLExpr usingExpr - ((stmts', binders, (maybeByExpr', thing)), fvs) <- - rnNormalStmtsAndFindUsedBinders (TransformStmtCtxt ctxt) stmts $ \_unshadowed_bndrs -> do - (maybeByExpr', fv_maybeByExpr) <- rnMaybeLExpr maybeByExpr - (thing, fv_thing) <- thing_inside - - return ((maybeByExpr', thing), fv_maybeByExpr `plusFV` fv_thing) +rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside + = do { checkTransformStmt ctxt - return (([L loc (TransformStmt (stmts', binders) usingExpr' maybeByExpr')], thing), - fv_usingExpr `plusFV` fvs) - where - rnMaybeLExpr Nothing = return (Nothing, emptyFVs) - rnMaybeLExpr (Just expr) = do - (expr', fv_expr) <- rnLExpr expr - return (Just expr', fv_expr) + ; (using', fvs1) <- rnLExpr using + + ; ((stmts', (by', used_bndrs, thing)), fvs2) + <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs -> + do { (by', fvs_by) <- case by of + Nothing -> return (Nothing, emptyFVs) + Just e -> do { (e', fvs) <- rnLExpr e; return (Just e', fvs) } + ; (thing, fvs_thing) <- thing_inside bndrs + ; let fvs = fvs_by `plusFV` fvs_thing + used_bndrs = filter (`elemNameSet` fvs_thing) bndrs + ; return ((by', used_bndrs, thing), fvs) } + + ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), + fvs1 `plusFV` fvs2) } -rnStmt ctxt (L loc (GroupStmt (stmts, _) groupByClause)) thing_inside = do - checkTransformStmt ctxt - - -- We must rename the using expression in the context before the transform is begun - groupByClauseAction <- - case groupByClause of - GroupByNothing usingExpr -> do - (usingExpr', fv_usingExpr) <- rnLExpr usingExpr - (return . return) (GroupByNothing usingExpr', fv_usingExpr) - GroupBySomething eitherUsingExpr byExpr -> do - (eitherUsingExpr', fv_eitherUsingExpr) <- - case eitherUsingExpr of - Right _ -> return (Right $ HsVar groupWithName, unitNameSet groupWithName) - Left usingExpr -> do - (usingExpr', fv_usingExpr) <- rnLExpr usingExpr - return (Left usingExpr', fv_usingExpr) - - return $ do - (byExpr', fv_byExpr) <- rnLExpr byExpr - return (GroupBySomething eitherUsingExpr' byExpr', fv_eitherUsingExpr `plusFV` fv_byExpr) +rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside + = do { checkTransformStmt ctxt - -- We only use rnNormalStmtsAndFindUsedBinders to get unshadowed_bndrs, so - -- perhaps we could refactor this to use rnNormalStmts directly? - ((stmts', _, (groupByClause', usedBinderMap, thing)), fvs) <- - rnNormalStmtsAndFindUsedBinders (TransformStmtCtxt ctxt) stmts $ \unshadowed_bndrs -> do - (groupByClause', fv_groupByClause) <- groupByClauseAction - - unshadowed_bndrs' <- mapM newLocalName unshadowed_bndrs - let binderMap = zip unshadowed_bndrs unshadowed_bndrs' - - -- Bind the "thing" inside a context where we have REBOUND everything - -- bound by the statements before the group. This is necessary since after - -- the grouping the same identifiers actually have different meanings - -- i.e. they refer to lists not singletons! - (thing, fv_thing) <- bindLocalNames unshadowed_bndrs' thing_inside - - -- We remove entries from the binder map that are not used in the thing_inside. - -- We can then use that usage information to ensure that the free variables do - -- not contain the things we just bound, but do contain the things we need to - -- make those bindings (i.e. the corresponding non-listy variables) - - -- Note that we also retain those entries which have an old binder in our - -- own free variables (the using or by expression). This is because this map - -- is reused in the desugarer to create the type to bind from the statements - -- that occur before this one. If the binders we need are not in the map, they - -- will never get bound into our desugared expression and hence the simplifier - -- crashes as we refer to variables that don't exist! - let usedBinderMap = filter - (\(old_binder, new_binder) -> - (new_binder `elemNameSet` fv_thing) || - (old_binder `elemNameSet` fv_groupByClause)) binderMap - (usedOldBinders, usedNewBinders) = unzip usedBinderMap - real_fv_thing = (delListFromNameSet fv_thing usedNewBinders) `plusFV` (mkNameSet usedOldBinders) - - return ((groupByClause', usedBinderMap, thing), fv_groupByClause `plusFV` real_fv_thing) - - traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap) - return (([L loc (GroupStmt (stmts', usedBinderMap) groupByClause')], thing), fvs) - -rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name - -> [LStmt RdrName] - -> ([Name] -> RnM (thing, FreeVars)) - -> RnM (([LStmt Name], [Name], thing), FreeVars) -rnNormalStmtsAndFindUsedBinders ctxt stmts thing_inside = do - ((stmts', (used_bndrs, inner_thing)), fvs) <- rnNormalStmts ctxt stmts $ do - -- Find the Names that are bound by stmts that - -- by assumption we have just renamed - local_env <- getLocalRdrEnv - let - stmts_binders = collectLStmtsBinders stmts - bndrs = map (expectJust "rnStmt" - . lookupLocalRdrEnv local_env - . unLoc) stmts_binders - - -- If shadow, we'll look up (Unqual x) twice, getting - -- the second binding both times, which is the - -- one we want - unshadowed_bndrs = nub bndrs - - -- Typecheck the thing inside, passing on all - -- the Names bound before it for its information - (thing, fvs) <- thing_inside unshadowed_bndrs - - -- Figure out which of the bound names are used - -- after the statements we renamed - let used_bndrs = filter (`elemNameSet` fvs) bndrs - return ((used_bndrs, thing), fvs) - - -- Flatten the tuple returned by the above call a bit! - return ((stmts', used_bndrs, inner_thing), fvs) - -rnParallelStmts :: HsStmtContext Name -> [([LStmt RdrName], [RdrName])] - -> RnM (thing, FreeVars) - -> RnM (([([LStmt Name], [Name])], thing), FreeVars) -rnParallelStmts ctxt segs thing_inside = do - orig_lcl_env <- getLocalRdrEnv - go orig_lcl_env [] segs - where - go orig_lcl_env bndrs [] = do - let (bndrs', dups) = removeDups cmpByOcc bndrs - inner_env = extendLocalRdrEnvList orig_lcl_env bndrs' - - mapM_ dupErr dups - (thing, fvs) <- setLocalRdrEnv inner_env thing_inside - return (([], thing), fvs) - - go orig_lcl_env bndrs_so_far ((stmts, _) : segs) = do - ((stmts', bndrs, (segs', thing)), fvs) <- rnNormalStmtsAndFindUsedBinders ctxt stmts $ \new_bndrs -> do - -- Typecheck the thing inside, passing on all - -- the Names bound, but separately; revert the envt - setLocalRdrEnv orig_lcl_env $ do - go orig_lcl_env (new_bndrs ++ bndrs_so_far) segs - - let seg' = (stmts', bndrs) - return (((seg':segs'), thing), delListFromNameSet fvs bndrs) - - cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2 - dupErr vs = addErr (ptext (sLit "Duplicate binding in parallel list comprehension for:") + -- Rename the 'using' expression in the context before the transform is begun + ; (using', fvs1) <- case using of + Left e -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) } + Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName + ; return (Right e', fvs) } + + -- Rename the stmts and the 'by' expression + -- Keep track of the variables mentioned in the 'by' expression + ; ((stmts', (by', used_bndrs, thing)), fvs2) + <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs -> + do { (by', fvs_by) <- mapMaybeFvRn rnLExpr by + ; (thing, fvs_thing) <- thing_inside bndrs + ; let fvs = fvs_by `plusFV` fvs_thing + used_bndrs = filter (`elemNameSet` fvs) bndrs + ; return ((by', used_bndrs, thing), fvs) } + + ; let all_fvs = fvs1 `plusFV` fvs2 + bndr_map = used_bndrs `zip` used_bndrs + + ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map) + ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) } + + +type ParSeg id = ([LStmt id], [id]) -- The Names are bound by the Stmts + +rnParallelStmts :: forall thing. HsStmtContext Name + -> [ParSeg RdrName] + -> ([Name] -> RnM (thing, FreeVars)) + -> RnM (([ParSeg Name], thing), FreeVars) +-- Note [Renaming parallel Stmts] +rnParallelStmts ctxt segs thing_inside + = do { orig_lcl_env <- getLocalRdrEnv + ; rn_segs orig_lcl_env [] segs } + where + rn_segs :: LocalRdrEnv + -> [Name] -> [ParSeg RdrName] + -> RnM (([ParSeg Name], thing), FreeVars) + rn_segs _ bndrs_so_far [] + = do { let (bndrs', dups) = removeDups cmpByOcc bndrs_so_far + ; mapM_ dupErr dups + ; (thing, fvs) <- bindLocalNames bndrs' (thing_inside bndrs') + ; return (([], thing), fvs) } + + rn_segs env bndrs_so_far ((stmts,_) : segs) + = do { ((stmts', (used_bndrs, segs', thing)), fvs) + <- rnNormalStmts ctxt stmts $ \ bndrs -> + setLocalRdrEnv env $ do + { ((segs', thing), fvs) <- rn_segs env (bndrs ++ bndrs_so_far) segs + ; let used_bndrs = filter (`elemNameSet` fvs) bndrs + ; return ((used_bndrs, segs', thing), fvs) } + + ; let seg' = (stmts', used_bndrs) + ; return ((seg':segs', thing), fvs) } + + cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2 + dupErr vs = addErr (ptext (sLit "Duplicate binding in parallel list comprehension for:") <+> quotes (ppr (head vs))) \end{code} +Note [Renaming parallel Stmts] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Renaming parallel statements is painful. Given, say + [ a+c | a <- as, bs <- bss + | c <- bs, a <- ds ] +Note that + (a) In order to report "Defined by not used" about 'bs', we must rename + each group of Stmts with a thing_inside whose FreeVars include at least {a,c} + + (b) We want to report that 'a' is illegally bound in both branches + + (c) The 'bs' in the second group must obviously not be captured by + the binding in the first group + +To satisfy (a) we nest the segements. +To satisfy (b) we check for duplicates just before thing_inside. +To satisfy (c) we reset the LocalRdrEnv each time. %************************************************************************ %* * @@ -916,7 +884,7 @@ rn_rec_stmts_and_then s cont ; new_lhs_and_fv <- rn_rec_stmts_lhs fix_env s -- ...bring them and their fixities into scope - ; let bound_names = map unLoc $ collectLStmtsBinders (map fst new_lhs_and_fv) + ; let bound_names = collectLStmtsBinders (map fst new_lhs_and_fv) ; bindLocalNamesFV_WithFixities bound_names fix_env $ do -- (C) do the right-hand-sides and thing-inside @@ -972,10 +940,10 @@ rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec in rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _)) -- Syntactically illegal in mdo = pprPanic "rn_rec_stmt" (ppr stmt) -rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt _ _ _)) -- Syntactically illegal in mdo +rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {})) -- Syntactically illegal in mdo = pprPanic "rn_rec_stmt" (ppr stmt) -rn_rec_stmt_lhs _ stmt@(L _ (GroupStmt _ _)) -- Syntactically illegal in mdo +rn_rec_stmt_lhs _ stmt@(L _ (GroupStmt {})) -- Syntactically illegal in mdo = pprPanic "rn_rec_stmt" (ppr stmt) rn_rec_stmt_lhs _ (L _ (LetStmt EmptyLocalBinds)) @@ -985,13 +953,13 @@ rn_rec_stmts_lhs :: MiniFixityEnv -> [LStmt RdrName] -> RnM [(LStmtLR Name RdrName, FreeVars)] rn_rec_stmts_lhs fix_env stmts - = do { let boundNames = collectLStmtsBinders stmts + = do { ls <- concatMapM (rn_rec_stmt_lhs fix_env) stmts + ; let boundNames = collectLStmtsBinders (map fst ls) -- First do error checking: we need to check for dups here because we -- don't bind all of the variables from the Stmt at once -- with bindLocatedLocals. - ; checkDupRdrNames boundNames - ; ls <- mapM (rn_rec_stmt_lhs fix_env) stmts - ; return (concat ls) } + ; checkDupNames boundNames + ; return ls } -- right-hand-sides diff --git a/compiler/rename/RnPat.lhs b/compiler/rename/RnPat.lhs index bc17495..813f39b 100644 --- a/compiler/rename/RnPat.lhs +++ b/compiler/rename/RnPat.lhs @@ -233,7 +233,8 @@ rnPats ctxt pats thing_inside rnPat :: HsMatchContext Name -- for error messages -> LPat RdrName -> (LPat Name -> RnM (a, FreeVars)) - -> RnM (a, FreeVars) + -> RnM (a, FreeVars) -- Variables bound by pattern do not + -- appear in the result FreeVars rnPat ctxt pat thing_inside = rnPats ctxt [pat] (\[pat'] -> thing_inside pat') diff --git a/compiler/rename/RnSource.lhs b/compiler/rename/RnSource.lhs index c01afec..f2683e8 100644 --- a/compiler/rename/RnSource.lhs +++ b/compiler/rename/RnSource.lhs @@ -125,7 +125,7 @@ rnSrcDecls group@(HsGroup {hs_valds = val_decls, -- It uses the fixity env from (A) to bind fixities for view patterns. new_lhs <- rnTopBindsLHS local_fix_env val_decls ; -- bind the LHSes (and their fixities) in the global rdr environment - let { val_binders = map unLoc $ collectHsValBinders new_lhs ; + let { val_binders = collectHsValBinders new_lhs ; val_bndr_set = mkNameSet val_binders ; all_bndr_set = val_bndr_set `unionNameSets` availsToNameSet tc_avails ; val_avails = map Avail val_binders @@ -440,7 +440,7 @@ rnSrcInstDecl (InstDecl inst_ty mbinds uprags ats) -- The typechecker (not the renamer) checks that all -- the bindings are for the right class let - meth_names = collectHsBindLocatedBinders mbinds + meth_names = collectMethodBinders mbinds (inst_tyvars, _, cls,_) = splitHsInstDeclTy (unLoc inst_ty') in checkDupRdrNames meth_names `thenM_` @@ -478,7 +478,7 @@ rnSrcInstDecl (InstDecl inst_ty mbinds uprags ats) -- -- But the (unqualified) method names are in scope let - binders = collectHsBindBinders mbinds' + binders = collectHsBindsBinders mbinds' bndr_set = mkNameSet binders in bindLocalNames binders diff --git a/compiler/typecheck/TcBinds.lhs b/compiler/typecheck/TcBinds.lhs index 2871f3b..2e675ac 100644 --- a/compiler/typecheck/TcBinds.lhs +++ b/compiler/typecheck/TcBinds.lhs @@ -310,7 +310,7 @@ tcPolyBinds :: TopLevelFlag -> TcSigFun -> TcPragFun tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc binds = let bind_list = bagToList binds - binder_names = collectHsBindBinders binds + binder_names = collectHsBindsBinders binds loc = getLoc (head bind_list) -- TODO: location a bit awkward, but the mbinds have been -- dependency analysed and may no longer be adjacent diff --git a/compiler/typecheck/TcDeriv.lhs b/compiler/typecheck/TcDeriv.lhs index b60a9be..3bf030d 100644 --- a/compiler/typecheck/TcDeriv.lhs +++ b/compiler/typecheck/TcDeriv.lhs @@ -337,7 +337,7 @@ renameDeriv is_boot gen_binds insts ; let aux_binds = listToBag $ map (genAuxBind loc) $ rm_dups [] $ concat deriv_aux_binds ; rn_aux_lhs <- rnTopBindsLHS emptyFsEnv (ValBindsIn aux_binds []) - ; let aux_names = map unLoc (collectHsValBinders rn_aux_lhs) + ; let aux_names = collectHsValBinders rn_aux_lhs ; bindLocalNames aux_names $ do { (rn_aux, dus_aux) <- rnTopBindsRHS (mkNameSet aux_names) rn_aux_lhs diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs index e46ab45..1708349 100644 --- a/compiler/typecheck/TcHsSyn.lhs +++ b/compiler/typecheck/TcHsSyn.lhs @@ -318,7 +318,7 @@ zonkValBinds env (ValBindsOut binds sigs) zonkRecMonoBinds :: ZonkEnv -> LHsBinds TcId -> TcM (ZonkEnv, LHsBinds Id) zonkRecMonoBinds env binds = fixM (\ ~(_, new_binds) -> do - { let env1 = extendZonkEnv env (collectHsBindBinders new_binds) + { let env1 = extendZonkEnv env (collectHsBindsBinders new_binds) ; binds' <- zonkMonoBinds env1 binds ; return (env1, binds') }) @@ -351,7 +351,7 @@ zonk_bind env (AbsBinds { abs_tvs = tyvars, abs_dicts = dicts, fixM (\ ~(new_val_binds, _) -> let env1 = extendZonkEnv env new_dicts - env2 = extendZonkEnv env1 (collectHsBindBinders new_val_binds) + env2 = extendZonkEnv env1 (collectHsBindsBinders new_val_binds) in zonkMonoBinds env2 val_binds `thenM` \ new_val_binds -> mappM (zonkExport env2) exports `thenM` \ new_exports -> @@ -710,32 +710,21 @@ zonkStmt env (ExprStmt expr then_op ty) zonkTcTypeToType env ty `thenM` \ new_ty -> returnM (env, ExprStmt new_expr new_then new_ty) -zonkStmt env (TransformStmt (stmts, binders) usingExpr maybeByExpr) +zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr) = do { (env', stmts') <- zonkStmts env stmts ; let binders' = zonkIdOccs env' binders ; usingExpr' <- zonkLExpr env' usingExpr ; maybeByExpr' <- zonkMaybeLExpr env' maybeByExpr - ; return (env', TransformStmt (stmts', binders') usingExpr' maybeByExpr') } + ; return (env', TransformStmt stmts' binders' usingExpr' maybeByExpr') } -zonkStmt env (GroupStmt (stmts, binderMap) groupByClause) +zonkStmt env (GroupStmt stmts binderMap by using) = do { (env', stmts') <- zonkStmts env stmts ; binderMap' <- mappM (zonkBinderMapEntry env') binderMap - ; groupByClause' <- - case groupByClause of - GroupByNothing usingExpr -> (zonkLExpr env' usingExpr) >>= (return . GroupByNothing) - GroupBySomething eitherUsingExpr byExpr -> do - eitherUsingExpr' <- mapEitherM (zonkLExpr env') (zonkExpr env') eitherUsingExpr - byExpr' <- zonkLExpr env' byExpr - return $ GroupBySomething eitherUsingExpr' byExpr' - + ; by' <- fmapMaybeM (zonkLExpr env') by + ; using' <- fmapEitherM (zonkLExpr env) (zonkExpr env) using ; let env'' = extendZonkEnv env' (map snd binderMap') - ; return (env'', GroupStmt (stmts', binderMap') groupByClause') } + ; return (env'', GroupStmt stmts' binderMap' by' using') } where - mapEitherM f g x = do - case x of - Left a -> f a >>= (return . Left) - Right b -> g b >>= (return . Right) - zonkBinderMapEntry env (oldBinder, newBinder) = do let oldBinder' = zonkIdOcc env oldBinder newBinder' <- zonkIdBndr env newBinder diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs index 6d917d1..cbe5940 100644 --- a/compiler/typecheck/TcMatches.lhs +++ b/compiler/typecheck/TcMatches.lhs @@ -392,7 +392,7 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside ; return (ids, pairs', thing) } ; return ( (stmts', ids) : pairs', thing ) } -tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty thing_inside = do +tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr) elt_ty thing_inside = do (stmts', (binders', usingExpr', maybeByExpr', thing)) <- tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do let alphaListTy = mkTyConApp m_tc [alphaTy] @@ -414,46 +414,47 @@ tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty return (binders', usingExpr', maybeByExpr', thing) - return (TransformStmt (stmts', binders') usingExpr' maybeByExpr', thing) + return (TransformStmt stmts' binders' usingExpr' maybeByExpr', thing) -tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_inside = do - (stmts', (bindersMap', groupByClause', thing)) <- +tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside + = do { let (bndr_names, list_bndr_names) = unzip bindersMap + + ; (stmts', (bndr_ids, by', using_ty, elt_ty')) <- tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do - let alphaListTy = mkTyConApp m_tc [alphaTy] - alphaListListTy = mkTyConApp m_tc [alphaListTy] - - groupByClause' <- - case groupByClause of - GroupByNothing usingExpr -> - -- We must validate that usingExpr :: forall a. [a] -> [[a]] - tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing) - GroupBySomething eitherUsingExpr byExpr -> do - -- We must infer a type such that byExpr :: t - (byExpr', tTy) <- tcInferRhoNC byExpr - - -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]] - let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy)) - eitherUsingExpr' <- - case eitherUsingExpr of - Left usingExpr -> (tcPolyExpr usingExpr expectedUsingType) >>= (return . Left) - Right usingExpr -> (tcPolyExpr (noLoc usingExpr) expectedUsingType) >>= (return . Right . unLoc) - return $ GroupBySomething eitherUsingExpr' byExpr' - - -- Find the IDs and types of all old binders - let (oldBinders, newBinders) = unzip bindersMap - oldBinders' <- tcLookupLocalIds oldBinders + (by', using_ty) <- case by of + Nothing -> -- check that using :: forall a. [a] -> [[a]] + return (Nothing, mkForAllTy alphaTyVar $ + alphaListTy `mkFunTy` alphaListListTy) + + Just by_e -> -- check that using :: forall a. (a -> t) -> [a] -> [[a]] + -- where by :: t + do { (by_e', t_ty) <- tcInferRhoNC by_e + ; return (Just by_e', mkForAllTy alphaTyVar $ + (alphaTy `mkFunTy` t_ty) + `mkFunTy` alphaListTy + `mkFunTy` alphaListListTy) } + -- Find the Ids (and hence types) of all old binders + bndr_ids <- tcLookupLocalIds bndr_names + return (bndr_ids, by', using_ty, elt_ty') + -- Ensure that every old binder of type b is linked up with its new binder which should have type [b] - let newBinders' = zipWith associateNewBinder oldBinders' newBinders + ; let list_bndr_ids = zipWith mk_list_bndr list_bndr_names bndr_ids + bindersMap' = bndr_ids `zip` list_bndr_ids - -- Type check the thing in the environment with these new binders and return the result - thing <- tcExtendIdEnv newBinders' (thing_inside elt_ty') - return (zipEqual "tcLcStmt: Old and new binder lists were not of the same length" oldBinders' newBinders', groupByClause', thing) - - return (GroupStmt (stmts', bindersMap') groupByClause', thing) - where - associateNewBinder :: TcId -> Name -> TcId - associateNewBinder oldBinder newBinder = mkLocalId newBinder (mkTyConApp m_tc [idType oldBinder]) + ; using' <- case using of + Left e -> do { e' <- tcPolyExpr e using_ty; return (Left e') } + Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) } + + -- Type check the thing in the environment with these new binders and return the result + ; thing <- tcExtendIdEnv list_bndr_ids (thing_inside elt_ty') + ; return (GroupStmt stmts' bindersMap' by' using', thing) } + where + alphaListTy = mkTyConApp m_tc [alphaTy] + alphaListListTy = mkTyConApp m_tc [alphaListTy] + + mk_list_bndr :: Name -> TcId -> TcId + mk_list_bndr list_bndr_name bndr_id = mkLocalId list_bndr_name (mkTyConApp m_tc [idType bndr_id]) tcLcStmt _ _ stmt _ _ = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt) diff --git a/compiler/typecheck/TcRnDriver.lhs b/compiler/typecheck/TcRnDriver.lhs index 42e98b2..acaf05c 100644 --- a/compiler/typecheck/TcRnDriver.lhs +++ b/compiler/typecheck/TcRnDriver.lhs @@ -1198,7 +1198,7 @@ mkPlan (L loc (ExprStmt expr _ _)) -- An expression typed at the prompt ]} mkPlan stmt@(L loc (BindStmt {})) - | [L _ v] <- collectLStmtBinders stmt -- One binder, for a bind stmt + | [v] <- collectLStmtBinders stmt -- One binder, for a bind stmt = do { let print_v = L loc $ ExprStmt (nlHsApp (nlHsVar printName) (nlHsVar v)) (HsVar thenIOName) placeHolderType @@ -1229,7 +1229,7 @@ tcGhciStmts stmts io_ret_ty = mkTyConApp ioTyCon [ret_ty] ; tc_io_stmts stmts = tcStmts GhciStmt tcDoStmt stmts io_ret_ty ; - names = map unLoc (collectLStmtsBinders stmts) ; + names = collectLStmtsBinders stmts ; -- mk_return builds the expression -- returnIO @ [()] [coerce () x, .., coerce () z]