Fix dependency analysis (notably bindInstsOfLocalFuns) in TcBinds
authorsimonpj@microsoft.com <unknown>
Tue, 5 Sep 2006 10:51:43 +0000 (10:51 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 5 Sep 2006 10:51:43 +0000 (10:51 +0000)
GHC 6.5 does enhanced dependency analysis for recursive bindings, to
maximise polymorphism based on type signatures.  (See Mark Jones's
THIH paper.)

I didn't do the bindInstsOfLocalFuns part correctly though, and jhc
showed up the bug.  (It only matters when you have a recursive group
of two or more functions with a type signature, not at top level, which
is why it hasn't shown up till now.)

Test is tc207.hs

compiler/typecheck/TcBinds.lhs

index 076de00..33c8ddb 100644 (file)
@@ -182,31 +182,32 @@ tcValBinds top_lvl (ValBindsOut binds sigs) thing_inside
 
                -- Extend the envt right away with all 
                -- the Ids declared with type signatures
+       ; gla_exts     <- doptM Opt_GlasgowExts
        ; (binds', thing) <- tcExtendIdEnv poly_ids $
-                            tc_val_binds top_lvl sig_fn prag_fn 
+                            tc_val_binds gla_exts top_lvl sig_fn prag_fn 
                                          binds thing_inside
 
        ; return (ValBindsOut binds' sigs, thing) }
 
 ------------------------
-tc_val_binds :: TopLevelFlag -> TcSigFun -> TcPragFun
+tc_val_binds :: Bool -> TopLevelFlag -> TcSigFun -> TcPragFun
             -> [(RecFlag, LHsBinds Name)] -> TcM thing
             -> TcM ([(RecFlag, LHsBinds TcId)], thing)
 -- Typecheck a whole lot of value bindings,
 -- one strongly-connected component at a time
 
-tc_val_binds top_lvl sig_fn prag_fn [] thing_inside
+tc_val_binds gla_exts top_lvl sig_fn prag_fn [] thing_inside
   = do { thing <- thing_inside
        ; return ([], thing) }
 
-tc_val_binds top_lvl sig_fn prag_fn (group : groups) thing_inside
+tc_val_binds gla_exts top_lvl sig_fn prag_fn (group : groups) thing_inside
   = do { (group', (groups', thing))
-               <- tc_group top_lvl sig_fn prag_fn group $ 
-                  tc_val_binds top_lvl sig_fn prag_fn groups thing_inside
+               <- tc_group gla_exts top_lvl sig_fn prag_fn group $ 
+                  tc_val_binds gla_exts top_lvl sig_fn prag_fn groups thing_inside
        ; return (group' ++ groups', thing) }
 
 ------------------------
-tc_group :: TopLevelFlag -> TcSigFun -> TcPragFun
+tc_group :: Bool -> TopLevelFlag -> TcSigFun -> TcPragFun
         -> (RecFlag, LHsBinds Name) -> TcM thing
         -> TcM ([(RecFlag, LHsBinds TcId)], thing)
 
@@ -214,41 +215,60 @@ tc_group :: TopLevelFlag -> TcSigFun -> TcPragFun
 -- We get a list of groups back, because there may 
 -- be specialisations etc as well
 
-tc_group top_lvl sig_fn prag_fn (NonRecursive, binds) thing_inside
-  =    -- A single non-recursive binding
+tc_group gla_exts top_lvl sig_fn prag_fn (NonRecursive, binds) thing_inside
+       -- A single non-recursive binding
        -- We want to keep non-recursive things non-recursive
         -- so that we desugar unlifted bindings correctly
-    do { (binds, thing) <- tcPolyBinds top_lvl NonRecursive NonRecursive
-                                       sig_fn prag_fn binds thing_inside
+ =  do { (binds, thing) <- tc_haskell98 top_lvl sig_fn prag_fn NonRecursive binds thing_inside
        ; return ([(NonRecursive, b) | b <- binds], thing) }
 
-tc_group top_lvl sig_fn prag_fn (Recursive, binds) thing_inside
-  =    -- A recursive strongly-connected component
-       -- To maximise polymorphism (with -fglasgow-exts), we do a new 
+tc_group gla_exts top_lvl sig_fn prag_fn (Recursive, binds) thing_inside
+  | not gla_exts       -- Recursive group, normal Haskell 98 route
+  = do { (binds1, thing) <- tc_haskell98 top_lvl sig_fn prag_fn Recursive binds thing_inside
+       ; return ([(Recursive, unionManyBags binds1)], thing) }
+
+  | otherwise          -- Recursive group, with gla-exts
+  =    -- To maximise polymorphism (with -fglasgow-exts), we do a new 
        -- strongly-connected-component analysis, this time omitting 
        -- any references to variables with type signatures.
        --
-       -- Then we bring into scope all the variables with type signatures
+       -- Notice that the bindInsts thing covers *all* the bindings in the original
+       -- group at once; an earlier one may use a later one!
     do { traceTc (text "tc_group rec" <+> pprLHsBinds binds)
-       ; gla_exts     <- doptM Opt_GlasgowExts
-       ; (binds,thing) <- if gla_exts 
-                          then go new_sccs
-                          else tc_binds Recursive binds thing_inside
-       ; return ([(Recursive, unionManyBags binds)], thing) }
+       ; (binds1,thing) <- bindLocalInsts top_lvl $
+                           go (stronglyConnComp (mkEdges sig_fn binds))
+       ; return ([(Recursive, unionManyBags binds1)], thing) }
                -- Rec them all together
   where
-    new_sccs :: [SCC (LHsBind Name)]
-    new_sccs = stronglyConnComp (mkEdges sig_fn binds)
+--  go :: SCC (LHsBind Name) -> TcM ([LHsBind TcId], [TcId], thing)
+    go (scc:sccs) = do { (binds1, ids1) <- tc_scc scc
+                       ; (binds2, ids2, thing) <- tcExtendIdEnv ids1 $ go sccs
+                       ; return (binds1 ++ binds2, ids1 ++ ids2, thing) }
+    go []        = do  { thing <- thing_inside; return ([], [], thing) }
 
---  go :: SCC (LHsBind Name) -> TcM ([LHsBind TcId], thing)
-    go (scc:sccs) = do { (binds1, (binds2, thing)) <- go1 scc (go sccs)
-                       ; return (binds1 ++ binds2, thing) }
-    go []        = do  { thing <- thing_inside; return ([], thing) }
+    tc_scc (AcyclicSCC bind) = tc_sub_group NonRecursive (unitBag bind)
+    tc_scc (CyclicSCC binds) = tc_sub_group Recursive    (listToBag binds)
 
-    go1 (AcyclicSCC bind) = tc_binds NonRecursive (unitBag bind)
-    go1 (CyclicSCC binds) = tc_binds Recursive    (listToBag binds)
+    tc_sub_group = tcPolyBinds top_lvl sig_fn prag_fn Recursive
 
-    tc_binds rec_tc binds = tcPolyBinds top_lvl Recursive rec_tc sig_fn prag_fn binds
+tc_haskell98 top_lvl sig_fn prag_fn rec_flag binds thing_inside
+  = bindLocalInsts top_lvl $ do
+    { (binds1, ids) <- tcPolyBinds top_lvl sig_fn prag_fn rec_flag rec_flag binds
+    ; thing <- tcExtendIdEnv ids thing_inside
+    ; return (binds1, ids, thing) }
+
+------------------------
+bindLocalInsts :: TopLevelFlag -> TcM ([LHsBinds TcId], [TcId], a) -> TcM ([LHsBinds TcId], a)
+bindLocalInsts top_lvl thing_inside
+  | isTopLevel top_lvl = do { (binds, ids, thing) <- thing_inside; return (binds, thing) }
+       -- For the top level don't bother will all this bindInstsOfLocalFuns stuff. 
+       -- All the top level things are rec'd together anyway, so it's fine to
+       -- leave them to the tcSimplifyTop, and quite a bit faster too
+
+  | otherwise  -- Nested case
+  = do { ((binds, ids, thing), lie) <- getLIE thing_inside
+       ; lie_binds <- bindInstsOfLocalFuns lie ids
+       ; return (binds ++ [lie_binds], thing) }
 
 ------------------------
 mkEdges :: TcSigFun -> LHsBinds Name
@@ -276,63 +296,28 @@ bindersOfHsBind (PatBind { pat_lhs = pat })  = collectPatBinders pat
 bindersOfHsBind (FunBind { fun_id = L _ f }) = [f]
 
 ------------------------
-tcPolyBinds :: TopLevelFlag 
+tcPolyBinds :: TopLevelFlag -> TcSigFun -> TcPragFun
            -> RecFlag                  -- Whether the group is really recursive
-           -> RecFlag                  -- Whether it's recursive for typechecking purposes
-           -> TcSigFun -> TcPragFun
+           -> RecFlag                  -- Whether it's recursive after breaking
+                                       -- dependencies based on type signatures
            -> LHsBinds Name
-           -> TcM thing
-           -> TcM ([LHsBinds TcId], thing)
+           -> TcM ([LHsBinds TcId], [TcId])
 
 -- Typechecks a single bunch of bindings all together, 
 -- and generalises them.  The bunch may be only part of a recursive
 -- group, because we use type signatures to maximise polymorphism
 --
--- Deals with the bindInstsOfLocalFuns thing too
---
 -- Returns a list because the input may be a single non-recursive binding,
 -- in which case the dependency order of the resulting bindings is
 -- important.  
-
-tcPolyBinds top_lvl rec_group rec_tc sig_fn prag_fn scc thing_inside
-  =    -- NB: polymorphic recursion means that a function
-       -- may use an instance of itself, we must look at the LIE arising
-       -- from the function's own right hand side.  Hence the getLIE
-       -- encloses the tc_poly_binds. 
-    do { traceTc (text "tcPolyBinds" <+> ppr scc)
-       ; ((binds1, poly_ids, thing), lie) <- getLIE $ 
-               do { (binds1, poly_ids) <- tc_poly_binds top_lvl rec_group rec_tc
-                                                        sig_fn prag_fn scc
-                  ; thing <- tcExtendIdEnv poly_ids thing_inside
-                  ; return (binds1, poly_ids, thing) }
-
-       ; if isTopLevel top_lvl 
-         then          -- For the top level don't bother will all this
-                       -- bindInstsOfLocalFuns stuff. All the top level 
-                       -- things are rec'd together anyway, so it's fine to
-                       -- leave them to the tcSimplifyTop, 
-                       -- and quite a bit faster too
-               do { extendLIEs lie; return (binds1, thing) }
-
-         else do       -- Nested case
-               { lie_binds <- bindInstsOfLocalFuns lie poly_ids
-               ; return (binds1 ++ [lie_binds], thing) }}
-
-------------------------
-tc_poly_binds :: TopLevelFlag          -- See comments on tcPolyBinds
-             -> RecFlag -> RecFlag
-             -> TcSigFun -> TcPragFun
-             -> LHsBinds Name
-             -> TcM ([LHsBinds TcId], [TcId])
--- Typechecks the bindings themselves
+-- 
 -- Knows nothing about the scope of the bindings
 
-tc_poly_binds top_lvl rec_group rec_tc sig_fn prag_fn binds
+tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc binds
   = let 
-        binder_names = collectHsBindBinders binds
        bind_list    = bagToList binds
-
-       loc = getLoc (head bind_list)
+        binder_names = collectHsBindBinders binds
+       loc          = getLoc (head bind_list)
                -- TODO: location a bit awkward, but the mbinds have been
                --       dependency analysed and may no longer be adjacent
     in