Adjust Activations for specialise and work/wrap, and better simplify in InlineRules
[ghc-hetmet.git] / compiler / specialise / SpecConstr.lhs
index f366cd7..404b6cc 100644 (file)
@@ -39,9 +39,8 @@ import Name
 import DynFlags                ( DynFlags(..) )
 import StaticFlags     ( opt_PprStyle_Debug )
 import StaticFlags     ( opt_SpecInlineJoinPoints )
-import BasicTypes      ( Activation(..) )
 import Maybes          ( orElse, catMaybes, isJust, isNothing )
-import NewDemand
+import Demand
 import DmdAnal         ( both )
 import Serialized       ( deserializeWithData )
 import Util
@@ -53,7 +52,11 @@ import qualified LazyUniqFM as L
 import MonadUtils
 import Control.Monad   ( zipWithM )
 import Data.List
+#if __GLASGOW_HASKELL__ > 609
 import Data.Data        ( Data, Typeable )
+#else
+import Data.Generics    ( Data, Typeable )
+#endif
 \end{code}
 
 -----------------------------------------------------
@@ -472,7 +475,8 @@ Annotating a type with NoSpecConstr will make SpecConstr not specialise
 for arguments of that type.
 
 \begin{code}
-data SpecConstrAnnotation = NoSpecConstr deriving( Data, Typeable )
+data SpecConstrAnnotation = NoSpecConstr | ForceSpecConstr
+                deriving( Data, Typeable, Eq )
 \end{code}
 
 %************************************************************************
@@ -487,7 +491,7 @@ specConstrProgram guts
   = do
       dflags <- getDynFlags
       us     <- getUniqueSupplyM
-      annos  <- deserializeAnnotations deserializeWithData
+      annos  <- getFirstAnnotations deserializeWithData guts
       let binds' = fst $ initUs us (go (initScEnv dflags annos) (mg_binds guts))
       return (guts { mg_binds = binds' })
   where
@@ -543,14 +547,14 @@ instance Outputable Value where
    ppr LambdaVal        = ptext (sLit "<Lambda>")
 
 ---------------------
-initScEnv :: DynFlags -> L.UniqFM [SpecConstrAnnotation] -> ScEnv
-initScEnv dflags annos
+initScEnv :: DynFlags -> L.UniqFM SpecConstrAnnotation -> ScEnv
+initScEnv dflags anns
   = SCE { sc_size = specConstrThreshold dflags,
          sc_count = specConstrCount dflags,
          sc_subst = emptySubst, 
          sc_how_bound = emptyVarEnv, 
          sc_vals = emptyVarEnv,
-          sc_annotations = L.mapUFM head $ L.filterUFM (not . null) annos }
+          sc_annotations = anns }
 
 data HowBound = RecFun -- These are the recursive functions for which 
                        -- we seek interesting call patterns
@@ -652,9 +656,7 @@ extendCaseBndrs env case_bndr con alt_bndrs
 
 ignoreTyCon :: ScEnv -> TyCon -> Bool
 ignoreTyCon env tycon
-  = case L.lookupUFM (sc_annotations env) tycon of
-      Just NoSpecConstr -> True
-      _                 -> False
+  = L.lookupUFM (sc_annotations env) tycon == Just NoSpecConstr
 
 ignoreType :: ScEnv -> Type -> Bool
 ignoreType env ty
@@ -666,6 +668,24 @@ ignoreAltCon :: ScEnv -> AltCon -> Bool
 ignoreAltCon env (DataAlt dc) = ignoreTyCon env (dataConTyCon dc)
 ignoreAltCon env (LitAlt lit) = ignoreType env (literalType lit)
 ignoreAltCon _   DEFAULT      = True
+
+forceSpecBndr :: ScEnv -> Var -> Bool
+forceSpecBndr env var = forceSpecFunTy env . varType $ var
+
+forceSpecFunTy :: ScEnv -> Type -> Bool
+forceSpecFunTy env = any (forceSpecArgTy env) . fst . splitFunTys
+
+forceSpecArgTy :: ScEnv -> Type -> Bool
+forceSpecArgTy env ty
+  | Just ty' <- coreView ty = forceSpecArgTy env ty'
+
+forceSpecArgTy env ty
+  | Just (tycon, tys) <- splitTyConApp_maybe ty
+  , tycon /= funTyCon
+      = L.lookupUFM (sc_annotations env) tycon == Just ForceSpecConstr
+        || any (forceSpecArgTy env) tys
+
+forceSpecArgTy _ _ = False
 \end{code}
 
 
@@ -896,12 +916,14 @@ scExpr' env (Let (Rec prs) body)
   = do { let (bndrs,rhss) = unzip prs
              (rhs_env1,bndrs') = extendRecBndrs env bndrs
              rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
+              force_spec = any (forceSpecBndr env) bndrs'
 
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; (body_usg, body')     <- scExpr rhs_env2 body
 
        -- NB: start specLoop from body_usg
-       ; (spec_usg, specs) <- specLoop rhs_env2 (scu_calls body_usg) rhs_infos nullUsage
+       ; (spec_usg, specs) <- specLoop rhs_env2 force_spec
+                                        (scu_calls body_usg) rhs_infos nullUsage
                                        [SI [] 0 (Just usg) | usg <- rhs_usgs]
 
        ; let all_usg = spec_usg `combineUsage` body_usg
@@ -955,6 +977,7 @@ scApp env (other_fn, args)
 scTopBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, CoreBind)
 scTopBind env (Rec prs)
   | Just threshold <- sc_size env
+  , not force_spec
   , not (all (couldBeSmallEnoughToInline threshold) rhss)
                -- No specialisation
   = do { let (rhs_env,bndrs') = extendRecBndrs env bndrs
@@ -967,13 +990,15 @@ scTopBind env (Rec prs)
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; let rhs_usg = combineUsages rhs_usgs
 
-       ; (_, specs) <- specLoop rhs_env2 (scu_calls rhs_usg) rhs_infos nullUsage
+       ; (_, specs) <- specLoop rhs_env2 force_spec
+                                 (scu_calls rhs_usg) rhs_infos nullUsage
                                 [SI [] 0 Nothing | _ <- bndrs]
 
        ; return (rhs_env1,  -- For the body of the letrec, delete the RecFun business
                  Rec (concat (zipWith specInfoBinds rhs_infos specs))) }
   where
     (bndrs,rhss) = unzip prs
+    force_spec = any (forceSpecBndr env) bndrs
 
 scTopBind env (NonRec bndr rhs)
   = do { (_, rhs') <- scExpr env rhs
@@ -1038,12 +1063,13 @@ data OneSpec  = OS CallPat              -- Call pattern that generated this specialisation
 
 
 specLoop :: ScEnv
+         -> Bool                                -- force specialisation?
         -> CallEnv
         -> [RhsInfo]
         -> ScUsage -> [SpecInfo]               -- One per binder; acccumulating parameter
         -> UniqSM (ScUsage, [SpecInfo])        -- ...ditto...
-specLoop env all_calls rhs_infos usg_so_far specs_so_far
-  = do { specs_w_usg <- zipWithM (specialise env all_calls) rhs_infos specs_so_far
+specLoop env force_spec all_calls rhs_infos usg_so_far specs_so_far
+  = do { specs_w_usg <- zipWithM (specialise env force_spec all_calls) rhs_infos specs_so_far
        ; let (new_usg_s, all_specs) = unzip specs_w_usg
              new_usg   = combineUsages new_usg_s
              new_calls = scu_calls new_usg
@@ -1051,10 +1077,11 @@ specLoop env all_calls rhs_infos usg_so_far specs_so_far
        ; if isEmptyVarEnv new_calls then
                return (all_usg, all_specs) 
          else 
-               specLoop env new_calls rhs_infos all_usg all_specs }
+               specLoop env force_spec new_calls rhs_infos all_usg all_specs }
 
 specialise 
    :: ScEnv
+   -> Bool                              -- force specialisation?
    -> CallEnv                          -- Info on calls
    -> RhsInfo
    -> SpecInfo                         -- Original RHS plus patterns dealt with
@@ -1064,7 +1091,7 @@ specialise
 -- So when we make a specialised copy of the RHS, we're starting
 -- from an RHS whose nested functions have been optimised already.
 
-specialise env bind_calls (fn, arg_bndrs, body, arg_occs) 
+specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) 
                          spec_info@(SI specs spec_count mb_unspec)
   | not (isBottomingId fn)      -- Note [Do not specialise diverging functions]
   , notNull arg_bndrs          -- Only specialise functions
@@ -1079,7 +1106,7 @@ specialise env bind_calls (fn, arg_bndrs, body, arg_occs)
                -- Rather a hacky way to do so, but it'll do for now
        ; let spec_count' = length pats + spec_count
        ; case sc_count env of
-           Just max | spec_count' > max
+           Just max | not force_spec && spec_count' > max
                -> WARN( True, msg ) return (nullUsage, spec_info)
                where
                   msg = vcat [ sep [ ptext (sLit "SpecConstr: specialisation of") <+> quotes (ppr fn)
@@ -1151,18 +1178,19 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
                -- Usual w/w hack to avoid generating 
                -- a spec_rhs of unlifted type and no args
        
-             fn_name   = idName fn
-             fn_loc    = nameSrcSpan fn_name
-             spec_occ  = mkSpecOcc (nameOccName fn_name)
-             rule_name = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number))
-             spec_rhs  = mkLams spec_lam_args spec_body
-             spec_str  = calcSpecStrictness fn spec_lam_args pats
-             spec_id   = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc
-                           `setIdNewStrictness` spec_str       -- See Note [Transfer strictness]
-                           `setIdArity` count isId spec_lam_args
-             body_ty   = exprType spec_body
-             rule_rhs  = mkVarApps (Var spec_id) spec_call_args
-             rule      = mkLocalRule rule_name specConstrActivation fn_name qvars pats rule_rhs
+             fn_name    = idName fn
+             fn_loc     = nameSrcSpan fn_name
+             spec_occ   = mkSpecOcc (nameOccName fn_name)
+             rule_name  = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number))
+             spec_rhs   = mkLams spec_lam_args spec_body
+             spec_str   = calcSpecStrictness fn spec_lam_args pats
+             spec_id    = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc
+                            `setIdStrictness` spec_str         -- See Note [Transfer strictness]
+                            `setIdArity` count isId spec_lam_args
+             body_ty    = exprType spec_body
+             rule_rhs   = mkVarApps (Var spec_id) spec_call_args
+              inline_act = idInlineActivation fn
+             rule       = mkLocalRule rule_name inline_act fn_name qvars pats rule_rhs
        ; return (spec_usg, OS call_pat rule spec_id spec_rhs) }
 
 calcSpecStrictness :: Id                    -- The original function
@@ -1173,7 +1201,7 @@ calcSpecStrictness fn qvars pats
   = StrictSig (mkTopDmdType spec_dmds TopRes)
   where
     spec_dmds = [ lookupVarEnv dmd_env qv `orElse` lazyDmd | qv <- qvars, isId qv ]
-    StrictSig (DmdType _ dmds _) = idNewStrictness fn
+    StrictSig (DmdType _ dmds _) = idStrictness fn
 
     dmd_env = go emptyVarEnv dmds pats
 
@@ -1187,18 +1215,23 @@ calcSpecStrictness fn qvars pats
           | (Var _, args) <- collectArgs e = go env ds args
     go_one env _         _ = env
 
--- In which phase should the specialise-constructor rules be active?
--- Originally I made them always-active, but Manuel found that
--- this defeated some clever user-written rules.  So Plan B
--- is to make them active only in Phase 0; after all, currently,
--- the specConstr transformation is only run after the simplifier
--- has reached Phase 0.  In general one would want it to be 
--- flag-controllable, but for now I'm leaving it baked in
---                                     [SLPJ Oct 01]
-specConstrActivation :: Activation
-specConstrActivation = ActiveAfter 0   -- Baked in; see comments above
 \end{code}
 
+Note [Transfer activation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+In which phase should the specialise-constructor rules be active?
+Originally I made them always-active, but Manuel found that this
+defeated some clever user-written rules.  Then I made them active only
+in Phase 0; after all, currently, the specConstr transformation is
+only run after the simplifier has reached Phase 0, but that meant
+that specialisations didn't fire inside wrappers; see test
+simplCore/should_compile/spec-inline.
+
+So now I just use the inline-activation of the parent Id, as the
+activation for the specialiation RULE, just like the main specialiser;
+see Note [Auto-specialisation and RULES] in Specialise.
+
+
 Note [Transfer strictness]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 We must transfer strictness information from the original function to