X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=blobdiff_plain;f=compiler%2Fspecialise%2FSpecConstr.lhs;h=219e758c4aa0f36210d2cc733c41b65ae6c8d9e5;hp=b811f404ebdbe5a9efaab945bafa1c93352c426e;hb=59b01a2fb6cd6a9af37f5fd6775f574bc53af02a;hpb=99f41975ae61fc919638aa389199b32742332eff diff --git a/compiler/specialise/SpecConstr.lhs b/compiler/specialise/SpecConstr.lhs index b811f40..219e758 100644 --- a/compiler/specialise/SpecConstr.lhs +++ b/compiler/specialise/SpecConstr.lhs @@ -389,6 +389,38 @@ But fspec doesn't have decent strictnes info. As it happened, and hence f. But now f's strictness is less than its arity, which breaks an invariant. +Note [Forcing specialisation] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +With stream fusion and in other similar cases, we want to fully specialise +some (but not necessarily all!) loops regardless of their size and the +number of specialisations. We allow a library to specify this by annotating +a type with ForceSpecConstr and then adding a parameter of that type to the +loop. Here is a (simplified) example from the vector library: + + data SPEC = SPEC | SPEC2 + {-# ANN type SPEC ForceSpecConstr #-} + + foldl :: (a -> b -> a) -> a -> Stream b -> a + {-# INLINE foldl #-} + foldl f z (Stream step s _) = foldl_loop SPEC z s + where + foldl_loop SPEC z s = case step s of + Yield x s' -> foldl_loop SPEC (f z x) s' + Skip -> foldl_loop SPEC z s' + Done -> z + +SpecConstr will spot the SPEC parameter and always fully specialise +foldl_loop. Note that we can't just annotate foldl_loop since it isn't a +top-level function but even if we could, inlining etc. could easily drop the +annotation. We also have to prevent the SPEC argument from being removed by +w/w which is why SPEC is a sum type. This is all quite ugly; we ought to come +up with a better design. + +ForceSpecConstr arguments are spotted in scExpr' and scTopBinds which then set +force_spec to True when calling specLoop. This flag makes specLoop and +specialise ignore specConstrCount and specConstrThreshold when deciding +whether to specialise a function. + ----------------------------------------------------- Stuff not yet handled ----------------------------------------------------- @@ -510,6 +542,7 @@ specConstrProgram guts \begin{code} data ScEnv = SCE { sc_size :: Maybe Int, -- Size threshold sc_count :: Maybe Int, -- Max # of specialisations for any one fn + -- See Note [Avoiding exponential blowup] sc_subst :: Subst, -- Current substitution -- Maps InIds to OutExprs @@ -528,6 +561,7 @@ data ScEnv = SCE { sc_size :: Maybe Int, -- Size threshold --------------------- -- As we go, we apply a substitution (sc_subst) to the current term type InExpr = CoreExpr -- _Before_ applying the subst +type InVar = Var type OutExpr = CoreExpr -- _After_ applying the subst type OutId = Id @@ -669,7 +703,7 @@ ignoreAltCon env (LitAlt lit) = ignoreType env (literalType lit) ignoreAltCon _ DEFAULT = True forceSpecBndr :: ScEnv -> Var -> Bool -forceSpecBndr env var = forceSpecFunTy env . varType $ var +forceSpecBndr env var = forceSpecFunTy env . snd . splitForAllTys . varType $ var forceSpecFunTy :: ScEnv -> Type -> Bool forceSpecFunTy env = any (forceSpecArgTy env) . fst . splitFunTys @@ -685,8 +719,39 @@ forceSpecArgTy env ty || any (forceSpecArgTy env) tys forceSpecArgTy _ _ = False + +decreaseSpecCount :: ScEnv -> Int -> ScEnv +-- See Note [Avoiding exponential blowup] +decreaseSpecCount env n_specs + = env { sc_count = case sc_count env of + Nothing -> Nothing + Just n -> Just (n `div` (n_specs + 1)) } + -- The "+1" takes account of the original function; + -- See Note [Avoiding exponential blowup] \end{code} +Note [Avoiding exponential blowup] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The sc_count field of the ScEnv says how many times we are prepared to +duplicate a single function. But we must take care with recursive +specialiations. Consider + + let $j1 = let $j2 = let $j3 = ... + in + ...$j3... + in + ...$j2... + in + ...$j1... + +If we specialise $j1 then in each specialisation (as well as the original) +we can specialise $j2, and similarly $j3. Even if we make just *one* +specialisation of each, becuase we also have the original we'll get 2^n +copies of $j3, which is not good. + +So when recursively specialising we divide the sc_count by the number of +copies we are making at this level, including the original. + %************************************************************************ %* * @@ -880,18 +945,23 @@ scExpr' env (Let (NonRec bndr rhs) body) | otherwise -- Note [Local let bindings] = do { let (body_env, bndr') = extendBndr env bndr + body_env2 = extendHowBound body_env [bndr'] RecFun + ; (body_usg, body') <- scExpr body_env2 body + ; (rhs_usg, rhs_info) <- scRecRhs env (bndr',rhs) + + -- NB: We don't use the ForceSpecConstr mechanism (see + -- Note [Forcing specialisation]) for non-recursive bindings + -- at the moment. I'm not sure if this is the right thing to do. ; let force_spec = False - ; let body_env2 = extendHowBound body_env [bndr'] RecFun - ; (body_usg, body') <- scExpr body_env2 body ; (spec_usg, specs) <- specialise env force_spec (scu_calls body_usg) rhs_info - (SI [] 0 Nothing) + (SI [] 0 (Just rhs_usg)) ; return (body_usg { scu_calls = scu_calls body_usg `delVarEnv` bndr' } - `combineUsage` rhs_usg `combineUsage` spec_usg, - mkLets [NonRec b r | (b,r) <- specInfoBinds rhs_info specs] body') + `combineUsage` spec_usg, + mkLets [NonRec b r | (b,r) <- specInfoBinds rhs_info specs] body') } @@ -901,6 +971,7 @@ scExpr' env (Let (Rec prs) body) (rhs_env1,bndrs') = extendRecBndrs env bndrs rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun force_spec = any (forceSpecBndr env) bndrs' + -- Note [Forcing specialisation] ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss) ; (body_usg, body') <- scExpr rhs_env2 body @@ -909,6 +980,9 @@ scExpr' env (Let (Rec prs) body) ; (spec_usg, specs) <- specLoop rhs_env2 force_spec (scu_calls body_usg) rhs_infos nullUsage [SI [] 0 (Just usg) | usg <- rhs_usgs] + -- Do not unconditionally use rhs_usgs. + -- Instead use them only if we find an unspecialised call + -- See Note [Local recursive groups] ; let all_usg = spec_usg `combineUsage` body_usg bind' = Rec (concat (zipWith specInfoBinds rhs_infos specs)) @@ -1001,6 +1075,7 @@ scTopBind env (Rec prs) where (bndrs,rhss) = unzip prs force_spec = any (forceSpecBndr env) bndrs + -- Note [Forcing specialisation] scTopBind env (NonRec bndr rhs) = do { (_, rhs') <- scExpr env rhs @@ -1015,8 +1090,8 @@ scRecRhs env (bndr,rhs) (body_env, arg_bndrs') = extendBndrsWith RecArg env arg_bndrs ; (body_usg, body') <- scExpr body_env body ; let (rhs_usg, arg_occs) = lookupOccs body_usg arg_bndrs' - ; return (rhs_usg, (bndr, arg_bndrs', body', arg_occs)) } - + ; return (rhs_usg, RI bndr (mkLams arg_bndrs' body') + arg_bndrs body arg_occs) } -- The arg_occs says how the visible, -- lambda-bound binders of the RHS are used -- (including the TyVar binders) @@ -1024,9 +1099,9 @@ scRecRhs env (bndr,rhs) ---------------------- specInfoBinds :: RhsInfo -> SpecInfo -> [(Id,CoreExpr)] -specInfoBinds (fn, args, body, _) (SI specs _ _) +specInfoBinds (RI fn new_rhs _ _ _) (SI specs _ _) = [(id,rhs) | OS _ _ id rhs <- specs] ++ - [(fn `addIdSpecialisations` rules, mkLams args body)] + [(fn `addIdSpecialisations` rules, new_rhs)] where rules = [r | OS _ r _ _ <- specs] @@ -1046,17 +1121,21 @@ varUsage env v use %************************************************************************ \begin{code} -type RhsInfo = (OutId, [OutVar], OutExpr, [ArgOcc]) - -- Info about the *original* RHS of a binding we are specialising - -- Original binding f = \xs.body - -- Plus info about usage of arguments +data RhsInfo = RI OutId -- The binder + OutExpr -- The new RHS + [InVar] InExpr -- The *original* RHS (\xs.body) + -- Note [Specialise original body] + [ArgOcc] -- Info on how the xs occur in body data SpecInfo = SI [OneSpec] -- The specialisations we have generated + Int -- Length of specs; used for numbering them + (Maybe ScUsage) -- Nothing => we have generated specialisations -- from calls in the *original* RHS -- Just cs => we haven't, and this is the usage -- of the original RHS + -- See Note [Local recursive groups] -- One specialisation: Rule plus definition data OneSpec = OS CallPat -- Call pattern that generated this specialisation @@ -1066,6 +1145,7 @@ data OneSpec = OS CallPat -- Call pattern that generated this specialisation specLoop :: ScEnv -> Bool -- force specialisation? + -- Note [Forcing specialisation] -> CallEnv -> [RhsInfo] -> ScUsage -> [SpecInfo] -- One per binder; acccumulating parameter @@ -1084,6 +1164,7 @@ specLoop env force_spec all_calls rhs_infos usg_so_far specs_so_far specialise :: ScEnv -> Bool -- force specialisation? + -- Note [Forcing specialisation] -> CallEnv -- Info on calls -> RhsInfo -> SpecInfo -- Original RHS plus patterns dealt with @@ -1093,30 +1174,30 @@ specialise -- So when we make a specialised copy of the RHS, we're starting -- from an RHS whose nested functions have been optimised already. -specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) +specialise env force_spec bind_calls (RI fn _ arg_bndrs body arg_occs) spec_info@(SI specs spec_count mb_unspec) | not (isBottomingId fn) -- Note [Do not specialise diverging functions] , notNull arg_bndrs -- Only specialise functions , Just all_calls <- lookupVarEnv bind_calls fn = do { (boring_call, pats) <- callsToPats env specs arg_occs all_calls -- ; pprTrace "specialise" (vcat [ ppr fn <+> text "with" <+> int (length pats) <+> text "good patterns" --- , text "arg_occs" <+> ppr arg_occs, --- , text "calls" <+> ppr all_calls, +-- , text "arg_occs" <+> ppr arg_occs +-- , text "calls" <+> ppr all_calls -- , text "good pats" <+> ppr pats]) $ -- return () -- Bale out if too many specialisations - -- Rather a hacky way to do so, but it'll do for now - ; let n_pats = length pats - spec_count' = length pats + spec_count + ; let n_pats = length pats + spec_count' = n_pats + spec_count ; case sc_count env of Just max | not force_spec && spec_count' > max -> pprTrace "SpecConstr" msg $ - return (nullUsage, spec_info) + return (nullUsage, spec_info) where msg = vcat [ sep [ ptext (sLit "Function") <+> quotes (ppr fn) - , nest 2 (ptext (sLit "has") <+> int n_pats <+> - ptext (sLit "call pattterns, but the limit is") <+> int max) ] + , nest 2 (ptext (sLit "has") <+> + speakNOf spec_count' (ptext (sLit "call pattern")) <> comma <+> + ptext (sLit "but the limit is") <+> int max) ] , ptext (sLit "Use -fspec-constr-count=n to set the bound") , extra ] extra | not opt_PprStyle_Debug = ptext (sLit "Use -dppr-debug to see specialisations") @@ -1124,8 +1205,10 @@ specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) _normal_case -> do { - (spec_usgs, new_specs) <- mapAndUnzipM (spec_one env fn arg_bndrs body) + let spec_env = decreaseSpecCount env n_pats + ; (spec_usgs, new_specs) <- mapAndUnzipM (spec_one spec_env fn arg_bndrs body) (pats `zip` [spec_count..]) + -- See Note [Specialise original body] ; let spec_usg = combineUsages spec_usgs (new_usg, mb_unspec') @@ -1141,8 +1224,8 @@ specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) --------------------- spec_one :: ScEnv -> OutId -- Function - -> [Var] -- Lambda-binders of RHS; should match patterns - -> CoreExpr -- Body of the original function + -> [InVar] -- Lambda-binders of RHS; should match patterns + -> InExpr -- Body of the original function -> (CallPat, Int) -> UniqSM (ScUsage, OneSpec) -- Rule and binding @@ -1169,30 +1252,33 @@ spec_one :: ScEnv -} spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number) - = do { -- Specialise the body - let spec_env = extendScSubstList (extendScInScope env qvars) + = do { spec_uniq <- getUniqueUs + ; let spec_env = extendScSubstList (extendScInScope env qvars) (arg_bndrs `zip` pats) - ; (spec_usg, spec_body) <- scExpr spec_env body - --- ; pprTrace "spec_one" (ppr fn <+> vcat [text "pats" <+> ppr pats, --- text "calls" <+> (ppr (scu_calls spec_usg))]) --- (return ()) - - -- And build the results - ; spec_uniq <- getUniqueUs - ; let (spec_lam_args, spec_call_args) = mkWorkerArgs qvars body_ty - -- Usual w/w hack to avoid generating - -- a spec_rhs of unlifted type and no args - fn_name = idName fn fn_loc = nameSrcSpan fn_name spec_occ = mkSpecOcc (nameOccName fn_name) rule_name = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number)) - spec_rhs = mkLams spec_lam_args spec_body - spec_str = calcSpecStrictness fn spec_lam_args pats - spec_id = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc + spec_name = mkInternalName spec_uniq spec_occ fn_loc +-- ; pprTrace "{spec_one" (ppr (sc_count env) <+> ppr fn <+> ppr pats <+> text "-->" <+> ppr spec_name) $ +-- return () + + -- Specialise the body + ; (spec_usg, spec_body) <- scExpr spec_env body + +-- ; pprTrace "done spec_one}" (ppr fn) $ +-- return () + + -- And build the results + ; let spec_id = mkLocalId spec_name (mkPiTypes spec_lam_args body_ty) `setIdStrictness` spec_str -- See Note [Transfer strictness] `setIdArity` count isId spec_lam_args + spec_str = calcSpecStrictness fn spec_lam_args pats + (spec_lam_args, spec_call_args) = mkWorkerArgs qvars body_ty + -- Usual w/w hack to avoid generating + -- a spec_rhs of unlifted type and no args + + spec_rhs = mkLams spec_lam_args spec_body body_ty = exprType spec_body rule_rhs = mkVarApps (Var spec_id) spec_call_args inline_act = idInlineActivation fn @@ -1223,6 +1309,13 @@ calcSpecStrictness fn qvars pats \end{code} +Note [Specialise original body] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The RhsInfo for a binding keeps the *original* body of the binding. We +must specialise that, *not* the result of applying specExpr to the RHS +(which is also kept in RhsInfo). Otherwise we end up specialising a +specialised RHS, and that can lead directly to exponential behaviour. + Note [Transfer activation] ~~~~~~~~~~~~~~~~~~~~~~~~~~ In which phase should the specialise-constructor rules be active?