Improve the handling of default methods
[ghc-hetmet.git] / compiler / deSugar / DsBinds.lhs
index 3fe8d54..9e29c96 100644 (file)
@@ -142,12 +142,10 @@ dsHsBind _ rest
 dsHsBind auto_scc rest (AbsBinds [] [] exports binds)
   = do { core_prs <- ds_lhs_binds NoSccs binds
        ; let env = mkABEnv exports
 dsHsBind auto_scc rest (AbsBinds [] [] exports binds)
   = do { core_prs <- ds_lhs_binds NoSccs binds
        ; let env = mkABEnv exports
-             ar_env = mkArityEnv binds
              do_one (lcl_id, rhs) 
                | Just (_, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
              do_one (lcl_id, rhs) 
                | Just (_, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
-               = WARN( not (null spec_prags), ppr gbl_id $$ ppr spec_prags )     -- Not overloaded
-                  makeCorePair gbl_id (lookupArity ar_env lcl_id)
-                              (addAutoScc auto_scc gbl_id rhs)
+               = WARN( hasSpecPrags spec_prags, pprTcSpecPrags gbl_id spec_prags )       -- Not overloaded
+                  makeCorePair gbl_id False 0 (addAutoScc auto_scc gbl_id rhs)
 
                | otherwise = (lcl_id, rhs)
 
 
                | otherwise = (lcl_id, rhs)
 
@@ -217,9 +215,7 @@ dsHsBind auto_scc rest (AbsBinds tyvars [] exports binds)
                where
                  fvs = exprSomeFreeVars (`elemVarSet` bndrs) rhs
 
                where
                  fvs = exprSomeFreeVars (`elemVarSet` bndrs) rhs
 
-             ar_env = mkArityEnv binds
              env = mkABEnv exports
              env = mkABEnv exports
-
              mk_lg_bind lcl_id gbl_id tyvars
                 = NonRec (setIdInfo lcl_id vanillaIdInfo)
                                -- Nuke the IdInfo so that no old unfoldings
              mk_lg_bind lcl_id gbl_id tyvars
                 = NonRec (setIdInfo lcl_id vanillaIdInfo)
                                -- Nuke the IdInfo so that no old unfoldings
@@ -229,14 +225,14 @@ dsHsBind auto_scc rest (AbsBinds tyvars [] exports binds)
 
              do_one lg_binds (lcl_id, rhs) 
                | Just (id_tvs, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
 
              do_one lg_binds (lcl_id, rhs) 
                | Just (id_tvs, gbl_id, _, spec_prags) <- lookupVarEnv env lcl_id
-               = WARN( not (null spec_prags), ppr gbl_id $$ ppr spec_prags )     -- Not overloaded
+               = WARN( hasSpecPrags spec_prags, pprTcSpecPrags gbl_id spec_prags )       -- Not overloaded
                   (let rhs' = addAutoScc auto_scc gbl_id  $
                              mkLams id_tvs $
                              mkLets [ NonRec tv (Type (lookupVarEnv_NF arby_env tv))
                                     | tv <- tyvars, not (tv `elem` id_tvs)] $
                              add_lets lg_binds rhs
                  in return (mk_lg_bind lcl_id gbl_id id_tvs,
                   (let rhs' = addAutoScc auto_scc gbl_id  $
                              mkLams id_tvs $
                              mkLets [ NonRec tv (Type (lookupVarEnv_NF arby_env tv))
                                     | tv <- tyvars, not (tv `elem` id_tvs)] $
                              add_lets lg_binds rhs
                  in return (mk_lg_bind lcl_id gbl_id id_tvs,
-                            makeCorePair gbl_id (lookupArity ar_env lcl_id) rhs'))
+                            makeCorePair gbl_id False 0 rhs'))
                | otherwise
                = do { non_exp_gbl_id <- newUniqueId lcl_id (mkForAllTys tyvars (idType lcl_id))
                     ; return (mk_lg_bind lcl_id non_exp_gbl_id tyvars,
                | otherwise
                = do { non_exp_gbl_id <- newUniqueId lcl_id (mkForAllTys tyvars (idType lcl_id))
                     ; return (mk_lg_bind lcl_id non_exp_gbl_id tyvars,
@@ -254,25 +250,24 @@ dsHsBind auto_scc rest
   = ASSERT( all (`elem` tyvars) all_tyvars )
     do { core_prs <- ds_lhs_binds NoSccs binds
 
   = ASSERT( all (`elem` tyvars) all_tyvars )
     do { core_prs <- ds_lhs_binds NoSccs binds
 
-       ; let   -- Always treat the binds as recursive, because the typechecker
-               -- makes rather mixed-up dictionary bindings
+       ; let   -- Always treat the binds as recursive, because the 
+               -- typechecker makes rather mixed-up dictionary bindings
                core_bind = Rec core_prs
                core_bind = Rec core_prs
-               inl_arity = lookupArity (mkArityEnv binds) local
     
        ; (spec_binds, rules) <- dsSpecs all_tyvars dicts tyvars global 
     
        ; (spec_binds, rules) <- dsSpecs all_tyvars dicts tyvars global 
-                                        local inl_arity core_bind prags
+                                        local core_bind prags
 
        ; let   global'   = addIdSpecialisations global rules
                rhs       = addAutoScc auto_scc global $
                            mkLams tyvars $ mkLams dicts $ Let core_bind (Var local)
 
        ; let   global'   = addIdSpecialisations global rules
                rhs       = addAutoScc auto_scc global $
                            mkLams tyvars $ mkLams dicts $ Let core_bind (Var local)
-               main_bind = makeCorePair global' (inl_arity + dictArity dicts) rhs
+               main_bind = makeCorePair global' (isDefaultMethod prags)
+                                         (dictArity dicts) rhs 
     
        ; return (main_bind : spec_binds ++ rest) }
 
 dsHsBind auto_scc rest (AbsBinds all_tyvars dicts exports binds)
   = do { core_prs <- ds_lhs_binds NoSccs binds
        ; let env = mkABEnv exports
     
        ; return (main_bind : spec_binds ++ rest) }
 
 dsHsBind auto_scc rest (AbsBinds all_tyvars dicts exports binds)
   = do { core_prs <- ds_lhs_binds NoSccs binds
        ; let env = mkABEnv exports
-             ar_env = mkArityEnv binds
              do_one (lcl_id,rhs) | Just (_, gbl_id, _, _prags) <- lookupVarEnv env lcl_id
                                  = (lcl_id, addAutoScc auto_scc gbl_id rhs)
                                  | otherwise = (lcl_id,rhs)
              do_one (lcl_id,rhs) | Just (_, gbl_id, _, _prags) <- lookupVarEnv env lcl_id
                                  = (lcl_id, addAutoScc auto_scc gbl_id rhs)
                                  | otherwise = (lcl_id,rhs)
@@ -297,7 +292,7 @@ dsHsBind auto_scc rest (AbsBinds all_tyvars dicts exports binds)
                     ; locals' <- newSysLocalsDs (map substitute local_tys)
                     ; tup_id  <- newSysLocalDs  (substitute tup_ty)
                     ; (spec_binds, rules) <- dsSpecs all_tyvars dicts tyvars global local 
                     ; locals' <- newSysLocalsDs (map substitute local_tys)
                     ; tup_id  <- newSysLocalDs  (substitute tup_ty)
                     ; (spec_binds, rules) <- dsSpecs all_tyvars dicts tyvars global local 
-                                                     (lookupArity ar_env local) core_bind 
+                                                     core_bind 
                                                      spec_prags
                     ; let global' = addIdSpecialisations global rules
                           rhs = mkLams tyvars $ mkLams dicts $
                                                      spec_prags
                     ; let global' = addIdSpecialisations global rules
                           rhs = mkLams tyvars $ mkLams dicts $
@@ -317,50 +312,40 @@ dsHsBind auto_scc rest (AbsBinds all_tyvars dicts exports binds)
                    (concat export_binds_s ++ rest)) }
 
 ------------------------
                    (concat export_binds_s ++ rest)) }
 
 ------------------------
-makeCorePair :: Id-> Arity -> CoreExpr -> (Id, CoreExpr)
-makeCorePair gbl_id arity rhs
-  | isInlinePragma (idInlinePragma gbl_id)
+makeCorePair :: Id -> Bool -> Arity -> CoreExpr -> (Id, CoreExpr)
+makeCorePair gbl_id is_default_method dict_arity rhs
+  | is_default_method                -- Default methods are *always* inlined
+  = (gbl_id `setIdUnfolding` mkCompulsoryUnfolding rhs, rhs)
+
+  | not (isInlinePragma inline_prag)
+  = (gbl_id, rhs)
+
+  | Just arity <- inlinePragmaSat inline_prag
        -- Add an Unfolding for an INLINE (but not for NOINLINE)
        -- And eta-expand the RHS; see Note [Eta-expanding INLINE things]
        -- Add an Unfolding for an INLINE (but not for NOINLINE)
        -- And eta-expand the RHS; see Note [Eta-expanding INLINE things]
-  = (gbl_id `setIdUnfolding` mkInlineRule InlSat rhs arity,
+  = (gbl_id `setIdUnfolding` mkInlineRule rhs (Just (dict_arity + arity)),
+           -- NB: The arity in the InlineRule takes account of the dictionaries
      etaExpand arity rhs)
      etaExpand arity rhs)
+
   | otherwise
   | otherwise
-  = (gbl_id, rhs)
+  = (gbl_id `setIdUnfolding` mkInlineRule rhs Nothing, rhs)
+  where
+    inline_prag = idInlinePragma gbl_id
+
+dictArity :: [Var] -> Arity
+-- Don't count coercion variables in arity
+dictArity dicts = count isId dicts
+
 
 ------------------------
 
 ------------------------
-type AbsBindEnv = VarEnv ([TyVar], Id, Id, [LSpecPrag])
+type AbsBindEnv = VarEnv ([TyVar], Id, Id, TcSpecPrags)
        -- Maps the "lcl_id" for an AbsBind to
        -- its "gbl_id" and associated pragmas, if any
 
        -- Maps the "lcl_id" for an AbsBind to
        -- its "gbl_id" and associated pragmas, if any
 
-mkABEnv :: [([TyVar], Id, Id, [LSpecPrag])] -> AbsBindEnv
+mkABEnv :: [([TyVar], Id, Id, TcSpecPrags)] -> AbsBindEnv
 -- Takes the exports of a AbsBinds, and returns a mapping
 --     lcl_id -> (tyvars, gbl_id, lcl_id, prags)
 mkABEnv exports = mkVarEnv [ (lcl_id, export) | export@(_, _, lcl_id, _) <- exports]
 -- Takes the exports of a AbsBinds, and returns a mapping
 --     lcl_id -> (tyvars, gbl_id, lcl_id, prags)
 mkABEnv exports = mkVarEnv [ (lcl_id, export) | export@(_, _, lcl_id, _) <- exports]
-
-mkArityEnv :: LHsBinds Id -> IdEnv Arity
-       -- Maps a local to the arity of its definition
-mkArityEnv binds = foldrBag (plusVarEnv . lhsBindArity) emptyVarEnv binds
-
-lhsBindArity :: LHsBind Id -> IdEnv Arity
-lhsBindArity (L _ (FunBind { fun_id = id, fun_matches = ms })) 
-  = unitVarEnv (unLoc id) (matchGroupArity ms)
-lhsBindArity (L _ (AbsBinds { abs_exports = exports
-                            , abs_dicts = dicts
-                            , abs_binds = binds })) 
-  = mkVarEnv [ (gbl, lookupArity ar_env lcl + n_val_dicts) 
-             | (_, gbl, lcl, _) <- exports]
-  where             -- See Note [Nested arities] 
-    ar_env = mkArityEnv binds
-    n_val_dicts = dictArity dicts      
-
-lhsBindArity _ = emptyVarEnv   -- PatBind/VarBind
-
-dictArity :: [Var] -> Arity
--- Don't count coercion variables in arity
-dictArity dicts = count isId dicts
-
-lookupArity :: IdEnv Arity -> Id -> Arity
-lookupArity ar_env id = lookupVarEnv ar_env id `orElse` 0
 \end{code}
 
 Note [Eta-expanding INLINE things]
 \end{code}
 
 Note [Eta-expanding INLINE things]
@@ -397,44 +382,57 @@ gotten from the binding for fromT_1.
 It might be better to have just one level of AbsBinds, but that requires more
 thought!
 
 It might be better to have just one level of AbsBinds, but that requires more
 thought!
 
+Note [Implementing SPECIALISE pragmas]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Example:
+       f :: (Eq a, Ix b) => a -> b -> Bool
+       {-# SPECIALISE f :: (Ix p, Ix q) => Int -> (p,q) -> Bool #-}
+
+From this the typechecker generates
+
+    AbsBinds [ab] [d1,d2] [([ab], f, f_mono, prags)] binds
+
+    SpecPrag (wrap_fn :: forall a b. (Eq a, Ix b) => XXX
+                      -> forall p q. (Ix p, Ix q) => XXX[ Int/a, (p,q)/b ])
+
+Note that wrap_fn can transform *any* function with the right type prefix 
+    forall ab. (Eq a, Ix b) => <blah>
+regardless of <blah>.  It's sort of polymorphic in <blah>.  This is
+useful: we use the same wrapper to transform each of the class ops, as
+well as the dict.
+
+From these we generate:
+
+    Rule:      forall p, q, (dp:Ix p), (dq:Ix q). 
+                    f Int (p,q) dInt ($dfInPair dp dq) = f_spec p q dp dq
+
+    Spec bind: f_spec = wrap_fn (/\ab \d1 d2. Let binds in f_mono)
+
+Note that 
+
+  * The LHS of the rule may mention dictionary *expressions* (eg
+    $dfIxPair dp dq), and that is essential because the dp, dq are
+    needed on the RHS.
+
+  * The RHS of f_spec has a *copy* of 'binds', so that it can fully
+    specialise it.
 
 \begin{code}
 ------------------------
 dsSpecs :: [TyVar] -> [DictId] -> [TyVar]
 
 \begin{code}
 ------------------------
 dsSpecs :: [TyVar] -> [DictId] -> [TyVar]
-        -> Id -> Id -> Arity           -- Global, local, arity of local
-        -> CoreBind -> [LSpecPrag]
+        -> Id -> Id    -- Global, local
+        -> CoreBind -> TcSpecPrags
         -> DsM ( [(Id,CoreExpr)]       -- Binding for specialised Ids
               , [CoreRule] )           -- Rules for the Global Ids
         -> DsM ( [(Id,CoreExpr)]       -- Binding for specialised Ids
               , [CoreRule] )           -- Rules for the Global Ids
--- Example:
---     f :: (Eq a, Ix b) => a -> b -> b
---     {-# SPECIALISE f :: Ix b => Int -> b -> b #-}
---
---     AbsBinds [ab] [d1,d2] [([ab], f, f_mono, prags)] binds
--- 
---     SpecPrag (/\b.\(d:Ix b). f Int b dInt d) 
---              (forall b. Ix b => Int -> b -> b)
---
--- Rule:       forall b,(d:Ix b). f Int b dInt d = f_spec b d
---
--- Spec bind:  f_spec = Let f = /\ab \(d1:Eq a)(d2:Ix b). let binds in f_mono 
---                      /\b.\(d:Ix b). in f Int b dInt d
---             The idea is that f occurs just once, so it'll be 
---             inlined and specialised
---
--- Given SpecPrag (/\as.\ds. f es) t, we have
--- the defn            f_spec as ds = let-nonrec f = /\fas\fds. let f_mono = <f-rhs> in f_mono
---                                    in f es 
--- and the RULE                forall as, ds. f es = f_spec as ds
---
--- It is *possible* that 'es' does not mention all of the dictionaries 'ds'
--- (a bit silly, because then the 
-
-dsSpecs all_tvs dicts tvs poly_id mono_id inl_arity mono_bind prags
-  = do { pairs <- mapMaybeM spec_one prags
-       ; let (spec_binds_s, rules) = unzip pairs
-       ; return (concat spec_binds_s, rules) }
+-- See Note [Implementing SPECIALISE pragmas]
+dsSpecs all_tvs dicts tvs poly_id mono_id mono_bind prags
+  = case prags of
+      IsDefaultMethod      -> return ([], [])
+      SpecPrags sps -> do { pairs <- mapMaybeM spec_one sps
+                          ; let (spec_binds_s, rules) = unzip pairs
+                          ; return (concat spec_binds_s, rules) }
  where 
  where 
-    spec_one :: LSpecPrag -> DsM (Maybe ([(Id,CoreExpr)], CoreRule))
+    spec_one :: Located TcSpecPrag -> DsM (Maybe ([(Id,CoreExpr)], CoreRule))
     spec_one (L loc (SpecPrag spec_co spec_inl))
       = putSrcSpanDs loc $ 
         do { let poly_name = idName poly_id
     spec_one (L loc (SpecPrag spec_co spec_inl))
       = putSrcSpanDs loc $ 
         do { let poly_name = idName poly_id
@@ -452,7 +450,7 @@ dsSpecs all_tvs dicts tvs poly_id mono_id inl_arity mono_bind prags
                bs | not (null bs) -> do { warnDs (dead_msg bs); return Nothing } 
                   | otherwise -> do
 
                bs | not (null bs) -> do { warnDs (dead_msg bs); return Nothing } 
                   | otherwise -> do
 
-          { (spec_unf, unf_pairs) <- specUnfolding wrap_fn (idUnfolding poly_id)
+          { (spec_unf, unf_pairs) <- specUnfolding wrap_fn (realIdUnfolding poly_id)
 
           ; let f_body = fix_up (Let mono_bind (Var mono_id))
                  spec_ty = exprType ds_spec_expr
 
           ; let f_body = fix_up (Let mono_bind (Var mono_id))
                  spec_ty = exprType ds_spec_expr
@@ -464,11 +462,9 @@ dsSpecs all_tvs dicts tvs poly_id mono_id inl_arity mono_bind prags
                      -- Get the INLINE pragma from SPECIALISE declaration, or,
                       -- failing that, from the original Id
 
                      -- Get the INLINE pragma from SPECIALISE declaration, or,
                       -- failing that, from the original Id
 
-                spec_id_arity = inl_arity + count isDictId bndrs
-
                 extra_dict_bndrs = [ localiseId d  -- See Note [Constant rule dicts]
                 extra_dict_bndrs = [ localiseId d  -- See Note [Constant rule dicts]
-                                        | d <- varSetElems (exprFreeVars ds_spec_expr)
-                                        , isDictId d]
+                                   | d <- varSetElems (exprFreeVars ds_spec_expr)
+                                   , isDictId d]
                                -- Note [Const rule dicts]
 
                 rule =  mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
                                -- Note [Const rule dicts]
 
                 rule =  mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
@@ -477,7 +473,7 @@ dsSpecs all_tvs dicts tvs poly_id mono_id inl_arity mono_bind prags
                                (mkVarApps (Var spec_id) bndrs)
 
                  spec_rhs = wrap_fn (mkLams (tvs ++ dicts) f_body)
                                (mkVarApps (Var spec_id) bndrs)
 
                  spec_rhs = wrap_fn (mkLams (tvs ++ dicts) f_body)
-                 spec_pair = makeCorePair spec_id spec_id_arity spec_rhs
+                 spec_pair = makeCorePair spec_id False (dictArity bndrs) spec_rhs
 
            ; return (Just (spec_pair : unf_pairs, rule))
            } } } }
 
            ; return (Just (spec_pair : unf_pairs, rule))
            } } } }