Implement auto-specialisation of imported Ids
[ghc-hetmet.git] / compiler / deSugar / DsBinds.lhs
index c886c8e..7e922fd 100644 (file)
@@ -10,7 +10,7 @@ in that the @Rec@/@NonRec@/etc structure is thrown away (whereas at
 lower levels it is preserved with @let@/@letrec@s).
 
 \begin{code}
-module DsBinds ( dsTopLHsBinds, dsLHsBinds, decomposeRuleLhs, 
+module DsBinds ( dsTopLHsBinds, dsLHsBinds, decomposeRuleLhs, dsSpec,
                 dsHsWrapper, dsTcEvBinds, dsEvBinds, wrapDsEvBinds, 
                 DsEvBind(..), AutoScc(..)
   ) where
@@ -69,9 +69,8 @@ import MonadUtils
 %************************************************************************
 
 \begin{code}
-dsTopLHsBinds :: AutoScc -> LHsBinds Id -> DsM [(Id,CoreExpr)]
-dsTopLHsBinds auto_scc binds = do { binds' <- ds_lhs_binds auto_scc binds
-                                  ; return (fromOL binds') }
+dsTopLHsBinds :: AutoScc -> LHsBinds Id -> DsM (OrdList (Id,CoreExpr))
+dsTopLHsBinds auto_scc binds = ds_lhs_binds auto_scc binds
 
 dsLHsBinds :: LHsBinds Id -> DsM [(Id,CoreExpr)]
 dsLHsBinds binds = do { binds' <- ds_lhs_binds NoSccs binds
@@ -107,91 +106,16 @@ dsHsBind _ (FunBind { fun_id = L _ fun, fun_matches = matches
  = do  { (args, body) <- matchWrapper (FunRhs (idName fun) inf) matches
        ; body'    <- mkOptTickBox tick body
        ; wrap_fn' <- dsHsWrapper co_fn 
-       ; return (unitOL (fun, wrap_fn' (mkLams args body'))) }
+       ; let rhs = wrap_fn' (mkLams args body')
+       ; return (unitOL (makeCorePair fun False 0 rhs)) }
 
 dsHsBind _ (PatBind { pat_lhs = pat, pat_rhs = grhss, pat_rhs_ty = ty })
   = do { body_expr <- dsGuarded grhss ty
        ; sel_binds <- mkSelectorBinds pat body_expr
+         -- We silently ignore inline pragmas; no makeCorePair
+         -- Not so cool, but really doesn't matter
        ; return (toOL sel_binds) }
 
-{-
-dsHsBind auto_scc (AbsBinds { abs_tvs = [], abs_ev_vars = []
-                                   , abs_exports = exports, abs_ev_binds = ev_binds
-                                   , abs_binds = binds })
-  = do { bind_prs    <- ds_lhs_binds NoSccs binds
-        ; ds_ev_binds <- dsTcEvBinds ev_binds
-
-       ; let core_prs = addEvPairs ds_ev_binds bind_prs
-              env = mkABEnv exports
-             do_one (lcl_id, rhs) 
-               | Just (_, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
-               = do { let rhs' = addAutoScc auto_scc gbl_id rhs
-                    ; (spec_binds, rules) <- dsSpecs gbl_id (Let (Rec core_prs) rhs') spec_prags
-                                   -- See Note [Specialising in no-dict case]
-                     ; let   gbl_id'   = addIdSpecialisations gbl_id rules
-                             main_bind = makeCorePair gbl_id' False 0 rhs'
-                    ; return (main_bind : spec_binds) }
-
-               | otherwise = return [(lcl_id, rhs)]
-
-             locals'  = [(lcl_id, Var gbl_id) | (_, gbl_id, lcl_id, _) <- exports]
-                       -- Note [Rules and inlining]
-        ; export_binds <- mapM do_one core_prs
-       ; return (concat export_binds ++ locals' ++ rest) }
-               -- No Rec needed here (contrast the other AbsBinds cases)
-               -- because we can rely on the enclosing dsBind to wrap in Rec
-
-
-dsHsBind auto_scc rest (AbsBinds { abs_tvs = tyvars, abs_ev_vars = []
-                                        , abs_exports = exports, abs_ev_binds = ev_binds
-                                        , abs_binds = binds })
-  | opt_DsMultiTyVar   -- This (static) debug flag just lets us
-                       -- switch on and off this optimisation to
-                       -- see if it has any impact; it is on by default
-  , allOL isLazyEvBind ev_binds
-  =    -- Note [Abstracting over tyvars only]
-    do { bind_prs    <- ds_lhs_binds NoSccs binds
-        ; ds_ev_binds <- dsTcEvBinds ev_binds
-
-       ; let core_prs = addEvPairs ds_ev_binds bind_prs
-              arby_env = mkArbitraryTypeEnv tyvars exports
-             bndrs = mkVarSet (map fst core_prs)
-
-             add_lets | core_prs `lengthExceeds` 10 = add_some
-                      | otherwise                   = mkLets
-             add_some lg_binds rhs = mkLets [ NonRec b r | NonRec b r <- lg_binds
-                                                         , b `elemVarSet` fvs] rhs
-               where
-                 fvs = exprSomeFreeVars (`elemVarSet` bndrs) rhs
-
-             env = mkABEnv exports
-             mk_lg_bind lcl_id gbl_id tyvars
-                = NonRec (setIdInfo lcl_id vanillaIdInfo)
-                               -- Nuke the IdInfo so that no old unfoldings
-                               -- confuse use (it might mention something not
-                               -- even in scope at the new site
-                         (mkTyApps (Var gbl_id) (mkTyVarTys tyvars))
-
-             do_one lg_binds (lcl_id, rhs) 
-               | Just (id_tvs, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
-               = do { let rhs' = addAutoScc auto_scc gbl_id  $
-                                 mkLams id_tvs $
-                                 mkLets [ NonRec tv (Type (lookupVarEnv_NF arby_env tv))
-                                        | tv <- tyvars, not (tv `elem` id_tvs)] $
-                                 add_lets lg_binds rhs
-                    ; (spec_binds, rules) <- dsSpecs gbl_id rhs' spec_prags
-                     ; let   gbl_id'   = addIdSpecialisations gbl_id rules
-                             main_bind = makeCorePair gbl_id' False 0 rhs'
-                    ; return (mk_lg_bind lcl_id gbl_id' id_tvs, main_bind : spec_binds) }
-               | otherwise
-               = do { non_exp_gbl_id <- newUniqueId lcl_id (mkForAllTys tyvars (idType lcl_id))
-                    ; return (mk_lg_bind lcl_id non_exp_gbl_id tyvars,
-                              [(non_exp_gbl_id, mkLams tyvars (add_lets lg_binds rhs))]) }
-                                                 
-       ; (_, core_prs') <- fixDs (\ ~(lg_binds, _) -> mapAndUnzipM (do_one lg_binds) core_prs)
-       ; return (concat core_prs' ++ rest) }
--}
-
        -- A common case: one exported variable
        -- Non-recursive bindings come through this way
        -- So do self-recursive bindings, and recursive bindings
@@ -210,7 +134,7 @@ dsHsBind auto_scc (AbsBinds { abs_tvs = all_tyvars, abs_ev_vars = dicts
                             Let core_bind $
                             Var local
     
-       ; (spec_binds, rules) <- dsSpecs global rhs prags
+       ; (spec_binds, rules) <- dsSpecs rhs prags
 
        ; let   global'   = addIdSpecialisations global rules
                main_bind = makeCorePair global' (isDefaultMethod prags)
@@ -253,9 +177,9 @@ dsHsBind auto_scc (AbsBinds { abs_tvs = all_tyvars, abs_ev_vars = dicts
                                 mkTupleSelector locals' (locals' !! n) tup_id $
                                 mkVarApps (mkTyApps (Var poly_tup_id) ty_args)
                                           dicts
-                    ; (spec_binds, rules) <- dsSpecs global
-                                                     (Let (NonRec poly_tup_id poly_tup_rhs) rhs)
-                                                     spec_prags
+                           full_rhs = Let (NonRec poly_tup_id poly_tup_rhs) rhs
+                    ; (spec_binds, rules) <- dsSpecs full_rhs spec_prags
+                                                     
                     ; let global' = addIdSpecialisations global rules
                     ; return ((global', rhs) `consOL` spec_binds) }
                where
@@ -355,21 +279,29 @@ makeCorePair gbl_id is_default_method dict_arity rhs
   | is_default_method                -- Default methods are *always* inlined
   = (gbl_id `setIdUnfolding` mkCompulsoryUnfolding rhs, rhs)
 
-  | not (isInlinePragma inline_prag)
-  = (gbl_id, rhs)
+  | otherwise
+  = case inlinePragmaSpec inline_prag of
+         EmptyInlineSpec -> (gbl_id, rhs)
+         NoInline        -> (gbl_id, rhs)
+         Inlinable       -> (gbl_id `setIdUnfolding` inlinable_unf, rhs)
+          Inline          -> inline_pair
 
-  | Just arity <- inlinePragmaSat inline_prag
+  where
+    inline_prag   = idInlinePragma gbl_id
+    inlinable_unf = mkInlinableUnfolding rhs
+    inline_pair
+       | Just arity <- inlinePragmaSat inline_prag
        -- Add an Unfolding for an INLINE (but not for NOINLINE)
        -- And eta-expand the RHS; see Note [Eta-expanding INLINE things]
-  , let real_arity = dict_arity + arity
+       , let real_arity = dict_arity + arity
         -- NB: The arity in the InlineRule takes account of the dictionaries
-  = (gbl_id `setIdUnfolding` mkInlineRule rhs (Just real_arity),
-     etaExpand real_arity rhs)
+       = ( gbl_id `setIdUnfolding` mkInlineUnfolding (Just real_arity) rhs
+         , etaExpand real_arity rhs)
+
+       | otherwise
+       = pprTrace "makeCorePair: arity missing" (ppr gbl_id) $
+         (gbl_id `setIdUnfolding` mkInlineUnfolding Nothing rhs, rhs)
 
-  | otherwise
-  = (gbl_id `setIdUnfolding` mkInlineRule rhs Nothing, rhs)
-  where
-    inline_prag = idInlinePragma gbl_id
 
 dictArity :: [Var] -> Arity
 -- Don't count coercion variables in arity
@@ -409,7 +341,7 @@ This does not happen in the same way to polymorphic binds,
 because they desugar to
        M.f = /\a. let f_lcl = ...f_lcl... in f_lcl
 Although I'm a bit worried about whether full laziness might
-float the f_lcl binding out and then inline M.f at its call site -}
+float the f_lcl binding out and then inline M.f at its call site
 
 Note [Specialising in no-dict case]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -542,66 +474,69 @@ Note that
 
 \begin{code}
 ------------------------
-dsSpecs :: Id          -- The polymorphic Id
-        -> CoreExpr     -- Its rhs
+dsSpecs :: CoreExpr     -- Its rhs
         -> TcSpecPrags
         -> DsM ( OrdList (Id,CoreExpr)         -- Binding for specialised Ids
               , [CoreRule] )           -- Rules for the Global Ids
 -- See Note [Implementing SPECIALISE pragmas]
-dsSpecs poly_id poly_rhs prags
-  = case prags of
-      IsDefaultMethod      -> return (nilOL, [])
-      SpecPrags sps -> do { pairs <- mapMaybeM spec_one sps
-                          ; let (spec_binds_s, rules) = unzip pairs
-                          ; return (concatOL spec_binds_s, rules) }
- where 
-    spec_one :: Located TcSpecPrag -> DsM (Maybe (OrdList (Id,CoreExpr), CoreRule))
-    spec_one (L loc (SpecPrag spec_co spec_inl))
-      = putSrcSpanDs loc $ 
-        do { let poly_name = idName poly_id
-          ; spec_name <- newLocalName poly_name
-          ; wrap_fn   <- dsHsWrapper spec_co
-           ; let (bndrs, ds_lhs) = collectBinders (wrap_fn (Var poly_id))
-                 spec_ty = mkPiTypes bndrs (exprType ds_lhs)
-          ; case decomposeRuleLhs ds_lhs of {
-              Nothing -> do { warnDs (decomp_msg spec_co)
-                             ; return Nothing } ;
-
-              Just (_fn, args) ->
-
-          -- Check for dead binders: Note [Unused spec binders]
-             let arg_fvs = exprsFreeVars args
-                 bad_bndrs = filterOut (`elemVarSet` arg_fvs) bndrs
-            in if not (null bad_bndrs)
-                then do { warnDs (dead_msg bad_bndrs); return Nothing } 
-               else do
-
-          { (spec_unf, unf_pairs) <- specUnfolding wrap_fn spec_ty (realIdUnfolding poly_id)
-
-          ; let spec_id  = mkLocalId spec_name spec_ty 
-                           `setInlinePragma` inl_prag
-                           `setIdUnfolding`  spec_unf
-                inl_prag | isDefaultInlinePragma spec_inl = idInlinePragma poly_id
-                         | otherwise                      = spec_inl
-                     -- Get the INLINE pragma from SPECIALISE declaration, or,
-                      -- failing that, from the original Id
-
-                extra_dict_bndrs = [ mkLocalId (localiseName (idName d)) (idType d)
-                                            -- See Note [Constant rule dicts]
-                                   | d <- varSetElems (arg_fvs `delVarSetList` bndrs)
-                                   , isDictId d]
-
-                rule =  mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
-                               AlwaysActive poly_name
-                               (extra_dict_bndrs ++ bndrs) args
-                               (mkVarApps (Var spec_id) bndrs)
-
-                 spec_rhs  = wrap_fn poly_rhs
-                 spec_pair = makeCorePair spec_id False (dictArity bndrs) spec_rhs
-
-           ; return (Just (spec_pair `consOL` unf_pairs, rule))
-           } } }
-
+dsSpecs _ IsDefaultMethod = return (nilOL, [])
+dsSpecs poly_rhs (SpecPrags sps)
+  = do { pairs <- mapMaybeM (dsSpec (Just poly_rhs)) sps
+       ; let (spec_binds_s, rules) = unzip pairs
+       ; return (concatOL spec_binds_s, rules) }
+
+dsSpec :: Maybe CoreExpr       -- Just rhs => RULE is for a local binding
+                                       -- Nothing => RULE is for an imported Id
+                               --            rhs is in the Id's unfolding
+       -> Located TcSpecPrag
+       -> DsM (Maybe (OrdList (Id,CoreExpr), CoreRule))
+dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl))
+  = putSrcSpanDs loc $ 
+    do { let poly_name = idName poly_id
+       ; spec_name <- newLocalName poly_name
+       ; wrap_fn   <- dsHsWrapper spec_co
+       ; let (bndrs, ds_lhs) = collectBinders (wrap_fn (Var poly_id))
+             spec_ty = mkPiTypes bndrs (exprType ds_lhs)
+       ; case decomposeRuleLhs ds_lhs of {
+          Nothing -> do { warnDs (decomp_msg spec_co)
+                        ; return Nothing } ;
+
+          Just (_fn, args) ->
+
+         -- Check for dead binders: Note [Unused spec binders]
+         let arg_fvs = exprsFreeVars args
+             bad_bndrs = filterOut (`elemVarSet` arg_fvs) bndrs
+         in if not (null bad_bndrs)
+            then do { warnDs (dead_msg bad_bndrs); return Nothing } 
+                   else do
+
+       { (spec_unf, unf_pairs) <- specUnfolding wrap_fn spec_ty (realIdUnfolding poly_id)
+
+       ; let spec_id  = mkLocalId spec_name spec_ty 
+                           `setInlinePragma` inl_prag
+                           `setIdUnfolding`  spec_unf
+             inl_prag | isDefaultInlinePragma spec_inl = idInlinePragma poly_id
+                      | otherwise                      = spec_inl
+                     -- Get the INLINE pragma from SPECIALISE declaration, or,
+              -- failing that, from the original Id
+
+             extra_dict_bndrs = [ mkLocalId (localiseName (idName d)) (idType d)
+                                       -- See Note [Constant rule dicts]
+                               | d <- varSetElems (arg_fvs `delVarSetList` bndrs)
+                               , isDictId d]
+
+             rule =  mkRule False {- Not auto -} is_local_id
+                        (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
+                               AlwaysActive poly_name
+                               (extra_dict_bndrs ++ bndrs) args
+                               (mkVarApps (Var spec_id) bndrs)
+
+             spec_rhs  = wrap_fn poly_rhs
+             spec_pair = makeCorePair spec_id False (dictArity bndrs) spec_rhs
+
+       ; return (Just (spec_pair `consOL` unf_pairs, rule))
+       } } }
+  where
     dead_msg bs = vcat [ sep [ptext (sLit "Useless constraint") <> plural bs
                                 <+> ptext (sLit "in specialied type:"),
                             nest 2 (pprTheta (map get_pred bs))]
@@ -612,6 +547,15 @@ dsSpecs poly_id poly_rhs prags
         = hang (ptext (sLit "Specialisation too complicated to desugar; ignored"))
             2 (pprHsWrapper (ppr poly_id) spec_co)
             
+    is_local_id = isJust mb_poly_rhs
+    poly_rhs | Just rhs <-  mb_poly_rhs
+             = rhs
+             | Just unfolding <- maybeUnfoldingTemplate (idUnfolding poly_id)
+             = unfolding
+             | otherwise = pprPanic "dsImpSpecs" (ppr poly_id)
+       -- In the Nothing case the specialisation is for an imported Id
+       -- whose unfolding gives the RHS to be specialised
+        -- The type checker has checked that it has an unfolding
 
 specUnfolding :: (CoreExpr -> CoreExpr) -> Type 
               -> Unfolding -> DsM (Unfolding, OrdList (Id,CoreExpr))