Refactor part of the renamer to fix Trac #3901
[ghc-hetmet.git] / compiler / rename / RnExpr.lhs
index 6dc6801..78088d5 100644 (file)
@@ -42,7 +42,6 @@ import UniqSet
 import Data.List
 import Util            ( isSingleton )
 import ListSetOps      ( removeDups )
-import Maybes          ( expectJust )
 import Outputable
 import SrcLoc
 import FastString
@@ -538,8 +537,8 @@ methodNamesStmt (BindStmt _ cmd _ _)             = methodNamesLCmd cmd
 methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
 methodNamesStmt (LetStmt _)                      = emptyFVs
 methodNamesStmt (ParStmt _)                      = emptyFVs
-methodNamesStmt (TransformStmt _ _ _)            = emptyFVs
-methodNamesStmt (GroupStmt _ _)                  = emptyFVs
+methodNamesStmt (TransformStmt {})               = emptyFVs
+methodNamesStmt (GroupStmt {})                   = emptyFVs
    -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error 
    -- here so we just do what's convenient
 \end{code}
@@ -635,33 +634,43 @@ rnBracket (DecBrG _) = panic "rnBracket: unexpected DecBrG"
 rnStmts :: HsStmtContext Name -> [LStmt RdrName] 
        -> RnM (thing, FreeVars)
        -> RnM (([LStmt Name], thing), FreeVars)
+-- Variables bound by the Stmts, and mentioned in thing_inside,
+-- do not appear in the result FreeVars
 
-rnStmts (MDoExpr _) = rnMDoStmts
-rnStmts ctxt        = rnNormalStmts ctxt
+rnStmts (MDoExpr _) stmts thing_inside = rnMDoStmts    stmts thing_inside
+rnStmts ctxt        stmts thing_inside = rnNormalStmts ctxt stmts (\ _ -> thing_inside)
 
 rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName]
-             -> RnM (thing, FreeVars)
+             -> ([Name] -> RnM (thing, FreeVars))
              -> RnM (([LStmt Name], thing), FreeVars)  
+-- Variables bound by the Stmts, and mentioned in thing_inside,
+-- do not appear in the result FreeVars
+--
+-- Renaming a single RecStmt can give a sequence of smaller Stmts
+
 rnNormalStmts _ [] thing_inside 
-  = do { (thing, fvs) <- thing_inside
-       ; return (([],thing), fvs) } 
+  = do { (res, fvs) <- thing_inside []
+       ; return (([], res), fvs) }
 
 rnNormalStmts ctxt (stmt@(L loc _) : stmts) thing_inside
   = do { ((stmts1, (stmts2, thing)), fvs) 
-            <- setSrcSpan loc $
-               rnStmt ctxt stmt $
-               rnNormalStmts ctxt stmts thing_inside
+            <- setSrcSpan loc           $
+               rnStmt ctxt stmt         $ \ bndrs1 ->
+               rnNormalStmts ctxt stmts $ \ bndrs2 ->
+               thing_inside (bndrs1 ++ bndrs2)
        ; return (((stmts1 ++ stmts2), thing), fvs) }
 
 
 rnStmt :: HsStmtContext Name -> LStmt RdrName
-       -> RnM (thing, FreeVars)
+       -> ([Name] -> RnM (thing, FreeVars))
        -> RnM (([LStmt Name], thing), FreeVars)
+-- Variables bound by the Stmt, and mentioned in thing_inside,
+-- do not appear in the result FreeVars
 
 rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
        ; (then_op, fvs1)  <- lookupSyntaxName thenMName
-       ; (thing, fvs2)    <- thing_inside
+       ; (thing, fvs2)    <- thing_inside []
        ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
                  fv_expr `plusFV` fvs1 `plusFV` fvs2) }
 
@@ -671,7 +680,7 @@ rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
        ; (bind_op, fvs1) <- lookupSyntaxName bindMName
        ; (fail_op, fvs2) <- lookupSyntaxName failMName
        ; rnPat (StmtCtxt ctxt) pat $ \ pat' -> do
-       { (thing, fvs3) <- thing_inside
+       { (thing, fvs3) <- thing_inside (collectPatBinders pat')
        ; return (([L loc (BindStmt pat' expr' bind_op fail_op)], thing),
                  fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }}
        -- fv_expr shouldn't really be filtered by the rnPatsAndThen
@@ -680,7 +689,7 @@ rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
 rnStmt ctxt (L loc (LetStmt binds)) thing_inside 
   = do { checkLetStmt ctxt binds
        ; rnLocalBindsAndThen binds $ \binds' -> do
-       { (thing, fvs) <- thing_inside
+       { (thing, fvs) <- thing_inside (collectLocalBinders binds')
         ; return (([L loc (LetStmt binds')], thing), fvs) }  }
 
 rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
@@ -697,7 +706,9 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
        -- context.)
         ; rn_rec_stmts_and_then rec_stmts      $ \ segs -> do
 
-       { (thing, fvs_later) <- thing_inside
+       { let bndrs = nameSetToList $ foldr (unionNameSets . (\(ds,_,_,_) -> ds)) 
+                                            emptyNameSet segs
+        ; (thing, fvs_later) <- thing_inside bndrs
        ; (return_op, fvs1)  <- lookupSyntaxName returnMName
        ; (mfix_op,   fvs2)  <- lookupSyntaxName mfixName
        ; (bind_op,   fvs3)  <- lookupSyntaxName bindMName
@@ -730,146 +741,103 @@ rnStmt ctxt (L loc (ParStmt segs)) thing_inside
        ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
        ; return (([L loc (ParStmt segs')], thing), fvs) }
 
-rnStmt ctxt (L loc (TransformStmt (stmts, _) usingExpr maybeByExpr)) thing_inside = do
-    checkTransformStmt ctxt
-    
-    (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
-    ((stmts', binders, (maybeByExpr', thing)), fvs) <- 
-        rnNormalStmtsAndFindUsedBinders (TransformStmtCtxt ctxt) stmts $ \_unshadowed_bndrs -> do
-            (maybeByExpr', fv_maybeByExpr)  <- rnMaybeLExpr maybeByExpr
-            (thing, fv_thing)               <- thing_inside
-            
-            return ((maybeByExpr', thing), fv_maybeByExpr `plusFV` fv_thing)
+rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
+  = do { checkTransformStmt ctxt
     
-    return (([L loc (TransformStmt (stmts', binders) usingExpr' maybeByExpr')], thing), 
-             fv_usingExpr `plusFV` fvs)
-  where
-    rnMaybeLExpr Nothing = return (Nothing, emptyFVs)
-    rnMaybeLExpr (Just expr) = do
-        (expr', fv_expr) <- rnLExpr expr
-        return (Just expr', fv_expr)
+       ; (using', fvs1) <- rnLExpr using
+
+       ; ((stmts', (by', used_bndrs, thing)), fvs2)
+             <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
+                do { (by', fvs_by) <- case by of
+                                        Nothing -> return (Nothing, emptyFVs)
+                                        Just e  -> do { (e', fvs) <- rnLExpr e; return (Just e', fvs) }
+                   ; (thing, fvs_thing) <- thing_inside bndrs
+                   ; let fvs        = fvs_by `plusFV` fvs_thing
+                         used_bndrs = filter (`elemNameSet` fvs_thing) bndrs
+                   ; return ((by', used_bndrs, thing), fvs) }
+
+       ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), 
+                 fvs1 `plusFV` fvs2) }
         
-rnStmt ctxt (L loc (GroupStmt (stmts, _) groupByClause)) thing_inside = do
-    checkTransformStmt ctxt
-    
-    -- We must rename the using expression in the context before the transform is begun
-    groupByClauseAction <- 
-        case groupByClause of
-            GroupByNothing usingExpr -> do
-                (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
-                (return . return) (GroupByNothing usingExpr', fv_usingExpr)
-            GroupBySomething eitherUsingExpr byExpr -> do
-                (eitherUsingExpr', fv_eitherUsingExpr) <- 
-                    case eitherUsingExpr of
-                        Right _ -> return (Right $ HsVar groupWithName, unitNameSet groupWithName)
-                        Left usingExpr -> do
-                            (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
-                            return (Left usingExpr', fv_usingExpr)
-                            
-                return $ do
-                    (byExpr', fv_byExpr) <- rnLExpr byExpr
-                    return (GroupBySomething eitherUsingExpr' byExpr', fv_eitherUsingExpr `plusFV` fv_byExpr)
+rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside
+  = do { checkTransformStmt ctxt
     
-    -- We only use rnNormalStmtsAndFindUsedBinders to get unshadowed_bndrs, so
-    -- perhaps we could refactor this to use rnNormalStmts directly?
-    ((stmts', _, (groupByClause', usedBinderMap, thing)), fvs) <- 
-        rnNormalStmtsAndFindUsedBinders (TransformStmtCtxt ctxt) stmts $ \unshadowed_bndrs -> do
-            (groupByClause', fv_groupByClause) <- groupByClauseAction
-            
-            unshadowed_bndrs' <- mapM newLocalName unshadowed_bndrs
-            let binderMap = zip unshadowed_bndrs unshadowed_bndrs'
-            
-            -- Bind the "thing" inside a context where we have REBOUND everything
-            -- bound by the statements before the group. This is necessary since after
-            -- the grouping the same identifiers actually have different meanings
-            -- i.e. they refer to lists not singletons!
-            (thing, fv_thing) <- bindLocalNames unshadowed_bndrs' thing_inside
-            
-            -- We remove entries from the binder map that are not used in the thing_inside.
-            -- We can then use that usage information to ensure that the free variables do 
-            -- not contain the things we just bound, but do contain the things we need to
-            -- make those bindings (i.e. the corresponding non-listy variables)
-            
-            -- Note that we also retain those entries which have an old binder in our
-            -- own free variables (the using or by expression). This is because this map
-            -- is reused in the desugarer to create the type to bind from the statements
-            -- that occur before this one. If the binders we need are not in the map, they
-            -- will never get bound into our desugared expression and hence the simplifier
-            -- crashes as we refer to variables that don't exist!
-            let usedBinderMap = filter 
-                    (\(old_binder, new_binder) -> 
-                        (new_binder `elemNameSet` fv_thing) || 
-                        (old_binder `elemNameSet` fv_groupByClause)) binderMap
-                (usedOldBinders, usedNewBinders) = unzip usedBinderMap
-                real_fv_thing = (delListFromNameSet fv_thing usedNewBinders) `plusFV` (mkNameSet usedOldBinders)
-            
-            return ((groupByClause', usedBinderMap, thing), fv_groupByClause `plusFV` real_fv_thing)
-    
-    traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap)
-    return (([L loc (GroupStmt (stmts', usedBinderMap) groupByClause')], thing), fvs)
-  
-rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name 
-          -> [LStmt RdrName]
-          -> ([Name] -> RnM (thing, FreeVars))
-          -> RnM (([LStmt Name], [Name], thing), FreeVars)     
-rnNormalStmtsAndFindUsedBinders ctxt stmts thing_inside = do
-    ((stmts', (used_bndrs, inner_thing)), fvs) <- rnNormalStmts ctxt stmts $ do
-        -- Find the Names that are bound by stmts that
-        -- by assumption we have just renamed
-        local_env <- getLocalRdrEnv
-        let 
-            stmts_binders = collectLStmtsBinders stmts
-            bndrs = map (expectJust "rnStmt"
-                        . lookupLocalRdrEnv local_env
-                        . unLoc) stmts_binders
-                        
-            -- If shadow, we'll look up (Unqual x) twice, getting
-            -- the second binding both times, which is the
-            -- one we want
-            unshadowed_bndrs = nub bndrs
-                        
-        -- Typecheck the thing inside, passing on all 
-        -- the Names bound before it for its information
-        (thing, fvs) <- thing_inside unshadowed_bndrs
-
-        -- Figure out which of the bound names are used
-        -- after the statements we renamed
-        let used_bndrs = filter (`elemNameSet` fvs) bndrs
-        return ((used_bndrs, thing), fvs)
-
-    -- Flatten the tuple returned by the above call a bit!
-    return ((stmts', used_bndrs, inner_thing), fvs)
-
-rnParallelStmts :: HsStmtContext Name -> [([LStmt RdrName], [RdrName])]
-                -> RnM (thing, FreeVars)
-                -> RnM (([([LStmt Name], [Name])], thing), FreeVars)
-rnParallelStmts ctxt segs thing_inside = do
-        orig_lcl_env <- getLocalRdrEnv
-        go orig_lcl_env [] segs
-    where
-        go orig_lcl_env bndrs [] = do 
-            let (bndrs', dups) = removeDups cmpByOcc bndrs
-                inner_env = extendLocalRdrEnvList orig_lcl_env bndrs'
-            
-            mapM_ dupErr dups
-            (thing, fvs) <- setLocalRdrEnv inner_env thing_inside
-            return (([], thing), fvs)
-
-        go orig_lcl_env bndrs_so_far ((stmts, _) : segs) = do 
-            ((stmts', bndrs, (segs', thing)), fvs) <- rnNormalStmtsAndFindUsedBinders ctxt stmts $ \new_bndrs -> do
-                -- Typecheck the thing inside, passing on all
-                -- the Names bound, but separately; revert the envt
-                setLocalRdrEnv orig_lcl_env $ do
-                    go orig_lcl_env (new_bndrs ++ bndrs_so_far) segs
-
-            let seg' = (stmts', bndrs)
-            return (((seg':segs'), thing), delListFromNameSet fvs bndrs)
-
-        cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
-        dupErr vs = addErr (ptext (sLit "Duplicate binding in parallel list comprehension for:")
+         -- Rename the 'using' expression in the context before the transform is begun
+       ; (using', fvs1) <- case using of
+                             Left e  -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
+                            Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName
+                                           ; return (Right e', fvs) }
+
+         -- Rename the stmts and the 'by' expression
+        -- Keep track of the variables mentioned in the 'by' expression
+       ; ((stmts', (by', used_bndrs, thing)), fvs2) 
+             <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
+                do { (by',   fvs_by) <- mapMaybeFvRn rnLExpr by
+                   ; (thing, fvs_thing) <- thing_inside bndrs
+                   ; let fvs = fvs_by `plusFV` fvs_thing
+                         used_bndrs = filter (`elemNameSet` fvs) bndrs
+                   ; return ((by', used_bndrs, thing), fvs) }
+
+       ; let all_fvs  = fvs1 `plusFV` fvs2 
+             bndr_map = used_bndrs `zip` used_bndrs
+
+       ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
+       ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) }
+
+
+type ParSeg id = ([LStmt id], [id])       -- The Names are bound by the Stmts
+
+rnParallelStmts :: forall thing. HsStmtContext Name 
+                -> [ParSeg RdrName]
+                -> ([Name] -> RnM (thing, FreeVars))
+                -> RnM (([ParSeg Name], thing), FreeVars)
+-- Note [Renaming parallel Stmts]
+rnParallelStmts ctxt segs thing_inside
+  = do { orig_lcl_env <- getLocalRdrEnv
+       ; rn_segs orig_lcl_env [] segs }
+  where
+    rn_segs :: LocalRdrEnv
+            -> [Name] -> [ParSeg RdrName]
+            -> RnM (([ParSeg Name], thing), FreeVars)
+    rn_segs _ bndrs_so_far [] 
+      = do { let (bndrs', dups) = removeDups cmpByOcc bndrs_so_far
+           ; mapM_ dupErr dups
+           ; (thing, fvs) <- bindLocalNames bndrs' (thing_inside bndrs')
+           ; return (([], thing), fvs) }
+
+    rn_segs env bndrs_so_far ((stmts,_) : segs) 
+      = do { ((stmts', (used_bndrs, segs', thing)), fvs)
+                    <- rnNormalStmts ctxt stmts $ \ bndrs ->
+                       setLocalRdrEnv env       $ do
+                       { ((segs', thing), fvs) <- rn_segs env (bndrs ++ bndrs_so_far) segs
+                      ; let used_bndrs = filter (`elemNameSet` fvs) bndrs
+                       ; return ((used_bndrs, segs', thing), fvs) }
+                      
+           ; let seg' = (stmts', used_bndrs)
+           ; return ((seg':segs', thing), fvs) }
+
+    cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
+    dupErr vs = addErr (ptext (sLit "Duplicate binding in parallel list comprehension for:")
                     <+> quotes (ppr (head vs)))
 \end{code}
 
+Note [Renaming parallel Stmts]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Renaming parallel statements is painful.  Given, say  
+     [ a+c | a <- as, bs <- bss
+           | c <- bs, a <- ds ]
+Note that
+  (a) In order to report "Defined by not used" about 'bs', we must rename
+      each group of Stmts with a thing_inside whose FreeVars include at least {a,c}
+   
+  (b) We want to report that 'a' is illegally bound in both branches
+
+  (c) The 'bs' in the second group must obviously not be captured by 
+      the binding in the first group
+
+To satisfy (a) we nest the segements. 
+To satisfy (b) we check for duplicates just before thing_inside.
+To satisfy (c) we reset the LocalRdrEnv each time.
 
 %************************************************************************
 %*                                                                     *
@@ -916,7 +884,7 @@ rn_rec_stmts_and_then s cont
        ; new_lhs_and_fv <- rn_rec_stmts_lhs fix_env s
 
          --    ...bring them and their fixities into scope
-       ; let bound_names = map unLoc $ collectLStmtsBinders (map fst new_lhs_and_fv)
+       ; let bound_names = collectLStmtsBinders (map fst new_lhs_and_fv)
        ; bindLocalNamesFV_WithFixities bound_names fix_env $ do
 
          -- (C) do the right-hand-sides and thing-inside
@@ -972,10 +940,10 @@ rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts }))    -- Flatten Rec in
 rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _))       -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
   
-rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt _ _ _))     -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {}))        -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
   
-rn_rec_stmt_lhs _ stmt@(L _ (GroupStmt _ _))   -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (GroupStmt {}))    -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
 
 rn_rec_stmt_lhs _ (L _ (LetStmt EmptyLocalBinds))
@@ -985,13 +953,13 @@ rn_rec_stmts_lhs :: MiniFixityEnv
                  -> [LStmt RdrName] 
                  -> RnM [(LStmtLR Name RdrName, FreeVars)]
 rn_rec_stmts_lhs fix_env stmts
-  = do { let boundNames = collectLStmtsBinders stmts
+  = do { ls <- concatMapM (rn_rec_stmt_lhs fix_env) stmts
+       ; let boundNames = collectLStmtsBinders (map fst ls)
             -- First do error checking: we need to check for dups here because we
             -- don't bind all of the variables from the Stmt at once
             -- with bindLocatedLocals.
-       ; checkDupRdrNames boundNames
-       ; ls <- mapM (rn_rec_stmt_lhs fix_env) stmts
-       ; return (concat ls) }
+       ; checkDupNames boundNames
+       ; return ls }
 
 
 -- right-hand-sides