[project @ 2001-11-23 12:01:34 by simonmar]
[ghc-hetmet.git] / ghc / compiler / deSugar / Check.lhs
index ef3bcf5..b679729 100644 (file)
@@ -11,46 +11,24 @@ module Check ( check , ExhaustivePat ) where
 
 
 import HsSyn           
-import TcHsSyn         ( TypecheckedPat )
-import DsHsSyn         ( outPatType ) 
-import CoreSyn         
-
-import DsUtils         ( EquationInfo(..),
-                         MatchResult(..),
-                         EqnSet,
-                         CanItFail(..)
+import TcHsSyn         ( TypecheckedPat, outPatType )
+import TcType          ( tcTyConAppTyCon, tcTyConAppArgs )
+import DsUtils         ( EquationInfo(..), MatchResult(..), EqnSet, 
+                         CanItFail(..),  tidyLitPat, tidyNPat, 
                        )
 import Id              ( idType )
-import DataCon         ( DataCon, isTupleCon, isUnboxedTupleCon,
+import DataCon         ( DataCon, dataConTyCon, dataConArgTys,
                          dataConSourceArity, dataConFieldLabels )
-import Name             ( Name, mkLocalName, getOccName, isDataSymOcc, getName, mkSrcVarOcc )
-import Type            ( Type, 
-                          isUnboxedType, 
-                          splitTyConApp_maybe
-                       )
-import TysPrim         ( intPrimTy, 
-                          charPrimTy, 
-                          floatPrimTy, 
-                          doublePrimTy,
-                         addrPrimTy, 
-                          wordPrimTy
-                       )
-import TysWiredIn      ( nilDataCon, consDataCon, 
-                          mkTupleTy, tupleCon,
-                         mkUnboxedTupleTy, unboxedTupleCon,
-                          mkListTy, 
-                          charTy, charDataCon, 
-                          intTy, intDataCon,
-                         floatTy, floatDataCon, 
-                          doubleTy, doubleDataCon, 
-                          addrTy, addrDataCon,
-                          wordTy, wordDataCon,
-                         stringTy
-                       )
-import Unique          ( unboundKey )
-import TyCon            ( tyConDataCons )
+import Name             ( Name, mkLocalName, getOccName, isDataSymOcc, getName, mkVarOcc )
+import TcType          ( mkTyVarTys )
+import TysPrim         ( charPrimTy )
+import TysWiredIn
+import PrelNames       ( unboundKey )
+import TyCon            ( tyConDataCons, tupleTyConBoxity, isTupleTyCon )
+import BasicTypes      ( Boxity(..) )
 import SrcLoc          ( noSrcLoc )
 import UniqSet
+import Util             ( takeList, splitAtList )
 import Outputable
 
 #include "HsVersions.h"
@@ -166,13 +144,7 @@ untidy b (ConOpPatIn pat1 name fixity pat2) =
 untidy _ (ListPatIn pats)  = ListPatIn (map untidy_no_pars pats) 
 untidy _ (TuplePatIn pats boxed) = TuplePatIn (map untidy_no_pars pats) boxed
 
-untidy _ (SigPatIn pat ty)      = panic "Check.untidy: SigPatIn"
-untidy _ (LazyPatIn pat)        = panic "Check.untidy: LazyPatIn"
-untidy _ (AsPatIn name pat)     = panic "Check.untidy: AsPatIn"
-untidy _ (NPlusKPatIn name lit) = panic "Check.untidy: NPlusKPatIn"
-untidy _ (NegPatIn ipat)        = panic "Check.untidy: NegPatIn"
-untidy _ (ParPatIn pat)         = panic "Check.untidy: ParPatIn"
-untidy _ (RecPatIn name fields) = panic "Check.untidy: RecPatIn"
+untidy _ pat = pprPanic "Check.untidy: SigPatIn" (ppr pat)
 
 pars :: NeedPars -> WarningPat -> WarningPat
 pars True p = ParPatIn p
@@ -216,7 +188,7 @@ check' :: [EquationInfo] -> ([ExhaustivePat],EqnSet)
 check' []                                              = ([([],[])],emptyUniqSet)
 
 check' [EqnInfo n ctx ps (MatchResult CanFail _)] 
-   | all_vars ps  = ([(take (length ps) (repeat new_wild_pat),[])],  unitUniqSet n)
+   | all_vars ps  = ([(takeList ps (repeat new_wild_pat),[])],  unitUniqSet n)
 
 check' qs@((EqnInfo n ctx ps (MatchResult CanFail _)):rs)
    | all_vars ps  = (pats,  addOneToUniqSet indexs n)
@@ -230,7 +202,7 @@ check' qs@((EqnInfo n ctx ps result):_)
    | literals     = split_by_literals qs
    | constructors = split_by_constructor qs
    | only_vars    = first_column_only_vars qs
-   | otherwise    = panic ("Check.check': Not implemented :-(")
+   | otherwise    = panic "Check.check': Not implemented :-("
   where
      -- Note: RecPats will have been simplified to ConPats
      --       at this stage.
@@ -273,8 +245,8 @@ must be one Variable to be complete.
 
 process_literals :: [HsLit] -> [EquationInfo] -> ([ExhaustivePat],EqnSet)
 process_literals used_lits qs 
-  | length default_eqns == 0 = ([make_row_vars used_lits (head qs)]++pats,indexs)
-  | otherwise                = (pats_default,indexs_default)
+  | null default_eqns  = ([make_row_vars used_lits (head qs)]++pats,indexs)
+  | otherwise          = (pats_default,indexs_default)
      where
        (pats,indexs)   = process_explicit_literals used_lits qs
        default_eqns    = (map remove_var (filter is_var qs))
@@ -312,8 +284,9 @@ same constructor.
 
 split_by_constructor :: [EquationInfo] -> ([ExhaustivePat],EqnSet)
 
-split_by_constructor qs | length unused_cons /= 0 = need_default_case used_cons unused_cons qs 
-                        | otherwise               = no_need_default_case used_cons qs 
+split_by_constructor qs 
+  | not (null unused_cons) = need_default_case used_cons unused_cons qs 
+  | otherwise              = no_need_default_case used_cons qs 
                        where 
                           used_cons   = get_used_cons qs 
                           unused_cons = get_unused_cons used_cons 
@@ -348,8 +321,8 @@ no_need_default_case cons qs = (concat pats, unionManyUniqSets indexs)
 
 need_default_case :: [TypecheckedPat] -> [DataCon] -> [EquationInfo] -> ([ExhaustivePat],EqnSet)
 need_default_case used_cons unused_cons qs 
-  | length default_eqns == 0 = (pats_default_no_eqns,indexs)
-  | otherwise                = (pats_default,indexs_default)
+  | null default_eqns  = (pats_default_no_eqns,indexs)
+  | otherwise          = (pats_default,indexs_default)
      where
        (pats,indexs)   = no_need_default_case used_cons qs
        default_eqns    = (map remove_var (filter is_var qs))
@@ -397,15 +370,15 @@ remove_first_column (ConPat con _ _ _ con_pats) qs =
 
 make_row_vars :: [HsLit] -> EquationInfo -> ExhaustivePat
 make_row_vars used_lits (EqnInfo _ _ pats _ ) = 
-   (VarPatIn new_var:take (length (tail pats)) (repeat new_wild_pat),[(new_var,used_lits)])
+   (VarPatIn new_var:takeList (tail pats) (repeat new_wild_pat),[(new_var,used_lits)])
   where new_var = hash_x
 
 hash_x = mkLocalName unboundKey {- doesn't matter much -}
-                    (mkSrcVarOcc SLIT("#x"))
+                    (mkVarOcc SLIT("#x"))
                     noSrcLoc
 
 make_row_vars_for_constructor :: EquationInfo -> [WarningPat]
-make_row_vars_for_constructor (EqnInfo _ _ pats _ ) = take (length (tail pats)) (repeat new_wild_pat)
+make_row_vars_for_constructor (EqnInfo _ _ pats _ ) = takeList (tail pats) (repeat new_wild_pat)
 
 compare_cons :: TypecheckedPat -> TypecheckedPat -> Bool
 compare_cons (ConPat id1 _ _ _ _) (ConPat id2 _ _ _ _) = id1 == id2  
@@ -442,13 +415,12 @@ get_unused_cons :: [TypecheckedPat] -> [DataCon]
 get_unused_cons used_cons = unused_cons
      where
        (ConPat _ ty _ _ _) = head used_cons
-       Just (ty_con,_)            = splitTyConApp_maybe ty
+       ty_con             = tcTyConAppTyCon ty         -- Newtype observable
        all_cons                   = tyConDataCons ty_con
        used_cons_as_id            = map (\ (ConPat d _ _ _ _) -> d) used_cons
        unused_cons                = uniqSetToList
                 (mkUniqSet all_cons `minusUniqSet` mkUniqSet used_cons_as_id) 
 
-
 all_vars :: [TypecheckedPat] -> Bool
 all_vars []              = True
 all_vars (WildPat _:ps)  = all_vars ps
@@ -552,13 +524,11 @@ make_con (ConPat id _ _ _ _) (p:q:ps, constraints)
           fixity = panic "Check.make_con: Guessing fixity"
 
 make_con (ConPat id _ _ _ pats) (ps,constraints) 
-      | isTupleCon id        = (TuplePatIn pats_con True : rest_pats,    constraints) 
-      | isUnboxedTupleCon id = (TuplePatIn pats_con False : rest_pats, constraints)
-      | otherwise     = (ConPatIn name pats_con : rest_pats, constraints)
-    where num_args  = length pats
-          name      = getName id
-          pats_con  = take num_args ps
-          rest_pats = drop num_args ps
+      | isTupleTyCon tc = (TuplePatIn pats_con (tupleTyConBoxity tc) : rest_pats, constraints) 
+      | otherwise       = (ConPatIn name pats_con                   : rest_pats, constraints)
+    where name      = getName id
+         (pats_con, rest_pats) = splitAtList pats ps
+         tc        = dataConTyCon id
          
 
 make_whole_con :: DataCon -> WarningPat
@@ -568,7 +538,7 @@ make_whole_con con | isInfixCon con = ConOpPatIn new_wild_pat name fixity new_wi
                   fixity = panic "Check.make_whole_con: Guessing fixity"
                   name   = getName con
                   arity  = dataConSourceArity con 
-                  pats   = take arity (repeat new_wild_pat)
+                  pats   = replicate arity new_wild_pat
 
 
 new_wild_pat :: WarningPat
@@ -605,27 +575,24 @@ simplify_pat (ListPat ty ps) = foldr (\ x -> \y -> ConPat consDataCon list_ty []
                              where list_ty = mkListTy ty
 
 
-simplify_pat (TuplePat ps True) = ConPat (tupleCon arity)
-                                   (mkTupleTy arity (map outPatType ps)) [] []
-                                   (map simplify_pat ps)
-                           where
-                              arity = length ps
-
-simplify_pat (TuplePat ps False) 
-  = ConPat (unboxedTupleCon arity)
-          (mkUnboxedTupleTy arity (map outPatType ps)) [] []
+simplify_pat (TuplePat ps boxity)
+  = ConPat (tupleCon boxity arity)
+          (mkTupleTy boxity arity (map outPatType ps)) [] []
           (map simplify_pat ps)
   where
     arity = length ps
 
-simplify_pat (RecPat dc ty tvs dicts [])   
-  = ConPat dc ty tvs dicts all_wild_pats
+simplify_pat (RecPat dc ty ex_tvs dicts [])   
+  = ConPat dc ty ex_tvs dicts all_wild_pats
   where
-    all_wild_pats = map (\ _ -> WildPat gt) (dataConFieldLabels dc)
-    gt = panic "Check.symplify_pat{RecPat-1}"
+    all_wild_pats = map WildPat con_arg_tys
 
-simplify_pat (RecPat dc ty tvs dicts idps) 
-  = ConPat dc ty tvs dicts pats
+      -- Identical to machinations in Match.tidy1:
+    inst_tys    = tcTyConAppArgs ty    -- Newtype is observable
+    con_arg_tys = dataConArgTys dc (inst_tys ++ mkTyVarTys ex_tvs)
+
+simplify_pat (RecPat dc ty ex_tvs dicts idps) 
+  = ConPat dc ty ex_tvs dicts pats
   where
     pats = map (simplify_pat.snd) all_pats
 
@@ -643,64 +610,18 @@ simplify_pat (RecPat dc ty tvs dicts idps)
       | nm == n    = (nm,p):xs
       | otherwise  = x : insertNm nm p xs
 
-simplify_pat pat@(LitPat lit lit_ty) 
-  | isUnboxedType lit_ty = pat
-
-  | lit_ty == charTy = ConPat charDataCon charTy [] [] [LitPat (mk_char lit) charPrimTy]
+simplify_pat pat@(LitPat lit lit_ty)        = tidyLitPat lit pat
 
-  | otherwise = pprPanic "Check.simplify_pat: LitPat:" (ppr pat)
+-- unpack string patterns fully, so we can see when they overlap with
+-- each other, or even explicit lists of Chars.
+simplify_pat pat@(NPat (HsString s) _ _) = 
+   foldr (\c pat -> ConPat consDataCon stringTy [] [] [mk_char_lit c,pat])
+       (ConPat nilDataCon stringTy [] [] []) (_UNPK_INT_ s)
   where
-    mk_char (HsChar c)    = HsCharPrim c
-
-simplify_pat (NPat lit lit_ty hsexpr) = better_pat
-  where
-    better_pat
-      | lit_ty == charTy   = ConPat charDataCon   lit_ty [] [] [LitPat (mk_char lit)   charPrimTy]
-      | lit_ty == intTy    = ConPat intDataCon    lit_ty [] [] [LitPat (mk_int lit)    intPrimTy]
-      | lit_ty == wordTy   = ConPat wordDataCon   lit_ty [] [] [LitPat (mk_word lit)   wordPrimTy]
-      | lit_ty == addrTy   = ConPat addrDataCon   lit_ty [] [] [LitPat (mk_addr lit)   addrPrimTy]
-      | lit_ty == floatTy  = ConPat floatDataCon  lit_ty [] [] [LitPat (mk_float lit)  floatPrimTy]
-      | lit_ty == doubleTy = ConPat doubleDataCon lit_ty [] [] [LitPat (mk_double lit) doublePrimTy]
-
-               -- Convert the literal pattern "" to the constructor pattern [].
-      | null_str_lit lit      = ConPat nilDataCon  lit_ty [] [] []
-      | lit_ty == stringTy = 
-            foldr (\ x -> \y -> ConPat consDataCon list_ty [] [] [x, y])
-                               (ConPat nilDataCon  list_ty [] [] [])
-                               (mk_string lit)
-      | otherwise             = NPat lit lit_ty hsexpr
-
-    list_ty = mkListTy lit_ty
-
-    mk_int    (HsInt i)      = HsIntPrim i
-    mk_int    l@(HsLitLit s) = l
-
-    mk_head_char (HsString s) = HsCharPrim (_HEAD_ s)
-    mk_string    (HsString s) = 
-       map (\ c -> ConPat charDataCon charTy [] []
-                        [LitPat (HsCharPrim c) charPrimTy]) 
-           (_UNPK_ s)
-
-    mk_char   (HsChar c)     = HsCharPrim c
-    mk_char   l@(HsLitLit s) = l
-
-    mk_word   l@(HsLitLit s) = l
-
-    mk_addr   l@(HsLitLit s) = l
-
-    mk_float  (HsInt i)      = HsFloatPrim (fromInteger i)
-    mk_float  (HsFrac f)     = HsFloatPrim f
-    mk_float  l@(HsLitLit s) = l
-
-    mk_double (HsInt i)      = HsDoublePrim (fromInteger i)
-    mk_double (HsFrac f)     = HsDoublePrim f
-    mk_double l@(HsLitLit s) = l
-
-    null_str_lit (HsString s) = _NULL_ s
-    null_str_lit other_lit    = False
+    mk_char_lit c = ConPat charDataCon charTy [] [] 
+                       [LitPat (HsCharPrim c) charPrimTy]
 
-    one_str_lit (HsString s) = _LENGTH_ s == (1::Int)
-    one_str_lit other_lit    = False
+simplify_pat pat@(NPat lit lit_ty hsexpr) = tidyNPat lit lit_ty pat
 
 simplify_pat (NPlusKPat        id hslit ty hsexpr1 hsexpr2) = 
      WildPat ty
@@ -708,9 +629,9 @@ simplify_pat (NPlusKPat     id hslit ty hsexpr1 hsexpr2) =
 
 simplify_pat (DictPat dicts methods) = 
     case num_of_d_and_ms of
-       0 -> simplify_pat (TuplePat [] True) 
+       0 -> simplify_pat (TuplePat [] Boxed) 
        1 -> simplify_pat (head dict_and_method_pats) 
-       _ -> simplify_pat (TuplePat dict_and_method_pats True)
+       _ -> simplify_pat (TuplePat dict_and_method_pats Boxed)
     where
        num_of_d_and_ms  = length dicts + length methods
        dict_and_method_pats = map VarPat (dicts ++ methods)