Fix scoped type variables for expression type signatures
[ghc-hetmet.git] / compiler / parser / RdrHsSyn.lhs
index ca24070..ace6fd0 100644 (file)
@@ -8,7 +8,7 @@ module RdrHsSyn (
        extractHsTyRdrTyVars, 
        extractHsRhoRdrTyVars, extractGenericPatTyVars,
  
-       mkHsOpApp, mkClassDecl, 
+       mkHsOpApp, mkClassDecl,
        mkHsNegApp, mkHsIntegral, mkHsFractional,
        mkHsDo, mkHsSplice,
         mkTyData, mkPrefixCon, mkRecCon, mkInlineSpec, 
@@ -35,8 +35,10 @@ module RdrHsSyn (
        checkPrecP,           -- Int -> P Int
        checkContext,         -- HsType -> P HsContext
        checkPred,            -- HsType -> P HsPred
-       checkTyClHdr,         -- LHsContext RdrName -> LHsType RdrName -> P (LHsContext RdrName, Located RdrName, [LHsTyVarBndr RdrName])
-       checkSynHdr,          -- LHsType RdrName -> P (Located RdrName, [LHsTyVarBndr RdrName])
+       checkTyClHdr,         -- LHsContext RdrName -> LHsType RdrName -> P (LHsContext RdrName, Located RdrName, [LHsTyVarBndr RdrName], [LHsType RdrName])
+       checkTyVars,          -- [LHsType RdrName] -> P ()
+       checkSynHdr,          -- LHsType RdrName -> P (Located RdrName, [LHsTyVarBndr RdrName], [LHsType RdrName])
+       checkKindSigs,        -- [LTyClDecl RdrName] -> P ()
        checkInstType,        -- HsType -> P HsType
        checkPattern,         -- HsExp -> P HsPat
        checkPatterns,        -- SrcLoc -> [HsExp] -> P [HsPat]
@@ -68,6 +70,7 @@ import FastString
 import Panic
 
 import List            ( isSuffixOf, nubBy )
+import Monad           ( unless )
 \end{code}
 
 
@@ -151,16 +154,17 @@ Similarly for mkConDecl, mkClassOpSig and default-method names.
        *** See "THE NAMING STORY" in HsDecls ****
   
 \begin{code}
-mkClassDecl (cxt, cname, tyvars) fds sigs mbinds
+mkClassDecl (cxt, cname, tyvars) fds sigs mbinds ats
   = ClassDecl { tcdCtxt = cxt, tcdLName = cname, tcdTyVars = tyvars,
                tcdFDs = fds,  
                tcdSigs = sigs,
-               tcdMeths = mbinds
+               tcdMeths = mbinds,
+               tcdATs   = ats
                }
 
-mkTyData new_or_data (context, tname, tyvars) ksig data_cons maybe_deriv
+mkTyData new_or_data (context, tname, tyvars, typats) ksig data_cons maybe_deriv
   = TyData { tcdND = new_or_data, tcdCtxt = context, tcdLName = tname,
-            tcdTyVars = tyvars,  tcdCons = data_cons, 
+            tcdTyVars = tyvars, tcdTyPats = typats, tcdCons = data_cons, 
             tcdKindSig = ksig, tcdDerivs = maybe_deriv }
 \end{code}
 
@@ -198,23 +202,29 @@ cvTopDecls decls = go (fromOL decls)
                            where (L l' b', ds') = getMonoBind (L l b) ds
     go (d : ds)            = d : go ds
 
+-- Declaration list may only contain value bindings and signatures
+--
 cvBindGroup :: OrdList (LHsDecl RdrName) -> HsValBinds RdrName
 cvBindGroup binding
-  = case (cvBindsAndSigs binding) of { (mbs, sigs) ->
-    ValBindsIn mbs sigs
-    }
+  = case cvBindsAndSigs binding of
+      (mbs, sigs, []) ->                 -- list of type decls *always* empty
+        ValBindsIn mbs sigs
 
 cvBindsAndSigs :: OrdList (LHsDecl RdrName)
-  -> (Bag (LHsBind RdrName), [LSig RdrName])
+  -> (Bag (LHsBind RdrName), [LSig RdrName], [LTyClDecl RdrName])
 -- Input decls contain just value bindings and signatures
+-- and in case of class or instance declarations also
+-- associated type declarations
 cvBindsAndSigs  fb = go (fromOL fb)
   where
-    go []                 = (emptyBag, [])
-    go (L l (SigD s) : ds) = (bs, L l s : ss)
-                           where (bs,ss) = go ds
-    go (L l (ValD b) : ds) = (b' `consBag` bs, ss)
-                           where (b',ds') = getMonoBind (L l b) ds
-                                 (bs,ss)  = go ds'
+    go []                 = (emptyBag, [], [])
+    go (L l (SigD s) : ds) = (bs, L l s : ss, ts)
+                           where (bs, ss, ts) = go ds
+    go (L l (ValD b) : ds) = (b' `consBag` bs, ss, ts)
+                           where (b', ds')    = getMonoBind (L l b) ds
+                                 (bs, ss, ts) = go ds'
+    go (L l (TyClD t): ds) = (bs, ss, L l t : ts)
+                           where (bs, ss, ts) = go ds
 
 -----------------------------------------------------------------------------
 -- Group function bindings into equation groups
@@ -368,44 +378,76 @@ checkInstType (L l t)
        ty ->   do dict_ty <- checkDictTy (L l ty)
                   return (L l (HsForAllTy Implicit [] (noLoc []) dict_ty))
 
-checkTyVars :: [LHsType RdrName] -> P [LHsTyVarBndr RdrName]
-checkTyVars tvs 
-  = mapM chk tvs
+-- Check whether the given list of type parameters are all type variables
+-- (possibly with a kind signature).  If the second argument is `False',
+-- only type variables are allowed and we raise an error on encountering a
+-- non-variable; otherwise, we allow non-variable arguments and return the
+-- entire list of parameters.
+--
+checkTyVars :: [LHsType RdrName] -> P ()
+checkTyVars tparms = mapM_ chk tparms
   where
-       --  Check that the name space is correct!
+       -- Check that the name space is correct!
     chk (L l (HsKindSig (L _ (HsTyVar tv)) k))
-       | isRdrTyVar tv = return (L l (KindedTyVar tv k))
+       | isRdrTyVar tv    = return ()
     chk (L l (HsTyVar tv))
-        | isRdrTyVar tv = return (L l (UserTyVar tv))
-    chk (L l other)
-       = parseError l "Type found where type variable expected"
-
-checkSynHdr :: LHsType RdrName -> P (Located RdrName, [LHsTyVarBndr RdrName])
-checkSynHdr ty = do { (_, tc, tvs) <- checkTyClHdr (noLoc []) ty
-                   ; return (tc, tvs) }
-
+        | isRdrTyVar tv    = return ()
+    chk (L l other)        =
+         parseError l "Type found where type variable expected"
+
+-- Check whether the type arguments in a type synonym head are simply
+-- variables.  If not, we have a type equation of a type function and return
+-- all patterns.  If yes, we return 'Nothing' as the third component to
+-- indicate a vanilla type synonym.
+--
+checkSynHdr :: LHsType RdrName 
+           -> Bool                             -- is type instance?
+           -> P (Located RdrName,              -- head symbol
+                 [LHsTyVarBndr RdrName],       -- parameters
+                 [LHsType RdrName])            -- type patterns
+checkSynHdr ty isTyInst = 
+  do { (_, tc, tvs, tparms) <- checkTyClHdr (noLoc []) ty
+     ; unless isTyInst $ checkTyVars tparms
+     ; return (tc, tvs, tparms) }
+
+
+-- Well-formedness check and decomposition of type and class heads.
+--
 checkTyClHdr :: LHsContext RdrName -> LHsType RdrName
-  -> P (LHsContext RdrName, Located RdrName, [LHsTyVarBndr RdrName])
+  -> P (LHsContext RdrName,         -- the type context
+        Located RdrName,            -- the head symbol (type or class name)
+       [LHsTyVarBndr RdrName],      -- free variables of the non-context part
+       [LHsType RdrName])           -- parameters of head symbol
 -- The header of a type or class decl should look like
 --     (C a, D b) => T a b
 -- or  T a b
 -- or  a + b
 -- etc
+-- With associated types, we can also have non-variable parameters; ie,
+--      T Int [a]
+-- The unaltered parameter list is returned in the fourth component of the
+-- result.  Eg, for
+--      T Int [a]
+-- we return
+--      ('()', 'T', ['a'], ['Int', '[a]'])
 checkTyClHdr (L l cxt) ty
-  = do (tc, tvs) <- gol ty []
+  = do (tc, tvs, parms) <- gol ty []
        mapM_ chk_pred cxt
-       return (L l cxt, tc, tvs)
+       return (L l cxt, tc, tvs, parms)
   where
     gol (L l ty) acc = go l ty acc
 
     go l (HsTyVar tc)    acc 
-       | not (isRdrTyVar tc)   = checkTyVars acc               >>= \ tvs ->
-                                 return (L l tc, tvs)
-    go l (HsOpTy t1 tc t2) acc  = checkTyVars (t1:t2:acc)      >>= \ tvs ->
-                                 return (tc, tvs)
+       | not (isRdrTyVar tc)   = do
+                                   tvs <- extractTyVars acc
+                                   return (L l tc, tvs, acc)
+    go l (HsOpTy t1 tc t2) acc  = do
+                                   tvs <- extractTyVars (t1:t2:acc)
+                                   return (tc, tvs, acc)
     go l (HsParTy ty)    acc    = gol ty acc
     go l (HsAppTy t1 t2) acc    = gol t1 (t2:acc)
-    go l other          acc    = parseError l "Malformed LHS to type of class declaration"
+    go l other          acc    = 
+      parseError l "Malformed head of type or class declaration"
 
        -- The predicates in a type or class decl must all
        -- be HsClassPs.  They need not all be type variables,
@@ -414,7 +456,63 @@ checkTyClHdr (L l cxt) ty
     chk_pred (L l _)
        = parseError l "Malformed context in type or class declaration"
 
-  
+-- Extract the type variables of a list of type parameters.
+--
+-- * Type arguments can be complex type terms (needed for associated type
+--   declarations).
+--
+extractTyVars :: [LHsType RdrName] -> P [LHsTyVarBndr RdrName]
+extractTyVars tvs = collects [] tvs
+  where
+        -- Collect all variables (1st arg serves as an accumulator)
+    collect tvs (L l (HsForAllTy _ _ _ _)) =
+      parseError l "Forall type not allowed as type parameter"
+    collect tvs (L l (HsTyVar tv))
+      | isRdrTyVar tv                     = return $ L l (UserTyVar tv) : tvs
+      | otherwise                         = return tvs
+    collect tvs (L l (HsBangTy _ _      )) =
+      parseError l "Bang-style type annotations not allowed as type parameter"
+    collect tvs (L l (HsAppTy t1 t2     )) = do
+                                              tvs' <- collect tvs t2
+                                              collect tvs' t1
+    collect tvs (L l (HsFunTy t1 t2     )) = do
+                                              tvs' <- collect tvs t2
+                                              collect tvs' t1
+    collect tvs (L l (HsListTy t        )) = collect tvs t
+    collect tvs (L l (HsPArrTy t        )) = collect tvs t
+    collect tvs (L l (HsTupleTy _ ts    )) = collects tvs ts
+    collect tvs (L l (HsOpTy t1 _ t2    )) = do
+                                              tvs' <- collect tvs t2
+                                              collect tvs' t1
+    collect tvs (L l (HsParTy t         )) = collect tvs t
+    collect tvs (L l (HsNumTy t         )) = return tvs
+    collect tvs (L l (HsPredTy t        )) = 
+      parseError l "Predicate not allowed as type parameter"
+    collect tvs (L l (HsKindSig (L _ (HsTyVar tv)) k))
+       | isRdrTyVar tv                    = 
+         return $ L l (KindedTyVar tv k) : tvs
+       | otherwise                        =
+         parseError l "Kind signature only allowed for type variables"
+    collect tvs (L l (HsSpliceTy t      )) = 
+      parseError l "Splice not allowed as type parameter"
+
+        -- Collect all variables of a list of types
+    collects tvs []     = return tvs
+    collects tvs (t:ts) = do
+                           tvs' <- collects tvs ts
+                           collect tvs' t
+
+-- Check that associated type declarations of a class are all kind signatures.
+--
+checkKindSigs :: [LTyClDecl RdrName] -> P ()
+checkKindSigs = mapM_ check
+  where
+    check (L l tydecl) 
+      | isKindSigDecl tydecl
+        || isSynDecl tydecl  = return ()
+      | otherwise           = 
+       parseError l "Type declaration in a class must be a kind signature or synonym default"
+
 checkContext :: LHsType RdrName -> P (LHsContext RdrName)
 checkContext (L l t)
   = check t
@@ -622,7 +720,7 @@ makeFunBind :: Located id -> Bool -> [LMatch id] -> HsBind id
 -- Like HsUtils.mkFunBind, but we need to be able to set the fixity too
 makeFunBind fn is_infix ms 
   = FunBind { fun_id = fn, fun_infix = is_infix, fun_matches = mkMatchGroup ms,
-             fun_co_fn = idCoercion, bind_fvs = placeHolderNames }
+             fun_co_fn = idHsWrapper, bind_fvs = placeHolderNames }
 
 checkPatBind lhs (L _ grhss)
   = do { lhs <- checkPattern lhs