Do dependency analysis when kind-checking type declarations
[ghc-hetmet.git] / compiler / hsSyn / HsUtils.lhs
index d5ff6f5..3ef4bff 100644 (file)
@@ -18,16 +18,17 @@ module HsUtils(
   -- Terms
   mkHsPar, mkHsApp, mkHsConApp, mkSimpleHsAlt,
   mkSimpleMatch, unguardedGRHSs, unguardedRHS, 
-  mkMatchGroup, mkMatch, mkHsLam,
-  mkHsWrap, mkLHsWrap, mkHsWrapCoI, coiToHsWrapper, mkHsDictLet,
-  mkHsOpApp, mkHsDo,
+  mkMatchGroup, mkMatch, mkHsLam, mkHsIf,
+  mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI,
+  coiToHsWrapper, mkHsDictLet,
+  mkHsOpApp, mkHsDo, mkHsWrapPat, mkHsWrapPatCoI,
 
   nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, 
   nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList,
   mkLHsTupleExpr, mkLHsVarTuple, missingTupArg,
 
   -- Bindigns
-  mkFunBind, mkVarBind, mkHsVarBind, mk_easy_FunBind, mk_FunBind,
+  mkFunBind, mkVarBind, mkHsVarBind, mk_easy_FunBind, 
 
   -- Literals
   mkHsIntegral, mkHsFractional, mkHsIsString, mkHsString, 
@@ -52,14 +53,18 @@ module HsUtils(
   noRebindableInfo, 
 
   -- Collecting binders
-  collectLocalBinders, collectHsValBinders, 
+  collectLocalBinders, collectHsValBinders, collectHsBindListBinders,
   collectHsBindsBinders, collectHsBindBinders, collectMethodBinders,
   collectPatBinders, collectPatsBinders,
   collectLStmtsBinders, collectStmtsBinders,
   collectLStmtBinders, collectStmtBinders,
-  collectSigTysFromPats, collectSigTysFromPat
+  collectSigTysFromPats, collectSigTysFromPat,
+
+  hsTyClDeclBinders, hsTyClDeclsBinders, 
+  hsForeignDeclsBinders, hsGroupBinders
   ) where
 
+import HsDecls
 import HsBinds
 import HsExpr
 import HsPat
@@ -76,7 +81,6 @@ import NameSet
 import BasicTypes
 import SrcLoc
 import FastString
-import Outputable
 import Util
 import Bag
 \end{code}
@@ -128,13 +132,25 @@ mkHsWrap co_fn e | isIdHsWrapper co_fn = e
                 | otherwise           = HsWrap co_fn e
 
 mkHsWrapCoI :: CoercionI -> HsExpr id -> HsExpr id
-mkHsWrapCoI IdCo     e = e
+mkHsWrapCoI (IdCo _) e = e
 mkHsWrapCoI (ACo co) e = mkHsWrap (WpCast co) e
 
+mkLHsWrapCoI :: CoercionI -> LHsExpr id -> LHsExpr id
+mkLHsWrapCoI (IdCo _) e         = e
+mkLHsWrapCoI (ACo co) (L loc e) = L loc (mkHsWrap (WpCast co) e)
+
 coiToHsWrapper :: CoercionI -> HsWrapper
-coiToHsWrapper IdCo     = idHsWrapper
+coiToHsWrapper (IdCo _) = idHsWrapper
 coiToHsWrapper (ACo co) = WpCast co
 
+mkHsWrapPat :: HsWrapper -> Pat id -> Type -> Pat id
+mkHsWrapPat co_fn p ty | isIdHsWrapper co_fn = p
+                      | otherwise           = CoPat co_fn p ty
+
+mkHsWrapPatCoI :: CoercionI -> Pat id -> Type -> Pat id
+mkHsWrapPatCoI (IdCo _) pat _  = pat
+mkHsWrapPatCoI (ACo co) pat ty = CoPat (WpCast co) pat ty
+
 mkHsLam :: [LPat id] -> LHsExpr id -> LHsExpr id
 mkHsLam pats body = mkHsPar (L (getLoc body) (HsLam matches))
        where
@@ -143,14 +159,8 @@ mkHsLam pats body = mkHsPar (L (getLoc body) (HsLam matches))
 mkMatchGroup :: [LMatch id] -> MatchGroup id
 mkMatchGroup matches = MatchGroup matches placeHolderType
 
-mkHsDictLet :: LHsBinds Id -> LHsExpr Id -> LHsExpr Id
--- Used for the dictionary bindings gotten from TcSimplify
--- We make them recursive to be on the safe side
-mkHsDictLet binds expr 
-  | isEmptyLHsBinds binds = expr
-  | otherwise             = L (getLoc expr) (HsLet (HsValBinds val_binds) expr)
-                         where
-                           val_binds = ValBindsOut [(Recursive, binds)] []
+mkHsDictLet :: TcEvBinds -> LHsExpr Id -> LHsExpr Id
+mkHsDictLet ev_binds expr = mkLHsWrap (WpLet ev_binds) expr
 
 mkHsConApp :: DataCon -> [Type] -> [HsExpr Id] -> LHsExpr Id
 -- Used for constructing dictionary terms etc, so no locations 
@@ -195,6 +205,9 @@ noRebindableInfo = error "noRebindableInfo"         -- Just another placeholder;
 
 mkHsDo ctxt stmts body = HsDo ctxt stmts body placeHolderType
 
+mkHsIf :: LHsExpr id -> LHsExpr id -> LHsExpr id -> HsExpr id
+mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b
+
 mkNPat lit neg     = NPat lit neg noSyntaxExpr
 mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr
 
@@ -215,7 +228,7 @@ mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr
 emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = []
                        , recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr
                       , recS_bind_fn = noSyntaxExpr
-                       , recS_rec_rets = [], recS_dicts = emptyLHsBinds }
+                       , recS_rec_rets = [] }
 
 mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts }
 
@@ -319,7 +332,7 @@ nlList   :: [LHsExpr id] -> LHsExpr id
 
 nlHsLam        match           = noLoc (HsLam (mkMatchGroup [match]))
 nlHsPar e              = noLoc (HsPar e)
-nlHsIf cond true false = noLoc (HsIf cond true false)
+nlHsIf cond true false = noLoc (mkHsIf cond true false)
 nlHsCase expr matches  = noLoc (HsCase expr (mkMatchGroup matches))
 nlList exprs           = noLoc (ExplicitList placeHolderType exprs)
 
@@ -383,17 +396,6 @@ mk_easy_FunBind loc fun pats expr
   = L loc $ mkFunBind (L loc fun) [mkMatch pats expr emptyLocalBinds]
 
 ------------
-mk_FunBind :: SrcSpan -> id
-          -> [([LPat id], LHsExpr id)]
-          -> LHsBind id
-
-mk_FunBind _   _   [] = panic "TcGenDeriv:mk_FunBind"
-mk_FunBind loc fun pats_and_exprs
-  = L loc $ mkFunBind (L loc fun) matches
-  where
-    matches = [mkMatch p e emptyLocalBinds | (p,e) <-pats_and_exprs]
-
-------------
 mkMatch :: [LPat id] -> LHsExpr id -> HsLocalBinds id -> LMatch id
 mkMatch pats expr binds
   = noLoc (Match (map paren pats) Nothing 
@@ -423,7 +425,7 @@ it should return [x, y, f, a, b] (remember, order important).
 Note [Collect binders only after renaming]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 These functions should only be used on HsSyn *after* the renamer,
-to reuturn a [Name] or [Id].  Before renaming the record punning
+to return a [Name] or [Id].  Before renaming the record punning
 and wild-card mechanism makes it hard to know what is bound.
 So these functions should not be applied to (HsSyn RdrName)
 
@@ -457,6 +459,9 @@ collect_bind (AbsBinds { abs_exports = dbinds, abs_binds = _binds }) acc
 collectHsBindsBinders :: LHsBindsLR idL idR -> [idL]
 collectHsBindsBinders binds = collect_binds binds []
 
+collectHsBindListBinders :: [LHsBindLR idL idR] -> [idL]
+collectHsBindListBinders = foldr (collect_bind . unLoc) []
+
 collect_binds :: LHsBindsLR idL idR -> [idL] -> [idL]
 collect_binds binds acc = foldrBag (collect_bind . unLoc) acc binds
 
@@ -503,7 +508,6 @@ collect_lpat (L _ pat) bndrs
   = go pat
   where
     go (VarPat var)              = var : bndrs
-    go (VarPatOut var bs)        = var : collect_binds bs bndrs
     go (WildPat _)               = bndrs
     go (LazyPat pat)             = collect_lpat pat bndrs
     go (BangPat pat)             = collect_lpat pat bndrs
@@ -555,6 +559,59 @@ and *also* uses that dictionary to match the (n+1) pattern.  Yet, the
 variables bound by the lazy pattern are n,m, *not* the dictionary d.
 So in mkSelectorBinds in DsUtils, we want just m,n as the variables bound.
 
+\begin{code}
+hsGroupBinders :: HsGroup Name -> [Name]
+hsGroupBinders (HsGroup { hs_valds = val_decls, hs_tyclds = tycl_decls,
+                          hs_instds = inst_decls, hs_fords = foreign_decls })
+-- Collect the binders of a Group
+  =  collectHsValBinders val_decls
+  ++ hsTyClDeclsBinders tycl_decls inst_decls
+  ++ hsForeignDeclsBinders foreign_decls
+
+hsForeignDeclsBinders :: [LForeignDecl Name] -> [Name]
+hsForeignDeclsBinders foreign_decls
+  = [n | L _ (ForeignImport (L _ n) _ _) <- foreign_decls]
+
+hsTyClDeclsBinders :: [[LTyClDecl Name]] -> [Located (InstDecl Name)] -> [Name]
+hsTyClDeclsBinders tycl_decls inst_decls
+  = [n | d <- instDeclATs inst_decls ++ concat tycl_decls
+       , L _ n <- hsTyClDeclBinders d]
+
+hsTyClDeclBinders :: Eq name => Located (TyClDecl name) -> [Located name]
+-- ^ Returns all the /binding/ names of the decl, along with their SrcLocs.
+-- The first one is guaranteed to be the name of the decl. For record fields
+-- mentioned in multiple constructors, the SrcLoc will be from the first
+-- occurence.  We use the equality to filter out duplicate field names
+
+hsTyClDeclBinders (L _ (TyFamily    {tcdLName = name})) = [name]
+hsTyClDeclBinders (L _ (TySynonym   {tcdLName = name})) = [name]
+hsTyClDeclBinders (L _ (ForeignType {tcdLName = name})) = [name]
+
+hsTyClDeclBinders (L _ (ClassDecl {tcdLName = cls_name, tcdSigs = sigs, tcdATs = ats}))
+  = cls_name : 
+    concatMap hsTyClDeclBinders ats ++ [n | L _ (TypeSig n _) <- sigs]
+
+hsTyClDeclBinders (L _ (TyData {tcdLName = tc_name, tcdCons = cons}))
+  = tc_name : hsConDeclsBinders cons
+
+hsConDeclsBinders :: (Eq name) => [LConDecl name] -> [Located name]
+  -- See hsTyClDeclBinders for what this does
+  -- The function is boringly complicated because of the records
+  -- And since we only have equality, we have to be a little careful
+hsConDeclsBinders cons
+  = snd (foldl do_one ([], []) cons)
+  where
+    do_one (flds_seen, acc) (L _ (ConDecl { con_name = lname, con_details = RecCon flds }))
+       = (map unLoc new_flds ++ flds_seen, lname : new_flds ++ acc)
+       where
+         new_flds = filterOut (\f -> unLoc f `elem` flds_seen) 
+                              (map cd_fld_name flds)
+
+    do_one (flds_seen, acc) (L _ (ConDecl { con_name = lname }))
+       = (flds_seen, lname:acc)
+\end{code}
+
+
 %************************************************************************
 %*                                                                     *
        Collecting type signatures from patterns