Add new ForceSpecConstr annotation
[ghc-hetmet.git] / compiler / specialise / SpecConstr.lhs
index 5606830..8067617 100644 (file)
@@ -476,7 +476,8 @@ Annotating a type with NoSpecConstr will make SpecConstr not specialise
 for arguments of that type.
 
 \begin{code}
-data SpecConstrAnnotation = NoSpecConstr deriving( Data, Typeable )
+data SpecConstrAnnotation = NoSpecConstr | ForceSpecConstr
+                deriving( Data, Typeable, Eq )
 \end{code}
 
 %************************************************************************
@@ -491,7 +492,7 @@ specConstrProgram guts
   = do
       dflags <- getDynFlags
       us     <- getUniqueSupplyM
-      annos  <- deserializeAnnotations deserializeWithData
+      annos  <- deserializeAnnotations guts deserializeWithData
       let binds' = fst $ initUs us (go (initScEnv dflags annos) (mg_binds guts))
       return (guts { mg_binds = binds' })
   where
@@ -656,9 +657,7 @@ extendCaseBndrs env case_bndr con alt_bndrs
 
 ignoreTyCon :: ScEnv -> TyCon -> Bool
 ignoreTyCon env tycon
-  = case L.lookupUFM (sc_annotations env) tycon of
-      Just NoSpecConstr -> True
-      _                 -> False
+  = L.lookupUFM (sc_annotations env) tycon == Just NoSpecConstr
 
 ignoreType :: ScEnv -> Type -> Bool
 ignoreType env ty
@@ -670,6 +669,24 @@ ignoreAltCon :: ScEnv -> AltCon -> Bool
 ignoreAltCon env (DataAlt dc) = ignoreTyCon env (dataConTyCon dc)
 ignoreAltCon env (LitAlt lit) = ignoreType env (literalType lit)
 ignoreAltCon _   DEFAULT      = True
+
+forceSpecBndr :: ScEnv -> Var -> Bool
+forceSpecBndr env var = forceSpecFunTy env . varType $ var
+
+forceSpecFunTy :: ScEnv -> Type -> Bool
+forceSpecFunTy env = any (forceSpecArgTy env) . fst . splitFunTys
+
+forceSpecArgTy :: ScEnv -> Type -> Bool
+forceSpecArgTy env ty
+  | Just ty' <- coreView ty = forceSpecArgTy env ty'
+
+forceSpecArgTy env ty
+  | Just (tycon, tys) <- splitTyConApp_maybe ty
+  , tycon /= funTyCon
+      = L.lookupUFM (sc_annotations env) tycon == Just ForceSpecConstr
+        || any (forceSpecArgTy env) tys
+
+forceSpecArgTy _ _ = False
 \end{code}
 
 
@@ -900,12 +917,14 @@ scExpr' env (Let (Rec prs) body)
   = do { let (bndrs,rhss) = unzip prs
              (rhs_env1,bndrs') = extendRecBndrs env bndrs
              rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
+              force_spec = any (forceSpecBndr env) bndrs'
 
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; (body_usg, body')     <- scExpr rhs_env2 body
 
        -- NB: start specLoop from body_usg
-       ; (spec_usg, specs) <- specLoop rhs_env2 (scu_calls body_usg) rhs_infos nullUsage
+       ; (spec_usg, specs) <- specLoop rhs_env2 force_spec
+                                        (scu_calls body_usg) rhs_infos nullUsage
                                        [SI [] 0 (Just usg) | usg <- rhs_usgs]
 
        ; let all_usg = spec_usg `combineUsage` body_usg
@@ -959,6 +978,7 @@ scApp env (other_fn, args)
 scTopBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, CoreBind)
 scTopBind env (Rec prs)
   | Just threshold <- sc_size env
+  , not force_spec
   , not (all (couldBeSmallEnoughToInline threshold) rhss)
                -- No specialisation
   = do { let (rhs_env,bndrs') = extendRecBndrs env bndrs
@@ -971,13 +991,15 @@ scTopBind env (Rec prs)
        ; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
        ; let rhs_usg = combineUsages rhs_usgs
 
-       ; (_, specs) <- specLoop rhs_env2 (scu_calls rhs_usg) rhs_infos nullUsage
+       ; (_, specs) <- specLoop rhs_env2 force_spec
+                                 (scu_calls rhs_usg) rhs_infos nullUsage
                                 [SI [] 0 Nothing | _ <- bndrs]
 
        ; return (rhs_env1,  -- For the body of the letrec, delete the RecFun business
                  Rec (concat (zipWith specInfoBinds rhs_infos specs))) }
   where
     (bndrs,rhss) = unzip prs
+    force_spec = any (forceSpecBndr env) bndrs
 
 scTopBind env (NonRec bndr rhs)
   = do { (_, rhs') <- scExpr env rhs
@@ -1042,12 +1064,13 @@ data OneSpec  = OS CallPat              -- Call pattern that generated this specialisation
 
 
 specLoop :: ScEnv
+         -> Bool                                -- force specialisation?
         -> CallEnv
         -> [RhsInfo]
         -> ScUsage -> [SpecInfo]               -- One per binder; acccumulating parameter
         -> UniqSM (ScUsage, [SpecInfo])        -- ...ditto...
-specLoop env all_calls rhs_infos usg_so_far specs_so_far
-  = do { specs_w_usg <- zipWithM (specialise env all_calls) rhs_infos specs_so_far
+specLoop env force_spec all_calls rhs_infos usg_so_far specs_so_far
+  = do { specs_w_usg <- zipWithM (specialise env force_spec all_calls) rhs_infos specs_so_far
        ; let (new_usg_s, all_specs) = unzip specs_w_usg
              new_usg   = combineUsages new_usg_s
              new_calls = scu_calls new_usg
@@ -1055,10 +1078,11 @@ specLoop env all_calls rhs_infos usg_so_far specs_so_far
        ; if isEmptyVarEnv new_calls then
                return (all_usg, all_specs) 
          else 
-               specLoop env new_calls rhs_infos all_usg all_specs }
+               specLoop env force_spec new_calls rhs_infos all_usg all_specs }
 
 specialise 
    :: ScEnv
+   -> Bool                              -- force specialisation?
    -> CallEnv                          -- Info on calls
    -> RhsInfo
    -> SpecInfo                         -- Original RHS plus patterns dealt with
@@ -1068,7 +1092,7 @@ specialise
 -- So when we make a specialised copy of the RHS, we're starting
 -- from an RHS whose nested functions have been optimised already.
 
-specialise env bind_calls (fn, arg_bndrs, body, arg_occs) 
+specialise env force_spec bind_calls (fn, arg_bndrs, body, arg_occs) 
                          spec_info@(SI specs spec_count mb_unspec)
   | not (isBottomingId fn)      -- Note [Do not specialise diverging functions]
   , notNull arg_bndrs          -- Only specialise functions
@@ -1083,7 +1107,7 @@ specialise env bind_calls (fn, arg_bndrs, body, arg_occs)
                -- Rather a hacky way to do so, but it'll do for now
        ; let spec_count' = length pats + spec_count
        ; case sc_count env of
-           Just max | spec_count' > max
+           Just max | not force_spec && spec_count' > max
                -> WARN( True, msg ) return (nullUsage, spec_info)
                where
                   msg = vcat [ sep [ ptext (sLit "SpecConstr: specialisation of") <+> quotes (ppr fn)