Major refactoring of the type inference engine
[ghc-hetmet.git] / compiler / typecheck / TcBinds.lhs
index 0db76d1..c9f2a2d 100644 (file)
@@ -25,6 +25,7 @@ import TcHsType
 import TcPat
 import TcMType
 import TcType
+import RnBinds( misplacedSigErr )
 import Coercion
 import TysPrim
 import Id
@@ -32,7 +33,6 @@ import Var
 import Name
 import NameSet
 import NameEnv
-import VarSet
 import SrcLoc
 import Bag
 import ErrUtils
@@ -43,7 +43,10 @@ import BasicTypes
 import Outputable
 import FastString
 
+import Data.List( partition )
 import Control.Monad
+
+#include "HsVersions.h"
 \end{code}
 
 
@@ -79,13 +82,19 @@ At the top-level the LIE is sure to contain nothing but constant
 dictionaries, which we resolve at the module level.
 
 \begin{code}
-tcTopBinds :: HsValBinds Name -> TcM (LHsBinds TcId, TcLclEnv)
+tcTopBinds :: HsValBinds Name 
+           -> TcM ( LHsBinds TcId      -- Typechecked bindings
+                  , [LTcSpecPrag]      -- SPECIALISE prags for imported Ids
+                  , TcLclEnv)          -- Augmented environment
+
         -- Note: returning the TcLclEnv is more than we really
         --       want.  The bit we care about is the local bindings
         --       and the free type variables thereof
 tcTopBinds binds
-  = do  { (ValBindsOut prs _, env) <- tcValBinds TopLevel binds getLclEnv
-        ; return (foldr (unionBags . snd) emptyBag prs, env) }
+  = do  { (ValBindsOut prs sigs, env) <- tcValBinds TopLevel binds getLclEnv
+        ; let binds = foldr (unionBags . snd) emptyBag prs
+        ; specs <- tcImpPrags sigs
+        ; return (binds, specs, env) }
         -- The top level bindings are flattened into a giant 
         -- implicitly-mutually-recursive LHsBinds
 
@@ -120,14 +129,12 @@ tcLocalBinds (HsValBinds binds) thing_inside
 
 tcLocalBinds (HsIPBinds (IPBinds ip_binds _)) thing_inside
   = do  { (given_ips, ip_binds') <- mapAndUnzipM (wrapLocSndM tc_ip_bind) ip_binds
-        ; let ip_tvs = foldr (unionVarSet . tyVarsOfType . idType) emptyVarSet given_ips
 
         -- If the binding binds ?x = E, we  must now 
         -- discharge any ?x constraints in expr_lie
+        -- See Note [Implicit parameter untouchables]
         ; (ev_binds, result) <- checkConstraints (IPSkol ips) 
-                                  ip_tvs  -- See Note [Implicit parameter untouchables]
-                                  [] given_ips $
-                                thing_inside
+                                  [] given_ips thing_inside
 
         ; return (HsIPBinds (IPBinds ip_binds' ev_binds), result) }
   where
@@ -154,6 +161,9 @@ doesn't float that solved constraint out (it's not an unsolved
 wanted.  Result disaster: the (Num alpha) is again solved, this
 time by defaulting.  No no no.
 
+However [Oct 10] this is all handled automatically by the 
+untouchable-range idea.
+
 \begin{code}
 tcValBinds :: TopLevelFlag 
            -> HsValBinds Name -> TcM thing
@@ -261,7 +271,7 @@ bindLocalInsts top_lvl thing_inside
         -- leave them to the tcSimplifyTop, and quite a bit faster too
 
   | otherwise   -- Nested case
-  = do  { ((binds, ids, thing), lie) <- getConstraints thing_inside
+  = do  { ((binds, ids, thing), lie) <- captureConstraints thing_inside
         ; lie_binds <- bindLocalMethods lie ids
         ; return (binds, lie_binds, thing) }
 -}
@@ -360,7 +370,7 @@ tcPolyNoGen tc_sig_fn prag_fn rec_tc bind_list
       = do { mono_ty' <- zonkTcTypeCarefully (idType mono_id)
             -- Zonk, mainly to expose unboxed types to checkStrictBinds
            ; let mono_id' = setIdType mono_id mono_ty'
-           ; _specs <- tcSpecPrags False mono_id' (prag_fn name)
+           ; _specs <- tcSpecPrags mono_id' (prag_fn name)
            ; return mono_id' }
           -- NB: tcPrags generates error messages for
           --     specialisation pragmas for non-overloaded sigs
@@ -377,13 +387,12 @@ tcPolyCheck :: TcSigInfo -> PragFun
 --   it binds a single variable,
 --   it has a signature,
 tcPolyCheck sig@(TcSigInfo { sig_id = id, sig_tvs = tvs, sig_scoped = scoped
-                           , sig_theta = theta, sig_loc = loc })
+                           , sig_theta = theta, sig_tau = tau, sig_loc = loc })
     prag_fn rec_tc bind_list
   = do { ev_vars <- newEvVars theta
-
-       ; let skol_info = SigSkol (FunSigCtxt (idName id))
+       ; let skol_info = SigSkol (FunSigCtxt (idName id)) (mkPhiTy theta tau)
        ; (ev_binds, (binds', [mono_info])) 
-            <- checkConstraints skol_info emptyVarSet tvs ev_vars $
+            <- checkConstraints skol_info tvs ev_vars $
                tcExtendTyVarEnv2 (scoped `zip` mkTyVarTys tvs)    $
                tcMonoBinds (\_ -> Just sig) LetLclBndr rec_tc bind_list
 
@@ -407,17 +416,13 @@ tcPolyInfer
   -> TcM (LHsBinds TcId, [TcId])
 tcPolyInfer top_lvl mono sig_fn prag_fn rec_tc bind_list
   = do { ((binds', mono_infos), wanted) 
-             <- getConstraints $
+             <- captureConstraints $
                 tcMonoBinds sig_fn LetLclBndr rec_tc bind_list
 
        ; unifyCtxts [sig | (_, Just sig, _) <- mono_infos] 
 
-       ; let get_tvs | isTopLevel top_lvl = tyVarsOfType  
-                     | otherwise          = exactTyVarsOfType
-                    -- See Note [Silly type synonym] in TcType
-             tau_tvs = foldr (unionVarSet . get_tvs . getMonoType) emptyVarSet mono_infos
-
-       ; (qtvs, givens, ev_binds) <- simplifyInfer mono tau_tvs wanted
+       ; let name_taus = [(name, idType mono_id) | (name, _, mono_id) <- mono_infos]
+       ; (qtvs, givens, ev_binds) <- simplifyInfer top_lvl mono name_taus wanted
 
        ; exports <- mapM (mkExport prag_fn qtvs (map evVarPred givens))
                     mono_infos
@@ -456,7 +461,7 @@ mkExport prag_fn inferred_tvs theta
 
         ; poly_id' <- addInlinePrags poly_id prag_sigs
 
-        ; spec_prags <- tcSpecPrags (notNull theta) poly_id prag_sigs
+        ; spec_prags <- tcSpecPrags poly_id prag_sigs
                 -- tcPrags requires a zonked poly_id
 
         ; return (tvs, poly_id', mono_id, SpecPrags spec_prags) }
@@ -485,7 +490,9 @@ mkPragFun sigs binds = \n -> lookupNameEnv prag_env n `orElse` []
     get_sig _                         = Nothing
 
     add_arity (L _ n) inl_prag   -- Adjust inl_sat field to match visible arity of function
-      | Just ar <- lookupNameEnv ar_env n = inl_prag { inl_sat = Just ar }
+      | Just ar <- lookupNameEnv ar_env n,
+        Inline <- inl_inline inl_prag     = inl_prag { inl_sat = Just ar }
+        -- add arity only for real INLINE pragmas, not INLINABLE
       | otherwise                         = inl_prag
 
     prag_env :: NameEnv [LSig Name]
@@ -502,43 +509,75 @@ lhsBindArity (L _ (FunBind { fun_id = id, fun_matches = ms })) env
 lhsBindArity _ env = env       -- PatBind/VarBind
 
 ------------------
-tcSpecPrags :: Bool     -- True <=> function is overloaded
-            -> Id -> [LSig Name]
-            -> TcM [Located TcSpecPrag]
+tcSpecPrags :: Id -> [LSig Name]
+            -> TcM [LTcSpecPrag]
 -- Add INLINE and SPECIALSE pragmas
 --    INLINE prags are added to the (polymorphic) Id directly
 --    SPECIALISE prags are passed to the desugarer via TcSpecPrags
 -- Pre-condition: the poly_id is zonked
 -- Reason: required by tcSubExp
-tcSpecPrags is_overloaded_id poly_id prag_sigs
-  = do { unless (null spec_sigs || is_overloaded_id) warn_discarded_spec
-       ; unless (null bad_sigs) warn_discarded_sigs
-       ; mapM (wrapLocM tc_spec) spec_sigs }
+tcSpecPrags poly_id prag_sigs
+  = do { unless (null bad_sigs) warn_discarded_sigs
+       ; mapAndRecoverM (wrapLocM (tcSpec poly_id)) spec_sigs }
   where
     spec_sigs = filter isSpecLSig prag_sigs
     bad_sigs  = filter is_bad_sig prag_sigs
     is_bad_sig s = not (isSpecLSig s || isInlineLSig s)
 
-    name      = idName poly_id
-    poly_ty   = idType poly_id
-    sig_ctxt  = FunSigCtxt name
-    origin    = SpecPragOrigin name
-    skol_info = SigSkol sig_ctxt
-
-    tc_spec prag@(SpecSig _ hs_ty inl) 
-      = addErrCtxt (spec_ctxt prag) $
-        do  { spec_ty <- tcHsSigType sig_ctxt hs_ty
-            ; wrap <- tcSubType origin skol_info poly_ty spec_ty
-            ; return (SpecPrag wrap inl) }
-    tc_spec sig = pprPanic "tcSpecPrag" (ppr sig)
-
-    warn_discarded_spec = warnPrags poly_id spec_sigs $
-                          ptext (sLit "SPECIALISE pragmas for non-overloaded function")
     warn_discarded_sigs = warnPrags poly_id bad_sigs $
                           ptext (sLit "Discarding unexpected pragmas for")
 
+
+--------------
+tcSpec :: TcId -> Sig Name -> TcM TcSpecPrag
+tcSpec poly_id prag@(SpecSig _ hs_ty inl) 
+  -- The Name in the SpecSig may not be the same as that of the poly_id
+  -- Example: SPECIALISE for a class method: the Name in the SpecSig is
+  --          for the selector Id, but the poly_id is something like $cop
+  = addErrCtxt (spec_ctxt prag) $
+    do  { spec_ty <- tcHsSigType sig_ctxt hs_ty
+        ; warnIf (not (isOverloadedTy poly_ty || isInlinePragma inl))
+                 (ptext (sLit "SPECIALISE pragma for non-overloaded function") <+> quotes (ppr poly_id))
+                 -- Note [SPECIALISE pragmas]
+        ; wrap <- tcSubType origin sig_ctxt (idType poly_id) spec_ty
+        ; return (SpecPrag poly_id wrap inl) }
+  where
+    name      = idName poly_id
+    poly_ty   = idType poly_id
+    origin    = SpecPragOrigin name
+    sig_ctxt  = FunSigCtxt name
     spec_ctxt prag = hang (ptext (sLit "In the SPECIALISE pragma")) 2 (ppr prag)
 
+tcSpec _ prag = pprPanic "tcSpec" (ppr prag)
+
+--------------
+tcImpPrags :: [LSig Name] -> TcM [LTcSpecPrag]
+tcImpPrags prags
+  = do { this_mod <- getModule
+       ; let is_imp prag 
+               = case sigName prag of
+                   Nothing   -> False
+                   Just name -> not (nameIsLocalOrFrom this_mod name)
+             (spec_prags, others) = partition isSpecLSig $
+                                   filter is_imp prags
+       ; mapM_ misplacedSigErr others 
+       -- Messy that this misplaced-sig error comes here
+       -- but the others come from the renamer
+       ; mapAndRecoverM (wrapLocM tcImpSpec) spec_prags }
+
+tcImpSpec :: Sig Name -> TcM TcSpecPrag
+tcImpSpec prag@(SpecSig (L _ name) _ _)
+ = do { id <- tcLookupId name
+      ; checkTc (isAnyInlinePragma (idInlinePragma id))
+                (impSpecErr name)
+      ; tcSpec id prag }
+tcImpSpec p = pprPanic "tcImpSpec" (ppr p)
+
+impSpecErr :: Name -> SDoc
+impSpecErr name
+  = hang (ptext (sLit "You cannot SPECIALISE") <+> quotes (ppr name))
+       2 (vcat [ ptext (sLit "because its definition has no INLINE/INLINABLE pragma")
+               , ptext (sLit "(or you compiled its definining module without -O)")])
 --------------
 -- If typechecking the binds fails, then return with each
 -- signature-less binder given type (forall a.a), to minimise 
@@ -557,6 +596,26 @@ forall_a_a :: TcType
 forall_a_a = mkForAllTy openAlphaTyVar (mkTyVarTy openAlphaTyVar)
 \end{code}
 
+Note [SPECIALISE pragmas]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+There is no point in a SPECIALISE pragma for a non-overloaded function:
+   reverse :: [a] -> [a]
+   {-# SPECIALISE reverse :: [Int] -> [Int] #-}
+
+But SPECIALISE INLINE *can* make sense for GADTS:
+   data Arr e where
+     ArrInt :: !Int -> ByteArray# -> Arr Int
+     ArrPair :: !Int -> Arr e1 -> Arr e2 -> Arr (e1, e2)
+
+   (!:) :: Arr e -> Int -> e
+   {-# SPECIALISE INLINE (!:) :: Arr Int -> Int -> Int #-}  
+   {-# SPECIALISE INLINE (!:) :: Arr (a, b) -> Int -> (a, b) #-}
+   (ArrInt _ ba)     !: (I# i) = I# (indexIntArray# ba i)
+   (ArrPair _ a1 a2) !: i      = (a1 !: i, a2 !: i)
+
+When (!:) is specialised it becomes non-recursive, and can usefully
+be inlined.  Scary!  So we only warn for SPECIALISE *without* INLINE
+for a non-overloaded function.
 
 %************************************************************************
 %*                                                                      *
@@ -634,9 +693,6 @@ type MonoBindInfo = (Name, Maybe TcSigInfo, TcId)
         -- Type signature (if any), and
         -- the monomorphic bound things
 
-getMonoType :: MonoBindInfo -> TcTauType
-getMonoType (_,_,mono_id) = idType mono_id
-
 tcLhs :: TcSigFun -> LetBndrSpec -> HsBind Name -> TcM TcMonoBind
 tcLhs sig_fn no_gen (FunBind { fun_id = L nm_loc name, fun_infix = inf, fun_matches = matches })
   | Just sig <- sig_fn name
@@ -983,7 +1039,10 @@ tcInstSig sig_fn use_skols name
   | Just (scoped_tvs, loc) <- sig_fn name
   = do  { poly_id <- tcLookupId name    -- Cannot fail; the poly ids are put into 
                                         -- scope when starting the binding group
-        ; (tvs, theta, tau) <- tcInstSigType use_skols name (idType poly_id)
+        ; let poly_ty = idType poly_id
+        ; (tvs, theta, tau) <- if use_skols
+                               then tcInstType tcInstSkolTyVars poly_ty
+                               else tcInstType tcInstSigTyVars  poly_ty
         ; let sig = TcSigInfo { sig_id = poly_id
                              , sig_scoped = scoped_tvs
                               , sig_tvs = tvs, sig_theta = theta, sig_tau = tau
@@ -1010,6 +1069,7 @@ instance Outputable GeneralisationPlan where
 decideGeneralisationPlan 
    :: DynFlags -> TopLevelFlag -> [Name] -> [LHsBind Name] -> TcSigFun -> GeneralisationPlan
 decideGeneralisationPlan dflags top_lvl _bndrs binds sig_fn
+  | bang_pat_binds                         = NoGen
   | mono_pat_binds                         = NoGen
   | Just sig <- one_funbind_with_sig binds = if null (sig_tvs sig) && null (sig_theta sig)
                                              then NoGen              -- Optimise common case
@@ -1019,7 +1079,12 @@ decideGeneralisationPlan dflags top_lvl _bndrs binds sig_fn
   | otherwise                              = InferGen mono_restriction
 
   where
-    mono_pat_binds = xopt Opt_MonoPatBinds dflags 
+    bang_pat_binds = any (isBangHsBind . unLoc) binds
+       -- Bang patterns must not be polymorphic,
+       -- because we are going to force them
+       -- See Trac #4498
+
+    mono_pat_binds = xopt Opt_MonoPatBinds dflags
                   && any (is_pat_bind . unLoc) binds
 
     mono_restriction = xopt Opt_MonomorphismRestriction dflags 
@@ -1063,24 +1128,30 @@ checkStrictBinds top_lvl rec_group binds poly_ids
         ; checkTc (isNonRec rec_group)
                   (strictBindErr "Recursive" unlifted binds)
         ; checkTc (isSingleton binds)
-                  (strictBindErr "Multiple" unlifted binds) 
+                  (strictBindErr "Multiple" unlifted binds)
         -- This should be a checkTc, not a warnTc, but as of GHC 6.11
         -- the versions of alex and happy available have non-conforming
         -- templates, so the GHC build fails if it's an error:
         ; warnUnlifted <- doptM Opt_WarnLazyUnliftedBindings
-        ; warnTc (warnUnlifted && not bang_pat)
+        ; warnTc (warnUnlifted && not bang_pat && lifted_pat)
+                 -- No outer bang, but it's a compound pattern
+                 -- E.g   (I# x#) = blah
+                 -- Warn about this, but not about
+                 --      x# = 4# +# 1#
+                 --      (# a, b #) = ...
                  (unliftedMustBeBang binds) }
   | otherwise
   = return ()
   where
-    unlifted = any is_unlifted poly_ids
-    bang_pat = any (isBangHsBind . unLoc) binds
+    unlifted    = any is_unlifted poly_ids
+    bang_pat    = any (isBangHsBind . unLoc) binds
+    lifted_pat  = any (isLiftedPatBind . unLoc) binds
     is_unlifted id = case tcSplitForAllTys (idType id) of
                        (_, rho) -> isUnLiftedType rho
 
 unliftedMustBeBang :: [LHsBind Name] -> SDoc
 unliftedMustBeBang binds
-  = hang (text "Bindings containing unlifted types should use an outermost bang pattern:")
+  = hang (text "Pattern bindings containing unlifted types should use an outermost bang pattern:")
        2 (pprBindList binds)
 
 strictBindErr :: String -> Bool -> [LHsBind Name] -> SDoc