Teach SpecConstr how to handle mutually-recursive functions
authorsimonpj@microsoft.com <unknown>
Wed, 29 Nov 2006 21:39:31 +0000 (21:39 +0000)
committersimonpj@microsoft.com <unknown>
Wed, 29 Nov 2006 21:39:31 +0000 (21:39 +0000)
Roman found cases where it was important to do SpecConstr for
mutually-recursive definitions.  Here is one:
foo :: Maybe Int -> Int
foo Nothing  = 0
foo (Just 0) = foo Nothing
foo (Just n) = foo (Just (n-1))
By the time SpecConstr gets to it, it looks like this:
lvl = foo Nothing
foo Nothing  = 0
foo (Just 0) = lvl
foo (Just n) = foo (Just (n-1))

Happily, it turns out to be rather straightforward to generalise the
transformation to mutually-recursive functions.  Look, ma, only 4
extra lines of ocde!

compiler/specialise/SpecConstr.lhs

index 5cccf80..3876a44 100644 (file)
@@ -485,7 +485,13 @@ instance Outputable HowBound where
 
 lookupScopeEnv env v = lookupVarEnv (scope env) v
 
-extendBndrs env bndrs = env { scope = extendVarEnvList (scope env) [(b,Other) | b <- bndrs] }
+
+extendBndrsWith :: HowBound -> ScEnv -> [Var] -> ScEnv
+extendBndrsWith how_bound env bndrs 
+  =  env { scope = scope env `extendVarEnvList` 
+                       [(bndr,how_bound) | bndr <- bndrs] }
+
+extendBndrs env bndrs = extendBndrsWith Other env bndrs
 extendBndr  env bndr  = env { scope = extendVarEnv (scope env) bndr Other }
 
     -- When we encounter
@@ -497,11 +503,12 @@ extendCaseBndrs env case_bndr scrut con alt_bndrs
   = case con of
        DEFAULT    -> env1
        LitAlt lit -> extendCons env1 scrut case_bndr (CV con [])
-       DataAlt dc -> extend_data_con dc
+       DataAlt dc -> extendCons env1 scrut case_bndr (CV con vanilla_args)
+             where
+               vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
+                              varsToCoreExprs alt_bndrs
   where
-    cur_scope = scope env
-    env1 = env { scope = extendVarEnvList cur_scope 
-                               [(b,how_bound) | b <- case_bndr:alt_bndrs] }
+    env1 = extendBndrsWith (get_how scrut) env (case_bndr:alt_bndrs)
 
        -- Record RecArg for the components iff the scrutinee is RecArg
        -- I think the only reason for this is to keep the usage envt small
@@ -511,18 +518,10 @@ extendCaseBndrs env case_bndr scrut con alt_bndrs
        --           now in the branch of a case, and we don't want to
        --           record a non-scrutinee use of v if we have
        --              case v of { (a,b) -> ...(f v)... }" ]
-    how_bound = get_how scrut
-       where
-           get_how (Var v)    = lookupVarEnv cur_scope v `orElse` Other
-           get_how (Cast e _) = get_how e
-           get_how (Note _ e) = get_how e
-           get_how other      = Other
-
-    extend_data_con data_con = 
-      extendCons env1 scrut case_bndr (CV con vanilla_args)
-       where
-           vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
-                          varsToCoreExprs alt_bndrs
+    get_how (Var v)    = lookupVarEnv (scope env) v `orElse` Other
+    get_how (Cast e _) = get_how e
+    get_how (Note _ e) = get_how e
+    get_how other      = Other
 
 extendCons :: ScEnv -> CoreExpr -> Id -> ConValue -> ScEnv
 extendCons env scrut case_bndr val
@@ -531,14 +530,6 @@ extendCons env scrut case_bndr val
        other -> env { cons = cons1 }
   where
     cons1 = extendVarEnv (cons env) case_bndr val
-
-    -- When we encounter a recursive function binding
-    -- f = \x y -> ...
-    -- we want to extend the scope env with bindings 
-    -- that record that f is a RecFn and x,y are RecArgs
-extendRecBndr env fn bndrs
-  =  env { scope = scope env `extendVarEnvList` 
-                  ((fn,RecFun): [(bndr,RecArg) | bndr <- bndrs]) }
 \end{code}
 
 
@@ -551,7 +542,7 @@ extendRecBndr env fn bndrs
 \begin{code}
 data ScUsage
    = SCU {
-       calls :: !(IdEnv ([Call])),     -- Calls
+       calls :: !(IdEnv [Call]),       -- Calls
                                        -- The functions are a subset of the 
                                        --      RecFuns in the ScEnv
 
@@ -702,6 +693,7 @@ scExpr env e@(App _ _)
        ; (arg_usgs, args') <- mapAndUnzipUs (scExpr env) args
        ; let call_usg = case fn of
                           Var f | Just RecFun <- lookupScopeEnv env f
+                                , not (null args)      -- Not a proper call!
                                 -> SCU { calls = unitVarEnv f [(cons env, args)], 
                                          occs  = emptyVarEnv }
                           other -> nullUsage
@@ -723,36 +715,38 @@ scScrut env e             occ = scExpr env e
 
 ----------------------
 scBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, ScUsage, CoreBind)
-scBind env (Rec [(fn,rhs)])
-  | notNull val_bndrs
-  = scExpr env_fn_body body            `thenUs` \ (usg, body') ->
-    specialise env fn bndrs body' usg  `thenUs` \ (rules, spec_prs) ->
-       -- Note body': the specialised copies should be based on the 
-       --             optimised version of the body, in case there were
-       --             nested functions inside.
-    let
-       SCU { calls = calls, occs = occs } = usg
-    in
-    returnUs (extendBndr env fn,       -- For the body of the letrec, just
-                                       -- extend the env with Other to record 
-                                       -- that it's in scope; no funny RecFun business
-             SCU { calls = calls `delVarEnv` fn, occs = occs `delVarEnvList` val_bndrs},
-             Rec ((fn `addIdSpecialisations` rules, mkLams bndrs body') : spec_prs))
-  where
-    (bndrs,body) = collectBinders rhs
-    val_bndrs    = filter isId bndrs
-    env_fn_body         = extendRecBndr env fn bndrs
-
 scBind env (Rec prs)
-  = mapAndUnzipUs do_one prs   `thenUs` \ (usgs, prs') ->
-    returnUs (extendBndrs env (map fst prs), combineUsages usgs, Rec prs')
-  where
-    do_one (bndr,rhs) = scExpr env rhs `thenUs` \ (usg, rhs') ->
-                       returnUs (usg, (bndr,rhs'))
+  = do { let bndrs = map fst prs
+             rhs_env = extendBndrsWith RecFun env bndrs
+
+       ; (rhs_usgs, prs_w_occs) <- mapAndUnzipUs (scRecRhs rhs_env) prs
+       ; let rhs_usg   = combineUsages rhs_usgs
+             rhs_calls = calls rhs_usg
+
+       ; prs_s <- mapUs (specialise env rhs_calls) prs_w_occs
+       ; return (extendBndrs env bndrs, 
+                               -- For the body of the letrec, just
+                               -- extend the env with Other to record 
+                               -- that it's in scope; no funny RecFun business
+                   rhs_usg { calls = calls rhs_usg `delVarEnvList` bndrs },
+                   Rec (concat prs_s)) }
 
 scBind env (NonRec bndr rhs)
-  = scExpr env rhs     `thenUs` \ (usg, rhs') ->
-    returnUs (extendBndr env bndr, usg, NonRec bndr rhs')
+  = do { (usg, rhs') <- scExpr env rhs
+       ; return (extendBndr env bndr, usg, NonRec bndr rhs') }
+
+----------------------
+scRecRhs :: ScEnv -> (Id,CoreExpr)
+        -> UniqSM (ScUsage, (Id, CoreExpr, [ArgOcc]))
+-- The returned [ArgOcc] says how the visible,
+-- lambda-bound binders of the RHS are used
+-- (including the TyVar binders)
+scRecRhs env (bndr,rhs)
+  = do { let (arg_bndrs,body) = collectBinders rhs
+             body_env = extendBndrsWith RecArg env arg_bndrs
+       ; (body_usg, body') <- scExpr body_env body
+       ; let (rhs_usg, arg_occs) = lookupOccs body_usg arg_bndrs
+       ; return (rhs_usg, (bndr, mkLams arg_bndrs body', arg_occs)) }
 
 ----------------------
 varUsage env v use 
@@ -769,18 +763,22 @@ varUsage env v use
 %************************************************************************
 
 \begin{code}
-specialise :: ScEnv
-          -> Id                        -- Functionn
-          -> [CoreBndr] -> CoreExpr    -- Its RHS
-          -> ScUsage                   -- Info on usage
-          -> UniqSM ([CoreRule],       -- Rules
-                     [(Id,CoreExpr)])  -- Bindings
-
-specialise env fn bndrs body body_usg
-  = do { let (_, bndr_occs) = lookupOccs body_usg bndrs
-             all_calls = lookupVarEnv (calls body_usg) fn `orElse` []
-
-       ; mb_pats <- mapM (callToPats (scope env) bndr_occs) all_calls
+specialise 
+   :: ScEnv
+   -> IdEnv [Call]             -- Info on usage
+   -> (Id, CoreExpr, [ArgOcc]) -- Original binding, plus info on how the rhs's
+                               -- lambda-binders are used (includes TyVar bndrs)
+   -> UniqSM [(Id,CoreExpr)]   -- Original binding (decorated with rules)
+                               -- plus specialised bindings
+
+-- Note: the rhs here is the optimised version of the original rhs
+-- So when we make a specialised copy of the RHS, we're starting
+-- from an RHS whose nested functions have been optimised already.
+
+specialise env calls (fn, rhs, arg_occs)
+  | notNull arg_occs,  -- Only specialise functions
+    Just all_calls <- lookupVarEnv calls fn
+  = do { mb_pats <- mapM (callToPats (scope env) arg_occs) all_calls
 
        ; let good_pats :: [([Var], [CoreArg])]
              good_pats = catMaybes mb_pats
@@ -788,12 +786,19 @@ specialise env fn bndrs body body_usg
                         [ exprsFreeVars pats `delVarSetList` vs 
                         | (vs,pats) <- good_pats ]
              uniq_pats = nubBy (same_pat in_scope) good_pats
-       ; -- pprTrace "specialise" (vcat [ppr fn <+> ppr bndrs <+> ppr bndr_occs,
-         --                            text "calls" <+> ppr all_calls,
-         --                            text "good pats" <+> ppr good_pats,
-         --                            text "uniq pats" <+> ppr uniq_pats])  $
-         mapAndUnzipUs (spec_one env fn (mkLams bndrs body)) 
-                       (uniq_pats `zip` [1..]) }
+       ; pprTrace "specialise" (vcat [ppr fn <+> ppr arg_occs,
+                                       text "calls" <+> ppr all_calls,
+                                       text "good pats" <+> ppr good_pats,
+                               text "uniq pats" <+> ppr uniq_pats])  $
+         return ()
+
+       ; (rules, spec_prs) <- mapAndUnzipUs (spec_one fn rhs) 
+                                            (uniq_pats `zip` [1..])
+
+       ; return ((fn `addIdSpecialisations` rules, rhs) : spec_prs) }
+
+  | otherwise
+  = return [(fn,rhs)]  -- The boring case
   where
        -- Two pats are the same if they match both ways
     same_pat in_scope (vs1,as1)(vs2,as2)
@@ -821,8 +826,7 @@ callToPats in_scope bndr_occs (con_env, args)
          else return Nothing }
 
 ---------------------
-spec_one :: ScEnv
-        -> Id                                  -- Function
+spec_one :: Id                                 -- Function
         -> CoreExpr                            -- Rhs of the original function
         -> (([Var], [CoreArg]), Int)
         -> UniqSM (CoreRule, (Id,CoreExpr))    -- Rule and binding
@@ -848,7 +852,7 @@ spec_one :: ScEnv
            f (b,c) ((:) (a,(b,c)) (x,v) hw) = f_spec b c v hw
 -}
 
-spec_one env fn rhs ((vars_to_bind, pats), rule_number)
+spec_one fn rhs ((vars_to_bind, pats), rule_number)
   = getUniqueUs                `thenUs` \ spec_uniq ->
     let 
        fn_name      = idName fn