- = do { let bndrs = map fst prs
- rhs_env = extendBndrsWith RecFun env bndrs
-
- ; (rhs_usgs, prs_w_occs) <- mapAndUnzipUs (scRecRhs rhs_env) prs
- ; let rhs_usg = combineUsages rhs_usgs
- rhs_calls = calls rhs_usg
-
- ; prs_s <- mapUs (specialise env rhs_calls) prs_w_occs
- ; return (extendBndrs env bndrs,
- -- For the body of the letrec, just
- -- extend the env with Other to record
- -- that it's in scope; no funny RecFun business
- rhs_usg { calls = calls rhs_usg `delVarEnvList` bndrs },
- Rec (concat prs_s)) }
+ | not (all (couldBeSmallEnoughToInline (sc_size env)) rhss)
+ -- No specialisation
+ = do { let (rhs_env,bndrs') = extendRecBndrs env bndrs
+ ; (rhs_usgs, rhss') <- mapAndUnzipUs (scExpr rhs_env) rhss
+ ; return (rhs_env, combineUsages rhs_usgs, Rec (bndrs' `zip` rhss')) }
+ | otherwise -- Do specialisation
+ = do { let (rhs_env1,bndrs') = extendRecBndrs env bndrs
+ rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
+
+ ; (rhs_usgs, rhs_infos) <- mapAndUnzipUs (scRecRhs rhs_env2) (bndrs' `zip` rhss)
+ ; let rhs_usg = combineUsages rhs_usgs
+
+ ; (spec_usg, specs) <- spec_loop rhs_env2 (calls rhs_usg)
+ (repeat [] `zip` rhs_infos)
+
+ ; let all_usg = rhs_usg `combineUsage` spec_usg
+
+ ; return (rhs_env1, -- For the body of the letrec, delete the RecFun business
+ all_usg { calls = calls rhs_usg `delVarEnvList` bndrs' },
+ Rec (concat (zipWith addRules rhs_infos specs))) }
+ where
+ (bndrs,rhss) = unzip prs
+
+ spec_loop :: ScEnv
+ -> CallEnv
+ -> [([CallPat], RhsInfo)] -- One per binder
+ -> UniqSM (ScUsage, [[SpecInfo]]) -- One list per binder
+ spec_loop env all_calls rhs_stuff
+ = do { (spec_usg_s, new_pats_s, specs) <- mapAndUnzip3Us (specialise env all_calls) rhs_stuff
+ ; let spec_usg = combineUsages spec_usg_s
+ ; if all null new_pats_s then
+ return (spec_usg, specs) else do
+ { (spec_usg1, specs1) <- spec_loop env (calls spec_usg)
+ (zipWith add_pats new_pats_s rhs_stuff)
+ ; return (spec_usg `combineUsage` spec_usg1, zipWith (++) specs specs1) } }
+
+ add_pats :: [CallPat] -> ([CallPat], RhsInfo) -> ([CallPat], RhsInfo)
+ add_pats new_pats (done_pats, rhs_info) = (done_pats ++ new_pats, rhs_info)