[project @ 2000-09-22 15:56:12 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcMatches.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[TcMatches]{Typecheck some @Matches@}
5
6 \begin{code}
7 module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda, tcStmts, tcGRHSs ) where
8
9 #include "HsVersions.h"
10
11 import {-# SOURCE #-}   TcExpr( tcExpr )
12
13 import HsSyn            ( HsBinds(..), Match(..), GRHSs(..), GRHS(..),
14                           MonoBinds(..), StmtCtxt(..), Stmt(..),
15                           pprMatch, getMatchLoc, consLetStmt,
16                           mkMonoBind, collectSigTysFromPats
17                         )
18 import RnHsSyn          ( RenamedMatch, RenamedGRHSs, RenamedStmt )
19 import TcHsSyn          ( TcMatch, TcGRHSs, TcStmt )
20
21 import TcMonad
22 import TcMonoType       ( kcHsSigType, kcTyVarScope, newSigTyVars, checkSigTyVars, tcHsSigType, sigPatCtxt )
23 import Inst             ( LIE, plusLIE, emptyLIE, plusLIEs )
24 import TcEnv            ( tcExtendTyVarEnv, tcExtendLocalValEnv, tcExtendGlobalTyVars )
25 import TcPat            ( tcPat, tcPatBndr_NoSigs, polyPatSig )
26 import TcType           ( TcType, newTyVarTy )
27 import TcBinds          ( tcBindsAndThen )
28 import TcSimplify       ( tcSimplifyAndCheck, bindInstsOfLocalFuns )
29 import TcUnify          ( unifyFunTy, unifyTauTy )
30 import Name             ( Name )
31 import TysWiredIn       ( boolTy )
32
33 import BasicTypes       ( RecFlag(..) )
34 import Type             ( tyVarsOfType, isTauTy, mkFunTy, boxedTypeKind, openTypeKind )
35 import VarSet
36 import Var              ( Id )
37 import Bag
38 import Outputable
39 import List             ( nub )
40 \end{code}
41
42 %************************************************************************
43 %*                                                                      *
44 \subsection{tcMatchesFun, tcMatchesCase}
45 %*                                                                      *
46 %************************************************************************
47
48 @tcMatchesFun@ typechecks a @[Match]@ list which occurs in a
49 @FunMonoBind@.  The second argument is the name of the function, which
50 is used in error messages.  It checks that all the equations have the
51 same number of arguments before using @tcMatches@ to do the work.
52
53 \begin{code}
54 tcMatchesFun :: [(Name,Id)]     -- Bindings for the variables bound in this group
55              -> Name
56              -> TcType          -- Expected type
57              -> [RenamedMatch]
58              -> TcM s ([TcMatch], LIE)
59
60 tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
61   =      -- Check that they all have the same no of arguments
62          -- Set the location to that of the first equation, so that
63          -- any inter-equation error messages get some vaguely
64          -- sensible location.  Note: we have to do this odd
65          -- ann-grabbing, because we don't always have annotations in
66          -- hand when we call tcMatchesFun...
67     tcAddSrcLoc (getMatchLoc first_match)        (
68             checkTc (sameNoOfArgs matches)
69                     (varyingArgsErr fun_name matches)
70     )                                            `thenTc_`
71
72         -- ToDo: Don't use "expected" stuff if there ain't a type signature
73         -- because inconsistency between branches
74         -- may show up as something wrong with the (non-existent) type signature
75
76         -- No need to zonk expected_ty, because unifyFunTy does that on the fly
77     tcMatches xve matches expected_ty (FunRhs fun_name)
78 \end{code}
79
80 @tcMatchesCase@ doesn't do the argument-count check because the
81 parser guarantees that each equation has exactly one argument.
82
83 \begin{code}
84 tcMatchesCase :: [RenamedMatch]         -- The case alternatives
85               -> TcType                 -- Type of whole case expressions
86               -> TcM s (TcType,         -- Inferred type of the scrutinee
87                         [TcMatch],      -- Translated alternatives
88                         LIE)
89
90 tcMatchesCase matches expr_ty
91   = newTyVarTy openTypeKind                                     `thenNF_Tc` \ scrut_ty ->
92     tcMatches [] matches (mkFunTy scrut_ty expr_ty) CaseAlt     `thenTc` \ (matches', lie) ->
93     returnTc (scrut_ty, matches', lie)
94
95 tcMatchLambda :: RenamedMatch -> TcType -> TcM s (TcMatch, LIE)
96 tcMatchLambda match res_ty = tcMatch [] match res_ty LambdaBody
97 \end{code}
98
99
100 \begin{code}
101 tcMatches :: [(Name,Id)]
102           -> [RenamedMatch]
103           -> TcType
104           -> StmtCtxt
105           -> TcM s ([TcMatch], LIE)
106
107 tcMatches xve matches expected_ty fun_or_case
108   = mapAndUnzipTc tc_match matches      `thenTc` \ (matches, lies) ->
109     returnTc (matches, plusLIEs lies)
110   where
111     tc_match match = tcMatch xve match expected_ty fun_or_case
112 \end{code}
113
114
115 %************************************************************************
116 %*                                                                      *
117 \subsection{tcMatch}
118 %*                                                                      *
119 %************************************************************************
120
121 \begin{code}
122 tcMatch :: [(Name,Id)]
123         -> RenamedMatch
124         -> TcType               -- Expected result-type of the Match.
125                                 -- Early unification with this guy gives better error messages
126         -> StmtCtxt
127         -> TcM s (TcMatch, LIE)
128
129 tcMatch xve1 match@(Match sig_tvs pats maybe_rhs_sig grhss) expected_ty ctxt
130   = tcAddSrcLoc (getMatchLoc match)             $
131     tcAddErrCtxt (matchCtxt ctxt match)         $
132
133     if null sig_tvs then        -- The common case
134         tc_match expected_ty    `thenTc` \ (_, match_and_lie) ->
135         returnTc match_and_lie
136
137     else
138         -- If there are sig tvs we must be careful *not* to use
139         -- expected_ty right away, else we'll unify with tyvars free
140         -- in the envt.  So invent a fresh tyvar and use that instead
141         newTyVarTy openTypeKind         `thenNF_Tc` \ tyvar_ty ->
142
143         -- Extend the tyvar env and check the match itself
144         kcTyVarScope sig_tvs (mapTc_ kcHsSigType sig_tys)       `thenTc` \ sig_tv_kinds ->
145         newSigTyVars sig_tv_kinds                               `thenNF_Tc` \ sig_tyvars ->
146         tcExtendTyVarEnv sig_tyvars (tc_match tyvar_ty)         `thenTc` \ (pat_ids, match_and_lie) ->
147
148         -- Check that the scoped type variables from the patterns
149         -- have not been constrained
150         tcAddErrCtxtM (sigPatCtxt sig_tyvars pat_ids)           (
151                 checkSigTyVars sig_tyvars emptyVarSet
152         )                                                       `thenTc_`
153
154         -- *Now* we're free to unify with expected_ty
155         unifyTauTy expected_ty tyvar_ty `thenTc_`
156
157         returnTc match_and_lie
158
159   where
160     sig_tys = case maybe_rhs_sig of { Just t -> [t]; Nothing -> [] }
161               ++ collectSigTysFromPats pats
162               
163     tc_match expected_ty        -- Any sig tyvars are in scope by now
164       = -- STEP 1: Typecheck the patterns
165         tcMatchPats pats expected_ty    `thenTc` \ (rhs_ty, pats', lie_req1, ex_tvs, pat_bndrs, lie_avail) ->
166         let
167           xve2       = bagToList pat_bndrs
168           pat_ids    = map snd xve2
169           ex_tv_list = bagToList ex_tvs
170         in
171
172         -- STEP 2: Check that the remaining "expected type" is not a rank-2 type
173         -- If it is it'll mess up the unifier when checking the RHS
174         checkTc (isTauTy rhs_ty) lurkingRank2SigErr             `thenTc_`
175
176         -- STEP 3: Unify with the rhs type signature if any
177         (case maybe_rhs_sig of
178             Nothing  -> returnTc ()
179             Just sig -> tcHsSigType sig         `thenTc` \ sig_ty ->
180
181                         -- Check that the signature isn't a polymorphic one, which
182                         -- we don't permit (at present, anyway)
183                         checkTc (isTauTy sig_ty) (polyPatSig sig_ty)    `thenTc_`
184                         unifyTauTy rhs_ty sig_ty
185         )                                               `thenTc_`
186
187         -- STEP 4: Typecheck the guarded RHSs and the associated where clause
188         tcExtendLocalValEnv xve1 (tcExtendLocalValEnv xve2 (
189             tcGRHSs grhss rhs_ty ctxt
190         ))                                      `thenTc` \ (grhss', lie_req2) ->
191
192         -- STEP 5: Check for existentially bound type variables
193         tcExtendGlobalTyVars (tyVarsOfType rhs_ty)      (
194             tcAddErrCtxtM (sigPatCtxt ex_tv_list pat_ids)       $
195             checkSigTyVars ex_tv_list emptyVarSet               `thenTc` \ zonked_ex_tvs ->
196             tcSimplifyAndCheck 
197                 (text ("the existential context of a data constructor"))
198                 (mkVarSet zonked_ex_tvs)
199                 lie_avail (lie_req1 `plusLIE` lie_req2)
200         )                                                       `thenTc` \ (lie_req', ex_binds) ->
201
202         -- STEP 6 In case there are any polymorpic, overloaded binders in the pattern
203         -- (which can happen in the case of rank-2 type signatures, or data constructors
204         -- with polymorphic arguments), we must do a bindInstsOfLocalFns here
205         bindInstsOfLocalFuns lie_req' pat_ids           `thenTc` \ (lie_req'', inst_binds) ->
206
207         -- Phew!  All done.
208         let
209             grhss'' = glue_on Recursive ex_binds $
210                       glue_on Recursive inst_binds grhss'
211         in
212         returnTc (pat_ids, (Match [] pats' Nothing grhss'', lie_req''))
213
214         -- glue_on just avoids stupid dross
215 glue_on _ EmptyMonoBinds grhss = grhss          -- The common case
216 glue_on is_rec mbinds (GRHSs grhss binds ty)
217   = GRHSs grhss (mkMonoBind mbinds [] is_rec `ThenBinds` binds) ty
218
219 tcGRHSs :: RenamedGRHSs
220         -> TcType -> StmtCtxt
221         -> TcM s (TcGRHSs, LIE)
222
223 tcGRHSs (GRHSs grhss binds _) expected_ty ctxt
224   = tcBindsAndThen glue_on binds (tc_grhss grhss)
225   where
226     tc_grhss grhss
227         = mapAndUnzipTc tc_grhs grhss           `thenTc` \ (grhss', lies) ->
228           returnTc (GRHSs grhss' EmptyBinds (Just expected_ty), plusLIEs lies)
229
230     tc_grhs (GRHS guarded locn)
231         = tcAddSrcLoc locn                              $
232           tcStmts ctxt (\ty -> ty) guarded expected_ty  `thenTc` \ (guarded', lie) ->
233           returnTc (GRHS guarded' locn, lie)
234 \end{code}
235
236
237 %************************************************************************
238 %*                                                                      *
239 \subsection{tcMatchPats}
240 %*                                                                      *
241 %************************************************************************
242
243 \begin{code}
244 tcMatchPats [] expected_ty
245   = returnTc (expected_ty, [], emptyLIE, emptyBag, emptyBag, emptyLIE)
246
247 tcMatchPats (pat:pats) expected_ty
248   = unifyFunTy expected_ty              `thenTc` \ (arg_ty, rest_ty) ->
249     tcPat tcPatBndr_NoSigs pat arg_ty   `thenTc` \ (pat', lie_req, pat_tvs, pat_ids, lie_avail) ->
250     tcMatchPats pats rest_ty            `thenTc` \ (rhs_ty, pats', lie_reqs, pats_tvs, pats_ids, lie_avails) ->
251     returnTc (  rhs_ty, 
252                 pat':pats',
253                 lie_req `plusLIE` lie_reqs,
254                 pat_tvs `unionBags` pats_tvs,
255                 pat_ids `unionBags` pats_ids,
256                 lie_avail `plusLIE` lie_avails
257     )
258 \end{code}
259
260
261 %************************************************************************
262 %*                                                                      *
263 \subsection{tcStmts}
264 %*                                                                      *
265 %************************************************************************
266
267
268 \begin{code}
269 tcStmts :: StmtCtxt
270         -> (TcType -> TcType)   -- m, the relationship type of pat and rhs in pat <- rhs
271         -> [RenamedStmt]
272         -> TcType                       -- elt_ty, where type of the comprehension is (m elt_ty)
273         -> TcM s ([TcStmt], LIE)
274
275 tcStmts do_or_lc m (stmt@(ReturnStmt exp) : stmts) elt_ty
276   = ASSERT( null stmts )
277     tcSetErrCtxt (stmtCtxt do_or_lc stmt)       $
278     tcExpr exp elt_ty                           `thenTc`    \ (exp', exp_lie) ->
279     returnTc ([ReturnStmt exp'], exp_lie)
280
281         -- ExprStmt at the end
282 tcStmts do_or_lc m [stmt@(ExprStmt exp src_loc)] elt_ty
283   = tcSetErrCtxt (stmtCtxt do_or_lc stmt)       $
284     tcExpr exp (m elt_ty)                       `thenTc`    \ (exp', exp_lie) ->
285     returnTc ([ExprStmt exp' src_loc], exp_lie)
286
287         -- ExprStmt not at the end
288 tcStmts do_or_lc m (stmt@(ExprStmt exp src_loc) : stmts) elt_ty
289   = ASSERT( isDoStmt do_or_lc )
290     tcAddSrcLoc src_loc                 (
291         tcSetErrCtxt (stmtCtxt do_or_lc stmt)   $
292             -- exp has type (m tau) for some tau (doesn't matter what)
293         newTyVarTy openTypeKind         `thenNF_Tc` \ any_ty ->
294         tcExpr exp (m any_ty)
295     )                                   `thenTc` \ (exp', exp_lie) ->
296     tcStmts do_or_lc m stmts elt_ty     `thenTc` \ (stmts', stmts_lie) ->
297     returnTc (ExprStmt exp' src_loc : stmts',
298               exp_lie `plusLIE` stmts_lie)
299
300 tcStmts do_or_lc m (stmt@(GuardStmt exp src_loc) : stmts) elt_ty
301   = ASSERT( not (isDoStmt do_or_lc) )
302     tcSetErrCtxt (stmtCtxt do_or_lc stmt) (
303         tcAddSrcLoc src_loc             $
304         tcExpr exp boolTy
305     )                                   `thenTc` \ (exp', exp_lie) ->
306     tcStmts do_or_lc m stmts elt_ty     `thenTc` \ (stmts', stmts_lie) ->
307     returnTc (GuardStmt exp' src_loc : stmts',
308               exp_lie `plusLIE` stmts_lie)
309
310 tcStmts do_or_lc m (stmt@(BindStmt pat exp src_loc) : stmts) elt_ty
311   = tcAddSrcLoc src_loc         (
312         tcSetErrCtxt (stmtCtxt do_or_lc stmt)   $
313         newTyVarTy boxedTypeKind                `thenNF_Tc` \ pat_ty ->
314         tcPat tcPatBndr_NoSigs pat pat_ty       `thenTc` \ (pat', pat_lie, pat_tvs, pat_ids, avail) ->  
315         tcExpr exp (m pat_ty)                   `thenTc` \ (exp', exp_lie) ->
316         returnTc (pat', exp',
317                   pat_lie `plusLIE` exp_lie,
318                   pat_tvs, pat_ids, avail)
319     )                                   `thenTc` \ (pat', exp', lie_req, pat_tvs, pat_bndrs, lie_avail) ->
320     let
321         new_val_env = bagToList pat_bndrs
322         pat_ids     = map snd new_val_env
323         pat_tv_list = bagToList pat_tvs
324     in
325
326         -- Do the rest; we don't need to add the pat_tvs to the envt
327         -- because they all appear in the pat_ids's types
328     tcExtendLocalValEnv new_val_env (
329        tcStmts do_or_lc m stmts elt_ty
330     )                                           `thenTc` \ (stmts', stmts_lie) ->
331
332
333         -- Reinstate context for existential checks
334     tcSetErrCtxt (stmtCtxt do_or_lc stmt)               $
335     tcExtendGlobalTyVars (tyVarsOfType (m elt_ty))      $
336     tcAddErrCtxtM (sigPatCtxt pat_tv_list pat_ids)      $
337
338     checkSigTyVars pat_tv_list emptyVarSet              `thenTc` \ zonked_pat_tvs ->
339
340     tcSimplifyAndCheck 
341         (text ("the existential context of a data constructor"))
342         (mkVarSet zonked_pat_tvs)
343         lie_avail stmts_lie                     `thenTc` \ (final_lie, dict_binds) ->
344
345     returnTc (BindStmt pat' exp' src_loc : 
346                 consLetStmt (mkMonoBind dict_binds [] Recursive) stmts',
347               lie_req `plusLIE` final_lie)
348
349 tcStmts do_or_lc m (LetStmt binds : stmts) elt_ty
350      = tcBindsAndThen           -- No error context, but a binding group is
351         combine                 -- rather a large thing for an error context anyway
352         binds
353         (tcStmts do_or_lc m stmts elt_ty)
354      where
355         combine is_rec binds' stmts' = consLetStmt (mkMonoBind binds' [] is_rec) stmts'
356
357
358 isDoStmt DoStmt = True
359 isDoStmt other  = False
360 \end{code}
361
362
363 %************************************************************************
364 %*                                                                      *
365 \subsection{Errors and contexts}
366 %*                                                                      *
367 %************************************************************************
368
369 @sameNoOfArgs@ takes a @[RenamedMatch]@ and decides whether the same
370 number of args are used in each equation.
371
372 \begin{code}
373 sameNoOfArgs :: [RenamedMatch] -> Bool
374 sameNoOfArgs matches = length (nub (map args_in_match matches)) == 1
375   where
376     args_in_match :: RenamedMatch -> Int
377     args_in_match (Match _ pats _ _) = length pats
378 \end{code}
379
380 \begin{code}
381 matchCtxt CaseAlt match
382   = hang (ptext SLIT("In a case alternative:"))
383          4 (pprMatch (True,empty) {-is_case-} match)
384
385 matchCtxt (FunRhs fun) match
386   = hang (hcat [ptext SLIT("In an equation for function "), quotes (ppr_fun), char ':'])
387          4 (pprMatch (False, ppr_fun) {-not case-} match)
388   where
389     ppr_fun = ppr fun
390
391 matchCtxt LambdaBody match
392   = hang (ptext SLIT("In the lambda expression"))
393          4 (pprMatch (True, empty) match)
394
395 varyingArgsErr name matches
396   = sep [ptext SLIT("Varying number of arguments for function"), quotes (ppr name)]
397
398 lurkingRank2SigErr
399   = ptext SLIT("Too few explicit arguments when defining a function with a rank-2 type")
400
401 stmtCtxt do_or_lc stmt
402   = hang (ptext SLIT("In") <+> what <> colon)
403          4 (ppr stmt)
404   where
405     what = case do_or_lc of
406                 ListComp -> ptext SLIT("a list-comprehension qualifier")
407                 DoStmt   -> ptext SLIT("a do statement")
408                 PatBindRhs -> thing <+> ptext SLIT("a pattern binding")
409                 FunRhs f   -> thing <+> ptext SLIT("an equation for") <+> quotes (ppr f)
410                 CaseAlt    -> thing <+> ptext SLIT("a case alternative")
411                 LambdaBody -> thing <+> ptext SLIT("a lambda abstraction")
412     thing = case stmt of
413                 BindStmt _ _ _ -> ptext SLIT("a pattern guard for")
414                 GuardStmt _ _  -> ptext SLIT("a guard for")
415                 ExprStmt _ _   -> ptext SLIT("the right-hand side of")
416 \end{code}