Implement auto-specialisation of imported Ids
[ghc-hetmet.git] / compiler / typecheck / TcBinds.lhs
index 0db76d1..c918c9d 100644 (file)
@@ -25,6 +25,7 @@ import TcHsType
 import TcPat
 import TcMType
 import TcType
+import RnBinds( misplacedSigErr )
 import Coercion
 import TysPrim
 import Id
@@ -43,7 +44,10 @@ import BasicTypes
 import Outputable
 import FastString
 
+import Data.List( partition )
 import Control.Monad
+
+#include "HsVersions.h"
 \end{code}
 
 
@@ -79,13 +83,19 @@ At the top-level the LIE is sure to contain nothing but constant
 dictionaries, which we resolve at the module level.
 
 \begin{code}
-tcTopBinds :: HsValBinds Name -> TcM (LHsBinds TcId, TcLclEnv)
+tcTopBinds :: HsValBinds Name 
+           -> TcM ( LHsBinds TcId      -- Typechecked bindings
+                  , [LTcSpecPrag]      -- SPECIALISE prags for imported Ids
+                  , TcLclEnv)          -- Augmented environment
+
         -- Note: returning the TcLclEnv is more than we really
         --       want.  The bit we care about is the local bindings
         --       and the free type variables thereof
 tcTopBinds binds
-  = do  { (ValBindsOut prs _, env) <- tcValBinds TopLevel binds getLclEnv
-        ; return (foldr (unionBags . snd) emptyBag prs, env) }
+  = do  { (ValBindsOut prs sigs, env) <- tcValBinds TopLevel binds getLclEnv
+        ; let binds = foldr (unionBags . snd) emptyBag prs
+        ; specs <- tcImpPrags sigs
+        ; return (binds, specs, env) }
         -- The top level bindings are flattened into a giant 
         -- implicitly-mutually-recursive LHsBinds
 
@@ -360,7 +370,7 @@ tcPolyNoGen tc_sig_fn prag_fn rec_tc bind_list
       = do { mono_ty' <- zonkTcTypeCarefully (idType mono_id)
             -- Zonk, mainly to expose unboxed types to checkStrictBinds
            ; let mono_id' = setIdType mono_id mono_ty'
-           ; _specs <- tcSpecPrags False mono_id' (prag_fn name)
+           ; _specs <- tcSpecPrags mono_id' (prag_fn name)
            ; return mono_id' }
           -- NB: tcPrags generates error messages for
           --     specialisation pragmas for non-overloaded sigs
@@ -456,7 +466,7 @@ mkExport prag_fn inferred_tvs theta
 
         ; poly_id' <- addInlinePrags poly_id prag_sigs
 
-        ; spec_prags <- tcSpecPrags (notNull theta) poly_id prag_sigs
+        ; spec_prags <- tcSpecPrags poly_id prag_sigs
                 -- tcPrags requires a zonked poly_id
 
         ; return (tvs, poly_id', mono_id, SpecPrags spec_prags) }
@@ -502,42 +512,74 @@ lhsBindArity (L _ (FunBind { fun_id = id, fun_matches = ms })) env
 lhsBindArity _ env = env       -- PatBind/VarBind
 
 ------------------
-tcSpecPrags :: Bool     -- True <=> function is overloaded
-            -> Id -> [LSig Name]
-            -> TcM [Located TcSpecPrag]
+tcSpecPrags :: Id -> [LSig Name]
+            -> TcM [LTcSpecPrag]
 -- Add INLINE and SPECIALSE pragmas
 --    INLINE prags are added to the (polymorphic) Id directly
 --    SPECIALISE prags are passed to the desugarer via TcSpecPrags
 -- Pre-condition: the poly_id is zonked
 -- Reason: required by tcSubExp
-tcSpecPrags is_overloaded_id poly_id prag_sigs
-  = do { unless (null spec_sigs || is_overloaded_id) warn_discarded_spec
-       ; unless (null bad_sigs) warn_discarded_sigs
-       ; mapM (wrapLocM tc_spec) spec_sigs }
+tcSpecPrags poly_id prag_sigs
+  = do { unless (null bad_sigs) warn_discarded_sigs
+       ; mapAndRecoverM (wrapLocM (tcSpec poly_id)) spec_sigs }
   where
     spec_sigs = filter isSpecLSig prag_sigs
     bad_sigs  = filter is_bad_sig prag_sigs
     is_bad_sig s = not (isSpecLSig s || isInlineLSig s)
 
+    warn_discarded_sigs = warnPrags poly_id bad_sigs $
+                          ptext (sLit "Discarding unexpected pragmas for")
+
+
+--------------
+tcSpec :: TcId -> Sig Name -> TcM TcSpecPrag
+tcSpec poly_id prag@(SpecSig _ hs_ty inl) 
+  -- The Name in the SpecSig may not be the same as that of the poly_id
+  -- Example: SPECIALISE for a class method: the Name in the SpecSig is
+  --          for the selector Id, but the poly_id is something like $cop
+  = addErrCtxt (spec_ctxt prag) $
+    do  { spec_ty <- tcHsSigType sig_ctxt hs_ty
+        ; checkTc (isOverloadedTy poly_ty)
+                  (ptext (sLit "Discarding pragma for non-overloaded function") <+> quotes (ppr poly_id))
+        ; wrap <- tcSubType origin skol_info (idType poly_id) spec_ty
+        ; return (SpecPrag poly_id wrap inl) }
+  where
     name      = idName poly_id
     poly_ty   = idType poly_id
-    sig_ctxt  = FunSigCtxt name
     origin    = SpecPragOrigin name
+    sig_ctxt  = FunSigCtxt name
     skol_info = SigSkol sig_ctxt
+    spec_ctxt prag = hang (ptext (sLit "In the SPECIALISE pragma")) 2 (ppr prag)
 
-    tc_spec prag@(SpecSig _ hs_ty inl) 
-      = addErrCtxt (spec_ctxt prag) $
-        do  { spec_ty <- tcHsSigType sig_ctxt hs_ty
-            ; wrap <- tcSubType origin skol_info poly_ty spec_ty
-            ; return (SpecPrag wrap inl) }
-    tc_spec sig = pprPanic "tcSpecPrag" (ppr sig)
-
-    warn_discarded_spec = warnPrags poly_id spec_sigs $
-                          ptext (sLit "SPECIALISE pragmas for non-overloaded function")
-    warn_discarded_sigs = warnPrags poly_id bad_sigs $
-                          ptext (sLit "Discarding unexpected pragmas for")
+tcSpec _ prag = pprPanic "tcSpec" (ppr prag)
 
-    spec_ctxt prag = hang (ptext (sLit "In the SPECIALISE pragma")) 2 (ppr prag)
+--------------
+tcImpPrags :: [LSig Name] -> TcM [LTcSpecPrag]
+tcImpPrags prags
+  = do { this_mod <- getModule
+       ; let is_imp prag 
+               = case sigName prag of
+                   Nothing   -> False
+                   Just name -> not (nameIsLocalOrFrom this_mod name)
+             (spec_prags, others) = partition isSpecLSig $
+                                   filter is_imp prags
+       ; mapM_ misplacedSigErr others 
+       -- Messy that this misplaced-sig error comes here
+       -- but the others come from the renamer
+       ; mapAndRecoverM (wrapLocM tcImpSpec) spec_prags }
+
+tcImpSpec :: Sig Name -> TcM TcSpecPrag
+tcImpSpec prag@(SpecSig (L _ name) _ _)
+ = do { id <- tcLookupId name
+      ; checkTc (isInlinePragma (idInlinePragma id))
+                (impSpecErr name)
+      ; tcSpec id prag }
+tcImpSpec p = pprPanic "tcImpSpec" (ppr p)
+
+impSpecErr :: Name -> SDoc
+impSpecErr name
+  = hang (ptext (sLit "You cannot SPECIALISE") <+> quotes (ppr name))
+       2 (ptext (sLit "because its definition has no INLINE/INLINABLE pragma"))
 
 --------------
 -- If typechecking the binds fails, then return with each