Monadify specialise/Specialise: use do, return, standard monad functions and MonadUnique
[ghc-hetmet.git] / compiler / specialise / Specialise.lhs
index 7a0d8bc..37d5d81 100644 (file)
@@ -4,6 +4,13 @@
 \section[Specialise]{Stamping out overloading, and (optionally) polymorphism}
 
 \begin{code}
+{-# OPTIONS -w #-}
+-- The above warning supression flag is a temporary kludge.
+-- While working on this module you are encouraged to remove it and fix
+-- any warnings in the module. See
+--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
+-- for details
+
 module Specialise ( specProgram ) where
 
 #include "HsVersions.h"
@@ -29,8 +36,8 @@ import CoreLint               ( showPass, endPass )
 import Rules           ( addIdSpecialisations, mkLocalRule, lookupRule, emptyRuleBase, rulesOfBinds )
 import PprCore         ( pprRules )
 import UniqSupply      ( UniqSupply,
-                         UniqSM, initUs_, thenUs, returnUs, getUniqueUs, 
-                         getUs, mapUs
+                         UniqSM, initUs_,
+                         MonadUnique(..)
                        )
 import Name
 import MkId            ( voidArgId, realWorldPrimId )
@@ -45,7 +52,6 @@ import Util           ( zipEqual, zipWithEqual, cmpList, lengthIs,
 import Outputable
 import FastString
 
-infixr 9 `thenSM`
 \end{code}
 
 %************************************************************************
@@ -576,12 +582,12 @@ Hence, the invariant is this:
 
 \begin{code}
 specProgram :: DynFlags -> UniqSupply -> [CoreBind] -> IO [CoreBind]
-specProgram dflags us binds
-  = do
+specProgram dflags us binds = do
+   
        showPass dflags "Specialise"
 
-       let binds' = initSM us (go binds        `thenSM` \ (binds', uds') ->
-                               returnSM (dumpAllDictBinds uds' binds'))
+       let binds' = initSM us (do (binds', uds') <- go binds
+                                  return (dumpAllDictBinds uds' binds'))
 
        endPass dflags "Specialise" Opt_D_dump_spec binds'
 
@@ -595,12 +601,12 @@ specProgram dflags us binds
        -- accidentally re-use a unique that's already in use
        -- Easiest thing is to do it all at once, as if all the top-level
        -- decls were mutually recursive
-    top_subst      = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
+    top_subst       = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
 
-    go []          = returnSM ([], emptyUDs)
-    go (bind:binds) = go binds                                 `thenSM` \ (binds', uds) ->
-                     specBind top_subst bind uds       `thenSM` \ (bind', uds') ->
-                     returnSM (bind' ++ binds', uds')
+    go []           = return ([], emptyUDs)
+    go (bind:binds) = do (binds', uds) <- go binds
+                         (bind', uds') <- specBind top_subst bind uds
+                         return (bind' ++ binds', uds')
 \end{code}
 
 %************************************************************************
@@ -621,73 +627,69 @@ specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails)
 --        the RHS of specialised bindings (no type-let!)
 
 ---------------- First the easy cases --------------------
-specExpr subst (Type ty) = returnSM (Type (substTy subst ty), emptyUDs)
-specExpr subst (Var v)   = returnSM (specVar subst v,         emptyUDs)
-specExpr subst (Lit lit) = returnSM (Lit lit,                emptyUDs)
-specExpr subst (Cast e co) =
-  specExpr subst e              `thenSM` \ (e', uds) ->
-  returnSM ((Cast e' (substTy subst co)), uds)
-specExpr subst (Note note body)
-  = specExpr subst body        `thenSM` \ (body', uds) ->
-    returnSM (Note (specNote subst note) body', uds)
+specExpr subst (Type ty) = return (Type (substTy subst ty), emptyUDs)
+specExpr subst (Var v)   = return (specVar subst v,         emptyUDs)
+specExpr subst (Lit lit) = return (Lit lit,                 emptyUDs)
+specExpr subst (Cast e co) = do
+    (e', uds) <- specExpr subst e
+    return ((Cast e' (substTy subst co)), uds)
+specExpr subst (Note note body) = do
+    (body', uds) <- specExpr subst body
+    return (Note (specNote subst note) body', uds)
 
 
 ---------------- Applications might generate a call instance --------------------
 specExpr subst expr@(App fun arg)
   = go expr []
   where
-    go (App fun arg) args = specExpr subst arg `thenSM` \ (arg', uds_arg) ->
-                           go fun (arg':args)  `thenSM` \ (fun', uds_app) ->
-                           returnSM (App fun' arg', uds_arg `plusUDs` uds_app)
+    go (App fun arg) args = do (arg', uds_arg) <- specExpr subst arg
+                               (fun', uds_app) <- go fun (arg':args)
+                               return (App fun' arg', uds_arg `plusUDs` uds_app)
 
     go (Var f)       args = case specVar subst f of
-                               Var f' -> returnSM (Var f', mkCallUDs subst f' args)
-                               e'     -> returnSM (e', emptyUDs)       -- I don't expect this!
+                                Var f' -> return (Var f', mkCallUDs subst f' args)
+                                e'     -> return (e', emptyUDs)        -- I don't expect this!
     go other        args = specExpr subst other
 
 ---------------- Lambda/case require dumping of usage details --------------------
-specExpr subst e@(Lam _ _)
-  = specExpr subst' body       `thenSM` \ (body', uds) ->
-    let
-       (filtered_uds, body'') = dumpUDs bndrs' uds body'
-    in
-    returnSM (mkLams bndrs' body'', filtered_uds)
+specExpr subst e@(Lam _ _) = do
+    (body', uds) <- specExpr subst' body
+    let (filtered_uds, body'') = dumpUDs bndrs' uds body'
+    return (mkLams bndrs' body'', filtered_uds)
   where
     (bndrs, body) = collectBinders e
     (subst', bndrs') = substBndrs subst bndrs
        -- More efficient to collect a group of binders together all at once
        -- and we don't want to split a lambda group with dumped bindings
 
-specExpr subst (Case scrut case_bndr ty alts)
-  = specExpr subst scrut               `thenSM` \ (scrut', uds_scrut) ->
-    mapAndCombineSM spec_alt alts      `thenSM` \ (alts', uds_alts) ->
-    returnSM (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
+specExpr subst (Case scrut case_bndr ty alts) = do
+    (scrut', uds_scrut) <- specExpr subst scrut
+    (alts', uds_alts) <- mapAndCombineSM spec_alt alts
+    return (Case scrut' case_bndr' (substTy subst ty) alts', uds_scrut `plusUDs` uds_alts)
   where
     (subst_alt, case_bndr') = substBndr subst case_bndr
        -- No need to clone case binder; it can't float like a let(rec)
 
-    spec_alt (con, args, rhs)
-       = specExpr subst_rhs rhs                `thenSM` \ (rhs', uds) ->
-         let
-            (uds', rhs'') = dumpUDs args uds rhs'
-         in
-         returnSM ((con, args', rhs''), uds')
-       where
-         (subst_rhs, args') = substBndrs subst_alt args
+    spec_alt (con, args, rhs) = do
+          (rhs', uds) <- specExpr subst_rhs rhs
+          let (uds', rhs'') = do dumpUDs args uds rhs'
+          return ((con, args', rhs''), uds')
+        where
+          (subst_rhs, args') = substBndrs subst_alt args
 
 ---------------- Finally, let is the interesting case --------------------
-specExpr subst (Let bind body)
-  =    -- Clone binders
-    cloneBindSM subst bind                     `thenSM` \ (rhs_subst, body_subst, bind') ->
-       
-       -- Deal with the body
-    specExpr body_subst body                   `thenSM` \ (body', body_uds) ->
+specExpr subst (Let bind body) = do
+       -- Clone binders
+    (rhs_subst, body_subst, bind') <- cloneBindSM subst bind
+
+        -- Deal with the body
+    (body', body_uds) <- specExpr body_subst body
 
-       -- Deal with the bindings
-    specBind rhs_subst bind' body_uds          `thenSM` \ (binds', uds) ->
+        -- Deal with the bindings
+    (binds', uds) <- specBind rhs_subst bind' body_uds
 
-       -- All done
-    returnSM (foldr Let body' binds', uds)
+        -- All done
+    return (foldr Let body' binds', uds)
 
 -- Must apply the type substitution to coerceions
 specNote subst note          = note
@@ -706,8 +708,8 @@ specBind :: Subst                   -- Use this for RHSs
         -> SpecM ([CoreBind],          -- New bindings
                   UsageDetails)        -- And info to pass upstream
 
-specBind rhs_subst bind body_uds
-  = specBindItself rhs_subst bind (calls body_uds)     `thenSM` \ (bind', bind_uds) ->
+specBind rhs_subst bind body_uds = do
+    (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds)
     let
        bndrs   = bindersOf bind
        all_uds = zapCalls bndrs (body_uds `plusUDs` bind_uds)
@@ -716,7 +718,6 @@ specBind rhs_subst bind body_uds
                        -- used in the calls passed to specDefn.  So the
                        -- dictionary bindings in bind_uds may mention 
                        -- dictionaries bound in body_uds.
-    in
     case splitUDs bndrs all_uds of
 
        (_, ([],[]))    -- This binding doesn't bind anything needed
@@ -724,10 +725,10 @@ specBind rhs_subst bind body_uds
                        -- This is the case for most non-dict bindings, except
                        -- for the few that are mentioned in a dict binding
                        -- that is floating upwards in body_uds
-               -> returnSM ([bind'], all_uds)
+               -> return ([bind'], all_uds)
 
        (float_uds, (dict_binds, calls))        -- This binding is needed in the UDs, so float it out
-               -> returnSM ([], float_uds `plusUDs` mkBigUD bind' dict_binds calls)
+               -> return ([], float_uds `plusUDs` mkBigUD bind' dict_binds calls)
    
 
 -- A truly gruesome function
@@ -751,26 +752,24 @@ mkBigUD bind dbs calls
 
 -- specBindItself deals with the RHS, specialising it according
 -- to the calls found in the body (if any)
-specBindItself rhs_subst (NonRec bndr rhs) call_info
-  = specDefn rhs_subst call_info (bndr,rhs)    `thenSM` \ ((bndr',rhs'), spec_defns, spec_uds) ->
+specBindItself rhs_subst (NonRec bndr rhs) call_info = do
+    ((bndr',rhs'), spec_defns, spec_uds) <- specDefn rhs_subst call_info (bndr,rhs)
     let
         new_bind | null spec_defns = NonRec bndr' rhs'
                  | otherwise       = Rec ((bndr',rhs'):spec_defns)
                -- bndr' mentions the spec_defns in its SpecEnv
                -- Not sure why we couln't just put the spec_defns first
-    in
-    returnSM (new_bind, spec_uds)
+    return (new_bind, spec_uds)
 
-specBindItself rhs_subst (Rec pairs) call_info
-  = mapSM (specDefn rhs_subst call_info) pairs `thenSM` \ stuff ->
+specBindItself rhs_subst (Rec pairs) call_info = do
+    stuff <- mapM (specDefn rhs_subst call_info) pairs
     let
        (pairs', spec_defns_s, spec_uds_s) = unzip3 stuff
        spec_defns = concat spec_defns_s
        spec_uds   = plusUDList spec_uds_s
         new_bind   = Rec (spec_defns ++ pairs')
-    in
-    returnSM (new_bind, spec_uds)
-    
+    return (new_bind, spec_uds)
+
 
 specDefn :: Subst                      -- Subst to use for RHS
         -> CallDetails                 -- Info on how it is used in its scope
@@ -783,31 +782,33 @@ specDefn :: Subst                 -- Subst to use for RHS
 
 specDefn subst calls (fn, rhs)
        -- The first case is the interesting one
-  |  rhs_tyvars `lengthIs` n_tyvars    -- Rhs of fn's defn has right number of big lambdas
-  && rhs_bndrs  `lengthAtLeast` n_dicts        -- and enough dict args
+  |  rhs_tyvars `lengthIs`     n_tyvars -- Rhs of fn's defn has right number of big lambdas
+  && rhs_ids    `lengthAtLeast` n_dicts        -- and enough dict args
   && notNull calls_for_me              -- And there are some calls to specialise
 
 --   && not (certainlyWillInline (idUnfolding fn))     -- And it's not small
 --     See Note [Inline specialisation] for why we do not 
---     switch off specialisation for inline functions
-
-  =   -- Specialise the body of the function
-    specExpr subst rhs                                 `thenSM` \ (rhs', rhs_uds) ->
+--     switch off specialisation for inline functions = do
+  = do
+     -- Specialise the body of the function
+    (rhs', rhs_uds) <- specExpr subst rhs
 
       -- Make a specialised version for each call in calls_for_me
-    mapSM spec_call calls_for_me               `thenSM` \ stuff ->
+    stuff <- mapM spec_call calls_for_me
     let
-       (spec_defns, spec_uds, spec_rules) = unzip3 stuff
+        (spec_defns, spec_uds, spec_rules) = unzip3 stuff
 
-       fn' = addIdSpecialisations fn spec_rules
-    in
-    returnSM ((fn',rhs'), 
-             spec_defns, 
-             rhs_uds `plusUDs` plusUDList spec_uds)
+        fn' = addIdSpecialisations fn spec_rules
+
+    return ((fn',rhs'),
+              spec_defns,
+              rhs_uds `plusUDs` plusUDList spec_uds)
 
   | otherwise  -- No calls or RHS doesn't fit our preconceptions
-  = specExpr subst rhs                 `thenSM` \ (rhs', rhs_uds) ->
-    returnSM ((fn, rhs'), [], rhs_uds)
+  = WARN( notNull calls_for_me, ptext SLIT("Missed specialisation opportunity for") <+> ppr fn ) do
+         -- Note [Specialisation shape]
+    (rhs', rhs_uds) <- specExpr subst rhs
+    return ((fn, rhs'), [], rhs_uds)
   
   where
     fn_type           = idType fn
@@ -837,7 +838,7 @@ specDefn subst calls (fn, rhs)
                        UsageDetails,                   -- Usage details from specialised body
                        CoreRule)                       -- Info for the Id's SpecEnv
     spec_call (CallKey call_ts, (call_ds, call_fvs))
-      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts )
+      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts ) do
                -- Calls are only recorded for properly-saturated applications
        
        -- Suppose f's defn is  f = /\ a b c d -> \ d1 d2 -> rhs        
@@ -864,8 +865,8 @@ specDefn subst calls (fn, rhs)
                         mk_ty_arg rhs_tyvar Nothing   = Type (mkTyVarTy rhs_tyvar)
                         mk_ty_arg rhs_tyvar (Just ty) = Type ty
           rhs_subst  = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts])
-       in
-       cloneBinders rhs_subst rhs_dicts                `thenSM` \ (rhs_subst', rhs_dicts') ->
+
+       (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts
        let
           inst_args = ty_args ++ map Var rhs_dicts'
 
@@ -876,14 +877,15 @@ specDefn subst calls (fn, rhs)
                = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [realWorldPrimId])
                | otherwise = (poly_tyvars, poly_tyvars)
           spec_id_ty = mkPiTypes lam_args body_ty
-       in
-       newIdSM fn spec_id_ty                           `thenSM` \ spec_f ->
-       specExpr rhs_subst' (mkLams lam_args body)      `thenSM` \ (spec_rhs, rhs_uds) ->       
+
+        spec_f <- newIdSM fn spec_id_ty
+        (spec_rhs, rhs_uds) <- specExpr rhs_subst' (mkLams lam_args body)
        let
                -- The rule to put in the function's specialisation is:
                --      forall b,d, d1',d2'.  f t1 b t3 d d1' d2' = f1 b d  
            spec_env_rule = mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr fn)))
-                               AlwaysActive (idName fn)
+                               inline_prag     -- Note [Auto-specialisation and RULES]
+                               (idName fn)
                                (poly_tyvars ++ rhs_dicts')
                                inst_args 
                                (mkVarApps (Var spec_f) app_args)
@@ -893,15 +895,74 @@ specDefn subst calls (fn, rhs)
 
           spec_pr | inline_rhs = (spec_f `setInlinePragma` inline_prag, Note InlineMe spec_rhs)
                   | otherwise  = (spec_f,                               spec_rhs)
-       in
-        returnSM (spec_pr, final_uds, spec_env_rule)
+
+        return (spec_pr, final_uds, spec_env_rule)
 
       where
        my_zipEqual doc xs ys 
-        | not (equalLength xs ys) = pprPanic "my_zipEqual" (ppr xs $$ ppr ys $$ (ppr fn <+> ppr call_ts) $$ ppr rhs)
+#ifdef DEBUG
+        | not (equalLength xs ys) = pprPanic "my_zipEqual" (vcat 
+                                               [ ppr xs, ppr ys
+                                               , ppr fn <+> ppr call_ts
+                                               , ppr (idType fn), ppr theta
+                                               , ppr n_dicts, ppr rhs_dicts 
+                                               , ppr rhs])
+#endif
         | otherwise               = zipEqual doc xs ys
 \end{code}
 
+Note [Auto-specialisation and RULES]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider:
+   g :: Num a => a -> a
+   g = ...
+
+   f :: (Int -> Int) -> Int
+   f w = ...
+   {-# RULE f g = 0 #-}
+
+Suppose that auto-specialisation makes a specialised version of
+g::Int->Int That version won't appear in the LHS of the RULE for f.
+So if the specialisation rule fires too early, the rule for f may
+never fire. 
+
+It might be possible to add new rules, to "complete" the rewrite system.
+Thus when adding
+       RULE forall d. g Int d = g_spec
+also add
+       RULE f g_spec = 0
+
+But that's a bit complicated.  For now we ask the programmer's help,
+by *copying the INLINE activation pragma* to the auto-specialised rule.
+So if g says {-# NOINLINE[2] g #-}, then the auto-spec rule will also
+not be active until phase 2.  
+
+
+Note [Specialisation shape]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We only specialise a function if it has visible top-level lambdas
+corresponding to its overloading.  E.g. if
+       f :: forall a. Eq a => ....
+then its body must look like
+       f = /\a. \d. ...
+
+Reason: when specialising the body for a call (f ty dexp), we want to
+substitute dexp for d, and pick up specialised calls in the body of f.
+
+This doesn't always work.  One example I came across was htis:
+       newtype Gen a = MkGen{ unGen :: Int -> a }
+
+       choose :: Eq a => a -> Gen a
+       choose n = MkGen (\r -> n)
+
+       oneof = choose (1::Int)
+
+It's a silly exapmle, but we get
+       choose = /\a. g `cast` co
+where choose doesn't have any dict arguments.  Thus far I have not
+tried to fix this (wait till there's a real example).
+
+
 Note [Inline specialisations]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 We transfer to the specialised function any INLINE stuff from the
@@ -1147,46 +1208,37 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs,
 \begin{code}
 type SpecM a = UniqSM a
 
-thenSM    = thenUs
-returnSM  = returnUs
-getUniqSM = getUniqueUs
-mapSM     = mapUs
 initSM   = initUs_
 
-mapAndCombineSM f []     = returnSM ([], emptyUDs)
-mapAndCombineSM f (x:xs) = f x `thenSM` \ (y, uds1) ->
-                          mapAndCombineSM f xs `thenSM` \ (ys, uds2) ->
-                          returnSM (y:ys, uds1 `plusUDs` uds2)
+mapAndCombineSM f []     = return ([], emptyUDs)
+mapAndCombineSM f (x:xs) = do (y, uds1) <- f x
+                              (ys, uds2) <- mapAndCombineSM f xs
+                              return (y:ys, uds1 `plusUDs` uds2)
 
 cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind)
 -- Clone the binders of the bind; return new bind with the cloned binders
 -- Return the substitution to use for RHSs, and the one to use for the body
-cloneBindSM subst (NonRec bndr rhs)
-  = getUs      `thenUs` \ us ->
-    let
-       (subst', bndr') = cloneIdBndr subst us bndr
-    in
-    returnUs (subst, subst', NonRec bndr' rhs)
-
-cloneBindSM subst (Rec pairs)
-  = getUs      `thenUs` \ us ->
+cloneBindSM subst (NonRec bndr rhs) = do
+    us <- getUniqueSupplyM
+    let (subst', bndr') = do cloneIdBndr subst us bndr
+    return (subst, subst', NonRec bndr' rhs)
+
+cloneBindSM subst (Rec pairs) = do
+    us <- getUniqueSupplyM
+    let (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
+    return (subst', subst', Rec (bndrs' `zip` map snd pairs))
+
+cloneBinders subst bndrs = do
+    us <- getUniqueSupplyM
+    return (cloneIdBndrs subst us bndrs)
+
+newIdSM old_id new_ty = do
+    uniq <- getUniqueM
     let
-       (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs)
-    in
-    returnUs (subst', subst', Rec (bndrs' `zip` map snd pairs))
-
-cloneBinders subst bndrs
-  = getUs      `thenUs` \ us ->
-    returnUs (cloneIdBndrs subst us bndrs)
-
-newIdSM old_id new_ty
-  = getUniqSM          `thenSM` \ uniq ->
-    let 
-       -- Give the new Id a similar occurrence name to the old one
-       name   = idName old_id
-       new_id = mkUserLocal (mkSpecOcc (nameOccName name)) uniq new_ty (getSrcSpan name)
-    in
-    returnSM new_id
+        -- Give the new Id a similar occurrence name to the old one
+        name   = idName old_id
+        new_id = mkUserLocal (mkSpecOcc (nameOccName name)) uniq new_ty (getSrcSpan name)
+    return new_id
 \end{code}