bc19d69715bdf7070579e4a07bd3187a51c4d236
[ghc-hetmet.git] / compiler / typecheck / TcArrows.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5 Typecheck arrow notation
6
7 \begin{code}
8 module TcArrows ( tcProc ) where
9
10 import {-# SOURCE #-}   TcExpr( tcMonoExpr, tcInferRho )
11
12 import HsSyn
13 import TcHsSyn
14
15 import TcMatches
16
17 import TcType
18 import TcMType
19 import TcBinds
20 import TcSimplify
21 import TcPat
22 import TcUnify
23 import TcRnMonad
24 import Coercion
25 import Inst
26 import Name
27 import TysWiredIn
28 import VarSet 
29 import TysPrim
30 import Type
31
32 import SrcLoc
33 import Outputable
34 import FastString
35 import Util
36
37 import Control.Monad
38 \end{code}
39
40 %************************************************************************
41 %*                                                                      *
42                 Proc    
43 %*                                                                      *
44 %************************************************************************
45
46 \begin{code}
47 tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
48        -> BoxyRhoType                           -- Expected type of whole proc expression
49        -> TcM (OutPat TcId, LHsCmdTop TcId, CoercionI)
50
51 tcProc pat cmd exp_ty
52   = newArrowScope $
53     do  { ((exp_ty1, res_ty), coi) <- boxySplitAppTy exp_ty 
54         ; ((arr_ty, arg_ty), coi1) <- boxySplitAppTy exp_ty1
55         ; let cmd_env = CmdEnv { cmd_arr = arr_ty }
56         ; (pat', cmd') <- tcProcPat pat arg_ty res_ty $
57                           tcCmdTop cmd_env cmd []
58         ; let res_coi = mkTransCoI coi (mkAppTyCoI exp_ty1 coi1 res_ty IdCo)
59         ; return (pat', cmd', res_coi) 
60         }
61 \end{code}
62
63
64 %************************************************************************
65 %*                                                                      *
66                 Commands
67 %*                                                                      *
68 %************************************************************************
69
70 \begin{code}
71 type CmdStack = [TcTauType]
72 data CmdEnv
73   = CmdEnv {
74         cmd_arr         :: TcType -- arrow type constructor, of kind *->*->*
75     }
76
77 mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
78 mkCmdArrTy env t1 t2 = mkAppTys (cmd_arr env) [t1, t2]
79
80 ---------------------------------------
81 tcCmdTop :: CmdEnv 
82          -> LHsCmdTop Name
83          -> CmdStack
84          -> TcTauType   -- Expected result type; always a monotype
85                              -- We know exactly how many cmd args are expected,
86                              -- albeit perhaps not their types; so we can pass 
87                              -- in a CmdStack
88         -> TcM (LHsCmdTop TcId)
89
90 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
91   = setSrcSpan loc $
92     do  { cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
93         ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
94         ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
95
96
97 ----------------------------------------
98 tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
99              -> TcTauType -> TcM (LHsExpr TcId)
100 -- A wrapper that deals with the refinement (if any)
101 tcGuardedCmd env expr stk res_ty
102   = do  { body <- tcCmd env expr (stk, res_ty)
103         ; return body 
104         }
105
106 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
107         -- The main recursive function
108 tcCmd env (L loc expr) res_ty
109   = setSrcSpan loc $ do
110         { expr' <- tc_cmd env expr res_ty
111         ; return (L loc expr') }
112
113 tc_cmd :: CmdEnv -> HsExpr Name -> (CmdStack, TcTauType) -> TcM (HsExpr TcId)
114 tc_cmd env (HsPar cmd) res_ty
115   = do  { cmd' <- tcCmd env cmd res_ty
116         ; return (HsPar cmd') }
117
118 tc_cmd env (HsLet binds (L body_loc body)) res_ty
119   = do  { (binds', body') <- tcLocalBinds binds         $
120                              setSrcSpan body_loc        $
121                              tc_cmd env body res_ty
122         ; return (HsLet binds' (L body_loc body')) }
123
124 tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
125   = addErrCtxt (cmdCtxt in_cmd) $ do
126       (scrut', scrut_ty) <- addErrCtxt (caseScrutCtxt scrut) $
127                               tcInferRho scrut 
128       matches' <- tcMatchesCase match_ctxt scrut_ty matches res_ty
129       return (HsCase scrut' matches')
130   where
131     match_ctxt = MC { mc_what = CaseAlt,
132                       mc_body = mc_body }
133     mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
134
135 tc_cmd env (HsIf pred b1 b2) res_ty
136   = do  { pred' <- tcMonoExpr pred boolTy
137         ; b1'   <- tcCmd env b1 res_ty
138         ; b2'   <- tcCmd env b2 res_ty
139         ; return (HsIf pred' b1' b2')
140     }
141
142 -------------------------------------------
143 --              Arrow application
144 --          (f -< a)   or   (f -<< a)
145
146 tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
147   = addErrCtxt (cmdCtxt cmd)    $
148     do  { arg_ty <- newFlexiTyVarTy openTypeKind
149         ; let fun_ty = mkCmdArrTy env (foldl mkPairTy arg_ty cmd_stk) res_ty
150
151         ; fun' <- select_arrow_scope (tcMonoExpr fun fun_ty)
152
153         ; arg' <- tcMonoExpr arg arg_ty
154
155         ; return (HsArrApp fun' arg' fun_ty ho_app lr) }
156   where
157         -- Before type-checking f, use the environment of the enclosing
158         -- proc for the (-<) case.  
159         -- Local bindings, inside the enclosing proc, are not in scope 
160         -- inside f.  In the higher-order case (-<<), they are.
161     select_arrow_scope tc = case ho_app of
162         HsHigherOrderApp -> tc
163         HsFirstOrderApp  -> escapeArrowScope tc
164
165 -------------------------------------------
166 --              Command application
167
168 tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
169   = addErrCtxt (cmdCtxt cmd)    $
170     do  { arg_ty <- newFlexiTyVarTy openTypeKind
171
172         ; fun' <- tcCmd env fun (arg_ty:cmd_stk, res_ty)
173
174         ; arg' <- tcMonoExpr arg arg_ty
175
176         ; return (HsApp fun' arg') }
177
178 -------------------------------------------
179 --              Lambda
180
181 tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
182        (cmd_stk, res_ty)
183   = addErrCtxt (matchCtxt match_ctxt match)     $
184
185     do  {       -- Check the cmd stack is big enough
186         ; checkTc (lengthAtLeast cmd_stk n_pats)
187                   (kappaUnderflow cmd)
188
189                 -- Check the patterns, and the GRHSs inside
190         ; (pats', grhss') <- setSrcSpan mtch_loc                $
191                              tcLamPats pats cmd_stk res_ty      $
192                              tc_grhss grhss
193
194         ; let match' = L mtch_loc (Match pats' Nothing grhss')
195         ; return (HsLam (MatchGroup [match'] res_ty))
196         }
197
198   where
199     n_pats     = length pats
200     stk'       = drop n_pats cmd_stk
201     match_ctxt = (LambdaExpr :: HsMatchContext Name)    -- Maybe KappaExpr?
202     pg_ctxt    = PatGuard match_ctxt
203
204     tc_grhss (GRHSs grhss binds) res_ty
205         = do { (binds', grhss') <- tcLocalBinds binds $
206                                    mapM (wrapLocM (tc_grhs res_ty)) grhss
207              ; return (GRHSs grhss' binds') }
208
209     tc_grhs res_ty (GRHS guards body)
210         = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards res_ty $
211                                   tcGuardedCmd env body stk'
212              ; return (GRHS guards' rhs') }
213
214 -------------------------------------------
215 --              Do notation
216
217 tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty)
218   = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
219         ; (stmts', body') <- tcStmts do_or_lc tc_stmt stmts res_ty $
220                              tcGuardedCmd env body []
221         ; return (HsDo do_or_lc stmts' body' res_ty) }
222   where
223     tc_stmt = tcMDoStmt tc_rhs
224     tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
225                     ; rhs' <- tcCmd env rhs ([], ty)
226                     ; return (rhs', ty) }
227
228
229 -----------------------------------------------------------------
230 --      Arrow ``forms''       (| e c1 .. cn |)
231 --
232 --      G      |-b  c : [s1 .. sm] s
233 --      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
234 --                              -> a ((w,t1) .. tn) t
235 --      e \not\in (s, s1..sm, t, t1..tn)
236 --      ----------------------------------------------
237 --      G |-a  (| e c |)  :  [t1 .. tn] t
238
239 tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)       
240   = addErrCtxt (cmdCtxt cmd)    $
241     do  { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
242         ; [w_tv]     <- tcInstSkolTyVars ArrowSkol [alphaTyVar]
243         ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point
244
245                 --  a ((w,t1) .. tn) t
246         ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty
247
248                 --   b ((w,s1) .. sm) s
249                 --   -> a ((w,t1) .. tn) t
250         ; let e_ty = mkFunTys [mkAppTys b [tup,s] | (_,_,b,tup,s) <- cmds_w_tys] 
251                               e_res_ty
252
253                 -- Check expr
254         ; (expr', lie) <- escapeArrowScope (getLIE (tcMonoExpr expr e_ty))
255         ; loc <- getInstLoc (SigOrigin ArrowSkol)
256         ; inst_binds <- tcSimplifyCheck loc [w_tv] [] lie
257
258                 -- Check that the polymorphic variable hasn't been unified with anything
259                 -- and is not free in res_ty or the cmd_stk  (i.e.  t, t1..tn)
260         ; checkSigTyVarsWrt (tyVarsOfTypes (res_ty:cmd_stk)) [w_tv] 
261
262                 -- OK, now we are in a position to unscramble 
263                 -- the s1..sm and check each cmd
264         ; cmds' <- mapM (tc_cmd w_tv) cmds_w_tys
265
266         ; return (HsArrForm (noLoc $ HsWrap (WpTyLam w_tv) 
267                                                (unLoc $ mkHsDictLet inst_binds expr')) 
268                              fixity cmds')
269         }
270   where
271         -- Make the types       
272         --      b, ((e,s1) .. sm), s
273     new_cmd_ty :: LHsCmdTop Name -> Int
274                -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
275     new_cmd_ty cmd i
276           = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
277                 ; tup_ty <- newFlexiTyVarTy liftedTypeKind
278                         -- We actually make a type variable for the tuple
279                         -- because we don't know how deeply nested it is yet    
280                 ; s_ty   <- newFlexiTyVarTy liftedTypeKind
281                 ; return (cmd, i, b_ty, tup_ty, s_ty)
282                 }
283
284     tc_cmd w_tv (cmd, i, b, tup_ty, s)
285       = do { tup_ty' <- zonkTcType tup_ty
286            ; let (corner_ty, arg_tys) = unscramble tup_ty'
287
288                 -- Check that it has the right shape:
289                 --      ((w,s1) .. sn)
290                 -- where the si do not mention w
291            ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv && 
292                       not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
293                      (badFormFun i tup_ty')
294
295            ; tcCmdTop (env { cmd_arr = b }) cmd arg_tys s }
296
297     unscramble :: TcType -> (TcType, [TcType])
298     -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
299     unscramble ty
300        = case tcSplitTyConApp_maybe ty of
301             Just (tc, [t,s]) | tc == pairTyCon 
302                ->  let 
303                       (w,ss) = unscramble t  
304                    in (w, s:ss)
305                                     
306             _ -> (ty, [])
307
308 -----------------------------------------------------------------
309 --              Base case for illegal commands
310 -- This is where expressions that aren't commands get rejected
311
312 tc_cmd _ cmd _
313   = failWithTc (vcat [ptext (sLit "The expression"), nest 2 (ppr cmd), 
314                       ptext (sLit "was found where an arrow command was expected")])
315 \end{code}
316
317
318 %************************************************************************
319 %*                                                                      *
320                 Helpers
321 %*                                                                      *
322 %************************************************************************
323
324
325 \begin{code}
326 mkPairTy :: Type -> Type -> Type
327 mkPairTy t1 t2 = mkTyConApp pairTyCon [t1,t2]
328
329 arrowTyConKind :: Kind          --  *->*->*
330 arrowTyConKind = mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind
331 \end{code}
332
333
334 %************************************************************************
335 %*                                                                      *
336                 Errors
337 %*                                                                      *
338 %************************************************************************
339
340 \begin{code}
341 cmdCtxt :: HsExpr Name -> SDoc
342 cmdCtxt cmd = ptext (sLit "In the command:") <+> ppr cmd
343
344 caseScrutCtxt :: LHsExpr Name -> SDoc
345 caseScrutCtxt cmd
346   = hang (ptext (sLit "In the scrutinee of a case command:")) 4 (ppr cmd)
347
348 nonEmptyCmdStkErr :: HsExpr Name -> SDoc
349 nonEmptyCmdStkErr cmd
350   = hang (ptext (sLit "Non-empty command stack at command:"))
351          4 (ppr cmd)
352
353 kappaUnderflow :: HsExpr Name -> SDoc
354 kappaUnderflow cmd
355   = hang (ptext (sLit "Command stack underflow at command:"))
356          4 (ppr cmd)
357
358 badFormFun :: Int -> TcType -> SDoc
359 badFormFun i tup_ty'
360  = hang (ptext (sLit "The type of the") <+> speakNth i <+> ptext (sLit "argument of a command form has the wrong shape"))
361         4 (ptext (sLit "Argument type:") <+> ppr tup_ty')
362 \end{code}