Merge branch 'master' of http://darcs.haskell.org/ghc
[ghc-hetmet.git] / compiler / typecheck / TcBinds.lhs
index 638f692..b5bbeb1 100644 (file)
@@ -7,7 +7,7 @@
 \begin{code}
 module TcBinds ( tcLocalBinds, tcTopBinds, 
                  tcHsBootSigs, tcPolyBinds,
-                 PragFun, tcSpecPrags, mkPragFun, 
+                 PragFun, tcSpecPrags, tcVectDecls, mkPragFun, 
                  TcSigInfo(..), SigFun, mkSigFun,
                  badBootDeclErr ) where
 
@@ -25,7 +25,6 @@ import TcHsType
 import TcPat
 import TcMType
 import TcType
-import RnBinds( misplacedSigErr )
 import Coercion
 import TysPrim
 import Id
@@ -33,9 +32,9 @@ import Var
 import Name
 import NameSet
 import NameEnv
-import VarSet
 import SrcLoc
 import Bag
+import ListSetOps
 import ErrUtils
 import Digraph
 import Maybes
@@ -44,7 +43,6 @@ import BasicTypes
 import Outputable
 import FastString
 
-import Data.List( partition )
 import Control.Monad
 
 #include "HsVersions.h"
@@ -325,11 +323,13 @@ tcPolyBinds :: TopLevelFlag -> SigFun -> PragFun
 tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc bind_list
   = setSrcSpan loc                              $
     recoverM (recoveryCode binder_names sig_fn) $ do 
-        -- Set up main recoer; take advantage of any type sigs
+        -- Set up main recover; take advantage of any type sigs
 
     { traceTc "------------------------------------------------" empty
     ; traceTc "Bindings for" (ppr binder_names)
 
+    -- Instantiate the polytypes of any binders that have signatures
+    -- (as determined by sig_fn), returning a TcSigInfo for each
     ; tc_sig_fn <- tcInstSigs sig_fn binder_names
 
     ; dflags <- getDOpts
@@ -348,9 +348,10 @@ tcPolyBinds top_lvl sig_fn prag_fn rec_group rec_tc bind_list
     ; return (binds, poly_ids) }
   where
     binder_names = collectHsBindListBinders bind_list
-    loc = getLoc (head bind_list)
-         -- TODO: location a bit awkward, but the mbinds have been
-         --       dependency analysed and may no longer be adjacent
+    loc = foldr1 combineSrcSpans (map getLoc bind_list)
+         -- The mbinds have been dependency analysed and 
+         -- may no longer be adjacent; so find the narrowest
+        -- span that includes them all
 
 ------------------
 tcPolyNoGen 
@@ -388,11 +389,10 @@ tcPolyCheck :: TcSigInfo -> PragFun
 --   it binds a single variable,
 --   it has a signature,
 tcPolyCheck sig@(TcSigInfo { sig_id = id, sig_tvs = tvs, sig_scoped = scoped
-                           , sig_theta = theta, sig_loc = loc })
+                           , sig_theta = theta, sig_tau = tau })
     prag_fn rec_tc bind_list
   = do { ev_vars <- newEvVars theta
-
-       ; let skol_info = SigSkol (FunSigCtxt (idName id))
+       ; let skol_info = SigSkol (FunSigCtxt (idName id)) (mkPhiTy theta tau)
        ; (ev_binds, (binds', [mono_info])) 
             <- checkConstraints skol_info tvs ev_vars $
                tcExtendTyVarEnv2 (scoped `zip` mkTyVarTys tvs)    $
@@ -400,6 +400,7 @@ tcPolyCheck sig@(TcSigInfo { sig_id = id, sig_tvs = tvs, sig_scoped = scoped
 
        ; export <- mkExport prag_fn tvs theta mono_info
 
+       ; loc <- getSrcSpanM
        ; let (_, poly_id, _, _) = export
              abs_bind = L loc $ AbsBinds 
                         { abs_tvs = tvs
@@ -416,19 +417,15 @@ tcPolyInfer
                    -- dependencies based on type signatures
   -> [LHsBind Name]
   -> TcM (LHsBinds TcId, [TcId])
-tcPolyInfer top_lvl mono sig_fn prag_fn rec_tc bind_list
+tcPolyInfer top_lvl mono tc_sig_fn prag_fn rec_tc bind_list
   = do { ((binds', mono_infos), wanted) 
              <- captureConstraints $
-                tcMonoBinds sig_fn LetLclBndr rec_tc bind_list
+                tcMonoBinds tc_sig_fn LetLclBndr rec_tc bind_list
 
        ; unifyCtxts [sig | (_, Just sig, _) <- mono_infos] 
 
-       ; let get_tvs | isTopLevel top_lvl = tyVarsOfType  
-                     | otherwise          = exactTyVarsOfType
-                    -- See Note [Silly type synonym] in TcType
-             tau_tvs = foldr (unionVarSet . get_tvs . getMonoType) emptyVarSet mono_infos
-
-       ; (qtvs, givens, ev_binds) <- simplifyInfer mono tau_tvs wanted
+       ; let name_taus = [(name, idType mono_id) | (name, _, mono_id) <- mono_infos]
+       ; (qtvs, givens, ev_binds) <- simplifyInfer top_lvl mono name_taus wanted
 
        ; exports <- mapM (mkExport prag_fn qtvs (map evVarPred givens))
                     mono_infos
@@ -542,47 +539,124 @@ tcSpec poly_id prag@(SpecSig _ hs_ty inl)
   --          for the selector Id, but the poly_id is something like $cop
   = addErrCtxt (spec_ctxt prag) $
     do  { spec_ty <- tcHsSigType sig_ctxt hs_ty
-        ; checkTc (isOverloadedTy poly_ty)
-                  (ptext (sLit "Discarding pragma for non-overloaded function") <+> quotes (ppr poly_id))
-        ; wrap <- tcSubType origin skol_info (idType poly_id) spec_ty
+        ; warnIf (not (isOverloadedTy poly_ty || isInlinePragma inl))
+                 (ptext (sLit "SPECIALISE pragma for non-overloaded function") <+> quotes (ppr poly_id))
+                 -- Note [SPECIALISE pragmas]
+        ; wrap <- tcSubType origin sig_ctxt (idType poly_id) spec_ty
         ; return (SpecPrag poly_id wrap inl) }
   where
     name      = idName poly_id
     poly_ty   = idType poly_id
     origin    = SpecPragOrigin name
     sig_ctxt  = FunSigCtxt name
-    skol_info = SigSkol sig_ctxt
     spec_ctxt prag = hang (ptext (sLit "In the SPECIALISE pragma")) 2 (ppr prag)
 
 tcSpec _ prag = pprPanic "tcSpec" (ppr prag)
 
 --------------
 tcImpPrags :: [LSig Name] -> TcM [LTcSpecPrag]
+-- SPECIALISE pragamas for imported things
 tcImpPrags prags
   = do { this_mod <- getModule
-       ; let is_imp prag 
-               = case sigName prag of
-                   Nothing   -> False
-                   Just name -> not (nameIsLocalOrFrom this_mod name)
-             (spec_prags, others) = partition isSpecLSig $
-                                   filter is_imp prags
-       ; mapM_ misplacedSigErr others 
-       -- Messy that this misplaced-sig error comes here
-       -- but the others come from the renamer
-       ; mapAndRecoverM (wrapLocM tcImpSpec) spec_prags }
-
-tcImpSpec :: Sig Name -> TcM TcSpecPrag
-tcImpSpec prag@(SpecSig (L _ name) _ _)
+       ; dflags <- getDOpts
+       ; if (not_specialising dflags) then
+            return []
+         else
+            mapAndRecoverM (wrapLocM tcImpSpec) 
+            [L loc (name,prag) | (L loc prag@(SpecSig (L _ name) _ _)) <- prags
+                               , not (nameIsLocalOrFrom this_mod name) ] }
+  where
+    -- Ignore SPECIALISE pragmas for imported things
+    -- when we aren't specialising, or when we aren't generating
+    -- code.  The latter happens when Haddocking the base library;
+    -- we don't wnat complaints about lack of INLINABLE pragmas 
+    not_specialising dflags
+      | not (dopt Opt_Specialise dflags) = True
+      | otherwise = case hscTarget dflags of
+                      HscNothing -> True
+                      HscInterpreted -> True
+                      _other         -> False
+
+tcImpSpec :: (Name, Sig Name) -> TcM TcSpecPrag
+tcImpSpec (name, prag)
  = do { id <- tcLookupId name
-      ; checkTc (isInlinePragma (idInlinePragma id))
-                (impSpecErr name)
+      ; unless (isAnyInlinePragma (idInlinePragma id))
+               (addWarnTc (impSpecErr name))
       ; tcSpec id prag }
-tcImpSpec p = pprPanic "tcImpSpec" (ppr p)
 
 impSpecErr :: Name -> SDoc
 impSpecErr name
   = hang (ptext (sLit "You cannot SPECIALISE") <+> quotes (ppr name))
-       2 (ptext (sLit "because its definition has no INLINE/INLINABLE pragma"))
+       2 (vcat [ ptext (sLit "because its definition has no INLINE/INLINABLE pragma")
+               , parens $ sep 
+                   [ ptext (sLit "or its defining module") <+> quotes (ppr mod)
+                   , ptext (sLit "was compiled without -O")]])
+  where
+    mod = nameModule name
+
+--------------
+tcVectDecls :: [LVectDecl Name] -> TcM ([LVectDecl TcId])
+tcVectDecls decls 
+  = do { decls' <- mapM (wrapLocM tcVect) decls
+       ; let ids  = map lvectDeclName decls'
+             dups = findDupsEq (==) ids
+       ; mapM_ reportVectDups dups
+       ; traceTcConstraints "End of tcVectDecls"
+       ; return decls'
+       }
+  where
+    reportVectDups (first:_second:_more) 
+      = addErrAt (getSrcSpan first) $
+          ptext (sLit "Duplicate vectorisation declarations for") <+> ppr first
+    reportVectDups _ = return ()
+
+--------------
+tcVect :: VectDecl Name -> TcM (VectDecl TcId)
+-- We can't typecheck the expression of a vectorisation declaration against the vectorised type
+-- of the original definition as this requires internals of the vectoriser not available during
+-- type checking.  Instead, we infer the type of the expression and leave it to the vectoriser
+-- to check the compatibility of the Core types.
+tcVect (HsVect name Nothing)
+  = addErrCtxt (vectCtxt name) $
+    do { id <- wrapLocM tcLookupId name
+       ; return $ HsVect id Nothing
+       }
+tcVect (HsVect name@(L loc _) (Just rhs))
+  = addErrCtxt (vectCtxt name) $
+    do { _id <- wrapLocM tcLookupId name     -- need to ensure that the name is already defined
+
+         -- turn the vectorisation declaration into a single non-recursive binding
+       ; let bind    = L loc $ mkFunBind name [mkSimpleMatch [] rhs] 
+             sigFun  = const Nothing
+             pragFun = mkPragFun [] (unitBag bind)
+
+         -- perform type inference (including generalisation)
+       ; (binds, [id']) <- tcPolyInfer TopLevel False sigFun pragFun NonRecursive [bind]
+
+       ; traceTc "tcVect inferred type" $ ppr (varType id')
+       ; traceTc "tcVect bindings"      $ ppr binds
+       
+         -- add all bindings, including the type variable and dictionary bindings produced by type
+         -- generalisation to the right-hand side of the vectorisation declaration
+       ; let [AbsBinds tvs evs _ evBinds actualBinds] = (map unLoc . bagToList) binds
+       ; let [bind']                                  = bagToList actualBinds
+             MatchGroup 
+               [L _ (Match _ _ (GRHSs [L _ (GRHS _ rhs')] _))]
+               _                                      = (fun_matches . unLoc) bind'
+             rhsWrapped                               = mkHsLams tvs evs (mkHsDictLet evBinds rhs')
+        
+        -- We return the type-checked 'Id', to propagate the inferred signature
+        -- to the vectoriser - see "Note [Typechecked vectorisation pragmas]" in HsDecls
+       ; return $ HsVect (L loc id') (Just rhsWrapped)
+       }
+tcVect (HsNoVect name)
+  = addErrCtxt (vectCtxt name) $
+    do { id <- wrapLocM tcLookupId name
+       ; return $ HsNoVect id
+       }
+
+vectCtxt :: Located Name -> SDoc
+vectCtxt name = ptext (sLit "When checking the vectorisation declaration for") <+> ppr name
 
 --------------
 -- If typechecking the binds fails, then return with each
@@ -602,6 +676,26 @@ forall_a_a :: TcType
 forall_a_a = mkForAllTy openAlphaTyVar (mkTyVarTy openAlphaTyVar)
 \end{code}
 
+Note [SPECIALISE pragmas]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+There is no point in a SPECIALISE pragma for a non-overloaded function:
+   reverse :: [a] -> [a]
+   {-# SPECIALISE reverse :: [Int] -> [Int] #-}
+
+But SPECIALISE INLINE *can* make sense for GADTS:
+   data Arr e where
+     ArrInt :: !Int -> ByteArray# -> Arr Int
+     ArrPair :: !Int -> Arr e1 -> Arr e2 -> Arr (e1, e2)
+
+   (!:) :: Arr e -> Int -> e
+   {-# SPECIALISE INLINE (!:) :: Arr Int -> Int -> Int #-}  
+   {-# SPECIALISE INLINE (!:) :: Arr (a, b) -> Int -> (a, b) #-}
+   (ArrInt _ ba)     !: (I# i) = I# (indexIntArray# ba i)
+   (ArrPair _ a1 a2) !: i      = (a1 !: i, a2 !: i)
+
+When (!:) is specialised it becomes non-recursive, and can usefully
+be inlined.  Scary!  So we only warn for SPECIALISE *without* INLINE
+for a non-overloaded function.
 
 %************************************************************************
 %*                                                                      *
@@ -679,9 +773,6 @@ type MonoBindInfo = (Name, Maybe TcSigInfo, TcId)
         -- Type signature (if any), and
         -- the monomorphic bound things
 
-getMonoType :: MonoBindInfo -> TcTauType
-getMonoType (_,_,mono_id) = idType mono_id
-
 tcLhs :: TcSigFun -> LetBndrSpec -> HsBind Name -> TcM TcMonoBind
 tcLhs sig_fn no_gen (FunBind { fun_id = L nm_loc name, fun_infix = inf, fun_matches = matches })
   | Just sig <- sig_fn name
@@ -780,7 +871,7 @@ unifyCtxts (sig1 : sigs)
                -- where F is a type function and (F a ~ [a])
                -- Then unification might succeed with a coercion.  But it's much
                -- much simpler to require that such signatures have identical contexts
-               checkTc (all isIdentityCoI cois)
+               checkTc (all isReflCo cois)
                        (ptext (sLit "Mutually dependent functions have syntactically distinct contexts"))
              }
 \end{code}
@@ -1028,7 +1119,10 @@ tcInstSig sig_fn use_skols name
   | Just (scoped_tvs, loc) <- sig_fn name
   = do  { poly_id <- tcLookupId name    -- Cannot fail; the poly ids are put into 
                                         -- scope when starting the binding group
-        ; (tvs, theta, tau) <- tcInstSigType use_skols name (idType poly_id)
+        ; let poly_ty = idType poly_id
+        ; (tvs, theta, tau) <- if use_skols
+                               then tcInstType tcInstSkolTyVars poly_ty
+                               else tcInstType tcInstSigTyVars  poly_ty
         ; let sig = TcSigInfo { sig_id = poly_id
                              , sig_scoped = scoped_tvs
                               , sig_tvs = tvs, sig_theta = theta, sig_tau = tau
@@ -1055,6 +1149,7 @@ instance Outputable GeneralisationPlan where
 decideGeneralisationPlan 
    :: DynFlags -> TopLevelFlag -> [Name] -> [LHsBind Name] -> TcSigFun -> GeneralisationPlan
 decideGeneralisationPlan dflags top_lvl _bndrs binds sig_fn
+  | bang_pat_binds                         = NoGen
   | mono_pat_binds                         = NoGen
   | Just sig <- one_funbind_with_sig binds = if null (sig_tvs sig) && null (sig_theta sig)
                                              then NoGen              -- Optimise common case
@@ -1064,7 +1159,12 @@ decideGeneralisationPlan dflags top_lvl _bndrs binds sig_fn
   | otherwise                              = InferGen mono_restriction
 
   where
-    mono_pat_binds = xopt Opt_MonoPatBinds dflags 
+    bang_pat_binds = any (isBangHsBind . unLoc) binds
+       -- Bang patterns must not be polymorphic,
+       -- because we are going to force them
+       -- See Trac #4498
+
+    mono_pat_binds = xopt Opt_MonoPatBinds dflags
                   && any (is_pat_bind . unLoc) binds
 
     mono_restriction = xopt Opt_MonomorphismRestriction dflags 
@@ -1108,24 +1208,30 @@ checkStrictBinds top_lvl rec_group binds poly_ids
         ; checkTc (isNonRec rec_group)
                   (strictBindErr "Recursive" unlifted binds)
         ; checkTc (isSingleton binds)
-                  (strictBindErr "Multiple" unlifted binds) 
+                  (strictBindErr "Multiple" unlifted binds)
         -- This should be a checkTc, not a warnTc, but as of GHC 6.11
         -- the versions of alex and happy available have non-conforming
         -- templates, so the GHC build fails if it's an error:
         ; warnUnlifted <- doptM Opt_WarnLazyUnliftedBindings
-        ; warnTc (warnUnlifted && not bang_pat)
+        ; warnTc (warnUnlifted && not bang_pat && lifted_pat)
+                 -- No outer bang, but it's a compound pattern
+                 -- E.g   (I# x#) = blah
+                 -- Warn about this, but not about
+                 --      x# = 4# +# 1#
+                 --      (# a, b #) = ...
                  (unliftedMustBeBang binds) }
   | otherwise
   = return ()
   where
-    unlifted = any is_unlifted poly_ids
-    bang_pat = any (isBangHsBind . unLoc) binds
+    unlifted    = any is_unlifted poly_ids
+    bang_pat    = any (isBangHsBind . unLoc) binds
+    lifted_pat  = any (isLiftedPatBind . unLoc) binds
     is_unlifted id = case tcSplitForAllTys (idType id) of
                        (_, rho) -> isUnLiftedType rho
 
 unliftedMustBeBang :: [LHsBind Name] -> SDoc
 unliftedMustBeBang binds
-  = hang (text "Bindings containing unlifted types should use an outermost bang pattern:")
+  = hang (text "Pattern bindings containing unlifted types should use an outermost bang pattern:")
        2 (pprBindList binds)
 
 strictBindErr :: String -> Bool -> [LHsBind Name] -> SDoc