[project @ 2003-12-10 14:15:16 by simonmar]
[ghc-hetmet.git] / ghc / compiler / stgSyn / StgLint.lhs
index 3692e06..0d1b7b5 100644 (file)
@@ -10,23 +10,25 @@ module StgLint ( lintStgBindings ) where
 
 import StgSyn
 
-import Bag             ( Bag, emptyBag, isEmptyBag, snocBag )
+import Bag             ( Bag, emptyBag, isEmptyBag, snocBag, bagToList )
 import Id              ( Id, idType, isLocalId )
 import VarSet
 import DataCon         ( DataCon, dataConArgTys, dataConRepType )
+import CoreSyn         ( AltCon(..) )
 import PrimOp          ( primOpType )
-import Literal         ( literalType, Literal )
+import Literal         ( literalType )
 import Maybes          ( catMaybes )
 import Name            ( getSrcLoc )
-import ErrUtils                ( ErrMsg, Message, addErrLocHdrLine, pprBagOfErrors, dontAddErrLoc )
+import ErrUtils                ( Message, mkLocMessage )
 import Type            ( mkFunTys, splitFunTys, splitTyConApp_maybe,
-                         isUnLiftedType, isTyVarTy, splitForAllTys, Type
+                         isUnLiftedType, isTyVarTy, dropForAlls, Type
                        )
-import TyCon           ( TyCon, isDataTyCon, tyConDataCons )
-import Util            ( zipEqual )
+import TyCon           ( isAlgTyCon, isNewTyCon, tyConDataCons )
+import Util            ( zipEqual, equalLength )
+import SrcLoc          ( srcLocSpan )
 import Outputable
 
-infixr 9 `thenL`, `thenL_`, `thenMaybeL`, `thenMaybeL_`
+infixr 9 `thenL`, `thenL_`, `thenMaybeL`
 \end{code}
 
 Checks for
@@ -89,11 +91,11 @@ lintStgVar v  = checkInScope v      `thenL_`
 
 \begin{code}
 lintStgBinds :: StgBinding -> LintM [Id]               -- Returns the binders
-lintStgBinds (StgNonRec _srt binder rhs)
+lintStgBinds (StgNonRec binder rhs)
   = lint_binds_help (binder,rhs)       `thenL_`
     returnL [binder]
 
-lintStgBinds (StgRec _srt pairs)
+lintStgBinds (StgRec pairs)
   = addInScopeVars binders (
        mapL lint_binds_help pairs `thenL_`
        returnL binders
@@ -127,10 +129,10 @@ lint_binds_help (binder, rhs)
 \begin{code}
 lintStgRhs :: StgRhs -> LintM (Maybe Type)
 
-lintStgRhs (StgRhsClosure _ _ _ _ [] expr)
+lintStgRhs (StgRhsClosure _ _ _ _ _ [] expr)
   = lintStgExpr expr
 
-lintStgRhs (StgRhsClosure _ _ _ _ binders expr)
+lintStgRhs (StgRhsClosure _ _ _ _ _ binders expr)
   = addLoc (LambdaBodyOf binders) (
     addInScopeVars binders (
        lintStgExpr expr   `thenMaybeL` \ body_ty ->
@@ -200,13 +202,14 @@ lintStgExpr (StgLetNoEscape _ _ binds body)
 
 lintStgExpr (StgSCC _ expr)    = lintStgExpr expr
 
-lintStgExpr e@(StgCase scrut _ _ bndr _ alts)
+lintStgExpr e@(StgCase scrut _ _ bndr _ alts_type alts)
   = lintStgExpr scrut          `thenMaybeL` \ _ ->
 
-    (case alts of
-       StgPrimAlts tc _ _       -> check_bndr tc
-       StgAlgAlts (Just tc) _ _ -> check_bndr tc
-       StgAlgAlts Nothing   _ _ -> returnL ()
+    (case alts_type of
+       AlgAlt tc    -> check_bndr tc
+       PrimAlt tc   -> check_bndr tc
+       UbxTupAlt tc -> check_bndr tc
+       PolyAlt      -> returnL ()
     )                                                  `thenL_`
        
     (trace (showSDoc (ppr e)) $ 
@@ -224,25 +227,15 @@ lintStgExpr e@(StgCase scrut _ _ bndr _ alts)
     check_bndr tc = case splitTyConApp_maybe scrut_ty of
                        Just (bndr_tc, _) -> checkL (tc == bndr_tc) bad_bndr
                        Nothing           -> addErrL bad_bndr
-\end{code}
 
-\begin{code}
-lintStgAlts :: StgCaseAlts
-            -> Type            -- Type of scrutinee
-            -> LintM (Maybe Type)      -- Type of alternatives
+
+lintStgAlts :: [StgAlt]
+           -> Type             -- Type of scrutinee
+           -> LintM (Maybe Type)       -- Type of alternatives
 
 lintStgAlts alts scrut_ty
-  = (case alts of
-        StgAlgAlts _ alg_alts deflt ->
-          mapL (lintAlgAlt scrut_ty) alg_alts  `thenL` \ maybe_alt_tys ->
-          lintDeflt deflt scrut_ty             `thenL` \ maybe_deflt_ty ->
-          returnL (maybe_deflt_ty : maybe_alt_tys)
-
-        StgPrimAlts _ prim_alts deflt ->
-          mapL (lintPrimAlt scrut_ty) prim_alts `thenL` \ maybe_alt_tys ->
-          lintDeflt deflt scrut_ty              `thenL` \ maybe_deflt_ty ->
-          returnL (maybe_deflt_ty : maybe_alt_tys)
-    )                                           `thenL` \ maybe_result_tys ->
+  = mapL (lintAlt scrut_ty) alts       `thenL` \ maybe_result_tys ->
+
         -- Check the result types
     case catMaybes (maybe_result_tys) of
       []            -> returnL Nothing
@@ -252,21 +245,29 @@ lintStgAlts alts scrut_ty
        where
          check ty = checkTys first_ty ty (mkCaseAltMsg alts)
 
-lintAlgAlt scrut_ty (con, args, _, rhs)
+lintAlt scrut_ty (DEFAULT, _, _, rhs)
+ = lintStgExpr rhs
+
+lintAlt scrut_ty (LitAlt lit, _, _, rhs)
+ = checkTys (literalType lit) scrut_ty (mkAltMsg1 scrut_ty)    `thenL_`
+   lintStgExpr rhs
+
+lintAlt scrut_ty (DataAlt con, args, _, rhs)
   = (case splitTyConApp_maybe scrut_ty of
-      Just (tycon, tys_applied) | isDataTyCon tycon ->
+      Just (tycon, tys_applied) | isAlgTyCon tycon && 
+                                 not (isNewTyCon tycon) ->
         let
           cons    = tyConDataCons tycon
           arg_tys = dataConArgTys con tys_applied
                -- This almost certainly does not work for existential constructors
         in
         checkL (con `elem` cons) (mkAlgAltMsg2 scrut_ty con) `thenL_`
-        checkL (length arg_tys == length args) (mkAlgAltMsg3 con args)
+        checkL (equalLength arg_tys args) (mkAlgAltMsg3 con args)
                                                                 `thenL_`
         mapL check (zipEqual "lintAlgAlt:stg" arg_tys args)     `thenL_`
         returnL ()
       other ->
-        addErrL (mkAlgAltMsg1 scrut_ty)
+        addErrL (mkAltMsg1 scrut_ty)
     )                                                           `thenL_`
     addInScopeVars args        (
         lintStgExpr rhs
@@ -279,13 +280,6 @@ lintAlgAlt scrut_ty (con, args, _, rhs)
     -- We give it its own copy, so it isn't overloaded.
     elem _ []      = False
     elem x (y:ys)   = x==y || elem x ys
-
-lintPrimAlt scrut_ty alt@(lit,rhs)
- = checkTys (literalType lit) scrut_ty (mkPrimAltMsg alt)      `thenL_`
-   lintStgExpr rhs
-
-lintDeflt StgNoDefault scrut_ty = returnL Nothing
-lintDeflt deflt@(StgBindDefault rhs) scrut_ty = lintStgExpr rhs
 \end{code}
 
 
@@ -298,8 +292,8 @@ lintDeflt deflt@(StgBindDefault rhs) scrut_ty = lintStgExpr rhs
 \begin{code}
 type LintM a = [LintLocInfo]   -- Locations
            -> IdSet            -- Local vars in scope
-           -> Bag ErrMsg       -- Error messages so far
-           -> (a, Bag ErrMsg)  -- Result and error messages (if any)
+           -> Bag Message      -- Error messages so far
+           -> (a, Bag Message) -- Result and error messages (if any)
 
 data LintLocInfo
   = RhsOf Id           -- The variable bound
@@ -307,12 +301,12 @@ data LintLocInfo
   | BodyOfLetRec [Id]  -- One of the binders
 
 dumpLoc (RhsOf v) =
-  (getSrcLoc v, ptext SLIT(" [RHS of ") <> pp_binders [v] <> char ']' )
+  (srcLocSpan (getSrcLoc v), ptext SLIT(" [RHS of ") <> pp_binders [v] <> char ']' )
 dumpLoc (LambdaBodyOf bs) =
-  (getSrcLoc (head bs), ptext SLIT(" [in body of lambda with binders ") <> pp_binders bs <> char ']' )
+  (srcLocSpan (getSrcLoc (head bs)), ptext SLIT(" [in body of lambda with binders ") <> pp_binders bs <> char ']' )
 
 dumpLoc (BodyOfLetRec bs) =
-  (getSrcLoc (head bs), ptext SLIT(" [in body of letrec with binders ") <> pp_binders bs <> char ']' )
+  (srcLocSpan (getSrcLoc (head bs)), ptext SLIT(" [in body of letrec with binders ") <> pp_binders bs <> char ']' )
 
 
 pp_binders :: [Id] -> SDoc
@@ -330,7 +324,7 @@ initL m
     if isEmptyBag errs then
        Nothing
     else
-       Just (pprBagOfErrors errs)
+       Just (vcat (punctuate (text "") (bagToList errs)))
     }
 
 returnL :: a -> LintM a
@@ -352,12 +346,6 @@ thenMaybeL m k loc scope errs
       (Nothing, errs2) -> (Nothing, errs2)
       (Just r,  errs2) -> k r loc scope errs2
 
-thenMaybeL_ :: LintM (Maybe a) -> LintM (Maybe b) -> LintM (Maybe b)
-thenMaybeL_ m k loc scope errs
-  = case m loc scope errs of
-      (Nothing, errs2) -> (Nothing, errs2)
-      (Just _,  errs2) -> k loc scope errs2
-
 mapL :: (a -> LintM b) -> [a] -> LintM [b]
 mapL f [] = returnL []
 mapL f (x:xs)
@@ -382,13 +370,14 @@ checkL False msg loc scope errs = ((), addErr errs msg loc)
 addErrL :: Message -> LintM ()
 addErrL msg loc scope errs = ((), addErr errs msg loc)
 
-addErr :: Bag ErrMsg -> Message -> [LintLocInfo] -> Bag ErrMsg
+addErr :: Bag Message -> Message -> [LintLocInfo] -> Bag Message
 
 addErr errs_so_far msg locs
   = errs_so_far `snocBag` mk_msg locs
   where
-    mk_msg (loc:_) = let (l,hdr) = dumpLoc loc in addErrLocHdrLine l hdr msg
-    mk_msg []      = dontAddErrLoc msg
+    mk_msg (loc:_) = let (l,hdr) = dumpLoc loc 
+                    in  mkLocMessage l (hdr $$ msg)
+    mk_msg []      = msg
 
 addLoc :: LintLocInfo -> LintM a -> LintM a
 addLoc extra_loc m loc scope errs
@@ -426,8 +415,7 @@ checkFunApp :: Type                     -- The function type
 checkFunApp fun_ty arg_tys msg loc scope errs
   = cfa res_ty expected_arg_tys arg_tys
   where
-    (_, de_forall_ty)         = splitForAllTys fun_ty
-    (expected_arg_tys, res_ty) = splitFunTys de_forall_ty
+    (expected_arg_tys, res_ty) = splitFunTys (dropForAlls fun_ty)
 
     cfa res_ty expected []     -- Args have run out; that's fine
       = (Just (mkFunTys expected res_ty), errs)
@@ -463,16 +451,11 @@ checkTys ty1 ty2 msg loc scope errs
 \end{code}
 
 \begin{code}
-mkCaseAltMsg :: StgCaseAlts -> Message
+mkCaseAltMsg :: [StgAlt] -> Message
 mkCaseAltMsg alts
   = ($$) (text "In some case alternatives, type of alternatives not all same:")
            (empty) -- LATER: ppr alts
 
-mkCaseAbstractMsg :: TyCon -> Message
-mkCaseAbstractMsg tycon
-  = ($$) (ptext SLIT("An algebraic case on an abstract type:"))
-           (ppr tycon)
-
 mkDefltMsg :: Id -> Message
 mkDefltMsg bndr
   = ($$) (ptext SLIT("Binder of a case expression doesn't match type of scrutinee:"))
@@ -491,16 +474,10 @@ mkRhsConMsg fun_ty arg_tys
              hang (ptext SLIT("Constructor type:")) 4 (ppr fun_ty),
              hang (ptext SLIT("Arg types:")) 4 (vcat (map (ppr) arg_tys))]
 
-mkUnappTyMsg :: Id -> Type -> Message
-mkUnappTyMsg var ty
-  = vcat [text "Variable has a for-all type, but isn't applied to any types.",
-             (<>) (ptext SLIT("Var:      ")) (ppr var),
-             (<>) (ptext SLIT("Its type: ")) (ppr ty)]
-
-mkAlgAltMsg1 :: Type -> Message
-mkAlgAltMsg1 ty
-  = ($$) (text "In some case statement, type of scrutinee is not a data type:")
-           (ppr ty)
+mkAltMsg1 :: Type -> Message
+mkAltMsg1 ty
+  = ($$) (text "In a case expression, type of scrutinee does not match patterns")
+        (ppr ty)
 
 mkAlgAltMsg2 :: Type -> DataCon -> Message
 mkAlgAltMsg2 ty con
@@ -526,11 +503,6 @@ mkAlgAltMsg4 ty arg
        ppr arg
     ]
 
-mkPrimAltMsg :: (Literal, StgExpr) -> Message
-mkPrimAltMsg alt
-  = text "In a primitive case alternative, type of literal doesn't match type of scrutinee:"
-    $$ ppr alt
-
 mkCaseOfCaseMsg :: StgExpr -> Message
 mkCaseOfCaseMsg e
   = text "Case of non-tail-call:" $$ ppr e