More hacking on monad-comp; now works
[ghc-hetmet.git] / compiler / typecheck / TcArrows.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5 Typecheck arrow notation
6
7 \begin{code}
8 module TcArrows ( tcProc ) where
9
10 import {-# SOURCE #-}   TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp )
11
12 import HsSyn
13 import TcMatches
14 import TcType
15 import TcMType
16 import TcBinds
17 import TcPat
18 import TcUnify
19 import TcRnMonad
20 import Coercion
21 import Inst
22 import Name
23 import TysWiredIn
24 import VarSet 
25 import TysPrim
26
27 import SrcLoc
28 import Outputable
29 import FastString
30 import Util
31
32 import Control.Monad
33 \end{code}
34
35 %************************************************************************
36 %*                                                                      *
37                 Proc    
38 %*                                                                      *
39 %************************************************************************
40
41 \begin{code}
42 tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
43        -> TcRhoType                             -- Expected type of whole proc expression
44        -> TcM (OutPat TcId, LHsCmdTop TcId, CoercionI)
45
46 tcProc pat cmd exp_ty
47   = newArrowScope $
48     do  { (coi, (exp_ty1, res_ty)) <- matchExpectedAppTy exp_ty 
49         ; (coi1, (arr_ty, arg_ty)) <- matchExpectedAppTy exp_ty1
50         ; let cmd_env = CmdEnv { cmd_arr = arr_ty }
51         ; (pat', cmd') <- tcPat ProcExpr pat arg_ty $
52                           tcCmdTop cmd_env cmd [] res_ty
53         ; let res_coi = mkTransCoI coi (mkAppTyCoI coi1 (IdCo res_ty))
54         ; return (pat', cmd', res_coi) }
55 \end{code}
56
57
58 %************************************************************************
59 %*                                                                      *
60                 Commands
61 %*                                                                      *
62 %************************************************************************
63
64 \begin{code}
65 type CmdStack = [TcTauType]
66 data CmdEnv
67   = CmdEnv {
68         cmd_arr         :: TcType -- arrow type constructor, of kind *->*->*
69     }
70
71 mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
72 mkCmdArrTy env t1 t2 = mkAppTys (cmd_arr env) [t1, t2]
73
74 ---------------------------------------
75 tcCmdTop :: CmdEnv 
76          -> LHsCmdTop Name
77          -> CmdStack
78          -> TcTauType   -- Expected result type; always a monotype
79                              -- We know exactly how many cmd args are expected,
80                              -- albeit perhaps not their types; so we can pass 
81                              -- in a CmdStack
82         -> TcM (LHsCmdTop TcId)
83
84 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
85   = setSrcSpan loc $
86     do  { cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
87         ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
88         ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
89
90
91 ----------------------------------------
92 tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
93              -> TcTauType -> TcM (LHsExpr TcId)
94 -- A wrapper that deals with the refinement (if any)
95 tcGuardedCmd env expr stk res_ty
96   = do  { body <- tcCmd env expr (stk, res_ty)
97         ; return body 
98         }
99
100 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
101         -- The main recursive function
102 tcCmd env (L loc expr) res_ty
103   = setSrcSpan loc $ do
104         { expr' <- tc_cmd env expr res_ty
105         ; return (L loc expr') }
106
107 tc_cmd :: CmdEnv -> HsExpr Name -> (CmdStack, TcTauType) -> TcM (HsExpr TcId)
108 tc_cmd env (HsPar cmd) res_ty
109   = do  { cmd' <- tcCmd env cmd res_ty
110         ; return (HsPar cmd') }
111
112 tc_cmd env (HsLet binds (L body_loc body)) res_ty
113   = do  { (binds', body') <- tcLocalBinds binds         $
114                              setSrcSpan body_loc        $
115                              tc_cmd env body res_ty
116         ; return (HsLet binds' (L body_loc body')) }
117
118 tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
119   = addErrCtxt (cmdCtxt in_cmd) $ do
120       (scrut', scrut_ty) <- tcInferRho scrut 
121       matches' <- tcMatchesCase match_ctxt scrut_ty matches res_ty
122       return (HsCase scrut' matches')
123   where
124     match_ctxt = MC { mc_what = CaseAlt,
125                       mc_body = mc_body }
126     mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
127
128 tc_cmd env (HsIf mb_fun pred b1 b2) (stack_ty,res_ty)
129   = do  { pred_ty <- newFlexiTyVarTy openTypeKind
130         ; b_ty <- newFlexiTyVarTy openTypeKind
131         ; let if_ty = mkFunTys [pred_ty, b_ty, b_ty] res_ty
132         ; mb_fun' <- case mb_fun of 
133               Nothing  -> return Nothing
134               Just fun -> liftM Just (tcSyntaxOp IfOrigin fun if_ty)
135         ; pred' <- tcMonoExpr pred pred_ty
136         ; b1'   <- tcCmd env b1 (stack_ty,b_ty)
137         ; b2'   <- tcCmd env b2 (stack_ty,b_ty)
138         ; return (HsIf mb_fun' pred' b1' b2')
139     }
140
141 -------------------------------------------
142 --              Arrow application
143 --          (f -< a)   or   (f -<< a)
144
145 tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
146   = addErrCtxt (cmdCtxt cmd)    $
147     do  { arg_ty <- newFlexiTyVarTy openTypeKind
148         ; let fun_ty = mkCmdArrTy env (foldl mkPairTy arg_ty cmd_stk) res_ty
149
150         ; fun' <- select_arrow_scope (tcMonoExpr fun fun_ty)
151
152         ; arg' <- tcMonoExpr arg arg_ty
153
154         ; return (HsArrApp fun' arg' fun_ty ho_app lr) }
155   where
156         -- Before type-checking f, use the environment of the enclosing
157         -- proc for the (-<) case.  
158         -- Local bindings, inside the enclosing proc, are not in scope 
159         -- inside f.  In the higher-order case (-<<), they are.
160     select_arrow_scope tc = case ho_app of
161         HsHigherOrderApp -> tc
162         HsFirstOrderApp  -> escapeArrowScope tc
163
164 -------------------------------------------
165 --              Command application
166
167 tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
168   = addErrCtxt (cmdCtxt cmd)    $
169     do  { arg_ty <- newFlexiTyVarTy openTypeKind
170
171         ; fun' <- tcCmd env fun (arg_ty:cmd_stk, res_ty)
172
173         ; arg' <- tcMonoExpr arg arg_ty
174
175         ; return (HsApp fun' arg') }
176
177 -------------------------------------------
178 --              Lambda
179
180 tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
181        (cmd_stk, res_ty)
182   = addErrCtxt (pprMatchInCtxt match_ctxt match)        $
183
184     do  {       -- Check the cmd stack is big enough
185         ; checkTc (lengthAtLeast cmd_stk n_pats)
186                   (kappaUnderflow cmd)
187
188                 -- Check the patterns, and the GRHSs inside
189         ; (pats', grhss') <- setSrcSpan mtch_loc                $
190                              tcPats LambdaExpr pats cmd_stk     $
191                              tc_grhss grhss res_ty
192
193         ; let match' = L mtch_loc (Match pats' Nothing grhss')
194         ; return (HsLam (MatchGroup [match'] res_ty))
195         }
196
197   where
198     n_pats     = length pats
199     stk'       = drop n_pats cmd_stk
200     match_ctxt = (LambdaExpr :: HsMatchContext Name)    -- Maybe KappaExpr?
201     pg_ctxt    = PatGuard match_ctxt
202
203     tc_grhss (GRHSs grhss binds) res_ty
204         = do { (binds', grhss') <- tcLocalBinds binds $
205                                    mapM (wrapLocM (tc_grhs res_ty)) grhss
206              ; return (GRHSs grhss' binds') }
207
208     tc_grhs res_ty (GRHS guards body)
209         = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
210                                   tcGuardedCmd env body stk'
211              ; return (GRHS guards' rhs') }
212
213 -------------------------------------------
214 --              Do notation
215
216 tc_cmd env cmd@(HsDo do_or_lc stmts _) (cmd_stk, res_ty)
217   = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
218         ; stmts' <- tcStmts do_or_lc (tcMDoStmt tc_rhs) stmts res_ty 
219         ; return (HsDo do_or_lc stmts' res_ty) }
220   where
221     tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
222                     ; rhs' <- tcCmd env rhs ([], ty)
223                     ; return (rhs', ty) }
224
225
226 -----------------------------------------------------------------
227 --      Arrow ``forms''       (| e c1 .. cn |)
228 --
229 --      G      |-b  c : [s1 .. sm] s
230 --      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
231 --                              -> a ((w,t1) .. tn) t
232 --      e \not\in (s, s1..sm, t, t1..tn)
233 --      ----------------------------------------------
234 --      G |-a  (| e c |)  :  [t1 .. tn] t
235
236 tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)       
237   = addErrCtxt (cmdCtxt cmd)    $
238     do  { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
239         ; [w_tv]     <- tcInstSkolTyVars [alphaTyVar]
240         ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point
241
242                 --  a ((w,t1) .. tn) t
243         ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty
244
245                 --   b ((w,s1) .. sm) s
246                 --   -> a ((w,t1) .. tn) t
247         ; let e_ty = mkFunTys [mkAppTys b [tup,s] | (_,_,b,tup,s) <- cmds_w_tys] 
248                               e_res_ty
249
250                 -- Check expr
251         ; (inst_binds, expr') <- checkConstraints ArrowSkol [w_tv] [] $
252                                  escapeArrowScope (tcMonoExpr expr e_ty)
253
254                 -- OK, now we are in a position to unscramble 
255                 -- the s1..sm and check each cmd
256         ; cmds' <- mapM (tc_cmd w_tv) cmds_w_tys
257
258         ; let wrap = WpTyLam w_tv <.> mkWpLet inst_binds
259         ; return (HsArrForm (mkLHsWrap wrap expr') fixity cmds') }
260   where
261         -- Make the types       
262         --      b, ((e,s1) .. sm), s
263     new_cmd_ty :: LHsCmdTop Name -> Int
264                -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
265     new_cmd_ty cmd i
266           = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
267                 ; tup_ty <- newFlexiTyVarTy liftedTypeKind
268                         -- We actually make a type variable for the tuple
269                         -- because we don't know how deeply nested it is yet    
270                 ; s_ty   <- newFlexiTyVarTy liftedTypeKind
271                 ; return (cmd, i, b_ty, tup_ty, s_ty)
272                 }
273
274     tc_cmd w_tv (cmd, i, b, tup_ty, s)
275       = do { tup_ty' <- zonkTcType tup_ty
276            ; let (corner_ty, arg_tys) = unscramble tup_ty'
277
278                 -- Check that it has the right shape:
279                 --      ((w,s1) .. sn)
280                 -- where the si do not mention w
281            ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv && 
282                       not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
283                      (badFormFun i tup_ty')
284
285            ; tcCmdTop (env { cmd_arr = b }) cmd arg_tys s }
286
287     unscramble :: TcType -> (TcType, [TcType])
288     -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
289     unscramble ty = unscramble' ty []
290
291     unscramble' ty ss
292        = case tcSplitTyConApp_maybe ty of
293             Just (tc, [t,s]) | tc == pairTyCon 
294                ->  unscramble' t (s:ss)
295             _ -> (ty, ss)
296
297 -----------------------------------------------------------------
298 --              Base case for illegal commands
299 -- This is where expressions that aren't commands get rejected
300
301 tc_cmd _ cmd _
302   = failWithTc (vcat [ptext (sLit "The expression"), nest 2 (ppr cmd), 
303                       ptext (sLit "was found where an arrow command was expected")])
304 \end{code}
305
306
307 %************************************************************************
308 %*                                                                      *
309                 Helpers
310 %*                                                                      *
311 %************************************************************************
312
313
314 \begin{code}
315 mkPairTy :: Type -> Type -> Type
316 mkPairTy t1 t2 = mkTyConApp pairTyCon [t1,t2]
317
318 arrowTyConKind :: Kind          --  *->*->*
319 arrowTyConKind = mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind
320 \end{code}
321
322
323 %************************************************************************
324 %*                                                                      *
325                 Errors
326 %*                                                                      *
327 %************************************************************************
328
329 \begin{code}
330 cmdCtxt :: HsExpr Name -> SDoc
331 cmdCtxt cmd = ptext (sLit "In the command:") <+> ppr cmd
332
333 nonEmptyCmdStkErr :: HsExpr Name -> SDoc
334 nonEmptyCmdStkErr cmd
335   = hang (ptext (sLit "Non-empty command stack at command:"))
336        2 (ppr cmd)
337
338 kappaUnderflow :: HsExpr Name -> SDoc
339 kappaUnderflow cmd
340   = hang (ptext (sLit "Command stack underflow at command:"))
341        2 (ppr cmd)
342
343 badFormFun :: Int -> TcType -> SDoc
344 badFormFun i tup_ty'
345  = hang (ptext (sLit "The type of the") <+> speakNth i <+> ptext (sLit "argument of a command form has the wrong shape"))
346       2 (ptext (sLit "Argument type:") <+> ppr tup_ty')
347 \end{code}