[project @ 1999-09-17 09:15:22 by simonpj]
authorsimonpj <unknown>
Fri, 17 Sep 1999 09:15:44 +0000 (09:15 +0000)
committersimonpj <unknown>
Fri, 17 Sep 1999 09:15:44 +0000 (09:15 +0000)
This bunch of commits represents work in progress on inlining and
worker/wrapper stuff.

Currently, I think it makes the compiler slightly worse than 4.04, for
reasons I don't yet understand.  But it means that Simon and I can
both peer at what is going on.

* Substantially improve handling of coerces in worker/wrapper

* exprIsDupable for an application (f e1 .. en) wasn't calling exprIsDupable
  on the arguments!!  So applications with few, but large, args were being dupliated.

* sizeExpr on an application wasn't doing a nukeScrutDiscount on the arg of
  an application!!  So bogus discounts could accumulate from arguments!

* Improve handling of INLINE pragmas in calcUnfoldingGuidance.  It was really
  wrong before

17 files changed:
ghc/compiler/coreSyn/CoreUnfold.lhs
ghc/compiler/coreSyn/CoreUtils.lhs
ghc/compiler/main/ErrUtils.lhs
ghc/compiler/main/MkIface.lhs
ghc/compiler/parser/RdrHsSyn.lhs
ghc/compiler/rename/Rename.lhs
ghc/compiler/rename/RnSource.lhs
ghc/compiler/simplCore/OccurAnal.lhs
ghc/compiler/simplCore/SetLevels.lhs
ghc/compiler/simplCore/SimplCore.lhs
ghc/compiler/simplCore/SimplMonad.lhs
ghc/compiler/simplCore/SimplUtils.lhs
ghc/compiler/simplCore/Simplify.lhs
ghc/compiler/stgSyn/CoreToStg.lhs
ghc/compiler/stranal/StrictAnal.lhs
ghc/compiler/stranal/WorkWrap.lhs
ghc/compiler/stranal/WwLib.lhs

index dfad210..96c93a6 100644 (file)
@@ -48,14 +48,14 @@ import PprCore              ( pprCoreExpr )
 import OccurAnal       ( occurAnalyseGlobalExpr )
 import BinderInfo      ( )
 import CoreUtils       ( coreExprType, exprIsTrivial, exprIsValue, exprIsCheap )
-import Id              ( Id, idType, idUnique, isId, 
+import Id              ( Id, idType, idUnique, isId, getIdWorkerInfo,
                          getIdSpecialisation, getInlinePragma, getIdUnfolding
                        )
 import VarSet
 import Name            ( isLocallyDefined )
 import Const           ( Con(..), isLitLitLit, isWHNFCon )
 import PrimOp          ( PrimOp(..), primOpIsDupable )
-import IdInfo          ( ArityInfo(..), InlinePragInfo(..), OccInfo(..) )
+import IdInfo          ( ArityInfo(..), InlinePragInfo(..), OccInfo(..), workerExists )
 import TyCon           ( tyConFamilySize )
 import Type            ( splitAlgTyConApp_maybe, splitFunTy_maybe, isUnLiftedType )
 import Const           ( isNoRepLit )
@@ -170,10 +170,8 @@ instance Outputable UnfoldingGuidance where
     ppr UnfoldAlways    = ptext SLIT("ALWAYS")
     ppr UnfoldNever    = ptext SLIT("NEVER")
     ppr (UnfoldIfGoodArgs v cs size discount)
-      = hsep [ptext SLIT("IF_ARGS"), int v,
-              if null cs       -- always print *something*
-               then char 'X'
-               else hcat (map (text . show) cs),
+      = hsep [ ptext SLIT("IF_ARGS"), int v,
+              brackets (hsep (map int cs)),
               int size,
               int discount ]
 \end{code}
@@ -199,21 +197,33 @@ calcUnfoldingGuidance bOMB_OUT_SIZE expr
   = UnfoldAlways
  
   | otherwise
-  = case collectBinders expr of { (binders, body) ->
-    let
-       val_binders = filter isId binders
-    in
+  = case collect_val_bndrs expr of { (inline, val_binders, body) ->
     case (sizeExpr bOMB_OUT_SIZE val_binders body) of
 
       TooBig -> UnfoldNever
 
       SizeIs size cased_args scrut_discount
        -> UnfoldIfGoodArgs
-                       (length val_binders)
+                       n_val_binders
                        (map discount_for val_binders)
-                       (I# size)
+                       final_size
                        (I# scrut_discount)
        where        
+           boxed_size    = I# size
+
+           n_val_binders = length val_binders
+
+           final_size | inline     = boxed_size `min` (n_val_binders + 2)
+                      | otherwise  = boxed_size
+               -- The idea is that if there is an INLINE pragma (inline is True)
+               -- and there's a big body, we give a size of n_val_binders+2.  This
+               -- This is enough to defeat the no-size-increase test in callSiteInline;
+               --   we don't want to inline an INLINE thing into a totally boring context
+               --
+               -- Sometimes, though, an INLINE thing is smaller than n_val_binders+2.
+               -- A particular case in point is a constructor, which has size 1.
+               -- We want to inline this regardless, hence the `min`
+
            discount_for b 
                | num_cases == 0 = 0
                | is_fun_ty      = num_cases * opt_UF_FunAppDiscount
@@ -228,6 +238,19 @@ calcUnfoldingGuidance bOMB_OUT_SIZE expr
                                          Nothing       -> (False, panic "discount")
                                          Just (tc,_,_) -> (True,  tc)
        }
+  where
+
+    collect_val_bndrs e = go False [] e
+       -- We need to be a bit careful about how we collect the
+       -- value binders.  In ptic, if we see 
+       --      __inline_me (\x y -> e)
+       -- We want to say "2 value binders".  Why?  So that 
+       -- we take account of information given for the arguments
+
+    go inline rev_vbs (Note InlineMe e)     = go True   rev_vbs     e
+    go inline rev_vbs (Lam b e) | isId b    = go inline (b:rev_vbs) e
+                               | otherwise = go inline rev_vbs     e
+    go inline rev_vbs e                            = (inline, reverse rev_vbs, e)
 \end{code}
 
 \begin{code}
@@ -243,12 +266,6 @@ sizeExpr (I# bOMB_OUT_SIZE) args expr
     size_up (Type t)         = sizeZero        -- Types cost nothing
     size_up (Var v)           = sizeOne
 
-    size_up (Note InlineMe _) = sizeTwo                -- The idea is that this is one more
-                                               -- than the size of the "call" (i.e. 1)
-                                               -- We want to reply "no" to noSizeIncrease
-                                               -- for a bare reference (i.e. applied to no args) 
-                                               -- to an INLINE thing
-
     size_up (Note _ body)     = size_up body   -- Notes cost nothing
 
     size_up (App fun (Type t))  = size_up fun
@@ -289,7 +306,7 @@ sizeExpr (I# bOMB_OUT_SIZE) args expr
 
     ------------ 
     size_up_app (App fun arg) args   = size_up_app fun (arg:args)
-    size_up_app fun          args   = foldr (addSize . size_up) (fun_discount fun) args
+    size_up_app fun          args   = foldr (addSize . nukeScrutDiscount . size_up) (fun_discount fun) args
 
        -- A function application with at least one value argument
        -- so if the function is an argument give it an arg-discount
@@ -597,7 +614,9 @@ computeDiscount n_vals_wanted arg_discounts res_discount arg_infos result_used
        -- we also discount 1 for each argument passed, because these will
        -- reduce with the lambdas in the function (we count 1 for a lambda
        -- in size_up).
-  = length (take n_vals_wanted arg_infos) +
+  = 1 +                        -- Discount of 1 because the result replaces the call
+                       -- so we count 1 for the function itself
+    length (take n_vals_wanted arg_infos) +
                        -- Discount of 1 for each arg supplied, because the 
                        -- result replaces the call
     round (opt_UF_KeenessFactor * 
@@ -636,10 +655,21 @@ blackListed :: IdSet              -- Used in transformation rules
 -- inlined because of the inline phase we are in.  This is the sole
 -- place that the inline phase number is looked at.
 
+--     ToDo: improve horrible coding style (too much duplication)
+
 -- Phase 0: used for 'no imported inlinings please'
 -- This prevents wrappers getting inlined which in turn is bad for full laziness
+-- NEW: try using 'not a wrapper' rather than 'not imported' in this phase.
+-- This allows a little more inlining, which seems to be important, sometimes.
+-- For example PrelArr.newIntArr gets better.
 blackListed rule_vars (Just 0)
-  = \v -> not (isLocallyDefined v)
+  = \v -> let v_uniq = idUnique v
+         in 
+               -- not (isLocallyDefined v)
+            workerExists (getIdWorkerInfo v)
+         || v `elemVarSet` rule_vars
+         || not (isEmptyCoreRules (getIdSpecialisation v))
+         || v_uniq == runSTRepIdKey
 
 -- Phase 1: don't inline any rule-y things or things with specialisations
 blackListed rule_vars (Just 1)
index fb0b0eb..e2a3b13 100644 (file)
@@ -10,7 +10,7 @@ module CoreUtils (
        exprIsBottom, exprIsDupable, exprIsTrivial, exprIsCheap, 
        exprIsValue,
        exprOkForSpeculation, exprIsBig, hashExpr,
-       exprArity, exprGenerousArity,
+       exprArity, exprEtaExpandArity,
        cheapEqExpr, eqExpr, applyTypeToArgs
     ) where
 
@@ -149,7 +149,7 @@ exprIsDupable (Con con args) = conIsDupable con &&
 
 exprIsDupable (Note _ e)     = exprIsDupable e
 exprIsDupable expr          = case collectArgs expr of  
-                                 (Var f, args) ->  valArgCount args <= dupAppSize
+                                 (Var f, args) ->  all exprIsDupable args && valArgCount args <= dupAppSize
                                  other         ->  False
 
 dupAppSize :: Int
@@ -230,7 +230,8 @@ It returns True iff
 
        the expression guarantees to terminate, 
        soon, 
-       without raising an exceptoin
+       without raising an exception,
+       without causing a side effect (e.g. writing a mutable variable)
 
 E.G.
        let x = case y# +# 1# of { r# -> I# r# }
@@ -303,13 +304,24 @@ exprIsValue e@(App _ _)   = case collectArgs e of
 exprArity :: CoreExpr -> Int   -- How many value lambdas are at the top
 exprArity (Lam b e)     | isTyVar b    = exprArity e
                        | otherwise     = 1 + exprArity e
+
 exprArity (Note note e) | ok_note note = exprArity e
-exprArity other                                = 0
+                       where
+                         ok_note (Coerce _ _) = True
+                               -- We *do* look through coerces when getting arities.
+                               -- Reason: arities are to do with *representation* and
+                               -- work duplication. 
+                         ok_note InlineMe     = True
+                         ok_note InlineCall   = True
+                         ok_note other        = False
+                               -- SCC and TermUsg might be over-conservative?
+
+exprArity other        = 0
 \end{code}
 
 
 \begin{code}
-exprGenerousArity :: CoreExpr -> Int   -- The number of args the thing can be applied to
+exprEtaExpandArity :: CoreExpr -> Int  -- The number of args the thing can be applied to
                                        -- without doing much work
 -- This is used when eta expanding
 --     e  ==>  \xy -> e x y
@@ -320,17 +332,36 @@ exprGenerousArity :: CoreExpr -> Int      -- The number of args the thing can be app
 -- We are prepared to evaluate x each time round the loop in order to get that
 -- Hence "generous" arity
 
-exprGenerousArity (Var v)              = arityLowerBound (getIdArity v)
-exprGenerousArity (Note note e)        
-  | ok_note note                       = exprGenerousArity e
-exprGenerousArity (Lam x e) 
-  | isId x                             = 1 + exprGenerousArity e
-  | otherwise                          = exprGenerousArity e
-exprGenerousArity (Let bind body)      
-  | all exprIsCheap (rhssOfBind bind)  = exprGenerousArity body
-exprGenerousArity (Case scrut _ alts)
-  | exprIsCheap scrut                  = min_zero [exprGenerousArity rhs | (_,_,rhs) <- alts]
-exprGenerousArity other                = 0     -- Could do better for applications
+exprEtaExpandArity (Var v)             = arityLowerBound (getIdArity v)
+exprEtaExpandArity (Lam x e) 
+  | isId x                             = 1 + exprEtaExpandArity e
+  | otherwise                          = exprEtaExpandArity e
+exprEtaExpandArity (Let bind body)     
+  | all exprIsCheap (rhssOfBind bind)  = exprEtaExpandArity body
+exprEtaExpandArity (Case scrut _ alts)
+  | exprIsCheap scrut                  = min_zero [exprEtaExpandArity rhs | (_,_,rhs) <- alts]
+
+exprEtaExpandArity (Note note e)       
+  | ok_note note                       = exprEtaExpandArity e
+  where
+    ok_note InlineCall = True
+    ok_note other      = False
+       -- Notice that we do not look through __inline_me__
+       -- This one is a bit more surprising, but consider
+       --      f = _inline_me (\x -> e)
+       -- We DO NOT want to eta expand this to
+       --      f = \x -> (_inline_me (\x -> e)) x
+       -- because the _inline_me gets dropped now it is applied, 
+       -- giving just
+       --      f = \x -> e
+       -- A Bad Idea
+       --
+       -- Notice also that we don't look through Coerce
+       -- This is simply because the etaExpand code in SimplUtils
+       -- isn't capable of making the alternating lambdas and coerces
+       -- that would be necessary to exploit it
+
+exprEtaExpandArity other               = 0     -- Could do better for applications
 
 min_zero :: [Int] -> Int       -- Find the minimum, but zero is the smallest
 min_zero (x:xs) = go x xs
@@ -340,24 +371,6 @@ min_zero (x:xs) = go x xs
                  go min (x:xs) | x < min   = go x xs
                                | otherwise = go min xs 
 
-ok_note (SCC _)             = False    -- (Over?) conservative
-ok_note (TermUsg _)  = False   -- Doesn't matter much
-
-ok_note (Coerce _ _) = True
-       -- We *do* look through coerces when getting arities.
-       -- Reason: arities are to do with *representation* and
-       -- work duplication. 
-
-ok_note InlineCall   = True
-ok_note InlineMe     = False
-       -- This one is a bit more surprising, but consider
-       --      f = _inline_me (\x -> e)
-       -- We DO NOT want to eta expand this to
-       --      f = \x -> (_inline_me (\x -> e)) x
-       -- because the _inline_me gets dropped now it is applied, 
-       -- giving just
-       --      f = \x -> e
-       -- A Bad Idea
 \end{code}
 
 
index c5abb68..bf94690 100644 (file)
@@ -9,14 +9,14 @@ module ErrUtils (
        addShortErrLocLine, addShortWarnLocLine,
        addErrLocHdrLine,
        dontAddErrLoc,
-       pprBagOfErrors, pprBagOfWarnings,
+       printErrorsAndWarnings, pprBagOfErrors, pprBagOfWarnings,
        ghcExit,
        doIfSet, dumpIfSet
     ) where
 
 #include "HsVersions.h"
 
-import Bag             ( Bag, bagToList )
+import Bag             ( Bag, bagToList, isEmptyBag )
 import SrcLoc          ( SrcLoc, noSrcLoc )
 import Util            ( sortLt )
 import Outputable
@@ -57,6 +57,16 @@ dontAddErrLoc title rest_of_err_msg
  | otherwise  =
     ( noSrcLoc, hang (text title <> colon) 4 rest_of_err_msg )
 
+printErrorsAndWarnings :: Bag ErrMsg -> Bag WarnMsg -> IO ()
+       -- Don't print any warnings if there are errors
+printErrorsAndWarnings errs warns
+  | no_errs && no_warns  = return ()
+  | no_errs             = printErrs (pprBagOfWarnings warns)
+  | otherwise           = printErrs (pprBagOfErrors   errs)
+  where
+    no_warns = isEmptyBag warns
+    no_errs  = isEmptyBag errs
+
 pprBagOfErrors :: Bag ErrMsg -> SDoc
 pprBagOfErrors bag_of_errors
   = vcat [text "" $$ p | (_,p) <- sorted_errs ]
index e823e47..a407ab7 100644 (file)
@@ -27,8 +27,8 @@ import Id             ( Id, idType, idInfo, omitIfaceSigForId, isUserExportedId,
 import Var             ( isId )
 import VarSet
 import DataCon         ( StrictnessMark(..), dataConSig, dataConFieldLabels, dataConStrictMarks )
-import IdInfo          ( IdInfo, StrictnessInfo, ArityInfo, InlinePragInfo(..), inlinePragInfo,
-                         arityInfo, ppArityInfo, 
+import IdInfo          ( IdInfo, StrictnessInfo(..), ArityInfo, InlinePragInfo(..), inlinePragInfo,
+                         arityInfo, ppArityInfo, arityLowerBound,
                          strictnessInfo, ppStrictnessInfo, isBottomingStrictness,
                          cafInfo, ppCafInfo, specInfo,
                          cprInfo, ppCprInfo,
@@ -290,7 +290,8 @@ ifaceId get_idinfo needed_ids is_rec id rhs
   = Nothing            -- Well, that was easy!
 
 ifaceId get_idinfo needed_ids is_rec id rhs
-  = Just (hsep [sig_pretty, prag_pretty, char ';'], new_needed_ids)
+  = ASSERT2( arity_matches_strictness, ppr id )
+    Just (hsep [sig_pretty, prag_pretty, char ';'], new_needed_ids)
   where
     core_idinfo = idInfo id
     stg_idinfo  = get_idinfo id
@@ -310,7 +311,8 @@ ifaceId get_idinfo needed_ids is_rec id rhs
                                        ptext SLIT("##-}")]
 
     ------------  Arity  --------------
-    arity_pretty  = ppArityInfo (arityInfo stg_idinfo)
+    arity_info    = arityInfo stg_idinfo
+    arity_pretty  = ppArityInfo arity_info
 
     ------------ Caf Info --------------
     caf_pretty = ppCafInfo (cafInfo stg_idinfo)
@@ -369,6 +371,15 @@ ifaceId get_idinfo needed_ids is_rec id rhs
 
     find_fvs expr = exprSomeFreeVars interestingId expr
 
+    ------------ Sanity checking --------------
+       -- The arity of a wrapper function should match its strictness,
+       -- or else an importing module will get very confused indeed.
+    arity_matches_strictness
+       = not has_worker ||
+         case strict_info of
+           StrictnessInfo ds _ -> length ds == arityLowerBound arity_info
+           other               -> True
+    
 interestingId id = isId id && isLocallyDefined id &&
                   not (omitIfaceSigForId id)
 \end{code}
index 40250ee..74b4da4 100644 (file)
@@ -48,7 +48,7 @@ module RdrHsSyn (
        RdrNameGenPragmas,
        RdrNameInstancePragmas,
        extractHsTyRdrNames, 
-       extractHsTyRdrTyVars,
+       extractHsTyRdrTyVars, extractHsTysRdrTyVars,
        extractPatsTyVars, 
        extractRuleBndrsTyVars,
  
@@ -138,6 +138,9 @@ extractHsTyRdrNames ty = nub (extract_ty ty [])
 extractHsTyRdrTyVars    :: RdrNameHsType -> [RdrName]
 extractHsTyRdrTyVars ty =  filter isRdrTyVar (extractHsTyRdrNames ty)
 
+extractHsTysRdrTyVars    :: [RdrNameHsType] -> [RdrName]
+extractHsTysRdrTyVars tys =  filter isRdrTyVar (nub (extract_tys tys []))
+
 extractRuleBndrsTyVars :: [RuleBndr RdrName] -> [RdrName]
 extractRuleBndrsTyVars bndrs = filter isRdrTyVar (nub (foldr go [] bndrs))
                            where
@@ -151,6 +154,8 @@ extract_ctxt ctxt acc = foldr extract_ass acc ctxt
                     where
                       extract_ass (cls, tys) acc = foldr extract_ty (cls : acc) tys
 
+extract_tys tys acc = foldr extract_ty acc tys
+
 extract_ty (MonoTyApp ty1 ty2)          acc = extract_ty ty1 (extract_ty ty2 acc)
 extract_ty (MonoListTy ty)              acc = extract_ty ty acc
 extract_ty (MonoTupleTy tys _)          acc = foldr extract_ty acc tys
index baf7b30..a15d700 100644 (file)
@@ -44,9 +44,7 @@ import PrelMods               ( mAIN_Name, pREL_MAIN_Name )
 import TysWiredIn      ( unitTyCon, intTyCon, doubleTyCon, boolTyCon )
 import PrelInfo                ( ioTyCon_NAME, numClass_RDR, thinAirIdNames, derivingOccurrences )
 import Type            ( namesOfType, funTyCon )
-import ErrUtils                ( pprBagOfErrors, pprBagOfWarnings,
-                         doIfSet, dumpIfSet, ghcExit
-                       )
+import ErrUtils                ( printErrorsAndWarnings, dumpIfSet, ghcExit )
 import BasicTypes      ( NewOrData(..) )
 import Bag             ( isEmptyBag, bagToList )
 import FiniteMap       ( fmToList, delListFromFM, addToFM, sizeFM, eltsFM )
@@ -77,14 +75,7 @@ renameModule us this_mod@(HsModule mod_name vers exports imports local_decls loc
        \ (maybe_rn_stuff, rn_errs_bag, rn_warns_bag) ->
 
        -- Check for warnings
-    doIfSet (not (isEmptyBag rn_warns_bag))
-           (printErrs (pprBagOfWarnings rn_warns_bag)) >>
-
-       -- Check for errors; exit if so
-    doIfSet (not (isEmptyBag rn_errs_bag))
-           (printErrs (pprBagOfErrors rn_errs_bag)      >>
-            ghcExit 1
-           )                                            >>
+    printErrorsAndWarnings rn_errs_bag rn_warns_bag    >>
 
        -- Dump output, if any
     (case maybe_rn_stuff of
@@ -95,7 +86,10 @@ renameModule us this_mod@(HsModule mod_name vers exports imports local_decls loc
     )                                                  >>
 
        -- Return results
-    return maybe_rn_stuff
+    if not (isEmptyBag rn_errs_bag) then
+           ghcExit 1 >> return Nothing
+    else
+           return maybe_rn_stuff
 \end{code}
 
 
index 780c91f..ecc7015 100644 (file)
@@ -14,7 +14,7 @@ import HsPragmas
 import HsTypes         ( getTyVarName, pprClassAssertion, cmpHsTypes )
 import RdrName         ( RdrName, isRdrDataCon, rdrNameOcc, isRdrTyVar )
 import RdrHsSyn                ( RdrNameContext, RdrNameHsType, RdrNameConDecl,
-                         extractRuleBndrsTyVars, extractHsTyRdrTyVars
+                         extractRuleBndrsTyVars, extractHsTyRdrTyVars, extractHsTysRdrTyVars
                        )
 import RnHsSyn
 import HsCore
@@ -551,7 +551,7 @@ rnHsPolyType doc (HsForAllTy Nothing ctxt ty)
        mentioned_in_tau = extractHsTyRdrTyVars ty
        forall_tyvars    = filter (not . (`elemFM` name_env)) mentioned_in_tau
     in
-    checkConstraints False doc forall_tyvars ctxt ty   `thenRn` \ ctxt' ->
+    checkConstraints doc forall_tyvars mentioned_in_tau ctxt ty        `thenRn` \ ctxt' ->
     rnForAll doc (map UserTyVar forall_tyvars) ctxt' ty
 
 rnHsPolyType doc (HsForAllTy (Just forall_tyvars) ctxt tau)
@@ -575,9 +575,9 @@ rnHsPolyType doc (HsForAllTy (Just forall_tyvars) ctxt tau)
  
        forall_tyvar_names    = map getTyVarName forall_tyvars
     in
-    mapRn_ (forAllErr doc tau) bad_guys                        `thenRn_`
-    mapRn_ (forAllWarn doc tau) warn_guys                      `thenRn_`
-    checkConstraints True doc forall_tyvar_names ctxt tau      `thenRn` \ ctxt' ->
+    mapRn_ (forAllErr doc tau) bad_guys                                        `thenRn_`
+    mapRn_ (forAllWarn doc tau) warn_guys                                      `thenRn_`
+    checkConstraints doc forall_tyvar_names mentioned_in_tau ctxt tau  `thenRn` \ ctxt' ->
     rnForAll doc forall_tyvars ctxt' tau
 
 rnHsPolyType doc other_ty = rnHsType doc other_ty
@@ -587,19 +587,26 @@ rnHsPolyType doc other_ty = rnHsType doc other_ty
 -- Since the forall'd type variables are a subset of the free tyvars
 -- of the tau-type part, this guarantees that every constraint mentions
 -- at least one of the free tyvars in ty
-checkConstraints explicit_forall doc forall_tyvars ctxt ty
+checkConstraints doc forall_tyvars tau_vars ctxt ty
    = mapRn check ctxt                  `thenRn` \ maybe_ctxt' ->
      returnRn (catMaybes maybe_ctxt')
            -- Remove problem ones, to avoid duplicate error message.
    where
      check ct@(_,tys)
-       | forall_mentioned = returnRn (Just ct)
-       | otherwise        = addErrRn (ctxtErr explicit_forall doc forall_tyvars ct ty)
-                            `thenRn_` returnRn Nothing
+       | ambiguous = failWithRn Nothing (ambigErr doc ct ty)
+       | not_univ  = failWithRn Nothing (univErr  doc ct ty)
+       | otherwise = returnRn (Just ct)
         where
-         forall_mentioned = foldr ((||) . any (`elem` forall_tyvars) . extractHsTyRdrTyVars)
-                            False
-                            tys
+         ct_vars    = extractHsTysRdrTyVars tys
+
+         ambiguous  =  -- All the universally-quantified tyvars in the constraint must appear in the tau ty
+                       -- (will change when we get functional dependencies)
+                       not (all (\ct_var -> not (ct_var `elem` forall_tyvars) || ct_var `elem` tau_vars) ct_vars)
+                       
+         not_univ   =  -- At least one of the tyvars in each constraint must
+                       -- be universally quantified. This restriction isn't in Hugs
+                       not (any (`elem` forall_tyvars) ct_vars)
+       
 
 rnForAll doc forall_tyvars ctxt ty
   = bindTyVarsFVRn doc forall_tyvars   $ \ new_tyvars ->
@@ -918,17 +925,22 @@ forAllErr doc ty tyvar
       $$
       (ptext SLIT("In") <+> doc))
 
-ctxtErr explicit_forall doc tyvars constraint ty
-  = sep [ptext SLIT("None of the type variable(s) in the constraint")
-          <+> quotes (pprClassAssertion constraint),
-        if explicit_forall then
-          nest 4 (ptext SLIT("is universally quantified (i.e. bound by the forall)"))
-        else
-          nest 4 (ptext SLIT("appears in the type") <+> quotes (ppr ty))
+univErr doc constraint ty
+  = sep [ptext SLIT("All of the type variable(s) in the constraint")
+          <+> quotes (pprClassAssertion constraint) 
+         <+> ptext SLIT("are already in scope"),
+        nest 4 (ptext SLIT("At least one must be universally quantified here"))
     ]
     $$
     (ptext SLIT("In") <+> doc)
 
+ambigErr doc constraint ty
+  = sep [ptext SLIT("Ambiguous constraint") <+> quotes (pprClassAssertion constraint),
+        nest 4 (ptext SLIT("in the type:") <+> ppr ty),
+        nest 4 (ptext SLIT("Each forall-d type variable mentioned by the constraint must appear after the =>."))]
+    $$
+    (ptext SLIT("In") <+> doc)
+
 unexpectedForAllTy ty
   = ptext SLIT("Unexpected forall type:") <+> ppr ty
 
index e137536..01e5652 100644 (file)
@@ -25,7 +25,7 @@ import CoreSyn
 import CoreFVs         ( idRuleVars )
 import CoreUtils       ( exprIsTrivial )
 import Const           ( Con(..), Literal(..) )
-import Id              ( isSpecPragmaId, isOneShotLambda,
+import Id              ( isSpecPragmaId, isOneShotLambda, setOneShotLambda, 
                          getInlinePragma, setInlinePragma,
                          isExportedId, modifyIdInfo, idInfo,
                          getIdSpecialisation, 
@@ -626,6 +626,10 @@ occAnal env expr@(Lam _ _)
   = case occAnal (env_body `addNewCands` binders) body of { (body_usage, body') ->
     let
         (final_usage, tagged_binders) = tagBinders body_usage binders
+       --      URGH!  Sept 99: we don't seem to be able to use binders' here, because
+       --      we get linear-typed things in the resulting program that we can't handle yet.
+       --      (e.g. PrelShow)  TODO 
+
        really_final_usage = if linear then
                                final_usage
                             else
@@ -635,7 +639,7 @@ occAnal env expr@(Lam _ _)
      mkLams tagged_binders body') }
   where
     (binders, body)    = collectBinders expr
-    (linear, env_body) = oneShotGroup env (filter isId binders)
+    (linear, env_body, binders') = oneShotGroup env binders
 
 occAnal env (Case scrut bndr alts)
   = case mapAndUnzip (occAnalAlt alt_env) alts of { (alts_usage_s, alts')   -> 
@@ -764,15 +768,31 @@ addNewCand (OccEnv ifun cands ctxt) id
 setCtxt :: OccEnv -> CtxtTy -> OccEnv
 setCtxt (OccEnv ifun cands _) ctxt = OccEnv ifun cands ctxt
 
-oneShotGroup :: OccEnv -> [Id] -> (Bool, OccEnv)       -- True <=> this is a one-shot linear lambda group
-                                                       -- The [Id] are the binders
+oneShotGroup :: OccEnv -> [CoreBndr] -> (Bool, OccEnv, [CoreBndr])
+       -- True <=> this is a one-shot linear lambda group
+       -- The [CoreBndr] are the binders.
+
+       -- The result binders have one-shot-ness set that they might not have had originally.
+       -- This happens in (build (\cn -> e)).  Here the occurrence analyser
+       -- linearity context knows that c,n are one-shot, and it records that fact in
+       -- the binder. This is useful to guide subsequent float-in/float-out tranformations
+
 oneShotGroup (OccEnv ifun cands ctxt) bndrs 
-  = (go bndrs ctxt, OccEnv ifun cands (drop (length bndrs) ctxt))
+  = case go ctxt bndrs [] of
+       (new_ctxt, new_bndrs) -> (all is_one_shot new_bndrs, OccEnv ifun cands new_ctxt, new_bndrs)
   where
-       -- Only return True if *all* the lambdas are linear
-    go (bndr:bndrs) (lin:ctxt)         = (lin || isOneShotLambda bndr) && go bndrs ctxt
-    go []          ctxt        = True
-    go bndrs       []          = all isOneShotLambda bndrs
+    is_one_shot b = isId b && isOneShotLambda b
+
+    go ctxt [] rev_bndrs = (ctxt, reverse rev_bndrs)
+
+    go (lin_ctxt:ctxt) (bndr:bndrs) rev_bndrs
+       | isId bndr = go ctxt bndrs (bndr':rev_bndrs)
+       where
+         bndr' | lin_ctxt  = setOneShotLambda bndr
+               | otherwise = bndr
+
+    go ctxt (bndr:bndrs) rev_bndrs = go ctxt bndrs (bndr:rev_bndrs)
+
 
 zapCtxt env@(OccEnv ifun cands []) = env
 zapCtxt     (OccEnv ifun cands _ ) = OccEnv ifun cands []
index 13970ff..fb552e4 100644 (file)
@@ -266,9 +266,9 @@ lvlExpr ctxt_lvl env (_, AnnNote note expr)
 -- Why not?  Because partial applications are fairly rare, and splitting
 -- lambdas makes them more expensive.
 
-lvlExpr ctxt_lvl env (_, AnnLam bndr rhs)
+lvlExpr ctxt_lvl env expr@(_, AnnLam bndr rhs)
   = lvlMFE incd_lvl new_env body       `thenLvl` \ body' ->
-    returnLvl (mkLams lvld_bndrs body')
+    returnLvl (mk_lams lvld_bndrs expr body')
   where
     bndr_is_id         = isId bndr
     bndr_is_tyvar      = isTyVar bndr
@@ -283,11 +283,21 @@ lvlExpr ctxt_lvl env (_, AnnLam bndr rhs)
     lvld_bndrs = [(b,incd_lvl) | b <- bndrs]
     new_env    = extendLvlEnv env lvld_bndrs
 
+       -- Ignore notes, because we don't want to split
+       -- a lambda like this (\x -> coerce t (\s -> ...))
+       -- This happens quite a bit in state-transformer programs
     go (_, AnnLam bndr rhs) |  bndr_is_id && isId bndr 
                            || bndr_is_tyvar && isTyVar bndr
                            =  case go rhs of { (bndrs, body) -> (bndr:bndrs, body) }
+    go (_, AnnNote _ rhs)   = go rhs
     go body                = ([], body)
 
+       -- Have to reconstruct the right Notes, since we ignored
+       -- them when gathering the lambdas
+    mk_lams (lb : lbs) (_, AnnLam _ body)     body' = Lam  lb   (mk_lams lbs body body')
+    mk_lams lbs               (_, AnnNote note body) body' = Note note (mk_lams lbs body body')
+    mk_lams []        body                   body' = body'
+
 lvlExpr ctxt_lvl env (_, AnnLet bind body)
   = lvlBind NotTopLevel ctxt_lvl env bind      `thenLvl` \ (binds', new_env) ->
     lvlExpr ctxt_lvl new_env body              `thenLvl` \ body' ->
index e3ab3d4..2356d85 100644 (file)
@@ -259,6 +259,11 @@ simplifyPgm (imported_rule_ids, rule_lhs_fvs)
           let { (binds', counts') = initSmpl sw_chkr us1 imported_rule_ids 
                                              black_list_fn 
                                              (simplTopBinds tagged_binds);
+                       -- The imported_rule_ids are used by initSmpl to initialise
+                       -- the in-scope set.  That way, the simplifier will change any
+                       -- occurrences of the imported id to the one in the imported_rule_ids
+                       -- set, which are decorated with their rules.
+
                 all_counts        = counts `plusSimplCount` counts'
               } ;
 
@@ -447,7 +452,14 @@ postSimplExpr (Let bind body)
     returnPM (Let bind' body')
 
 postSimplExpr (Note note body)
-  = postSimplExprEta body      `thenPM` \ body' ->
+  = postSimplExpr body         `thenPM` \ body' ->
+       -- Do *not* call postSimplExprEta here
+       -- We don't want to turn f = \x -> coerce t (\y -> f x y)
+       -- into                  f = \x -> coerce t (f x)
+       -- because then f has a lower arity.
+       -- This is not only bad in general, it causes the arity to 
+       -- not match the [Demand] on an Id, 
+       -- which confuses the importer of this module.
     returnPM (Note note body')
 
 postSimplExpr (Case scrut case_bndr alts)
index a946da4..80b8553 100644 (file)
@@ -39,6 +39,7 @@ module SimplMonad (
        getEnclosingCC, setEnclosingCC,
 
        -- Environments
+       getEnv, setAllExceptInScope,
        getSubst, setSubst,
        getSubstEnv, extendSubst, extendSubstList,
        getInScope, setInScope, extendInScope, extendInScopes, modifyInScope,
@@ -146,7 +147,7 @@ data SimplCont              -- Strict contexts
                                --      f (error "foo") ==> coerce t (error "foo")
                                -- when f is strict
                                -- We need to know the type t, to which to coerce.
-           (OutExpr -> SimplM OutExprStuff)    -- What to do with the result
+            (OutExpr -> SimplM OutExprStuff)   -- What to do with the result
 
 instance Outputable SimplCont where
   ppr (Stop _)                      = ptext SLIT("Stop")
@@ -777,6 +778,14 @@ emptySimplEnv sw_chkr in_scope black_list
               seSubst = mkSubst in_scope emptySubstEnv }
        -- The top level "enclosing CC" is "SUBSUMED".
 
+getEnv :: SimplM SimplEnv
+getEnv env us sc = (env, us, sc)
+
+setAllExceptInScope :: SimplEnv -> SimplM a -> SimplM a
+setAllExceptInScope new_env@(SimplEnv {seSubst = new_subst}) m 
+                           (SimplEnv {seSubst = old_subst}) us sc 
+  = m (new_env {seSubst = Subst.setInScope new_subst (substInScope old_subst)}) us sc
+
 getSubst :: SimplM Subst
 getSubst env us sc = (seSubst env, us, sc)
 
index a5877bd..b685876 100644 (file)
@@ -18,7 +18,7 @@ import BinderInfo
 import CmdLineOpts     ( opt_SimplDoLambdaEtaExpansion, opt_SimplCaseMerge )
 import CoreSyn
 import CoreFVs         ( exprFreeVars )
-import CoreUtils       ( exprIsTrivial, cheapEqExpr, coreExprType, exprIsCheap, exprGenerousArity )
+import CoreUtils       ( exprIsTrivial, cheapEqExpr, coreExprType, exprIsCheap, exprEtaExpandArity )
 import Subst           ( substBndrs, substBndr, substIds )
 import Id              ( Id, idType, getIdArity, isId, idName,
                          getInlinePragma, setInlinePragma,
@@ -322,7 +322,7 @@ tryEtaExpansion rhs
     (x_bndrs, body) = collectValBinders rhs
     (fun, args)            = collectArgs body
     trivial_args    = map exprIsTrivial args
-    fun_arity      = exprGenerousArity fun
+    fun_arity      = exprEtaExpandArity fun
 
     bind_z_arg (arg, trivial_arg) 
        | trivial_arg = returnSmpl (Nothing, arg)
@@ -357,7 +357,7 @@ tryEtaExpansion rhs
        -- See if the body could obviously do with more args
        (fun_arity - valArgCount args)
 
--- This case is now deal with by exprGenerousArity
+-- This case is now deal with by exprEtaExpandArity
        -- Finally, see if it's a state transformer, and xs is non-null
        -- (so it's also a function not a thunk) in which
        -- case we eta-expand on principle! This can waste work,
index 473b03b..0828a79 100644 (file)
@@ -14,7 +14,7 @@ import CmdLineOpts    ( intSwitchSet,
                          SimplifierSwitch(..)
                        )
 import SimplMonad
-import SimplUtils      ( mkCase, transformRhs, findAlt,
+import SimplUtils      ( mkCase, transformRhs, findAlt, etaCoreExpr,
                          simplBinder, simplBinders, simplIds, findDefault, mkCoerce
                        )
 import Var             ( TyVar, mkSysTyVar, tyVarKind, maybeModifyIdInfo )
@@ -24,7 +24,7 @@ import Id             ( Id, idType, idInfo, idUnique,
                          getIdUnfolding, setIdUnfolding, isExportedId, 
                          getIdSpecialisation, setIdSpecialisation,
                          getIdDemandInfo, setIdDemandInfo,
-                         getIdArity, setIdArity, setIdInfo,
+                         setIdInfo,
                          getIdStrictness, 
                          setInlinePragma, getInlinePragma, idMustBeINLINEd,
                          setOneShotLambda, maybeModifyIdInfo
@@ -217,8 +217,8 @@ simplExprF expr@(Con (PrimOp op) args) cont
          Nothing -> rebuild (Con (PrimOp op) args2) cont2
 
 simplExprF (Con con@(DataCon _) args) cont
-  = simplConArgs args          ( \ args' ->
-    rebuild (Con con args') cont)
+  = simplConArgs args          $ \ args' ->
+    rebuild (Con con args') cont
 
 simplExprF expr@(Con con@(Literal _) args) cont
   = ASSERT( null args )
@@ -334,11 +334,14 @@ simplLam fun cont
        -- Exactly enough args
     go expr cont = simplExprF expr cont
 
-
 -- completeLam deals with the case where a lambda doesn't have an ApplyTo
--- continuation.  Try for eta reduction, but *only* if we get all
--- the way to an exprIsTrivial expression.  
--- 'acc' holds the simplified binders, in reverse order
+-- continuation.  
+-- We used to try for eta reduction here, but I found that this was
+-- eta reducing things like 
+--     f = \x -> (coerce (\x -> e))
+-- This made f's arity reduce, which is a bad thing, so I removed the
+-- eta reduction at this point, and now do it only when binding 
+-- (at the call to postInlineUnconditionally
 
 completeLam acc (Lam bndr body) cont
   = simplBinder bndr                   $ \ bndr' ->
@@ -346,29 +349,8 @@ completeLam acc (Lam bndr body) cont
 
 completeLam acc body cont
   = simplExpr body                     `thenSmpl` \ body' ->
-
-    case (opt_SimplDoEtaReduction, check_eta acc body') of
-       (True, Just body'')     -- Eta reduce!
-               -> tick (EtaReduction (head acc))       `thenSmpl_`
-                  rebuild body'' cont
-
-       other   ->      -- No eta reduction
-                  rebuild (foldl (flip Lam) body' acc) cont
-                       -- Remember, acc is the reversed binders
-  where
-       -- NB: the binders are reversed
-    check_eta (b : bs) (App fun arg)
-       |  (varToCoreExpr b `cheapEqExpr` arg)
-       = check_eta bs fun
-
-    check_eta [] body
-       | exprIsTrivial body &&                 -- ONLY if the body is trivial
-         not (any (`elemVarSet` body_fvs) acc)
-       = Just body             -- Success!
-       where
-         body_fvs = exprFreeVars body
-
-    check_eta _ _ = Nothing    -- Bale out
+    rebuild (foldl (flip Lam) body' acc) cont
+               -- Remember, acc is the *reversed* binders
 
 mkLamBndrZapper :: CoreExpr    -- Function
                -> Int          -- Number of args
@@ -396,6 +378,10 @@ simplConArgs (arg:args) thing_inside
   = switchOffInlining (simplExpr arg)  `thenSmpl` \ arg' ->
        -- Simplify the RHS with inlining switched off, so that
        -- only absolutely essential things will happen.
+       -- If we don't do this, consider:
+       --      let x = e in C {x}
+       -- We end up inlining x back into C's argument,
+       -- and then let-binding it again!
 
     simplConArgs args                          $ \ args' ->
 
@@ -404,8 +390,8 @@ simplConArgs (arg:args) thing_inside
        thing_inside (arg' : args')
     else
        newId (coreExprType arg')               $ \ arg_id ->
-       thing_inside (Var arg_id : args')       `thenSmpl` \ res ->
-       returnSmpl (addBind (NonRec arg_id arg') res)
+       completeBeta arg_id arg_id arg'         $
+       thing_inside (Var arg_id : args')
 \end{code}
 
 
@@ -486,14 +472,30 @@ simplArg arg_ty demand arg arg_se cont_ty thing_inside
        -- Return true only for dictionary types where the dictionary
        -- has more than one component (else we risk poking on the component
        -- of a newtype dictionary)
-  = getSubstEnv                                        `thenSmpl` \ body_se ->
-    transformRhs arg                           `thenSmpl` \ t_arg ->
-    setSubstEnv arg_se (simplExprF t_arg (ArgOf NoDup cont_ty $ \ arg' ->
-    setSubstEnv body_se (thing_inside arg')
-    )) -- NB: we must restore body_se before carrying on with thing_inside!!
+  = transformRhs arg                   `thenSmpl` \ t_arg ->
+    getEnv                             `thenSmpl` \ env ->
+    setSubstEnv arg_se                                 $
+    simplExprF t_arg (ArgOf NoDup cont_ty      $ \ rhs' ->
+    setAllExceptInScope env                    $
+    etaFirst thing_inside rhs')
 
   | otherwise
-  = simplRhs NotTopLevel True arg_ty arg arg_se thing_inside
+  = simplRhs NotTopLevel True {- OK to float unboxed -}
+            arg_ty arg arg_se 
+            thing_inside
+   
+-- Do eta-reduction on the simplified RHS, if eta reduction is on
+-- NB: etaCoreExpr only eta-reduces if that results in something trivial
+etaFirst | opt_SimplDoEtaReduction = \ thing_inside rhs -> thing_inside (etaCoreExprToTrivial rhs)
+        | otherwise               = \ thing_inside rhs -> thing_inside rhs
+
+-- Try for eta reduction, but *only* if we get all
+-- the way to an exprIsTrivial expression.    We don't want to remove
+-- extra lambdas unless we are going to avoid allocating this thing altogether
+etaCoreExprToTrivial rhs | exprIsTrivial rhs' = rhs'
+                        | otherwise          = rhs
+                        where
+                          rhs' = etaCoreExpr rhs
 \end{code}
 
 
@@ -647,13 +649,13 @@ simplRhs top_lvl float_ubx rhs_ty rhs rhs_se thing_inside
                -- and so there can't be any 'will be demanded' bindings in the floats.
                -- Hence the assert
        WARN( any demanded_float floats_out, ppr floats_out )
-       setInScope in_scope' (thing_inside rhs'')       `thenSmpl` \ stuff ->
+       setInScope in_scope' (etaFirst thing_inside rhs'')      `thenSmpl` \ stuff ->
                -- in_scope' may be excessive, but that's OK;
                -- it's a superset of what's in scope
        returnSmpl (addBinds floats_out stuff)
     else       
                -- Don't do the float
-       thing_inside (mkLets floats rhs')
+       etaFirst thing_inside (mkLets floats rhs')
 
 -- In a let-from-let float, we just tick once, arbitrarily
 -- choosing the first floated binder to identify it
@@ -724,16 +726,6 @@ completeCall black_list_fn in_scope orig_var var cont
 -- Doing so results in a significant space leak.
 -- Instead we pass orig_var, which has no inlinings etc.
 
-       -- Look for rules or specialisations that match
-       -- Do this *before* trying inlining because some functions
-       -- have specialisations *and* are strict; we don't want to
-       -- inline the wrapper of the non-specialised thing... better
-       -- to call the specialised thing instead.
-  | maybeToBool maybe_rule_match
-  = tick (RuleFired rule_name)                 `thenSmpl_`
-    zapSubstEnv (simplExprF rule_rhs (pushArgs emptySubstEnv rule_args result_cont))
-       -- See note below about zapping the substitution here
-
        -- Look for an unfolding. There's a binding for the
        -- thing, but perhaps we want to inline it anyway
   | maybeToBool maybe_inline
@@ -751,33 +743,47 @@ completeCall black_list_fn in_scope orig_var var cont
   | otherwise          -- Neither rule nor inlining
                        -- Use prepareArgs to use function strictness
   = prepareArgs (ppr var) (idType var) (get_str var) cont      $ \ args' cont' ->
-    rebuild (mkApps (Var orig_var) args') cont'
+
+       -- Look for rules or specialisations that match
+       --
+       -- It's important to simplify the args first, because the rule-matcher
+       -- doesn't do substitution as it goes.  We don't want to use subst_args
+       -- (defined in the 'where') because that throws away useful occurrence info,
+       -- and perhaps-very-important specialisations.
+       --
+       -- Some functions have specialisations *and* are strict; in this case,
+       -- we don't want to inline the wrapper of the non-specialised thing; better
+       -- to call the specialised thing instead.
+       -- But the black-listing mechanism means that inlining of the wrapper
+       -- won't occur for things that have specialisations till a later phase, so
+       -- it's ok to try for inlining first.
+    case lookupRule in_scope var args' of
+       Just (rule_name, rule_rhs, rule_args) -> 
+               tick (RuleFired rule_name)                      `thenSmpl_`
+               zapSubstEnv (simplExprF rule_rhs (pushArgs emptySubstEnv rule_args cont'))
+                       -- See note above about zapping the substitution here
+       
+       Nothing -> rebuild (mkApps (Var orig_var) args') cont'
 
   where
     get_str var = case getIdStrictness var of
                        NoStrictnessInfo                  -> (repeat wwLazy, False)
                        StrictnessInfo demands result_bot -> (demands, result_bot)
 
-  
-    (args', result_cont) = contArgs in_scope cont
-    val_args            = filter isValArg args'
-    arg_infos                   = map (interestingArg in_scope) val_args
-    inline_call                 = contIsInline result_cont
-    interesting_cont     = contIsInteresting result_cont
-    discard_inline_cont  | inline_call = discardInline cont
-                        | otherwise   = cont
-
        ---------- Unfolding stuff
+    (subst_args, result_cont) = contArgs in_scope cont
+    val_args                 = filter isValArg subst_args
+    arg_infos                        = map (interestingArg in_scope) val_args
+    inline_call                      = contIsInline result_cont
+    interesting_cont          = contIsInteresting result_cont
+    discard_inline_cont       | inline_call = discardInline cont
+                             | otherwise   = cont
+
     maybe_inline  = callSiteInline black_listed inline_call 
                                   var arg_infos interesting_cont
     Just unf_template = maybe_inline
     black_listed      = black_list_fn var
 
-       ---------- Specialisation stuff
-    maybe_rule_match           = lookupRule in_scope var args'
-    Just (rule_name, rule_rhs, rule_args) = maybe_rule_match
-
-
 
 -- An argument is interesting if it has *some* structure
 -- We are here trying to avoid unfolding a function that
@@ -1403,8 +1409,11 @@ mkDupableCont join_arg_ty (ArgOf _ cont_ty cont_fn) thing_inside
        new_cont = ArgOf OkToDup cont_ty
                         (\arg' -> rebuild_done (App (Var join_id) arg'))
     in
-       
-       -- Do the thing inside
+
+    tick (CaseOfCase join_id)                                          `thenSmpl_`
+       -- Want to tick here so that we go round again,
+       -- and maybe copy or inline the code;
+       -- not strictly CaseOf Case
     thing_inside new_cont              `thenSmpl` \ res ->
     returnSmpl (addBind (NonRec join_id join_rhs) res)
 
@@ -1415,6 +1424,11 @@ mkDupableCont ty (ApplyTo _ arg se cont) thing_inside
        thing_inside (ApplyTo OkToDup arg' emptySubstEnv cont')
     else
     newId (coreExprType arg')                                          $ \ bndr ->
+
+    tick (CaseOfCase bndr)                                             `thenSmpl_`
+       -- Want to tick here so that we go round again,
+       -- and maybe copy or inline the code;
+       -- not strictly CaseOf Case
     thing_inside (ApplyTo OkToDup (Var bndr) emptySubstEnv cont')      `thenSmpl` \ res ->
     returnSmpl (addBind (NonRec bndr arg') res)
 
index 5e8bfa7..10cb1ce 100644 (file)
@@ -654,9 +654,15 @@ coreExprToStgFloat env
        (Case scrut@(Con (PrimOp SeqOp) [Type ty, e]) bndr alts) dem
   = coreExprToStgFloat env (Case e new_bndr [(DEFAULT,[],default_rhs)]) dem
   where 
-    new_bndr                   = setIdType bndr ty
     (other_alts, maybe_default) = findDefault alts
     Just default_rhs           = maybe_default
+    new_bndr                   = setIdType bndr ty
+       -- NB:  SeqOp :: forall a. a -> Int#
+       -- So bndr has type Int# 
+       -- But now we are going to scrutinise the SeqOp's argument directly,
+       -- so we must change the type of the case binder to match that
+       -- of the argument expression e.  We can get this type from the argument
+       -- type of the SeqOp.
 
 coreExprToStgFloat env 
        (Case scrut@(Con (PrimOp ParOp) args) bndr alts) dem
index 904ea3e..a2e8188 100644 (file)
@@ -331,6 +331,7 @@ addStrictnessInfoToId str_val abs_val binder body
        --      foldr k z (case e of p -> build g) 
        -- gets transformed to
        --      case e of p -> foldr k z (build g)
+       -- [foldr is only inlined late in compilation, after strictness analysis]
        (binders, rhs) -> binder `setIdStrictness` 
                          mkStrictnessInfo strictness
                where
index 7a95e55..d919b73 100644 (file)
@@ -4,7 +4,7 @@
 \section[WorkWrap]{Worker/wrapper-generating back-end of strictness analyser}
 
 \begin{code}
-module WorkWrap ( wwTopBinds ) where
+module WorkWrap ( wwTopBinds, mkWrapper ) where
 
 #include "HsVersions.h"
 
@@ -14,19 +14,19 @@ import CmdLineOpts  ( opt_UF_CreationThreshold , opt_D_verbose_core2core,
                           opt_D_dump_worker_wrapper
                        )
 import CoreLint                ( beginPass, endPass )
-import CoreUtils       ( coreExprType )
+import CoreUtils       ( coreExprType, exprArity )
 import Const           ( Con(..) )
 import DataCon         ( DataCon )
 import MkId            ( mkWorkerId )
-import Id              ( Id, getIdStrictness, setIdArity, 
-                         setIdStrictness, 
+import Id              ( Id, idType, getIdStrictness, setIdArity, 
+                         setIdStrictness, getIdDemandInfo,
                          setIdWorkerInfo, getIdCprInfo )
 import VarSet
-import Type            ( isNewType )
+import Type            ( Type, isNewType, splitForAllTys, splitFunTys )
 import IdInfo          ( mkStrictnessInfo, noStrictnessInfo, StrictnessInfo(..),
                          CprInfo(..), exactArity
                        )
-import Demand           ( wwLazy )
+import Demand           ( Demand, wwLazy )
 import SaLib
 import UniqSupply      ( UniqSupply, initUs_, returnUs, thenUs, mapUs, getUniqueUs, UniqSM )
 import UniqSet
@@ -85,15 +85,8 @@ workersAndWrappers :: UniqSupply -> [CoreBind] -> [CoreBind]
 
 workersAndWrappers us top_binds
   = initUs_ us $
-    mapUs (wwBind True{-top-level-}) top_binds `thenUs` \ top_binds2 ->
-    let
-       top_binds3 = map make_top_binding top_binds2
-    in
-    returnUs (concat top_binds3)
-  where
-    make_top_binding :: WwBinding -> [CoreBind]
-
-    make_top_binding (WwLet binds) = binds
+    mapUs wwBind top_binds `thenUs` \ top_binds' ->
+    returnUs (concat top_binds')
 \end{code}
 
 %************************************************************************
@@ -106,24 +99,23 @@ workersAndWrappers us top_binds
 turn.  Non-recursive case first, then recursive...
 
 \begin{code}
-wwBind :: Bool                 -- True <=> top-level binding
-       -> CoreBind
-       -> UniqSM WwBinding     -- returns a WwBinding intermediate form;
+wwBind :: CoreBind
+       -> UniqSM [CoreBind]    -- returns a WwBinding intermediate form;
                                -- the caller will convert to Expr/Binding,
                                -- as appropriate.
 
-wwBind top_level (NonRec binder rhs)
+wwBind (NonRec binder rhs)
   = wwExpr rhs                                         `thenUs` \ new_rhs ->
     tryWW True {- non-recursive -} binder new_rhs      `thenUs` \ new_pairs ->
-    returnUs (WwLet [NonRec b e | (b,e) <- new_pairs])
+    returnUs [NonRec b e | (b,e) <- new_pairs]
       -- Generated bindings must be non-recursive
       -- because the original binding was.
 
 ------------------------------
 
-wwBind top_level (Rec pairs)
+wwBind (Rec pairs)
   = mapUs do_one pairs         `thenUs` \ new_pairs ->
-    returnUs (WwLet [Rec (concat new_pairs)])
+    returnUs [Rec (concat new_pairs)]
   where
     do_one (binder, rhs) = wwExpr rhs  `thenUs` \ new_rhs ->
                           tryWW False {- recursive -} binder new_rhs
@@ -159,12 +151,9 @@ wwExpr (Note note expr)
     returnUs (Note note new_expr)
 
 wwExpr (Let bind expr)
-  = wwBind False{-not top-level-} bind `thenUs` \ intermediate_bind ->
-    wwExpr expr                                `thenUs` \ new_expr ->
-    returnUs (mash_ww_bind intermediate_bind new_expr)
-  where
-    mash_ww_bind (WwLet  binds)   body = mkLets binds body
-    mash_ww_bind (WwCase case_fn) body = case_fn body
+  = wwBind bind                        `thenUs` \ intermediate_bind ->
+    wwExpr expr                        `thenUs` \ new_expr ->
+    returnUs (mkLets intermediate_bind new_expr)
 
 wwExpr (Case expr binder alts)
   = wwExpr expr                                `thenUs` \ new_expr ->
@@ -206,83 +195,62 @@ tryWW     :: Bool                         -- True <=> a non-recursive binding
                                        -- wrapper.
 tryWW non_rec fn_id rhs
   | (non_rec &&                -- Don't split if its non-recursive and small
-     certainlySmallEnoughToInline (calcUnfoldingGuidance opt_UF_CreationThreshold rhs) &&
+     certainlySmallEnoughToInline (calcUnfoldingGuidance opt_UF_CreationThreshold rhs)
        -- No point in worker/wrappering something that is going to be
        -- INLINEd wholesale anyway.  If the strictness analyser is run
        -- twice, this test also prevents wrappers (which are INLINEd)
        -- from being re-done.
+    )
 
-     not (null wrap_args && do_coerce_ww)
-       -- However, if we have  f = coerce T E
-       -- then we want to w/w anyway, to get
-       --                      fw = E
-       --                      f  = coerce T fw
-       -- We want to do this even if the binding is small and non-rec.
-       -- Reason: I've seen this situation:
-       --      let f = coerce T (\s -> E)
-       --      in \x -> case x of
-       --                  p -> coerce T' f
-       --                  q -> \s -> E2
-       -- If only we w/w'd f, we'd inline the coerce (because it's trivial)
-       -- to get
-       --      let fw = \s -> E
-       --      in \x -> case x of
-       --                  p -> fw
-       --                  q -> \s -> E2
-       -- Now we'll see that fw has arity 1, and will arity expand
-       -- the \x to get what we want.
-     )
-
-  || not (do_strict_ww || do_cpr_ww || do_coerce_ww) 
+  || arity == 0                -- Don't split if it's not a function
+
+  || not (do_strict_ww || do_cpr_ww || do_coerce_ww)
   = returnUs [ (fn_id, rhs) ]
 
   | otherwise          -- Do w/w split
-  = mkWwBodies tyvars wrap_args 
-              body_ty 
-              wrap_demands
-              cpr_info
-                                                `thenUs` \ (wrap_fn, work_fn, work_demands) ->
+  = mkWwBodies fun_ty arity wrap_dmds cpr_info `thenUs` \ (work_args, wrap_fn, work_fn) ->
     getUniqueUs                                        `thenUs` \ work_uniq ->
     let
-       work_rhs  = work_fn body
-       work_id   = mkWorkerId work_uniq fn_id (coreExprType work_rhs) `setIdStrictness`
-                   (if has_strictness_info then mkStrictnessInfo (work_demands ++ remaining_arg_demands, result_bot)
-                                           else noStrictnessInfo) 
+       work_rhs     = work_fn rhs
+       work_demands = [getIdDemandInfo v | v <- work_args, isId v]
+       proto_work_id            = mkWorkerId work_uniq fn_id (coreExprType work_rhs) 
+       work_id | has_strictness = proto_work_id `setIdStrictness` mkStrictnessInfo (work_demands, result_bot)
+               | otherwise      = proto_work_id
 
        wrap_rhs = wrap_fn work_id
-       wrap_id  = fn_id `setIdStrictness` 
-                         (if has_strictness_info then mkStrictnessInfo (wrap_demands ++ remaining_arg_demands, result_bot)
-                                                else noStrictnessInfo) 
+       wrap_id  = fn_id `setIdStrictness`      wrapper_strictness
                          `setIdWorkerInfo`     Just work_id
-                        `setIdArity`           exactArity (length wrap_args)
+                        `setIdArity`           exactArity arity
                -- Add info to the wrapper:
-               --      (a) we want to inline it everywhere
+               --      (a) we want to set its arity
                --      (b) we want to pin on its revised strictness info
                --      (c) we pin on its worker id 
     in
     returnUs ([(work_id, work_rhs), (wrap_id, wrap_rhs)])
        -- Worker first, because wrapper mentions it
   where
-    (tyvars, wrap_args, body) = collectTyAndValBinders rhs
-    n_wrap_args                      = length wrap_args
-    body_ty                  = coreExprType body
-    strictness_info     = getIdStrictness fn_id
-    has_strictness_info = case strictness_info of
-                               StrictnessInfo _ _ -> True
-                               other              -> False
+    fun_ty = idType fn_id
+    arity  = exprArity rhs
 
+    strictness_info                      = getIdStrictness fn_id
     StrictnessInfo arg_demands result_bot = strictness_info
+    has_strictness                       = case strictness_info of
+                                               StrictnessInfo _ _ -> True
+                                               other              -> False
                        
-       -- NB: There maybe be more items in arg_demands than wrap_args, because
-       -- the strictness info is semantic and looks through InlineMe and Scc
-       -- Notes, whereas wrap_args does not
-    demands_for_visible_args = take n_wrap_args arg_demands
-    remaining_arg_demands    = drop n_wrap_args arg_demands
+    do_strict_ww = has_strictness && worthSplitting wrap_dmds result_bot
+
+       -- NB: There maybe be more items in arg_demands than arity, because
+       -- the strictness info is semantic and looks through InlineMe and Scc Notes, 
+       -- whereas arity does not
+    demands_for_visible_args = take arity arg_demands
+    remaining_arg_demands    = drop arity arg_demands
 
-    wrap_demands | has_strictness_info = setUnpackStrategy demands_for_visible_args
-                | otherwise           = repeat wwLazy
+    wrap_dmds | has_strictness = setUnpackStrategy demands_for_visible_args
+             | otherwise      = take arity (repeat wwLazy)
 
-    do_strict_ww = has_strictness_info && worthSplitting wrap_demands result_bot
+    wrapper_strictness | has_strictness = mkStrictnessInfo (wrap_dmds ++ remaining_arg_demands, result_bot)
+                      | otherwise      = noStrictnessInfo
 
        -------------------------------------------------------------
     cpr_info     = getIdCprInfo fn_id
@@ -293,52 +261,46 @@ tryWW non_rec fn_id rhs
     do_cpr_ww = has_cpr_info
 
        -------------------------------------------------------------
-       -- Do the coercion thing if the body is of a newtype
-    do_coerce_ww = isNewType body_ty
-
-
-{-     July 99: removed again by Simon
-
--- This rather (nay! extremely!) crude function looks at a wrapper function, and
--- snaffles out the worker Id from the wrapper.
--- This is needed when we write an interface file.
--- [May 1999: we used to get the constructors too, but that's no longer
---           necessary, because the renamer hauls in all type decls in 
---           their fullness.]
-
--- <Mar 1999 (keving)> - Well,  since the addition of the CPR transformation this function
--- got too crude!  
--- Now the worker id is stored directly in the id's Info field.  We still use this function to
--- snaffle the wrapper's constructors but I don't trust the code to find the worker id.
-getWorkerId :: Id -> CoreExpr -> Id
-getWorkerId wrap_id wrapper_fn
-  = work_id wrapper_fn
+    do_coerce_ww = check_for_coerce arity fun_ty
+
+-- See if there's a Coerce before we run out of arity;
+-- if so, it's worth trying a w/w split.  Reason: we find
+-- functions like      f = coerce (\s -> e)
+--          and        g = \x -> coerce (\s -> e)
+-- and they may have no useful strictness or cpr info, but if we
+-- do the w/w thing we get rid of the coerces.  
+
+check_for_coerce arity ty
+  = length arg_tys <= arity && isNewType res_ty
+       -- Don't look further than arity args, 
+       -- but if there are arity or fewer, see if there's
+       -- a newtype in the corner
   where
+    (_, tau)         = splitForAllTys ty
+    (arg_tys, res_ty) = splitFunTys tau
+\end{code}
+
+
 
-    work_id wrapper_fn
-            = case get_work_id wrapper_fn of
-                []   -> case work_id_try2 wrapper_fn of
-                        [] -> pprPanic "getWorkerId: can't find worker id" (ppr wrap_id)
-                        [id] -> id
-                       _    -> pprPanic "getWorkerId: found too many worker ids" (ppr wrap_id)
-                [id] -> id
-                _    -> pprPanic "getWorkerId: found too many worker ids" (ppr wrap_id)
-
-    get_work_id (Lam _ body)                    = get_work_id body
-    get_work_id (Case _ _ [(_,_,rhs@(Case _ _ _))])    = get_work_id rhs
-    get_work_id (Case scrut _ [(_,_,rhs)])             = (get_work_id scrut) ++ (get_work_id rhs)
-    get_work_id (Note _ body)                   = get_work_id body
-    get_work_id (Let _ body)                    = get_work_id body
-    get_work_id (App (Var work_id) _)           = [work_id]
-    get_work_id (App fn _)                      = get_work_id fn
-    get_work_id (Var work_id)                   = []
-    get_work_id other                           = [] 
-
-    work_id_try2 (Lam _ body)                   = work_id_try2 body
-    work_id_try2 (Note _ body)                  = work_id_try2 body
-    work_id_try2 (Let _ body)                   = work_id_try2 body
-    work_id_try2 (App fn _)                     = work_id_try2 fn
-    work_id_try2 (Var work_id)                  = [work_id]
-    work_id_try2 other                          = [] 
--}
+%************************************************************************
+%*                                                                     *
+\subsection{The worker wrapper core}
+%*                                                                     *
+%************************************************************************
+
+@mkWrapper@ is called when importing a function.  We have the type of 
+the function and the name of its worker, and we want to make its body (the wrapper).
+
+\begin{code}
+mkWrapper :: Type              -- Wrapper type
+         -> Int                -- Arity
+         -> [Demand]           -- Wrapper strictness info
+         -> CprInfo            -- Wrapper cpr info
+         -> UniqSM (Id -> CoreExpr)    -- Wrapper body, missing worker Id
+
+mkWrapper fun_ty arity demands cpr_info
+  = mkWwBodies fun_ty arity demands cpr_info   `thenUs` \ (_, wrap_fn, _) ->
+    returnUs wrap_fn
 \end{code}
+
+
index 235bd91..1a6c4de 100644 (file)
@@ -5,15 +5,14 @@
 
 \begin{code}
 module WwLib (
-       WwBinding(..),
-
-       worthSplitting, setUnpackStrategy,
-       mkWwBodies, mkWrapper
+       mkWwBodies,
+       worthSplitting, setUnpackStrategy
     ) where
 
 #include "HsVersions.h"
 
 import CoreSyn
+import CoreUtils       ( coreExprType )
 import Id              ( Id, idType, mkSysLocal, getIdDemandInfo, setIdDemandInfo,
                           mkWildId, setIdInfo
                        )
@@ -25,44 +24,20 @@ import PrelInfo             ( realWorldPrimId, aBSENT_ERROR_ID )
 import TysPrim         ( realWorldStatePrimTy )
 import TysWiredIn      ( unboxedTupleCon, unboxedTupleTyCon )
 import Type            ( isUnLiftedType, 
-                         splitForAllTys, splitFunTys, splitFunTysN,
+                         splitForAllTys, splitFunTys, 
                          splitAlgTyConApp_maybe, splitNewType_maybe,
-                         mkTyConApp, 
+                         mkTyConApp, mkFunTys,
                          Type
                        )
 import TyCon            ( isNewTyCon, isProductTyCon, TyCon )
-import BasicTypes      ( NewOrData(..) )
-import Var              ( TyVar )
+import BasicTypes      ( NewOrData(..), Arity )
+import Var              ( TyVar, IdOrTyVar )
 import UniqSupply      ( returnUs, thenUs, getUniqueUs, getUniquesUs, 
                           mapUs, UniqSM )
 import Util            ( zipWithEqual, zipEqual )
 import Outputable
 \end{code}
 
-%************************************************************************
-%*                                                                     *
-\subsection[datatype-WwLib]{@WwBinding@: a datatype for worker/wrapper-ing}
-%*                                                                     *
-%************************************************************************
-
-In the worker/wrapper stuff, we want to carry around @CoreBindings@ in
-an ``intermediate form'' that can later be turned into a \tr{let} or
-\tr{case} (depending on strictness info).
-
-\begin{code}
-data WwBinding
-  = WwLet  [CoreBind]
-  | WwCase (CoreExpr -> CoreExpr)
-               -- the "case" will be a "strict let" of the form:
-               --
-               --  case rhs of
-               --    <blah> -> body
-               --
-               -- (instead of "let <blah> = rhs in body")
-               --
-               -- The expr you pass to the function is "body" (the
-               -- expression that goes "in the corner").
-\end{code}
 
 %************************************************************************
 %*                                                                     *
@@ -208,10 +183,19 @@ nonAbsentArgs (d     : ds) = 1 + nonAbsentArgs ds
 worthSplitting :: [Demand]
               -> Bool  -- Result is bottom
               -> Bool  -- True <=> the wrapper would not be an identity function
-worthSplitting ds result_bot = not result_bot && any worth_it ds
-       -- Don't split if the result is bottom; there's no efficiency to
-       -- be gained, and (worse) the wrapper body may not look like a wrapper
-       -- body to getWorkerIdAndCons
+worthSplitting ds result_bot = any worth_it ds
+       -- We used not to split if the result is bottom.
+       -- [Justification:  there's no efficiency to be gained, 
+       --  and (worse) the wrapper body may not look like a wrapper
+       --  body to getWorkerIdAndCons]
+       -- But now (a) we don't have getWorkerIdAndCons, and
+       -- (b) it's sometimes bad not to make a wrapper.  Consider
+       --      fw = \x# -> let x = I# x# in case e of
+       --                                      p1 -> error_fn x
+       --                                      p2 -> error_fn x
+       --                                      p3 -> the real stuff
+       -- The re-boxing code won't go away unless error_fn gets a wrapper too.
+
   where
     worth_it (WwLazy True)      = True         -- Absent arg
     worth_it (WwUnpack _ True _) = True                -- Arg to unpack
@@ -233,66 +217,27 @@ allAbsent ds = all absent ds
 %*                                                                     *
 %************************************************************************
 
-@mkWrapper@ is called when importing a function.  We have the type of 
-the function and the name of its worker, and we want to make its body (the wrapper).
-
-\begin{code}
-mkWrapper :: Type              -- Wrapper type
-         -> Int                -- Arity
-         -> [Demand]           -- Wrapper strictness info
-         -> CprInfo            -- Wrapper cpr info
-         -> UniqSM (Id -> CoreExpr)    -- Wrapper body, missing worker Id
-
-mkWrapper fun_ty arity demands cpr_info
-  = getUniquesUs arity         `thenUs` \ wrap_uniqs ->
-    let
-       (tyvars, tau_ty)   = splitForAllTys fun_ty
-       (arg_tys, body_ty) = splitFunTysN "mkWrapper" arity tau_ty
-               -- The "expanding dicts" part here is important, even for the splitForAll
-               -- The imported thing might be a dictionary, such as Functor Foo
-               -- But Functor Foo = forall a b. (a->b) -> Foo a -> Foo b
-               -- and as such might have some strictness info attached.
-               -- Then we need to have enough args to zip to the strictness info
-       
-       wrap_args          = zipWith mk_ww_local wrap_uniqs arg_tys
-    in
-    mkWwBodies tyvars wrap_args body_ty demands cpr_info       `thenUs` \ (wrap_fn, _, _) ->
-    returnUs wrap_fn
-\end{code}
-
 @mkWwBodies@ is called when doing the worker/wrapper split inside a module.
 
 \begin{code}
-mkWwBodies :: [TyVar] -> [Id]                  -- Original fn args 
-          -> Type                              -- Type of result of original function
-          -> [Demand]                          -- Strictness info for original fn; corresp 1-1 with args
+mkWwBodies :: Type                             -- Type of original function
+          -> Arity                             -- Arity of original function
+          -> [Demand]                          -- Strictness of original function
           -> CprInfo                           -- Result of CPR analysis 
-          -> UniqSM (Id -> CoreExpr,           -- Wrapper body, lacking only the worker Id
-                     CoreExpr -> CoreExpr,     -- Worker body, lacking the original function body
-                     [Demand])                 -- Strictness info for worker
-
-mkWwBodies tyvars wrap_args res_ty demands cpr_info
-  = let
-        -- demands may be longer than number of args.  If we aren't doing w/w
-        -- for strictness then demands is an infinite list of 'lazy' args.
-       wrap_args_w_demands = zipWith setIdDemandInfo wrap_args demands
-       
-    in
-    mkWWstr wrap_args_w_demands                        `thenUs` \ (wrap_fn_str,    work_fn_str,    work_arg_dmds) ->
-    mkWWcoerce res_ty                          `thenUs` \ (wrap_fn_coerce, work_fn_coerce, coerce_res_ty) ->
-    mkWWcpr coerce_res_ty cpr_info             `thenUs` \ (wrap_fn_cpr,    work_fn_cpr,    cpr_res_ty) ->
-    mkWWfixup cpr_res_ty (null work_arg_dmds)  `thenUs` \ (wrap_fn_fixup,  work_fn_fixup) ->
-
-    returnUs (\ work_id -> Note InlineMe $
-                          mkLams tyvars $ mkLams wrap_args_w_demands $
-                          (wrap_fn_coerce . wrap_fn_cpr . wrap_fn_str . wrap_fn_fixup) $
-                          mkVarApps (Var work_id) tyvars,
-
-             \ work_body  -> mkLams tyvars $ 
-                             (work_fn_fixup . work_fn_str . work_fn_cpr . work_fn_coerce) 
-                             work_body,
-
-             work_arg_dmds)
+          -> UniqSM ([IdOrTyVar],              -- Worker args
+                     Id -> CoreExpr,           -- Wrapper body, lacking only the worker Id
+                     CoreExpr -> CoreExpr)     -- Worker body, lacking the original function rhs
+
+mkWwBodies fun_ty arity demands cpr_info
+  = WARN( arity /= length demands, text "mkWrapper" <+> ppr fun_ty <+> ppr arity <+> ppr demands )
+    mkWWargs fun_ty arity demands      `thenUs` \ (wrap_args, wrap_fn_args,   work_fn_args, res_ty) ->
+    mkWWstr wrap_args                  `thenUs` \ (work_args, wrap_fn_str,    work_fn_str) ->
+    mkWWcpr res_ty cpr_info            `thenUs` \ (wrap_fn_cpr,    work_fn_cpr,  cpr_res_ty) ->
+    mkWWfixup cpr_res_ty work_args     `thenUs` \ (wrap_fn_fixup,  work_fn_fixup) ->
+
+    returnUs (work_args,
+             Note InlineMe . wrap_fn_args . wrap_fn_cpr . wrap_fn_str . wrap_fn_fixup . Var,
+             work_fn_fixup . work_fn_str . work_fn_cpr . work_fn_args)
 \end{code}
 
 
@@ -302,26 +247,80 @@ mkWwBodies tyvars wrap_args res_ty demands cpr_info
 %*                                                                     *
 %************************************************************************
 
-The "coerce" transformation is
-       f :: T1 -> T2 -> R
-       f = \xy -> e
-===>
-       f = \xy -> coerce R R' (fw x y)
-       fw = \xy -> coerce R' R e
 
-where R' is the representation type for R.
+We really want to "look through" coerces.
+Reason: I've seen this situation:
+
+       let f = coerce T (\s -> E)
+       in \x -> case x of
+                   p -> coerce T' f
+                   q -> \s -> E2
+                   r -> coerce T' f
+
+If only we w/w'd f, we'd get
+       let f = coerce T (\s -> fw s)
+           fw = \s -> E
+       in ...
+
+Now we'll inline f to get
+
+       let fw = \s -> E
+       in \x -> case x of
+                   p -> fw
+                   q -> \s -> E2
+                   r -> fw
+
+Now we'll see that fw has arity 1, and will arity expand
+the \x to get what we want.
 
 \begin{code}
-mkWWcoerce body_ty 
-  = case splitNewType_maybe body_ty of
+-- mkWWargs is driven off the function type.  
+-- It chomps bites off foralls, arrows, newtypes
+-- and keeps repeating that until it's satisfied the supplied arity
 
-       Nothing     -> returnUs (id, id, body_ty)
+mkWWargs :: Type -> Int -> [Demand]
+        -> UniqSM  ([IdOrTyVar],                       -- Wrapper args
+                    CoreExpr -> CoreExpr,              -- Wrapper fn
+                    CoreExpr -> CoreExpr,              -- Worker fn
+                    Type)                              -- Type of wrapper body
 
-       Just rep_ty -> returnUs (mkNote (Coerce body_ty rep_ty),
-                                mkNote (Coerce rep_ty body_ty),
-                                rep_ty)
-\end{code}    
+mkWWargs fun_ty arity demands
+  | arity == 0
+  = returnUs ([], id, id, fun_ty)
 
+  | otherwise
+  = getUniquesUs n_args                `thenUs` \ wrap_uniqs ->
+    let
+      val_args = zipWith3 mk_wrap_arg wrap_uniqs arg_tys demands
+      wrap_args = tyvars ++ val_args
+    in
+    mkWWargs body_rep_ty 
+            (arity - n_args) 
+            (drop n_args demands)      `thenUs` \ (more_wrap_args, wrap_fn_args, work_fn_args, res_ty) ->
+
+    returnUs (wrap_args ++ more_wrap_args,
+             mkLams wrap_args . wrap_coerce_fn . wrap_fn_args,
+             work_fn_args . work_coerce_fn . applyToVars wrap_args,
+             res_ty)
+  where
+    (tyvars, tau)              = splitForAllTys fun_ty
+    (arg_tys, body_ty)         = splitFunTys tau
+    n_arg_tys          = length arg_tys
+    n_args             = arity `min` n_arg_tys
+    (wrap_coerce_fn, work_coerce_fn, body_rep_ty) 
+       | n_arg_tys == n_args           -- All arg_tys used up
+       = case splitNewType_maybe body_ty of
+               Just rep_ty -> (Note (Coerce body_ty rep_ty), Note (Coerce rep_ty body_ty), rep_ty)
+               Nothing     -> ASSERT2( n_args /= 0, text "mkWWargs" <+> ppr arity <+> ppr fun_ty )
+                              (id, id, body_ty)
+       | otherwise                     -- Leftover arg-tys
+       = (id, id, mkFunTys (drop n_args arg_tys) body_ty)
+
+applyToVars :: [IdOrTyVar] -> CoreExpr -> CoreExpr
+applyToVars vars fn = mkVarApps fn vars
+
+mk_wrap_arg uniq ty dmd = setIdDemandInfo (mkSysLocal SLIT("w") uniq ty) dmd
+\end{code}
 
 
 %************************************************************************
@@ -331,8 +330,8 @@ mkWWcoerce body_ty
 %************************************************************************
 
 \begin{code}
-mkWWfixup res_ty no_worker_args
-  | no_worker_args && isUnLiftedType res_ty 
+mkWWfixup res_ty work_args
+  | null work_args && isUnLiftedType res_ty 
        -- Horrid special case.  If the worker would have no arguments, and the
        -- function returns a primitive type value, that would make the worker into
        -- an unboxed value.  We box it by passing a dummy void argument, thus:
@@ -360,21 +359,22 @@ mkWWfixup res_ty no_worker_args
 %************************************************************************
 
 \begin{code}
-mkWWstr :: [Id]                                        -- Wrapper args; have their demand info on them
-        -> UniqSM (CoreExpr -> CoreExpr,       -- Wrapper body, lacking the worker call
+mkWWstr :: [IdOrTyVar]                         -- Wrapper args; have their demand info on them
+                                               -- *Includes type variables*
+        -> UniqSM ([IdOrTyVar],                        -- Worker args
+                  CoreExpr -> CoreExpr,        -- Wrapper body, lacking the worker call
                                                -- and without its lambdas 
                                                -- This fn adds the unboxing, and makes the
                                                -- call passing the unboxed things
                                
-                  CoreExpr -> CoreExpr,        -- Worker body, lacking the original body of the function,
+                  CoreExpr -> CoreExpr)        -- Worker body, lacking the original body of the function,
                                                -- but *with* lambdas
-                  [Demand])                    -- Worker arg demands
 
 mkWWstr wrap_args
-  = mk_ww_str wrap_args                `thenUs` \ (work_args_w_demands, wrap_fn, work_fn) ->
-    returnUs ( \ wrapper_body -> wrap_fn (mkVarApps wrapper_body work_args_w_demands),
-              \ worker_body  -> mkLams work_args_w_demands (work_fn worker_body),
-              map getIdDemandInfo work_args_w_demands)
+  = mk_ww_str wrap_args                `thenUs` \ (work_args, wrap_fn, work_fn) ->
+    returnUs ( work_args,
+              \ wrapper_body -> wrap_fn (mkVarApps wrapper_body work_args),
+              \ worker_body  -> mkLams work_args (work_fn worker_body))
 
        -- Empty case
 mk_ww_str []
@@ -384,6 +384,11 @@ mk_ww_str []
 
 
 mk_ww_str (arg : ds)
+  | isTyVar arg
+  = mk_ww_str ds               `thenUs` \ (worker_args, wrap_fn, work_fn) ->
+    returnUs (arg : worker_args, wrap_fn, work_fn)
+
+  | otherwise
   = case getIdDemandInfo arg of
 
        -- Absent case
@@ -437,143 +442,36 @@ mkWWcpr :: Type                              -- function body type
 
 mkWWcpr body_ty NoCPRInfo 
     = returnUs (id, id, body_ty)      -- Must be just the strictness transf.
+
 mkWWcpr body_ty (CPRInfo cpr_args)
-    = getUniqueUs              `thenUs` \ body_arg_uniq ->
+    | n_con_args == 1 && isUnLiftedType con_arg_ty1
+       -- Special case when there is a single result of unlifted type
+    = getUniquesUs 2                   `thenUs` \ [work_uniq, arg_uniq] ->
       let
-        body_var = mk_ww_local body_arg_uniq body_ty
+       work_wild = mk_ww_local work_uniq body_ty
+       arg       = mk_ww_local arg_uniq  con_arg_ty1
       in
-      cpr_reconstruct body_ty cpr_info'                   `thenUs` \reconst_fn ->
-      cpr_flatten body_ty cpr_info'                       `thenUs` \(flatten_fn, res_ty) ->
-      returnUs (reconst_fn, flatten_fn, res_ty)
-    where
-           -- We only make use of the outer level of CprInfo,  otherwise we
-           -- may lose laziness.  :-(  Hopefully,  we will find a use for the
-           -- extra info some day (e.g. creating versions specialized to 
-           -- the use made of the components of the result by the callee)
-      cpr_info' = CPRInfo (map (const NoCPRInfo) cpr_args) 
-\end{code}
-
+      returnUs (\ wkr_call -> mkConApp data_con (map Type tycon_arg_tys ++ [wkr_call]),
+               \ body     -> Case body work_wild [(DataCon data_con, [arg], Var arg)],
+               con_arg_ty1)
 
-@cpr_flatten@ takes the result type produced by the body and the info
-from the CPR analysis and flattens the constructed product components.
-These are returned in an unboxed tuple.
-
-\begin{code}
-cpr_flatten :: Type -> CprInfo -> UniqSM (CoreExpr -> CoreExpr, Type)
-cpr_flatten ty cpr_info
-    = mk_cpr_case (ty, cpr_info)       `thenUs` \(res_id, tup_ids, flatten_exp) ->
+    | otherwise                -- The general case
+    = getUniquesUs (n_con_args + 2)    `thenUs` \ uniqs ->
       let
-       (unbx_tuple, unbx_tuple_ty) = mk_unboxed_tuple tup_ids
+        (wrap_wild : work_wild : args) = zipWith mk_ww_local uniqs (ubx_tup_ty : body_ty : con_arg_tys)
+       arg_vars                       = map Var args
+       ubx_tup_con                    = unboxedTupleCon n_con_args
+       ubx_tup_ty                     = coreExprType ubx_tup_app
+       ubx_tup_app                    = mkConApp ubx_tup_con (map Type con_arg_tys   ++ arg_vars)
+        con_app                               = mkConApp data_con    (map Type tycon_arg_tys ++ arg_vars)
       in
-      returnUs (\body -> Case body res_id [(DEFAULT, [], flatten_exp unbx_tuple)],
-               unbx_tuple_ty)
-
-
-
-mk_cpr_case :: (Type, CprInfo) -> 
-               UniqSM (CoreBndr,                     -- Name of binder for this part of result 
-                      [(CoreExpr, Type)],            -- expressions for flattened result
-                      CoreExpr -> CoreExpr)          -- add in code to flatten result
-
-mk_cpr_case (ty, NoCPRInfo) 
-      -- this component must be returned as a component of the unboxed tuple result
-    = getUniqueUs            `thenUs`     \id_uniq   ->
-      let id_id = mk_ww_local id_uniq ty in
-        returnUs (id_id, [(Var id_id, ty)], id)
-mk_cpr_case (ty, cpr_info@(CPRInfo ci_args))
-    | isNewTyCon tycon  -- a new type: under the coercions must be a 
-                        -- constructed product
-    = ASSERT ( null $ tail inst_con_arg_tys )
-      mk_cpr_case (target_of_from_type, cpr_info) 
-                                 `thenUs`  \(arg, tup, exp) ->
-      getUniqueUs                `thenUs`  \id_uniq   ->
-      let id_id = mk_ww_local id_uniq ty 
-          new_exp_case = \var -> Case (Note (Coerce (idType arg) ty) (Var id_id))
-                                     arg
-                                     [(DEFAULT,[], exp var)]
-      in
-        returnUs (id_id, tup, new_exp_case)
-
-    | otherwise            -- a data type
-                           -- flatten components
-    = mapUs mk_cpr_case (zip inst_con_arg_tys ci_args) 
-                                 `thenUs`  \sub_builds ->
-      getUniqueUs                `thenUs`  \id_uniq   ->
-      let id_id = mk_ww_local id_uniq ty 
-          (args, tup, exp) = unzip3 sub_builds
-          -- not used: con_app = mkConApp data_con (map Var args) 
-          new_tup = concat tup
-          new_exp_case = \var -> Case (Var id_id) (mkWildId ty)
-                                [(DataCon data_con, args, 
-                                  foldl (\e f -> f e) var exp)]
-      in
-        returnUs (id_id, new_tup, new_exp_case)
+      returnUs (\ wkr_call -> Case wkr_call wrap_wild [(DataCon ubx_tup_con, args, con_app)],
+               \ body     -> Case body     work_wild [(DataCon data_con,    args, ubx_tup_app)],
+               ubx_tup_ty)
     where
-      (tycon, tycon_arg_tys, data_con, inst_con_arg_tys) = splitProductType "mk_cpr_case" ty
-      from_type = head inst_con_arg_tys
-      -- if coerced from a function 'look through' to find result type
-      target_of_from_type = (snd.splitFunTys.snd.splitForAllTys) from_type
-
-\end{code}
-
-@cpr_reconstruct@ does the opposite of @cpr_flatten@.  It takes the unboxed
-tuple produced by the worker and reconstructs the structured result.
-
-\begin{code}
-cpr_reconstruct :: Type -> CprInfo -> UniqSM (CoreExpr -> CoreExpr)
-cpr_reconstruct ty cpr_info
-    = mk_cpr_let (ty,cpr_info)     `thenUs`  \(res_id, tup_ids, reconstruct_exp) ->
-      returnUs (\worker -> Case worker (mkWildId $ worker_type tup_ids)
-                           [(DataCon $ unboxedTupleCon $ length tup_ids,
-                           tup_ids, reconstruct_exp $ Var res_id)])
-                            
-    where
-       worker_type ids = mkTyConApp (unboxedTupleTyCon (length ids)) (map idType ids) 
-
-
-mk_cpr_let :: (Type, CprInfo) -> 
-              UniqSM (CoreBndr,                -- Binder for this component of result 
-                      [CoreBndr],              -- Binders which will appear in worker's result
-                      CoreExpr -> CoreExpr)    -- Code to produce structured result.
-mk_cpr_let (ty, NoCPRInfo)
-      -- this component will appear explicitly in the unboxed tuple.
-    = getUniqueUs            `thenUs`     \id_uniq   ->
-      let
-       id_id = mk_ww_local id_uniq ty
-      in
-      returnUs (id_id, [id_id], id)
-
-mk_cpr_let (ty, cpr_info@(CPRInfo ci_args))
-
-{- Should not be needed now:  mkWWfixup does this job
-    | isNewTyCon tycon   -- a new type: must coerce the argument to this type
-    = ASSERT ( null $ tail inst_con_arg_tys )
-      mk_cpr_let (target_of_from_type, cpr_info) 
-                                 `thenUs`  \(arg, tup, exp) ->
-      getUniqueUs                `thenUs`  \id_uniq   ->
-      let id_id = mk_ww_local id_uniq ty 
-          new_exp = \var -> exp (Let (NonRec id_id (Note (Coerce ty (idType arg)) (Var arg))) var) 
-      in
-        returnUs (id_id, tup, new_exp)
-
-    | otherwise     -- a data type
-                    -- reconstruct components then apply data con
--}
-    = mapUs mk_cpr_let (zip inst_con_arg_tys ci_args) 
-                                 `thenUs`  \sub_builds ->
-      getUniqueUs                `thenUs`  \id_uniq   ->
-      let id_id = mk_ww_local id_uniq ty 
-          (args, tup, exp) = unzip3 sub_builds
-          con_app = mkConApp data_con $ (map Type tycon_arg_tys) ++ (map Var args) 
-          new_tup = concat tup
-          new_exp = \var -> foldl (\e f -> f e) (Let (NonRec id_id con_app) var) exp 
-      in
-        returnUs (id_id, new_tup, new_exp)
-    where
-      (tycon, tycon_arg_tys, data_con, inst_con_arg_tys) = splitProductType "mk_cpr_let" ty
-      from_type = head inst_con_arg_tys
-      -- if coerced from a function 'look through' to find result type
-      target_of_from_type = (snd.splitFunTys.snd.splitForAllTys) from_type
+      (tycon, tycon_arg_tys, data_con, con_arg_tys) = splitProductType "mkWWcpr" body_ty
+      n_con_args  = length con_arg_tys
+      con_arg_ty1 = head con_arg_tys
 
 
 splitProductType :: String -> Type -> (TyCon, [Type], DataCon, [Type])
@@ -664,12 +562,4 @@ mk_pk_let DataType arg boxing_con con_tys unpk_args body
 
 mk_ww_local uniq ty = mkSysLocal SLIT("ww") uniq ty
 
-
-mk_unboxed_tuple :: [(CoreExpr, Type)] -> (CoreExpr, Type)
-mk_unboxed_tuple contents
-    = (mkConApp (unboxedTupleCon (length contents)) 
-                (map (Type . snd) contents ++
-                 map fst contents),
-       mkTyConApp (unboxedTupleTyCon (length contents)) 
-                  (map snd contents))
 \end{code}