Implement auto-specialisation of imported Ids
[ghc-hetmet.git] / compiler / deSugar / DsBinds.lhs
index b5b58fe..7e922fd 100644 (file)
@@ -10,7 +10,7 @@ in that the @Rec@/@NonRec@/etc structure is thrown away (whereas at
 lower levels it is preserved with @let@/@letrec@s).
 
 \begin{code}
-module DsBinds ( dsTopLHsBinds, dsLHsBinds, decomposeRuleLhs, 
+module DsBinds ( dsTopLHsBinds, dsLHsBinds, decomposeRuleLhs, dsSpec,
                 dsHsWrapper, dsTcEvBinds, dsEvBinds, wrapDsEvBinds, 
                 DsEvBind(..), AutoScc(..)
   ) where
@@ -69,9 +69,8 @@ import MonadUtils
 %************************************************************************
 
 \begin{code}
-dsTopLHsBinds :: AutoScc -> LHsBinds Id -> DsM [(Id,CoreExpr)]
-dsTopLHsBinds auto_scc binds = do { binds' <- ds_lhs_binds auto_scc binds
-                                  ; return (fromOL binds') }
+dsTopLHsBinds :: AutoScc -> LHsBinds Id -> DsM (OrdList (Id,CoreExpr))
+dsTopLHsBinds auto_scc binds = ds_lhs_binds auto_scc binds
 
 dsLHsBinds :: LHsBinds Id -> DsM [(Id,CoreExpr)]
 dsLHsBinds binds = do { binds' <- ds_lhs_binds NoSccs binds
@@ -135,7 +134,7 @@ dsHsBind auto_scc (AbsBinds { abs_tvs = all_tyvars, abs_ev_vars = dicts
                             Let core_bind $
                             Var local
     
-       ; (spec_binds, rules) <- dsSpecs global rhs prags
+       ; (spec_binds, rules) <- dsSpecs rhs prags
 
        ; let   global'   = addIdSpecialisations global rules
                main_bind = makeCorePair global' (isDefaultMethod prags)
@@ -178,9 +177,9 @@ dsHsBind auto_scc (AbsBinds { abs_tvs = all_tyvars, abs_ev_vars = dicts
                                 mkTupleSelector locals' (locals' !! n) tup_id $
                                 mkVarApps (mkTyApps (Var poly_tup_id) ty_args)
                                           dicts
-                    ; (spec_binds, rules) <- dsSpecs global
-                                                     (Let (NonRec poly_tup_id poly_tup_rhs) rhs)
-                                                     spec_prags
+                           full_rhs = Let (NonRec poly_tup_id poly_tup_rhs) rhs
+                    ; (spec_binds, rules) <- dsSpecs full_rhs spec_prags
+                                                     
                     ; let global' = addIdSpecialisations global rules
                     ; return ((global', rhs) `consOL` spec_binds) }
                where
@@ -475,66 +474,69 @@ Note that
 
 \begin{code}
 ------------------------
-dsSpecs :: Id          -- The polymorphic Id
-        -> CoreExpr     -- Its rhs
+dsSpecs :: CoreExpr     -- Its rhs
         -> TcSpecPrags
         -> DsM ( OrdList (Id,CoreExpr)         -- Binding for specialised Ids
               , [CoreRule] )           -- Rules for the Global Ids
 -- See Note [Implementing SPECIALISE pragmas]
-dsSpecs poly_id poly_rhs prags
-  = case prags of
-      IsDefaultMethod      -> return (nilOL, [])
-      SpecPrags sps -> do { pairs <- mapMaybeM spec_one sps
-                          ; let (spec_binds_s, rules) = unzip pairs
-                          ; return (concatOL spec_binds_s, rules) }
- where 
-    spec_one :: Located TcSpecPrag -> DsM (Maybe (OrdList (Id,CoreExpr), CoreRule))
-    spec_one (L loc (SpecPrag spec_co spec_inl))
-      = putSrcSpanDs loc $ 
-        do { let poly_name = idName poly_id
-          ; spec_name <- newLocalName poly_name
-          ; wrap_fn   <- dsHsWrapper spec_co
-           ; let (bndrs, ds_lhs) = collectBinders (wrap_fn (Var poly_id))
-                 spec_ty = mkPiTypes bndrs (exprType ds_lhs)
-          ; case decomposeRuleLhs ds_lhs of {
-              Nothing -> do { warnDs (decomp_msg spec_co)
-                             ; return Nothing } ;
-
-              Just (_fn, args) ->
-
-          -- Check for dead binders: Note [Unused spec binders]
-             let arg_fvs = exprsFreeVars args
-                 bad_bndrs = filterOut (`elemVarSet` arg_fvs) bndrs
-            in if not (null bad_bndrs)
-                then do { warnDs (dead_msg bad_bndrs); return Nothing } 
-               else do
-
-          { (spec_unf, unf_pairs) <- specUnfolding wrap_fn spec_ty (realIdUnfolding poly_id)
-
-          ; let spec_id  = mkLocalId spec_name spec_ty 
-                           `setInlinePragma` inl_prag
-                           `setIdUnfolding`  spec_unf
-                inl_prag | isDefaultInlinePragma spec_inl = idInlinePragma poly_id
-                         | otherwise                      = spec_inl
-                     -- Get the INLINE pragma from SPECIALISE declaration, or,
-                      -- failing that, from the original Id
-
-                extra_dict_bndrs = [ mkLocalId (localiseName (idName d)) (idType d)
-                                            -- See Note [Constant rule dicts]
-                                   | d <- varSetElems (arg_fvs `delVarSetList` bndrs)
-                                   , isDictId d]
-
-                rule =  mkLocalRule (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
-                               AlwaysActive poly_name
-                               (extra_dict_bndrs ++ bndrs) args
-                               (mkVarApps (Var spec_id) bndrs)
-
-                 spec_rhs  = wrap_fn poly_rhs
-                 spec_pair = makeCorePair spec_id False (dictArity bndrs) spec_rhs
-
-           ; return (Just (spec_pair `consOL` unf_pairs, rule))
-           } } }
-
+dsSpecs _ IsDefaultMethod = return (nilOL, [])
+dsSpecs poly_rhs (SpecPrags sps)
+  = do { pairs <- mapMaybeM (dsSpec (Just poly_rhs)) sps
+       ; let (spec_binds_s, rules) = unzip pairs
+       ; return (concatOL spec_binds_s, rules) }
+
+dsSpec :: Maybe CoreExpr       -- Just rhs => RULE is for a local binding
+                                       -- Nothing => RULE is for an imported Id
+                               --            rhs is in the Id's unfolding
+       -> Located TcSpecPrag
+       -> DsM (Maybe (OrdList (Id,CoreExpr), CoreRule))
+dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl))
+  = putSrcSpanDs loc $ 
+    do { let poly_name = idName poly_id
+       ; spec_name <- newLocalName poly_name
+       ; wrap_fn   <- dsHsWrapper spec_co
+       ; let (bndrs, ds_lhs) = collectBinders (wrap_fn (Var poly_id))
+             spec_ty = mkPiTypes bndrs (exprType ds_lhs)
+       ; case decomposeRuleLhs ds_lhs of {
+          Nothing -> do { warnDs (decomp_msg spec_co)
+                        ; return Nothing } ;
+
+          Just (_fn, args) ->
+
+         -- Check for dead binders: Note [Unused spec binders]
+         let arg_fvs = exprsFreeVars args
+             bad_bndrs = filterOut (`elemVarSet` arg_fvs) bndrs
+         in if not (null bad_bndrs)
+            then do { warnDs (dead_msg bad_bndrs); return Nothing } 
+                   else do
+
+       { (spec_unf, unf_pairs) <- specUnfolding wrap_fn spec_ty (realIdUnfolding poly_id)
+
+       ; let spec_id  = mkLocalId spec_name spec_ty 
+                           `setInlinePragma` inl_prag
+                           `setIdUnfolding`  spec_unf
+             inl_prag | isDefaultInlinePragma spec_inl = idInlinePragma poly_id
+                      | otherwise                      = spec_inl
+                     -- Get the INLINE pragma from SPECIALISE declaration, or,
+              -- failing that, from the original Id
+
+             extra_dict_bndrs = [ mkLocalId (localiseName (idName d)) (idType d)
+                                       -- See Note [Constant rule dicts]
+                               | d <- varSetElems (arg_fvs `delVarSetList` bndrs)
+                               , isDictId d]
+
+             rule =  mkRule False {- Not auto -} is_local_id
+                        (mkFastString ("SPEC " ++ showSDoc (ppr poly_name)))
+                               AlwaysActive poly_name
+                               (extra_dict_bndrs ++ bndrs) args
+                               (mkVarApps (Var spec_id) bndrs)
+
+             spec_rhs  = wrap_fn poly_rhs
+             spec_pair = makeCorePair spec_id False (dictArity bndrs) spec_rhs
+
+       ; return (Just (spec_pair `consOL` unf_pairs, rule))
+       } } }
+  where
     dead_msg bs = vcat [ sep [ptext (sLit "Useless constraint") <> plural bs
                                 <+> ptext (sLit "in specialied type:"),
                             nest 2 (pprTheta (map get_pred bs))]
@@ -545,6 +547,15 @@ dsSpecs poly_id poly_rhs prags
         = hang (ptext (sLit "Specialisation too complicated to desugar; ignored"))
             2 (pprHsWrapper (ppr poly_id) spec_co)
             
+    is_local_id = isJust mb_poly_rhs
+    poly_rhs | Just rhs <-  mb_poly_rhs
+             = rhs
+             | Just unfolding <- maybeUnfoldingTemplate (idUnfolding poly_id)
+             = unfolding
+             | otherwise = pprPanic "dsImpSpecs" (ppr poly_id)
+       -- In the Nothing case the specialisation is for an imported Id
+       -- whose unfolding gives the RHS to be specialised
+        -- The type checker has checked that it has an unfolding
 
 specUnfolding :: (CoreExpr -> CoreExpr) -> Type 
               -> Unfolding -> DsM (Unfolding, OrdList (Id,CoreExpr))