Adjust Activations for specialise and work/wrap, and better simplify in InlineRules
[ghc-hetmet.git] / compiler / specialise / SpecConstr.lhs
index bd70e28..404b6cc 100644 (file)
@@ -11,7 +11,7 @@
 -- for details
 
 module SpecConstr(
-       specConstrProgram       
+       specConstrProgram, SpecConstrAnnotation(..)
     ) where
 
 #include "HsVersions.h"
@@ -21,8 +21,12 @@ import CoreSubst
 import CoreUtils
 import CoreUnfold      ( couldBeSmallEnoughToInline )
 import CoreFVs                 ( exprsFreeVars )
+import CoreMonad
+import HscTypes         ( ModGuts(..) )
 import WwLib           ( mkWorkerArgs )
-import DataCon         ( dataConRepArity, dataConUnivTyVars )
+import DataCon         ( dataConTyCon, dataConRepArity, dataConUnivTyVars )
+import TyCon            ( TyCon )
+import Literal          ( literalType )
 import Coercion        
 import Rules
 import Type            hiding( substTy )
@@ -35,16 +39,24 @@ import Name
 import DynFlags                ( DynFlags(..) )
 import StaticFlags     ( opt_PprStyle_Debug )
 import StaticFlags     ( opt_SpecInlineJoinPoints )
-import BasicTypes      ( Activation(..) )
 import Maybes          ( orElse, catMaybes, isJust, isNothing )
+import Demand
+import DmdAnal         ( both )
+import Serialized       ( deserializeWithData )
 import Util
 import UniqSupply
 import Outputable
 import FastString
 import UniqFM
+import qualified LazyUniqFM as L
 import MonadUtils
 import Control.Monad   ( zipWithM )
 import Data.List
+#if __GLASGOW_HASKELL__ > 609
+import Data.Data        ( Data, Typeable )
+#else
+import Data.Generics    ( Data, Typeable )
+#endif
 \end{code}
 
 -----------------------------------------------------
@@ -453,7 +465,19 @@ But perhaps the first one isn't good.  After all, we know that tpl_B2 is
 a T (I# x) really, because T is strict and Int has one constructor.  (We can't
 unbox the strict fields, becuase T is polymorphic!)
 
+%************************************************************************
+%*                                                                     *
+\subsection{Annotations}
+%*                                                                     *
+%************************************************************************
 
+Annotating a type with NoSpecConstr will make SpecConstr not specialise
+for arguments of that type.
+
+\begin{code}
+data SpecConstrAnnotation = NoSpecConstr | ForceSpecConstr
+                deriving( Data, Typeable, Eq )
+\end{code}
 
 %************************************************************************
 %*                                                                     *
@@ -462,8 +486,14 @@ unbox the strict fields, becuase T is polymorphic!)
 %************************************************************************
 
 \begin{code}
-specConstrProgram :: DynFlags -> UniqSupply -> [CoreBind] -> [CoreBind]
-specConstrProgram dflags us binds = fst $ initUs us (go (initScEnv dflags) binds)
+specConstrProgram :: ModGuts -> CoreM ModGuts
+specConstrProgram guts
+  = do
+      dflags <- getDynFlags
+      us     <- getUniqueSupplyM
+      annos  <- getFirstAnnotations deserializeWithData guts
+      let binds' = fst $ initUs us (go (initScEnv dflags annos) (mg_binds guts))
+      return (guts { mg_binds = binds' })
   where
     go _   []          = return []
     go env (bind:binds) = do (env', bind') <- scTopBind env bind
@@ -489,9 +519,11 @@ data ScEnv = SCE { sc_size  :: Maybe Int,  -- Size threshold
                        -- Binds interesting non-top-level variables
                        -- Domain is OutVars (*after* applying the substitution)
 
-                  sc_vals  :: ValueEnv
+                  sc_vals  :: ValueEnv,
                        -- Domain is OutIds (*after* applying the substitution)
                        -- Used even for top-level bindings (but not imported ones)
+
+                   sc_annotations :: L.UniqFM SpecConstrAnnotation
             }
 
 ---------------------
@@ -515,13 +547,14 @@ instance Outputable Value where
    ppr LambdaVal        = ptext (sLit "<Lambda>")
 
 ---------------------
-initScEnv :: DynFlags -> ScEnv
-initScEnv dflags
+initScEnv :: DynFlags -> L.UniqFM SpecConstrAnnotation -> ScEnv
+initScEnv dflags anns
   = SCE { sc_size = specConstrThreshold dflags,
          sc_count = specConstrCount dflags,
          sc_subst = emptySubst, 
          sc_how_bound = emptyVarEnv, 
-         sc_vals = emptyVarEnv }
+         sc_vals = emptyVarEnv,
+          sc_annotations = anns }
 
 data HowBound = RecFun -- These are the recursive functions for which 
                        -- we seek interesting call patterns
@@ -620,6 +653,39 @@ extendCaseBndrs env case_bndr con alt_bndrs
                      where
                        vanilla_args = map Type (tyConAppArgs (idType case_bndr)) ++
                                       varsToCoreExprs alt_bndrs
+
+ignoreTyCon :: ScEnv -> TyCon -> Bool
+ignoreTyCon env tycon
+  = L.lookupUFM (sc_annotations env) tycon == Just NoSpecConstr
+
+ignoreType :: ScEnv -> Type -> Bool
+ignoreType env ty
+  = case splitTyConApp_maybe ty of
+      Just (tycon, _) -> ignoreTyCon env tycon
+      _               -> False
+
+ignoreAltCon :: ScEnv -> AltCon -> Bool
+ignoreAltCon env (DataAlt dc) = ignoreTyCon env (dataConTyCon dc)
+ignoreAltCon env (LitAlt lit) = ignoreType env (literalType lit)
+ignoreAltCon _   DEFAULT      = True
+
+forceSpecBndr :: ScEnv -> Var -> Bool
+forceSpecBndr env var = forceSpecFunTy env . varType $ var
+
+forceSpecFunTy :: ScEnv -> Type -> Bool
+forceSpecFunTy env = any (forceSpecArgTy env) . fst . splitFunTys
+
+forceSpecArgTy :: ScEnv -> Type -> Bool
+forceSpecArgTy env ty
+  | Just ty' <- coreView ty = forceSpecArgTy env ty'
+
+forceSpecArgTy env ty
+  | Just (tycon, tys) <- splitTyConApp_maybe ty
+  , tycon /= funTyCon
+      = L.lookupUFM (sc_annotations env) tycon == Just ForceSpecConstr
+        || any (forceSpecArgTy env) tys
+
+forceSpecArgTy _ _ = False
 \end{code}
 
 
@@ -850,12 +916,14 @@ scExpr' env (Let (Rec prs) body)
   = do { let (bndrs,rhss) = unzip prs
              (rhs_env1,bndrs') = extendRecBndrs env bndrs
              rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
+              force_spec = any (forceSpecBndr env) bndrs'
 
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; (body_usg, body')     <- scExpr rhs_env2 body
 
        -- NB: start specLoop from body_usg
-       ; (spec_usg, specs) <- specLoop rhs_env2 (scu_calls body_usg) rhs_infos nullUsage
+       ; (spec_usg, specs) <- specLoop rhs_env2 force_spec
+                                        (scu_calls body_usg) rhs_infos nullUsage
                                        [SI [] 0 (Just usg) | usg <- rhs_usgs]
 
        ; let all_usg = spec_usg `combineUsage` body_usg
@@ -909,6 +977,7 @@ scApp env (other_fn, args)
 scTopBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, CoreBind)
 scTopBind env (Rec prs)
   | Just threshold <- sc_size env
+  , not force_spec
   , not (all (couldBeSmallEnoughToInline threshold) rhss)
                -- No specialisation
   = do { let (rhs_env,bndrs') = extendRecBndrs env bndrs
@@ -921,13 +990,15 @@ scTopBind env (Rec prs)
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; let rhs_usg = combineUsages rhs_usgs
 
-       ; (_, specs) <- specLoop rhs_env2 (scu_calls rhs_usg) rhs_infos nullUsage
+       ; (_, specs) <- specLoop rhs_env2 force_spec
+                                 (scu_calls rhs_usg) rhs_infos nullUsage
                                 [SI [] 0 Nothing | _ <- bndrs]
 
        ; return (rhs_env1,  -- For the body of the letrec, delete the RecFun business
                  Rec (concat (zipWith specInfoBinds rhs_infos specs))) }
   where
     (bndrs,rhss) = unzip prs
+    force_spec = any (forceSpecBndr env) bndrs
 
 scTopBind env (NonRec bndr rhs)
   = do { (_, rhs') <- scExpr env rhs
@@ -992,12 +1063,13 @@ data OneSpec  = OS CallPat               -- Call pattern that generated this specialisation
 
 
 specLoop :: ScEnv
+         -> Bool                                -- force specialisation?
         -> CallEnv
         -> [RhsInfo]
         -> ScUsage -> [SpecInfo]               -- One per binder; acccumulating parameter
         -> UniqSM (ScUsage, [SpecInfo])        -- ...ditto...
-specLoop env all_calls rhs_infos usg_so_far specs_so_far
-  = do { specs_w_usg <- zipWithM (specialise env all_calls) rhs_infos specs_so_far
+specLoop env force_spec all_calls rhs_infos usg_so_far specs_so_far
+  = do { specs_w_usg <- zipWithM (specialise env force_spec all_calls) rhs_infos specs_so_far
        ; let (new_usg_s, all_specs) = unzip specs_w_usg
              new_usg   = combineUsages new_usg_s
              new_calls = scu_calls new_usg
@@ -1005,10 +1077,11 @@ specLoop env all_calls rhs_infos usg_so_far specs_so_far
        ; if isEmptyVarEnv new_calls then
                return (all_usg, all_specs) 
          else 
-               specLoop env new_calls rhs_infos all_usg all_specs }
+               specLoop env force_spec new_calls rhs_infos all_usg all_specs }
 
 specialise 
    :: ScEnv
+   -> Bool                              -- force specialisation?
    -> CallEnv                          -- Info on calls
    -> RhsInfo
    -> SpecInfo                         -- Original RHS plus patterns dealt with
@@ -1018,7 +1091,7 @@ specialise
 -- So when we make a specialised copy of the RHS, we're starting
 -- from an RHS whose nested functions have been optimised already.
 
-specialise env bind_calls (fn, arg_bndrs, body, arg_occs) 
+specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) 
                          spec_info@(SI specs spec_count mb_unspec)
   | not (isBottomingId fn)      -- Note [Do not specialise diverging functions]
   , notNull arg_bndrs          -- Only specialise functions
@@ -1033,7 +1106,7 @@ specialise env bind_calls (fn, arg_bndrs, body, arg_occs)
                -- Rather a hacky way to do so, but it'll do for now
        ; let spec_count' = length pats + spec_count
        ; case sc_count env of
-           Just max | spec_count' > max
+           Just max | not force_spec && spec_count' > max
                -> WARN( True, msg ) return (nullUsage, spec_info)
                where
                   msg = vcat [ sep [ ptext (sLit "SpecConstr: specialisation of") <+> quotes (ppr fn)
@@ -1105,29 +1178,77 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
                -- Usual w/w hack to avoid generating 
                -- a spec_rhs of unlifted type and no args
        
-             fn_name   = idName fn
-             fn_loc    = nameSrcSpan fn_name
-             spec_occ  = mkSpecOcc (nameOccName fn_name)
-             rule_name = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number))
-             spec_rhs  = mkLams spec_lam_args spec_body
-             spec_id   = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc
-             body_ty   = exprType spec_body
-             rule_rhs  = mkVarApps (Var spec_id) spec_call_args
-             rule      = mkLocalRule rule_name specConstrActivation fn_name qvars pats rule_rhs
+             fn_name    = idName fn
+             fn_loc     = nameSrcSpan fn_name
+             spec_occ   = mkSpecOcc (nameOccName fn_name)
+             rule_name  = mkFastString ("SC:" ++ showSDoc (ppr fn <> int rule_number))
+             spec_rhs   = mkLams spec_lam_args spec_body
+             spec_str   = calcSpecStrictness fn spec_lam_args pats
+             spec_id    = mkUserLocal spec_occ spec_uniq (mkPiTypes spec_lam_args body_ty) fn_loc
+                            `setIdStrictness` spec_str         -- See Note [Transfer strictness]
+                            `setIdArity` count isId spec_lam_args
+             body_ty    = exprType spec_body
+             rule_rhs   = mkVarApps (Var spec_id) spec_call_args
+              inline_act = idInlineActivation fn
+             rule       = mkLocalRule rule_name inline_act fn_name qvars pats rule_rhs
        ; return (spec_usg, OS call_pat rule spec_id spec_rhs) }
 
--- In which phase should the specialise-constructor rules be active?
--- Originally I made them always-active, but Manuel found that
--- this defeated some clever user-written rules.  So Plan B
--- is to make them active only in Phase 0; after all, currently,
--- the specConstr transformation is only run after the simplifier
--- has reached Phase 0.  In general one would want it to be 
--- flag-controllable, but for now I'm leaving it baked in
---                                     [SLPJ Oct 01]
-specConstrActivation :: Activation
-specConstrActivation = ActiveAfter 0   -- Baked in; see comments above
+calcSpecStrictness :: Id                    -- The original function
+                   -> [Var] -> [CoreExpr]    -- Call pattern
+                  -> StrictSig              -- Strictness of specialised thing
+-- See Note [Transfer strictness]
+calcSpecStrictness fn qvars pats
+  = StrictSig (mkTopDmdType spec_dmds TopRes)
+  where
+    spec_dmds = [ lookupVarEnv dmd_env qv `orElse` lazyDmd | qv <- qvars, isId qv ]
+    StrictSig (DmdType _ dmds _) = idStrictness fn
+
+    dmd_env = go emptyVarEnv dmds pats
+
+    go env ds (Type {} : pats) = go env ds pats
+    go env (d:ds) (pat : pats) = go (go_one env d pat) ds pats
+    go env _      _            = env
+
+    go_one env d   (Var v) = extendVarEnv_C both env v d
+    go_one env (Box d)   e = go_one env d e
+    go_one env (Eval (Prod ds)) e 
+          | (Var _, args) <- collectArgs e = go env ds args
+    go_one env _         _ = env
+
 \end{code}
 
+Note [Transfer activation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+In which phase should the specialise-constructor rules be active?
+Originally I made them always-active, but Manuel found that this
+defeated some clever user-written rules.  Then I made them active only
+in Phase 0; after all, currently, the specConstr transformation is
+only run after the simplifier has reached Phase 0, but that meant
+that specialisations didn't fire inside wrappers; see test
+simplCore/should_compile/spec-inline.
+
+So now I just use the inline-activation of the parent Id, as the
+activation for the specialiation RULE, just like the main specialiser;
+see Note [Auto-specialisation and RULES] in Specialise.
+
+
+Note [Transfer strictness]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must transfer strictness information from the original function to
+the specialised one.  Suppose, for example
+
+  f has strictness     SS
+        and a RULE     f (a:as) b = f_spec a as b
+
+Now we want f_spec to have strictess  LLS, otherwise we'll use call-by-need
+when calling f_spec instead of call-by-value.  And that can result in 
+unbounded worsening in space (cf the classic foldl vs foldl')
+
+See Trac #3437 for a good example.
+
+The function calcSpecStrictness performs the calculation.
+
+
 %************************************************************************
 %*                                                                     *
 \subsection{Argument analysis}
@@ -1167,7 +1288,7 @@ callToPats env bndr_occs (con_env, args)
   = return Nothing
   | otherwise
   = do { let in_scope = substInScope (sc_subst env)
-       ; prs <- argsToPats in_scope con_env (args `zip` bndr_occs)
+       ; prs <- argsToPats env in_scope con_env (args `zip` bndr_occs)
        ; let (interesting_s, pats) = unzip prs
              pat_fvs = varSetElems (exprsFreeVars pats)
              qvars   = filterOut (`elemInScopeSet` in_scope) pat_fvs
@@ -1191,7 +1312,8 @@ callToPats env bndr_occs (con_env, args)
     -- placeholder variables.  For example:
     --    C a (D (f x) (g y))  ==>  C p1 (D p2 p3)
 
-argToPat :: InScopeSet                 -- What's in scope at the fn defn site
+argToPat :: ScEnv
+         -> InScopeSet                 -- What's in scope at the fn defn site
         -> ValueEnv                    -- ValueEnv at the call site
         -> CoreArg                     -- A call arg (or component thereof)
         -> ArgOcc
@@ -1206,11 +1328,11 @@ argToPat :: InScopeSet                  -- What's in scope at the fn defn site
 --             lvl7         --> (True, lvl7)      if lvl7 is bound 
 --                                                somewhere further out
 
-argToPat _in_scope _val_env arg@(Type {}) _arg_occ
+argToPat _env _in_scope _val_env arg@(Type {}) _arg_occ
   = return (False, arg)
 
-argToPat in_scope val_env (Note _ arg) arg_occ
-  = argToPat in_scope val_env arg arg_occ
+argToPat env in_scope val_env (Note _ arg) arg_occ
+  = argToPat env in_scope val_env arg arg_occ
        -- Note [Notes in call patterns]
        -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        -- Ignore Notes.  In particular, we want to ignore any InlineMe notes
@@ -1218,16 +1340,16 @@ argToPat in_scope val_env (Note _ arg) arg_occ
        -- ride roughshod over them all for now.
        --- See Note [Notes in RULE matching] in Rules
 
-argToPat in_scope val_env (Let _ arg) arg_occ
-  = argToPat in_scope val_env arg arg_occ
+argToPat env in_scope val_env (Let _ arg) arg_occ
+  = argToPat env in_scope val_env arg arg_occ
        -- Look through let expressions
        -- e.g.         f (let v = rhs in \y -> ...v...)
        -- Here we can specialise for f (\y -> ...)
        -- because the rule-matcher will look through the let.
 
-argToPat in_scope val_env (Cast arg co) arg_occ
-  = do { (interesting, arg') <- argToPat in_scope val_env arg arg_occ
-       ; let (ty1,ty2) = coercionKind co
+argToPat env in_scope val_env (Cast arg co) arg_occ
+  | not (ignoreType env ty2)
+  = do { (interesting, arg') <- argToPat env in_scope val_env arg arg_occ
        ; if not interesting then 
                wildCardPat ty2
          else do
@@ -1236,6 +1358,10 @@ argToPat in_scope val_env (Cast arg co) arg_occ
        ; let co_name = mkSysTvName uniq (fsLit "sg")
              co_var = mkCoVar co_name (mkCoKind ty1 ty2)
        ; return (interesting, Cast arg' (mkTyVarTy co_var)) } }
+  where
+    (ty1, ty2) = coercionKind co
+
+    
 
 {-     Disabling lambda specialisation for now
        It's fragile, and the spec_loop can be infinite
@@ -1251,15 +1377,16 @@ argToPat in_scope val_env arg arg_occ
 
   -- Check for a constructor application
   -- NB: this *precedes* the Var case, so that we catch nullary constrs
-argToPat in_scope val_env arg arg_occ
+argToPat env in_scope val_env arg arg_occ
   | Just (ConVal dc args) <- isValue val_env arg
+  , not (ignoreAltCon env dc)
   , case arg_occ of
        ScrutOcc _ -> True              -- Used only by case scrutinee
        BothOcc    -> case arg of       -- Used elsewhere
                        App {} -> True  --     see Note [Reboxing]
                        _other -> False
        _other     -> False     -- No point; the arg is not decomposed
-  = do { args' <- argsToPats in_scope val_env (args `zip` conArgOccs arg_occ dc)
+  = do { args' <- argsToPats env in_scope val_env (args `zip` conArgOccs arg_occ dc)
        ; return (True, mk_con_app dc (map snd args')) }
 
   -- Check if the argument is a variable that 
@@ -1267,9 +1394,10 @@ argToPat in_scope val_env arg arg_occ
   -- It's worth specialising on this if
   --   (a) it's used in an interesting way in the body
   --   (b) we know what its value is
-argToPat in_scope val_env (Var v) arg_occ
+argToPat env in_scope val_env (Var v) arg_occ
   | case arg_occ of { UnkOcc -> False; _other -> True },       -- (a)
-    is_value                                                   -- (b)
+    is_value,                                                  -- (b)
+    not (ignoreType env (varType v))
   = return (True, Var v)
   where
     is_value 
@@ -1298,7 +1426,7 @@ argToPat in_scope val_env (Var v) arg_occ
        -- We don't want to specialise for that *particular* x,y
 
   -- The default case: make a wild-card
-argToPat _in_scope _val_env arg _arg_occ
+argToPat _env _in_scope _val_env arg _arg_occ
   = wildCardPat (exprType arg)
 
 wildCardPat :: Type -> UniqSM (Bool, CoreArg)
@@ -1306,13 +1434,13 @@ wildCardPat ty = do { uniq <- getUniqueUs
                    ; let id = mkSysLocal (fsLit "sc") uniq ty
                    ; return (False, Var id) }
 
-argsToPats :: InScopeSet -> ValueEnv
+argsToPats :: ScEnv -> InScopeSet -> ValueEnv
           -> [(CoreArg, ArgOcc)]
           -> UniqSM [(Bool, CoreArg)]
-argsToPats in_scope val_env args
+argsToPats env in_scope val_env args
   = mapM do_one args
   where
-    do_one (arg,occ) = argToPat in_scope val_env arg occ
+    do_one (arg,occ) = argToPat env in_scope val_env arg occ
 \end{code}