Optimise desugaring of parallel array comprehensions
[ghc-hetmet.git] / compiler / deSugar / DsListComp.lhs
1 %
2 % (c) The University of Glasgow 2006
3 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 %
5
6 Desugaring list comprehensions and array comprehensions
7
8 \begin{code}
9 {-# OPTIONS -w #-}
10 -- The above warning supression flag is a temporary kludge.
11 -- While working on this module you are encouraged to remove it and fix
12 -- any warnings in the module. See
13 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
14 -- for details
15
16 module DsListComp ( dsListComp, dsPArrComp ) where
17
18 #include "HsVersions.h"
19
20 import {-# SOURCE #-} DsExpr ( dsLExpr, dsLocalBinds )
21
22 import BasicTypes
23 import HsSyn
24 import TcHsSyn
25 import CoreSyn
26
27 import DsMonad          -- the monadery used in the desugarer
28 import DsUtils
29
30 import DynFlags
31 import CoreUtils
32 import Var
33 import Type
34 import TysPrim
35 import TysWiredIn
36 import Match
37 import PrelNames
38 import PrelInfo
39 import SrcLoc
40 import Panic
41 \end{code}
42
43 List comprehensions may be desugared in one of two ways: ``ordinary''
44 (as you would expect if you read SLPJ's book) and ``with foldr/build
45 turned on'' (if you read Gill {\em et al.}'s paper on the subject).
46
47 There will be at least one ``qualifier'' in the input.
48
49 \begin{code}
50 dsListComp :: [LStmt Id] 
51            -> LHsExpr Id
52            -> Type              -- Type of list elements
53            -> DsM CoreExpr
54 dsListComp lquals body elt_ty
55   = getDOptsDs  `thenDs` \dflags ->
56     let
57         quals = map unLoc lquals
58     in
59     if not (dopt Opt_RewriteRules dflags) || dopt Opt_IgnoreInterfacePragmas dflags
60         -- Either rules are switched off, or we are ignoring what there are;
61         -- Either way foldr/build won't happen, so use the more efficient
62         -- Wadler-style desugaring
63         || isParallelComp quals
64                 -- Foldr-style desugaring can't handle
65                 -- parallel list comprehensions
66         then deListComp quals body (mkNilExpr elt_ty)
67
68    else         -- Foldr/build should be enabled, so desugar 
69                 -- into foldrs and builds
70     newTyVarsDs [alphaTyVar]    `thenDs` \ [n_tyvar] ->
71     let
72         n_ty = mkTyVarTy n_tyvar
73         c_ty = mkFunTys [elt_ty, n_ty] n_ty
74     in
75     newSysLocalsDs [c_ty,n_ty]          `thenDs` \ [c, n] ->
76     dfListComp c n quals body           `thenDs` \ result ->
77     dsLookupGlobalId buildName  `thenDs` \ build_id ->
78     returnDs (Var build_id `App` Type elt_ty 
79                            `App` mkLams [n_tyvar, c, n] result)
80
81   where isParallelComp (ParStmt bndrstmtss : _) = True
82         isParallelComp _                        = False
83 \end{code}
84
85 %************************************************************************
86 %*                                                                      *
87 \subsection[DsListComp-ordinary]{Ordinary desugaring of list comprehensions}
88 %*                                                                      *
89 %************************************************************************
90
91 Just as in Phil's chapter~7 in SLPJ, using the rules for
92 optimally-compiled list comprehensions.  This is what Kevin followed
93 as well, and I quite happily do the same.  The TQ translation scheme
94 transforms a list of qualifiers (either boolean expressions or
95 generators) into a single expression which implements the list
96 comprehension.  Because we are generating 2nd-order polymorphic
97 lambda-calculus, calls to NIL and CONS must be applied to a type
98 argument, as well as their usual value arguments.
99 \begin{verbatim}
100 TE << [ e | qs ] >>  =  TQ << [ e | qs ] ++ Nil (typeOf e) >>
101
102 (Rule C)
103 TQ << [ e | ] ++ L >> = Cons (typeOf e) TE <<e>> TE <<L>>
104
105 (Rule B)
106 TQ << [ e | b , qs ] ++ L >> =
107     if TE << b >> then TQ << [ e | qs ] ++ L >> else TE << L >>
108
109 (Rule A')
110 TQ << [ e | p <- L1, qs ]  ++  L2 >> =
111   letrec
112     h = \ u1 ->
113           case u1 of
114             []        ->  TE << L2 >>
115             (u2 : u3) ->
116                   (( \ TE << p >> -> ( TQ << [e | qs]  ++  (h u3) >> )) u2)
117                     [] (h u3)
118   in
119     h ( TE << L1 >> )
120
121 "h", "u1", "u2", and "u3" are new variables.
122 \end{verbatim}
123
124 @deListComp@ is the TQ translation scheme.  Roughly speaking, @dsExpr@
125 is the TE translation scheme.  Note that we carry around the @L@ list
126 already desugared.  @dsListComp@ does the top TE rule mentioned above.
127
128 To the above, we add an additional rule to deal with parallel list
129 comprehensions.  The translation goes roughly as follows:
130      [ e | p1 <- e11, let v1 = e12, p2 <- e13
131          | q1 <- e21, let v2 = e22, q2 <- e23]
132      =>
133      [ e | ((x1, .., xn), (y1, ..., ym)) <-
134                zip [(x1,..,xn) | p1 <- e11, let v1 = e12, p2 <- e13]
135                    [(y1,..,ym) | q1 <- e21, let v2 = e22, q2 <- e23]]
136 where (x1, .., xn) are the variables bound in p1, v1, p2
137       (y1, .., ym) are the variables bound in q1, v2, q2
138
139 In the translation below, the ParStmt branch translates each parallel branch
140 into a sub-comprehension, and desugars each independently.  The resulting lists
141 are fed to a zip function, we create a binding for all the variables bound in all
142 the comprehensions, and then we hand things off the the desugarer for bindings.
143 The zip function is generated here a) because it's small, and b) because then we
144 don't have to deal with arbitrary limits on the number of zip functions in the
145 prelude, nor which library the zip function came from.
146 The introduced tuples are Boxed, but only because I couldn't get it to work
147 with the Unboxed variety.
148
149 \begin{code}
150 deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr
151
152 deListComp (ParStmt stmtss_w_bndrs : quals) body list
153   = mappM do_list_comp stmtss_w_bndrs   `thenDs` \ exps ->
154     mkZipBind qual_tys                  `thenDs` \ (zip_fn, zip_rhs) ->
155
156         -- Deal with [e | pat <- zip l1 .. ln] in example above
157     deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)) 
158                    quals body list
159
160   where 
161         bndrs_s = map snd stmtss_w_bndrs
162
163         -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
164         pat      = mkTuplePat pats
165         pats     = map mk_hs_tuple_pat bndrs_s
166
167         -- Types of (x1,..,xn), (y1,..,yn) etc
168         qual_tys = map mk_bndrs_tys bndrs_s
169
170         do_list_comp (stmts, bndrs)
171           = dsListComp stmts (mk_hs_tuple_expr bndrs)
172                        (mk_bndrs_tys bndrs)
173
174         mk_bndrs_tys bndrs = mkCoreTupTy (map idType bndrs)
175
176         -- Last: the one to return
177 deListComp [] body list         -- Figure 7.4, SLPJ, p 135, rule C above
178   = dsLExpr body                `thenDs` \ core_body ->
179     returnDs (mkConsExpr (exprType core_body) core_body list)
180
181         -- Non-last: must be a guard
182 deListComp (ExprStmt guard _ _ : quals) body list       -- rule B above
183   = dsLExpr guard               `thenDs` \ core_guard ->
184     deListComp quals body list  `thenDs` \ core_rest ->
185     returnDs (mkIfThenElse core_guard core_rest list)
186
187 -- [e | let B, qs] = let B in [e | qs]
188 deListComp (LetStmt binds : quals) body list
189   = deListComp quals body list  `thenDs` \ core_rest ->
190     dsLocalBinds binds core_rest
191
192 deListComp (BindStmt pat list1 _ _ : quals) body core_list2 -- rule A' above
193   = dsLExpr list1                   `thenDs` \ core_list1 ->
194     deBindComp pat core_list1 quals body core_list2
195 \end{code}
196
197
198 \begin{code}
199 deBindComp pat core_list1 quals body core_list2
200   = let
201         u3_ty@u1_ty = exprType core_list1       -- two names, same thing
202
203         -- u1_ty is a [alpha] type, and u2_ty = alpha
204         u2_ty = hsLPatType pat
205
206         res_ty = exprType core_list2
207         h_ty   = u1_ty `mkFunTy` res_ty
208     in
209     newSysLocalsDs [h_ty, u1_ty, u2_ty, u3_ty]  `thenDs` \ [h, u1, u2, u3] ->
210
211     -- the "fail" value ...
212     let
213         core_fail   = App (Var h) (Var u3)
214         letrec_body = App (Var h) core_list1
215     in
216     deListComp quals body core_fail             `thenDs` \ rest_expr ->
217     matchSimply (Var u2) (StmtCtxt ListComp) pat
218                 rest_expr core_fail             `thenDs` \ core_match ->
219     let
220         rhs = Lam u1 $
221               Case (Var u1) u1 res_ty
222                    [(DataAlt nilDataCon,  [],       core_list2),
223                     (DataAlt consDataCon, [u2, u3], core_match)]
224                         -- Increasing order of tag
225     in
226     returnDs (Let (Rec [(h, rhs)]) letrec_body)
227 \end{code}
228
229
230 \begin{code}
231 mkZipBind :: [Type] -> DsM (Id, CoreExpr)
232 -- mkZipBind [t1, t2] 
233 -- = (zip, \as1:[t1] as2:[t2] 
234 --         -> case as1 of 
235 --              [] -> []
236 --              (a1:as'1) -> case as2 of
237 --                              [] -> []
238 --                              (a2:as'2) -> (a2,a2) : zip as'1 as'2)]
239
240 mkZipBind elt_tys 
241   = mappM newSysLocalDs  list_tys       `thenDs` \ ass ->
242     mappM newSysLocalDs  elt_tys        `thenDs` \ as' ->
243     mappM newSysLocalDs  list_tys       `thenDs` \ as's ->
244     newSysLocalDs zip_fn_ty             `thenDs` \ zip_fn ->
245     let 
246         inner_rhs = mkConsExpr ret_elt_ty 
247                         (mkCoreTup (map Var as'))
248                         (mkVarApps (Var zip_fn) as's)
249         zip_body  = foldr mk_case inner_rhs (zip3 ass as' as's)
250     in
251     returnDs (zip_fn, mkLams ass zip_body)
252   where
253     list_tys    = map mkListTy elt_tys
254     ret_elt_ty  = mkCoreTupTy elt_tys
255     list_ret_ty = mkListTy ret_elt_ty
256     zip_fn_ty   = mkFunTys list_tys list_ret_ty
257
258     mk_case (as, a', as') rest
259           = Case (Var as) as list_ret_ty
260                   [(DataAlt nilDataCon,  [],        mkNilExpr ret_elt_ty),
261                    (DataAlt consDataCon, [a', as'], rest)]
262                         -- Increasing order of tag
263 -- Helper functions that makes an HsTuple only for non-1-sized tuples
264 mk_hs_tuple_expr :: [Id] -> LHsExpr Id
265 mk_hs_tuple_expr []   = nlHsVar unitDataConId
266 mk_hs_tuple_expr [id] = nlHsVar id
267 mk_hs_tuple_expr ids  = noLoc $ ExplicitTuple [ nlHsVar i | i <- ids ] Boxed
268
269 mk_hs_tuple_pat :: [Id] -> LPat Id
270 mk_hs_tuple_pat bs  = mkTuplePat (map nlVarPat bs)
271 \end{code}
272
273
274 %************************************************************************
275 %*                                                                      *
276 \subsection[DsListComp-foldr-build]{Foldr/Build desugaring of list comprehensions}
277 %*                                                                      *
278 %************************************************************************
279
280 @dfListComp@ are the rules used with foldr/build turned on:
281
282 \begin{verbatim}
283 TE[ e | ]            c n = c e n
284 TE[ e | b , q ]      c n = if b then TE[ e | q ] c n else n
285 TE[ e | p <- l , q ] c n = let 
286                                 f = \ x b -> case x of
287                                                   p -> TE[ e | q ] c b
288                                                   _ -> b
289                            in
290                            foldr f n l
291 \end{verbatim}
292
293 \begin{code}
294 dfListComp :: Id -> Id                  -- 'c' and 'n'
295            -> [Stmt Id]         -- the rest of the qual's
296            -> LHsExpr Id
297            -> DsM CoreExpr
298
299         -- Last: the one to return
300 dfListComp c_id n_id [] body
301   = dsLExpr body                `thenDs` \ core_body ->
302     returnDs (mkApps (Var c_id) [core_body, Var n_id])
303
304         -- Non-last: must be a guard
305 dfListComp c_id n_id (ExprStmt guard _ _  : quals) body
306   = dsLExpr guard                       `thenDs` \ core_guard ->
307     dfListComp c_id n_id quals body     `thenDs` \ core_rest ->
308     returnDs (mkIfThenElse core_guard core_rest (Var n_id))
309
310 dfListComp c_id n_id (LetStmt binds : quals) body
311   -- new in 1.3, local bindings
312   = dfListComp c_id n_id quals body     `thenDs` \ core_rest ->
313     dsLocalBinds binds core_rest
314
315 dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) body
316     -- evaluate the two lists
317   = dsLExpr list1                       `thenDs` \ core_list1 ->
318
319     -- find the required type
320     let x_ty   = hsLPatType pat
321         b_ty   = idType n_id
322     in
323
324     -- create some new local id's
325     newSysLocalsDs [b_ty,x_ty]                  `thenDs` \ [b,x] ->
326
327     -- build rest of the comprehesion
328     dfListComp c_id b quals body                `thenDs` \ core_rest ->
329
330     -- build the pattern match
331     matchSimply (Var x) (StmtCtxt ListComp)
332                 pat core_rest (Var b)           `thenDs` \ core_expr ->
333
334     -- now build the outermost foldr, and return
335     dsLookupGlobalId foldrName          `thenDs` \ foldr_id ->
336     returnDs (
337       Var foldr_id `App` Type x_ty 
338                    `App` Type b_ty
339                    `App` mkLams [x, b] core_expr
340                    `App` Var n_id
341                    `App` core_list1
342     )
343 \end{code}
344
345 %************************************************************************
346 %*                                                                      *
347 \subsection[DsPArrComp]{Desugaring of array comprehensions}
348 %*                                                                      *
349 %************************************************************************
350
351 \begin{code}
352
353 -- entry point for desugaring a parallel array comprehension
354 --
355 --   [:e | qss:] = <<[:e | qss:]>> () [:():]
356 --
357 dsPArrComp      :: [Stmt Id] 
358                 -> LHsExpr Id
359                 -> Type             -- Don't use; called with `undefined' below
360                 -> DsM CoreExpr
361 dsPArrComp [ParStmt qss] body _  =  -- parallel comprehension
362   dePArrParComp qss body
363 dsPArrComp qs            body _  =  -- no ParStmt in `qs'
364   dsLookupGlobalId singletonPName                         `thenDs` \sglP ->
365   let unitArray = mkApps (Var sglP) [Type unitTy, 
366                                      mkCoreTup []]
367   in
368   dePArrComp qs body (mkTuplePat []) unitArray
369
370
371
372 -- the work horse
373 --
374 dePArrComp :: [Stmt Id] 
375            -> LHsExpr Id
376            -> LPat Id           -- the current generator pattern
377            -> CoreExpr          -- the current generator expression
378            -> DsM CoreExpr
379 --
380 --  <<[:e' | :]>> pa ea = mapP (\pa -> e') ea
381 --
382 dePArrComp [] e' pa cea =
383   dsLookupGlobalId mapPName                               `thenDs` \mapP    ->
384   let ty = parrElemType cea
385   in
386   deLambda ty pa e'                                       `thenDs` \(clam, 
387                                                                      ty'e') ->
388   returnDs $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea]
389 --
390 --  <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea)
391 --
392 dePArrComp (ExprStmt b _ _ : qs) body pa cea =
393   dsLookupGlobalId filterPName                    `thenDs` \filterP  ->
394   let ty = parrElemType cea
395   in
396   deLambda ty pa b                                `thenDs` \(clam,_) ->
397   dePArrComp qs body pa (mkApps (Var filterP) [Type ty, clam, cea])
398
399 --
400 --  <<[:e' | p <- e, qs:]>> pa ea =
401 --    let ef = \pa -> e
402 --    in
403 --    <<[:e' | qs:]>> (pa, p) (crossMap ea ef)
404 --
405 -- if matching again p cannot fail, or else
406 --
407 --  <<[:e' | p <- e, qs:]>> pa ea = 
408 --    let ef = \pa -> filterP (\x -> case x of {p -> True; _ -> False}) e
409 --    in
410 --    <<[:e' | qs:]>> (pa, p) (crossMapP ea ef)
411 --
412 dePArrComp (BindStmt p e _ _ : qs) body pa cea =
413   dsLookupGlobalId filterPName                    `thenDs` \filterP    ->
414   dsLookupGlobalId crossMapPName                  `thenDs` \crossMapP  ->
415   dsLExpr e                                       `thenDs` \ce         ->
416   let ety'cea = parrElemType cea
417       ety'ce  = parrElemType ce
418       false   = Var falseDataConId
419       true    = Var trueDataConId
420   in
421   newSysLocalDs ety'ce                                    `thenDs` \v       ->
422   matchSimply (Var v) (StmtCtxt PArrComp) p true false    `thenDs` \pred    ->
423   let cef | isIrrefutableHsPat p = ce
424           | otherwise            = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce]
425   in
426   mkLambda ety'cea pa cef                                 `thenDs` \(clam, 
427                                                                      _    ) ->
428   let ety'cef = ety'ce              -- filter doesn't change the element type
429       pa'     = mkTuplePat [pa, p]
430   in
431   dePArrComp qs body pa' (mkApps (Var crossMapP) 
432                                  [Type ety'cea, Type ety'cef, cea, clam])
433 --
434 --  <<[:e' | let ds, qs:]>> pa ea = 
435 --    <<[:e' | qs:]>> (pa, (x_1, ..., x_n)) 
436 --                    (mapP (\v@pa -> let ds in (v, (x_1, ..., x_n))) ea)
437 --  where
438 --    {x_1, ..., x_n} = DV (ds)         -- Defined Variables
439 --
440 dePArrComp (LetStmt ds : qs) body pa cea =
441   dsLookupGlobalId mapPName                               `thenDs` \mapP    ->
442   let xs     = map unLoc (collectLocalBinders ds)
443       ty'cea = parrElemType cea
444   in
445   newSysLocalDs ty'cea                                    `thenDs` \v       ->
446   dsLocalBinds ds (mkCoreTup (map Var xs))                `thenDs` \clet    ->
447   newSysLocalDs (exprType clet)                           `thenDs` \let'v   ->
448   let projBody = mkDsLet (NonRec let'v clet) $ 
449                  mkCoreTup [Var v, Var let'v]
450       errTy    = exprType projBody
451       errMsg   = "DsListComp.dePArrComp: internal error!"
452   in
453   mkErrorAppDs pAT_ERROR_ID errTy errMsg                  `thenDs` \cerr    ->
454   matchSimply (Var v) (StmtCtxt PArrComp) pa projBody cerr`thenDs` \ccase   ->
455   let pa'    = mkTuplePat [pa, mkTuplePat (map nlVarPat xs)]
456       proj   = mkLams [v] ccase
457   in
458   dePArrComp qs body pa' (mkApps (Var mapP) 
459                                  [Type ty'cea, Type errTy, proj, cea])
460 --
461 -- The parser guarantees that parallel comprehensions can only appear as
462 -- singeltons qualifier lists, which we already special case in the caller.
463 -- So, encountering one here is a bug.
464 --
465 dePArrComp (ParStmt _ : _) _ _ _ = 
466   panic "DsListComp.dePArrComp: malformed comprehension AST"
467
468 --  <<[:e' | qs | qss:]>> pa ea = 
469 --    <<[:e' | qss:]>> (pa, (x_1, ..., x_n)) 
470 --                     (zipP ea <<[:(x_1, ..., x_n) | qs:]>>)
471 --    where
472 --      {x_1, ..., x_n} = DV (qs)
473 --
474 dePArrParComp qss body = 
475   deParStmt qss                                         `thenDs` \(pQss, 
476                                                                    ceQss) ->
477   dePArrComp [] body pQss ceQss
478   where
479     deParStmt []             =
480       -- empty parallel statement lists have no source representation
481       panic "DsListComp.dePArrComp: Empty parallel list comprehension"
482     deParStmt ((qs, xs):qss) =          -- first statement
483       let res_expr = mkExplicitTuple (map nlHsVar xs)
484       in
485       dsPArrComp (map unLoc qs) res_expr undefined        `thenDs` \cqs     ->
486       parStmts qss (mkTuplePat (map nlVarPat xs)) cqs
487     ---
488     parStmts []             pa cea = return (pa, cea)
489     parStmts ((qs, xs):qss) pa cea =    -- subsequent statements (zip'ed)
490       dsLookupGlobalId zipPName                           `thenDs` \zipP    ->
491       let pa'      = mkTuplePat [pa, mkTuplePat (map nlVarPat xs)]
492           ty'cea   = parrElemType cea
493           res_expr = mkExplicitTuple (map nlHsVar xs)
494       in
495       dsPArrComp (map unLoc qs) res_expr undefined        `thenDs` \cqs     ->
496       let ty'cqs = parrElemType cqs
497           cea'   = mkApps (Var zipP) [Type ty'cea, Type ty'cqs, cea, cqs]
498       in
499       parStmts qss pa' cea'
500
501 -- generate Core corresponding to `\p -> e'
502 --
503 deLambda :: Type                        -- type of the argument
504           -> LPat Id                    -- argument pattern
505           -> LHsExpr Id                 -- body
506           -> DsM (CoreExpr, Type)
507 deLambda ty p e =
508   dsLExpr e                                               `thenDs` \ce      ->
509   mkLambda ty p ce
510
511 -- generate Core for a lambda pattern match, where the body is already in Core
512 --
513 mkLambda :: Type                        -- type of the argument
514          -> LPat Id                     -- argument pattern
515          -> CoreExpr                    -- desugared body
516          -> DsM (CoreExpr, Type)
517 mkLambda ty p ce =
518   newSysLocalDs ty                                        `thenDs` \v       ->
519   let errMsg = "DsListComp.deLambda: internal error!"
520       ce'ty  = exprType ce
521   in
522   mkErrorAppDs pAT_ERROR_ID ce'ty errMsg                  `thenDs` \cerr    -> 
523   matchSimply (Var v) (StmtCtxt PArrComp) p ce cerr       `thenDs` \res     ->
524   returnDs (mkLams [v] res, ce'ty)
525
526 -- obtain the element type of the parallel array produced by the given Core
527 -- expression
528 --
529 parrElemType   :: CoreExpr -> Type
530 parrElemType e  = 
531   case splitTyConApp_maybe (exprType e) of
532     Just (tycon, [ty]) | tycon == parrTyCon -> ty
533     _                                                     -> panic
534       "DsListComp.parrElemType: not a parallel array type"
535
536 -- Smart constructor for source tuple patterns
537 --
538 mkTuplePat :: [LPat Id] -> LPat Id
539 mkTuplePat [lpat] = lpat
540 mkTuplePat lpats  = noLoc $ mkVanillaTuplePat lpats Boxed
541
542 -- Smart constructor for source tuple expressions
543 --
544 mkExplicitTuple :: [LHsExpr id] -> LHsExpr id
545 mkExplicitTuple [lexp] = lexp
546 mkExplicitTuple lexps  = noLoc $ ExplicitTuple lexps Boxed
547 \end{code}