Major refactoring of the type inference engine
[ghc-hetmet.git] / compiler / typecheck / TcArrows.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5 Typecheck arrow notation
6
7 \begin{code}
8 module TcArrows ( tcProc ) where
9
10 import {-# SOURCE #-}   TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp )
11
12 import HsSyn
13 import TcMatches
14 import TcType
15 import TcMType
16 import TcBinds
17 import TcPat
18 import TcUnify
19 import TcRnMonad
20 import Coercion
21 import Inst
22 import Name
23 import TysWiredIn
24 import VarSet 
25 import TysPrim
26
27 import SrcLoc
28 import Outputable
29 import FastString
30 import Util
31
32 import Control.Monad
33 \end{code}
34
35 %************************************************************************
36 %*                                                                      *
37                 Proc    
38 %*                                                                      *
39 %************************************************************************
40
41 \begin{code}
42 tcProc :: InPat Name -> LHsCmdTop Name          -- proc pat -> expr
43        -> TcRhoType                             -- Expected type of whole proc expression
44        -> TcM (OutPat TcId, LHsCmdTop TcId, CoercionI)
45
46 tcProc pat cmd exp_ty
47   = newArrowScope $
48     do  { (coi, (exp_ty1, res_ty)) <- matchExpectedAppTy exp_ty 
49         ; (coi1, (arr_ty, arg_ty)) <- matchExpectedAppTy exp_ty1
50         ; let cmd_env = CmdEnv { cmd_arr = arr_ty }
51         ; (pat', cmd') <- tcPat ProcExpr pat arg_ty $
52                           tcCmdTop cmd_env cmd [] res_ty
53         ; let res_coi = mkTransCoI coi (mkAppTyCoI coi1 (IdCo res_ty))
54         ; return (pat', cmd', res_coi) }
55 \end{code}
56
57
58 %************************************************************************
59 %*                                                                      *
60                 Commands
61 %*                                                                      *
62 %************************************************************************
63
64 \begin{code}
65 type CmdStack = [TcTauType]
66 data CmdEnv
67   = CmdEnv {
68         cmd_arr         :: TcType -- arrow type constructor, of kind *->*->*
69     }
70
71 mkCmdArrTy :: CmdEnv -> TcTauType -> TcTauType -> TcTauType
72 mkCmdArrTy env t1 t2 = mkAppTys (cmd_arr env) [t1, t2]
73
74 ---------------------------------------
75 tcCmdTop :: CmdEnv 
76          -> LHsCmdTop Name
77          -> CmdStack
78          -> TcTauType   -- Expected result type; always a monotype
79                              -- We know exactly how many cmd args are expected,
80                              -- albeit perhaps not their types; so we can pass 
81                              -- in a CmdStack
82         -> TcM (LHsCmdTop TcId)
83
84 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
85   = setSrcSpan loc $
86     do  { cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
87         ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
88         ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
89
90
91 ----------------------------------------
92 tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
93              -> TcTauType -> TcM (LHsExpr TcId)
94 -- A wrapper that deals with the refinement (if any)
95 tcGuardedCmd env expr stk res_ty
96   = do  { body <- tcCmd env expr (stk, res_ty)
97         ; return body 
98         }
99
100 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
101         -- The main recursive function
102 tcCmd env (L loc expr) res_ty
103   = setSrcSpan loc $ do
104         { expr' <- tc_cmd env expr res_ty
105         ; return (L loc expr') }
106
107 tc_cmd :: CmdEnv -> HsExpr Name -> (CmdStack, TcTauType) -> TcM (HsExpr TcId)
108 tc_cmd env (HsPar cmd) res_ty
109   = do  { cmd' <- tcCmd env cmd res_ty
110         ; return (HsPar cmd') }
111
112 tc_cmd env (HsLet binds (L body_loc body)) res_ty
113   = do  { (binds', body') <- tcLocalBinds binds         $
114                              setSrcSpan body_loc        $
115                              tc_cmd env body res_ty
116         ; return (HsLet binds' (L body_loc body')) }
117
118 tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
119   = addErrCtxt (cmdCtxt in_cmd) $ do
120       (scrut', scrut_ty) <- tcInferRho scrut 
121       matches' <- tcMatchesCase match_ctxt scrut_ty matches res_ty
122       return (HsCase scrut' matches')
123   where
124     match_ctxt = MC { mc_what = CaseAlt,
125                       mc_body = mc_body }
126     mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
127
128 tc_cmd env (HsIf mb_fun pred b1 b2) (stack_ty,res_ty)
129   = do  { pred_ty <- newFlexiTyVarTy openTypeKind
130         ; b_ty <- newFlexiTyVarTy openTypeKind
131         ; let if_ty = mkFunTys [pred_ty, b_ty, b_ty] res_ty
132         ; mb_fun' <- case mb_fun of 
133               Nothing  -> return Nothing
134               Just fun -> liftM Just (tcSyntaxOp IfOrigin fun if_ty)
135         ; pred' <- tcMonoExpr pred pred_ty
136         ; b1'   <- tcCmd env b1 (stack_ty,b_ty)
137         ; b2'   <- tcCmd env b2 (stack_ty,b_ty)
138         ; return (HsIf mb_fun' pred' b1' b2')
139     }
140
141 -------------------------------------------
142 --              Arrow application
143 --          (f -< a)   or   (f -<< a)
144
145 tc_cmd env cmd@(HsArrApp fun arg _ ho_app lr) (cmd_stk, res_ty)
146   = addErrCtxt (cmdCtxt cmd)    $
147     do  { arg_ty <- newFlexiTyVarTy openTypeKind
148         ; let fun_ty = mkCmdArrTy env (foldl mkPairTy arg_ty cmd_stk) res_ty
149
150         ; fun' <- select_arrow_scope (tcMonoExpr fun fun_ty)
151
152         ; arg' <- tcMonoExpr arg arg_ty
153
154         ; return (HsArrApp fun' arg' fun_ty ho_app lr) }
155   where
156         -- Before type-checking f, use the environment of the enclosing
157         -- proc for the (-<) case.  
158         -- Local bindings, inside the enclosing proc, are not in scope 
159         -- inside f.  In the higher-order case (-<<), they are.
160     select_arrow_scope tc = case ho_app of
161         HsHigherOrderApp -> tc
162         HsFirstOrderApp  -> escapeArrowScope tc
163
164 -------------------------------------------
165 --              Command application
166
167 tc_cmd env cmd@(HsApp fun arg) (cmd_stk, res_ty)
168   = addErrCtxt (cmdCtxt cmd)    $
169     do  { arg_ty <- newFlexiTyVarTy openTypeKind
170
171         ; fun' <- tcCmd env fun (arg_ty:cmd_stk, res_ty)
172
173         ; arg' <- tcMonoExpr arg arg_ty
174
175         ; return (HsApp fun' arg') }
176
177 -------------------------------------------
178 --              Lambda
179
180 tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
181        (cmd_stk, res_ty)
182   = addErrCtxt (pprMatchInCtxt match_ctxt match)        $
183
184     do  {       -- Check the cmd stack is big enough
185         ; checkTc (lengthAtLeast cmd_stk n_pats)
186                   (kappaUnderflow cmd)
187
188                 -- Check the patterns, and the GRHSs inside
189         ; (pats', grhss') <- setSrcSpan mtch_loc                $
190                              tcPats LambdaExpr pats cmd_stk     $
191                              tc_grhss grhss res_ty
192
193         ; let match' = L mtch_loc (Match pats' Nothing grhss')
194         ; return (HsLam (MatchGroup [match'] res_ty))
195         }
196
197   where
198     n_pats     = length pats
199     stk'       = drop n_pats cmd_stk
200     match_ctxt = (LambdaExpr :: HsMatchContext Name)    -- Maybe KappaExpr?
201     pg_ctxt    = PatGuard match_ctxt
202
203     tc_grhss (GRHSs grhss binds) res_ty
204         = do { (binds', grhss') <- tcLocalBinds binds $
205                                    mapM (wrapLocM (tc_grhs res_ty)) grhss
206              ; return (GRHSs grhss' binds') }
207
208     tc_grhs res_ty (GRHS guards body)
209         = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards res_ty $
210                                   tcGuardedCmd env body stk'
211              ; return (GRHS guards' rhs') }
212
213 -------------------------------------------
214 --              Do notation
215
216 tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty)
217   = do  { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
218         ; (stmts', body') <- tcStmts do_or_lc (tcMDoStmt tc_rhs) stmts res_ty $
219                              tcGuardedCmd env body []
220         ; return (HsDo do_or_lc stmts' body' res_ty) }
221   where
222     tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
223                     ; rhs' <- tcCmd env rhs ([], ty)
224                     ; return (rhs', ty) }
225
226
227 -----------------------------------------------------------------
228 --      Arrow ``forms''       (| e c1 .. cn |)
229 --
230 --      G      |-b  c : [s1 .. sm] s
231 --      pop(G) |-   e : forall w. b ((w,s1) .. sm) s
232 --                              -> a ((w,t1) .. tn) t
233 --      e \not\in (s, s1..sm, t, t1..tn)
234 --      ----------------------------------------------
235 --      G |-a  (| e c |)  :  [t1 .. tn] t
236
237 tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)       
238   = addErrCtxt (cmdCtxt cmd)    $
239     do  { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
240         ; [w_tv]     <- tcInstSkolTyVars [alphaTyVar]
241         ; let w_ty = mkTyVarTy w_tv     -- Just a convenient starting point
242
243                 --  a ((w,t1) .. tn) t
244         ; let e_res_ty = mkCmdArrTy env (foldl mkPairTy w_ty cmd_stk) res_ty
245
246                 --   b ((w,s1) .. sm) s
247                 --   -> a ((w,t1) .. tn) t
248         ; let e_ty = mkFunTys [mkAppTys b [tup,s] | (_,_,b,tup,s) <- cmds_w_tys] 
249                               e_res_ty
250
251                 -- Check expr
252         ; (inst_binds, expr') <- checkConstraints ArrowSkol [w_tv] [] $
253                                  escapeArrowScope (tcMonoExpr expr e_ty)
254
255                 -- OK, now we are in a position to unscramble 
256                 -- the s1..sm and check each cmd
257         ; cmds' <- mapM (tc_cmd w_tv) cmds_w_tys
258
259         ; let wrap = WpTyLam w_tv <.> mkWpLet inst_binds
260         ; return (HsArrForm (mkLHsWrap wrap expr') fixity cmds') }
261   where
262         -- Make the types       
263         --      b, ((e,s1) .. sm), s
264     new_cmd_ty :: LHsCmdTop Name -> Int
265                -> TcM (LHsCmdTop Name, Int, TcType, TcType, TcType)
266     new_cmd_ty cmd i
267           = do  { b_ty   <- newFlexiTyVarTy arrowTyConKind
268                 ; tup_ty <- newFlexiTyVarTy liftedTypeKind
269                         -- We actually make a type variable for the tuple
270                         -- because we don't know how deeply nested it is yet    
271                 ; s_ty   <- newFlexiTyVarTy liftedTypeKind
272                 ; return (cmd, i, b_ty, tup_ty, s_ty)
273                 }
274
275     tc_cmd w_tv (cmd, i, b, tup_ty, s)
276       = do { tup_ty' <- zonkTcType tup_ty
277            ; let (corner_ty, arg_tys) = unscramble tup_ty'
278
279                 -- Check that it has the right shape:
280                 --      ((w,s1) .. sn)
281                 -- where the si do not mention w
282            ; checkTc (corner_ty `tcEqType` mkTyVarTy w_tv && 
283                       not (w_tv `elemVarSet` tyVarsOfTypes arg_tys))
284                      (badFormFun i tup_ty')
285
286            ; tcCmdTop (env { cmd_arr = b }) cmd arg_tys s }
287
288     unscramble :: TcType -> (TcType, [TcType])
289     -- unscramble ((w,s1) .. sn)        =  (w, [s1..sn])
290     unscramble ty = unscramble' ty []
291
292     unscramble' ty ss
293        = case tcSplitTyConApp_maybe ty of
294             Just (tc, [t,s]) | tc == pairTyCon 
295                ->  unscramble' t (s:ss)
296             _ -> (ty, ss)
297
298 -----------------------------------------------------------------
299 --              Base case for illegal commands
300 -- This is where expressions that aren't commands get rejected
301
302 tc_cmd _ cmd _
303   = failWithTc (vcat [ptext (sLit "The expression"), nest 2 (ppr cmd), 
304                       ptext (sLit "was found where an arrow command was expected")])
305 \end{code}
306
307
308 %************************************************************************
309 %*                                                                      *
310                 Helpers
311 %*                                                                      *
312 %************************************************************************
313
314
315 \begin{code}
316 mkPairTy :: Type -> Type -> Type
317 mkPairTy t1 t2 = mkTyConApp pairTyCon [t1,t2]
318
319 arrowTyConKind :: Kind          --  *->*->*
320 arrowTyConKind = mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind
321 \end{code}
322
323
324 %************************************************************************
325 %*                                                                      *
326                 Errors
327 %*                                                                      *
328 %************************************************************************
329
330 \begin{code}
331 cmdCtxt :: HsExpr Name -> SDoc
332 cmdCtxt cmd = ptext (sLit "In the command:") <+> ppr cmd
333
334 nonEmptyCmdStkErr :: HsExpr Name -> SDoc
335 nonEmptyCmdStkErr cmd
336   = hang (ptext (sLit "Non-empty command stack at command:"))
337        2 (ppr cmd)
338
339 kappaUnderflow :: HsExpr Name -> SDoc
340 kappaUnderflow cmd
341   = hang (ptext (sLit "Command stack underflow at command:"))
342        2 (ppr cmd)
343
344 badFormFun :: Int -> TcType -> SDoc
345 badFormFun i tup_ty'
346  = hang (ptext (sLit "The type of the") <+> speakNth i <+> ptext (sLit "argument of a command form has the wrong shape"))
347       2 (ptext (sLit "Argument type:") <+> ppr tup_ty')
348 \end{code}