[project @ 1998-12-02 13:17:09 by simonm]
[ghc-hetmet.git] / ghc / interpreter / pmc.c
1 /* -*- mode: hugs-c; -*- */
2 /* --------------------------------------------------------------------------
3  * Pattern matching Compiler
4  *
5  * Copyright (c) The University of Nottingham and Yale University, 1994-1997.
6  * All rights reserved. See NOTICE for details and conditions of use etc...
7  * Hugs version 1.4, December 1997
8  *
9  * $RCSfile: pmc.c,v $
10  * $Revision: 1.2 $
11  * $Date: 1998/12/02 13:22:29 $
12  * ------------------------------------------------------------------------*/
13
14 #include "prelude.h"
15 #include "storage.h"
16 #include "connect.h"
17 #include "errors.h"
18 #include "link.h"
19
20 #include "desugar.h"
21 #include "pat.h"
22 #include "pmc.h"
23
24 /* --------------------------------------------------------------------------
25  * Eliminate pattern matching in function definitions -- pattern matching
26  * compiler:
27  *
28  * The original Gofer/Hugs pattern matching compiler was based on Wadler's
29  * algorithms described in `Implementation of functional programming
30  * languages'.  That should still provide a good starting point for anyone
31  * wanting to understand this part of the system.  However, the original
32  * algorithm has been generalized and restructured in order to implement
33  * new features added in Haskell 1.3.
34  *
35  * During the translation, in preparation for later stages of compilation,
36  * all local and bound variables are replaced by suitable offsets, and
37  * locally defined function symbols are given new names (which will
38  * eventually be their names when lifted to make top level definitions).
39  * ------------------------------------------------------------------------*/
40
41 /* --------------------------------------------------------------------------
42  * Local function prototypes:
43  * ------------------------------------------------------------------------*/
44
45 static Cell local pmcPair               Args((Int,List,Pair));
46 static Cell local pmcTriple             Args((Int,List,Triple));
47 static Cell local pmcVar                Args((List,Text));
48 static Void local pmcLetrec             Args((Int,List,Pair));
49 static Cell local pmcVarDef             Args((Int,List,List));
50 static Void local pmcFunDef             Args((Int,List,Triple));
51 static Cell local joinMas               Args((Int,List));
52 static Bool local canFail               Args((Cell));
53 static List local addConTable           Args((Cell,Cell,List));
54 static Void local advance               Args((Int,Int,Cell));
55 static Bool local emptyMatch            Args((Cell));
56 static Cell local maDiscr               Args((Cell));
57 static Bool local isNumDiscr            Args((Cell));
58 static Bool local eqNumDiscr            Args((Cell,Cell));
59 #if TREX
60 static Bool local isExtDiscr            Args((Cell));
61 static Bool local eqExtDiscr            Args((Cell,Cell));
62 #endif
63
64 /* --------------------------------------------------------------------------
65  * 
66  * ------------------------------------------------------------------------*/
67
68 Cell pmcTerm(co,sc,e)                  /* apply pattern matching compiler  */
69 Int  co;                               /* co = current offset              */
70 List sc;                               /* sc = scope                       */
71 Cell e;  {                             /* e  = expr to transform           */
72     switch (whatIs(e)) {
73         case GUARDED  : map2Over(pmcPair,co,sc,snd(e));
74                         break;
75
76         case LETREC   : pmcLetrec(co,sc,snd(e));
77                         break;
78
79         case VARIDCELL:
80         case VAROPCELL:
81         case DICTVAR  : return pmcVar(sc,textOf(e));
82
83         case COND     : return ap(COND,pmcTriple(co,sc,snd(e)));
84
85         case AP       : return pmcPair(co,sc,e);
86
87 #if NPLUSK
88         case ADDPAT   :
89 #endif
90 #if TREX
91         case EXT      :
92 #endif
93         case TUPLE    :
94         case NAME     :
95         case CHARCELL :
96         case INTCELL  :
97         case BIGCELL  :
98         case FLOATCELL:
99         case STRCELL  : break;
100
101         default       : internal("pmcTerm");
102                         break;
103     }
104     return e;
105 }
106
107 static Cell local pmcPair(co,sc,pr)    /* apply pattern matching compiler  */
108 Int  co;                               /* to a pair of exprs               */
109 List sc;
110 Pair pr; {
111     return pair(pmcTerm(co,sc,fst(pr)),
112                 pmcTerm(co,sc,snd(pr)));
113 }
114
115 static Cell local pmcTriple(co,sc,tr)  /* apply pattern matching compiler  */
116 Int    co;                             /* to a triple of exprs             */
117 List   sc;
118 Triple tr; {
119     return triple(pmcTerm(co,sc,fst3(tr)),
120                   pmcTerm(co,sc,snd3(tr)),
121                   pmcTerm(co,sc,thd3(tr)));
122 }
123
124 static Cell local pmcVar(sc,t)         /* find translation of variable     */
125 List sc;                               /* in current scope                 */
126 Text t; {
127     List xs;
128     Name n;
129
130     for (xs=sc; nonNull(xs); xs=tl(xs)) {
131         Cell x = hd(xs);
132         if (t==textOf(fst(x)))
133             if (isOffset(snd(x))) {                  /* local variable ... */
134                 return snd(x);
135             }
136             else {                                   /* local function ... */
137                 return fst3(snd(x));
138             }
139     }
140
141     n = findName(t);
142     assert(nonNull(n));
143     return n;
144 }
145
146 static Void local pmcLetrec(co,sc,e)   /* apply pattern matching compiler  */
147 Int  co;                               /* to LETREC, splitting decls into  */
148 List sc;                               /* two sections                     */
149 Pair e; {
150     List fs = NIL;                     /* local function definitions       */
151     List vs = NIL;                     /* local variable definitions       */
152     List ds;
153
154     for (ds=fst(e); nonNull(ds); ds=tl(ds)) {      /* Split decls into two */
155         Cell v     = fst(hd(ds));
156         Int  arity = length(fst(hd(snd(hd(ds)))));
157
158         if (arity==0) {                            /* Variable declaration */
159             vs = cons(snd(hd(ds)),vs);
160             sc = cons(pair(v,mkOffset(++co)),sc);
161         }
162         else {                                     /* Function declaration */
163             fs = cons(triple(inventVar(),mkInt(arity),snd(hd(ds))),fs);
164             sc = cons(pair(v,hd(fs)),sc);
165         }
166     }
167     vs       = rev(vs);                /* Put declaration lists back in    */
168     fs       = rev(fs);                /* original order                   */
169     fst(e)   = pair(vs,fs);            /* Store declaration lists          */
170     map2Over(pmcVarDef,co,sc,vs);      /* Translate variable definitions   */
171     map2Proc(pmcFunDef,co,sc,fs);      /* Translate function definitions   */
172     snd(e)   = pmcTerm(co,sc,snd(e));  /* Translate LETREC body            */
173 }
174
175 static Cell local pmcVarDef(co,sc,vd)  /* apply pattern matching compiler  */
176 Int  co;                               /* to variable definition           */
177 List sc;
178 List vd; {                             /* vd :: [ ([], rhs) ]              */
179     Cell d = snd(hd(vd));
180     if (nonNull(tl(vd)) && canFail(d))
181         return ap(FATBAR,pair(pmcTerm(co,sc,d),
182                               pmcVarDef(co,sc,tl(vd))));
183     return pmcTerm(co,sc,d);
184 }
185
186 static Void local pmcFunDef(co,sc,fd)  /* apply pattern matching compiler  */
187 Int    co;                             /* to function definition           */
188 List   sc;
189 Triple fd; {                           /* fd :: (Var, Arity, [Alt])        */
190     Int    arity         = intOf(snd3(fd));
191     Cell   temp          = altsMatch(co+1,arity,sc,thd3(fd));
192     Cell   xs;
193
194     temp      = match(co+arity,temp);
195     thd3(fd)  = triple(NIL,NIL,temp);  /* used to be freevar info */
196
197 }
198
199 /* ---------------------------------------------------------------------------
200  * Main part of pattern matching compiler: convert [Alt] to case constructs
201  *
202  * This section of Hugs has been almost completely rewritten to be more
203  * general, in particular, to allow pattern matching in orders other than the
204  * strictly left-to-right approach of the previous version.  This is needed
205  * for the implementation of the so-called Haskell 1.3 `record' syntax.
206  *
207  * At each stage, the different branches for the cases to be considered
208  * are represented by a list of values of type:
209  *   Match ::= { maPats :: [Pat],       patterns to match
210  *               maOffs :: [Offs],      offsets of corresponding values
211  *               maSc   :: Scope,       mapping from vars to offsets
212  *               maRhs  :: Rhs }        right hand side
213  * [Implementation uses nested pairs, ((pats,offs),(sc,rhs)).]
214  *
215  * The Scope component has type:
216  *   Scope  ::= [(Var,Expr)]
217  * and provides a mapping from variable names to offsets used in the matching
218  * process.
219  *
220  * Matches can be normalized by reducing them to a form in which the list
221  * of patterns is empty (in which case the match itself is described as an
222  * empty match), or in which the list is non-empty and the first pattern is
223  * one that requires either a CASE or NUMCASE (or EXTCASE) to decompose.  
224  * ------------------------------------------------------------------------*/
225
226 #define mkMatch(ps,os,sc,r)     pair(pair(ps,os),pair(sc,r))
227 #define maPats(ma)              fst(fst(ma))
228 #define maOffs(ma)              snd(fst(ma))
229 #define maSc(ma)                fst(snd(ma))
230 #define maRhs(ma)               snd(snd(ma))
231 #define extSc(v,o,ma)           maSc(ma) = cons(pair(v,o),maSc(ma))
232
233 List altsMatch(co,n,sc,as)              /* Make a list of matches from list*/
234 Int  co;                                /* of Alts, with initial offsets   */
235 Int  n;                                 /* reverse (take n [co..])         */
236 List sc;
237 List as; {
238     List mas = NIL;
239     List us  = NIL;
240     for (; n>0; n--)
241         us = cons(mkOffset(co++),us);
242     for (; nonNull(as); as=tl(as))      /* Each Alt is ([Pat], Rhs)        */
243         mas = cons(mkMatch(fst(hd(as)),us,sc,snd(hd(as))),mas);
244     return rev(mas);
245 }
246
247 Cell match(co,mas)              /* Generate case statement for Matches mas */
248 Int  co;                        /* at current offset co                    */
249 List mas; {                     /* N.B. Assumes nonNull(mas).              */
250     Cell srhs = NIL;            /* Rhs for selected matches                */
251     List smas = mas;            /* List of selected matches                */
252     mas       = tl(mas);
253     tl(smas)  = NIL;
254
255     if (emptyMatch(hd(smas))) {         /* The case for empty matches:     */
256         while (nonNull(mas) && emptyMatch(hd(mas))) {
257             List temp = tl(mas);
258             tl(mas)   = smas;
259             smas      = mas;
260             mas       = temp;
261         }
262         srhs = joinMas(co,rev(smas));
263     }
264     else {                              /* Non-empty match                 */
265         Int  o = offsetOf(hd(maOffs(hd(smas))));
266         Cell d = maDiscr(hd(smas));
267         if (isNumDiscr(d)) {            /* Numeric match                   */
268             Int  da = discrArity(d);
269             Cell d1 = pmcTerm(co,maSc(hd(smas)),d);
270             while (nonNull(mas) && !emptyMatch(hd(mas))
271                                 && o==offsetOf(hd(maOffs(hd(mas))))
272                                 && isNumDiscr(d=maDiscr(hd(mas)))
273                                 && eqNumDiscr(d,d1)) {
274                 List temp = tl(mas);
275                 tl(mas)   = smas;
276                 smas      = mas;
277                 mas       = temp;
278             }
279             smas = rev(smas);
280             map2Proc(advance,co,da,smas);
281             srhs = ap(NUMCASE,triple(mkOffset(o),d1,match(co+da,smas)));
282         }
283 #if TREX
284         else if (isExtDiscr(d)) {       /* Record match                    */
285             Int  da = discrArity(d);
286             Cell d1 = pmcTerm(co,maSc(hd(smas)),d);
287             while (nonNull(mas) && !emptyMatch(hd(mas))
288                                 && o==offsetOf(hd(maOffs(hd(mas))))
289                                 && isExtDiscr(d=maDiscr(hd(mas)))
290                                 && eqExtDiscr(d,d1)) {
291                 List temp = tl(mas);
292                 tl(mas)   = smas;
293                 smas      = mas;
294                 mas       = temp;
295             }
296             smas = rev(smas);
297             map2Proc(advance,co,da,smas);
298             srhs = ap(EXTCASE,triple(mkOffset(o),d1,match(co+da,smas)));
299         }
300 #endif
301         else {                          /* Constructor match               */
302             List tab = addConTable(d,hd(smas),NIL);
303             Int  da;
304             while (nonNull(mas) && !emptyMatch(hd(mas))
305                                 && o==offsetOf(hd(maOffs(hd(mas))))
306                                 && !isNumDiscr(d=maDiscr(hd(mas)))) {
307                 tab = addConTable(d,hd(mas),tab);
308                 mas = tl(mas);
309             }
310             for (tab=rev(tab); nonNull(tab); tab=tl(tab)) {
311                 d    = fst(hd(tab));
312                 smas = snd(hd(tab));
313                 da   = discrArity(d);
314                 map2Proc(advance,co,da,smas);
315                 srhs = cons(pair(d,match(co+da,smas)),srhs);
316             }
317             srhs = ap(CASE,pair(mkOffset(o),srhs));
318         }
319     }
320     return nonNull(mas) ? ap(FATBAR,pair(srhs,match(co,mas))) : srhs;
321 }
322
323 static Cell local joinMas(co,mas)       /* Combine list of matches into rhs*/
324 Int  co;                                /* using FATBARs as necessary      */
325 List mas; {                             /* Non-empty list of empty matches */
326     Cell ma  = hd(mas);
327     Cell rhs = pmcTerm(co,maSc(ma),maRhs(ma));
328     if (nonNull(tl(mas)) && canFail(rhs))
329         return ap(FATBAR,pair(rhs,joinMas(co,tl(mas))));
330     else
331         return rhs;
332 }
333
334 static Bool local canFail(rhs)         /* Determine if expression (as rhs) */
335 Cell rhs; {                            /* might ever be able to fail       */
336     switch (whatIs(rhs)) {
337         case LETREC  : return canFail(snd(snd(rhs)));
338         case GUARDED : return TRUE;    /* could get more sophisticated ..? */
339         default      : return FALSE;
340     }
341 }
342
343 /* type Table a b = [(a, [b])]
344  *
345  * addTable                 :: a -> b -> Table a b -> Table a b
346  * addTable x y []           = [(x,[y])]
347  * addTable x y (z@(n,sws):zs)
348  *              | n == x     = (n,sws++[y]):zs
349  *              | otherwise  = (n,sws):addTable x y zs
350  */
351
352 static List local addConTable(x,y,tab) /* add element (x,y) to table       */
353 Cell x, y;
354 List tab; {
355     if (isNull(tab))
356         return singleton(pair(x,singleton(y)));
357     else if (fst(hd(tab))==x)
358         snd(hd(tab)) = appendOnto(snd(hd(tab)),singleton(y));
359     else
360         tl(tab) = addConTable(x,y,tl(tab));
361
362     return tab;
363 }
364
365 static Void local advance(co,a,ma)      /* Advance non-empty match by      */
366 Int  co;                                /* processing head pattern         */
367 Int  a;                                 /* discriminator arity             */
368 Cell ma; {
369     Cell p  = hd(maPats(ma));
370     List ps = tl(maPats(ma));
371     List us = tl(maOffs(ma));
372     if (whatIs(p)==CONFLDS) {           /* Special case for record syntax  */
373         Name c  = fst(snd(p));
374         List fs = snd(snd(p));
375         List qs = NIL;
376         List vs = NIL;
377         for (; nonNull(fs); fs=tl(fs)) {
378             vs = cons(mkOffset(co+a+1-sfunPos(fst(hd(fs)),c)),vs);
379             qs = cons(snd(hd(fs)),qs);
380         }
381         ps = revOnto(qs,ps);
382         us = revOnto(vs,us);
383     }
384     else                                /* Normally just spool off patterns*/
385         for (; a>0; --a) {              /* and corresponding offsets ...   */
386             us = cons(mkOffset(++co),us);
387             ps = cons(arg(p),ps);
388             p  = fun(p);
389         }
390
391     maPats(ma) = ps;
392     maOffs(ma) = us;
393 }
394
395 /* --------------------------------------------------------------------------
396  * Normalize and test for empty match:
397  * ------------------------------------------------------------------------*/
398
399 static Bool local emptyMatch(ma)/* Normalize and test to see if a given    */
400 Cell ma; {                      /* match, ma, is empty.                    */
401
402     while (nonNull(maPats(ma))) {
403         Cell p;
404 tidyHd: switch (whatIs(p=hd(maPats(ma)))) {
405             case LAZYPAT   : {   Cell nv   = inventVar();
406                                  maRhs(ma) = ap(LETREC,
407                                                 pair(remPat(snd(p),nv,NIL),
408                                                      maRhs(ma)));
409                                  p         = nv;
410                              }
411                              /* intentional fall-thru */
412             case VARIDCELL :
413             case VAROPCELL :
414             case DICTVAR   : extSc(p,hd(maOffs(ma)),ma);
415             case WILDCARD  : maPats(ma) = tl(maPats(ma));
416                              maOffs(ma) = tl(maOffs(ma));
417                              continue;
418
419             /* So-called "as-patterns"are really just pattern intersections:
420              *    (p1@p2:ps, o:os, sc, e) ==> (p1:p2:ps, o:o:os, sc, e)
421              * (But the input grammar probably doesn't let us take
422              * advantage of this, so we stick with the special case
423              * when p1 is a variable.)
424              */
425             case ASPAT     : extSc(fst(snd(p)),hd(maOffs(ma)),ma);
426                              hd(maPats(ma)) = snd(snd(p));
427                              goto tidyHd;
428
429             case FINLIST   : hd(maPats(ma)) = mkConsList(snd(p));
430                              return FALSE;
431
432             case STRCELL   : {   String s = textToStr(textOf(p));
433                                  for (p=NIL; *s!='\0'; ++s) {
434                                      if (*s!='\\' || *++s=='\\') {
435                                          p = ap2(nameCons,mkChar(*s),p);
436                                      } else {
437                                          p = ap2(nameCons,mkChar('\0'),p);
438                                      }
439                                  }
440                                  hd(maPats(ma)) = revOnto(p,nameNil);
441                              }
442                              return FALSE;
443
444             case AP        : if (isName(fun(p)) && isCfun(fun(p))
445                                  && cfunOf(fun(p))==0
446                                  && name(fun(p)).defn==nameId) {
447                                   hd(maPats(ma)) = arg(p);
448                                   goto tidyHd;
449                              }
450                              /* intentional fall-thru */
451             case CHARCELL  :
452 #if !OVERLOADED_CONSTANTS
453             case INTCELL   :
454             case BIGCELL   :
455             case FLOATCELL :
456 #endif
457             case NAME      :
458             case CONFLDS   :
459                              return FALSE;
460
461             default        : internal("emptyMatch");
462         }
463     }
464     return TRUE;
465 }
466
467 /* --------------------------------------------------------------------------
468  * Discriminators:
469  * ------------------------------------------------------------------------*/
470
471 static Cell local maDiscr(ma)   /* Get the discriminator for a non-empty   */
472 Cell ma; {                      /* match, ma.                              */
473     Cell p = hd(maPats(ma));
474     Cell h = getHead(p);
475     switch (whatIs(h)) {
476         case CONFLDS : return fst(snd(p));
477 #if NPLUSK
478         case ADDPAT  : arg(fun(p)) = translate(arg(fun(p)));
479                        return fun(p);
480 #endif
481 #if TREX
482         case EXT     : h      = fun(fun(p));
483                        arg(h) = translate(arg(h));
484                        return h;
485 #endif
486 #if OVERLOADED_CONSTANTS
487         case NAME    : if (h==nameFromInt || h==nameFromInteger
488                                           || h==nameFromDouble) {
489                            if (argCount==2)
490                                arg(fun(p)) = translate(arg(fun(p)));
491                            return p;
492                         }
493 #endif
494     }
495     return h;
496 }
497
498 static Bool local isNumDiscr(d) /* TRUE => numeric discriminator           */
499 Cell d; {
500     switch (whatIs(d)) {
501         case NAME      :
502         case TUPLE     :
503         case CHARCELL  : return FALSE;
504 #if OVERLOADED_CONSTANTS
505 #if TREX
506         case AP        : return !isExt(fun(d));
507 #else
508         case AP        : return TRUE;   /* must be a literal or (n+k)      */
509 #endif
510 #else
511         case INTCELL  :
512         case BIGCELL  :
513         case FLOATCELL:
514                         return TRUE;
515 #endif
516     }
517     internal("isNumDiscr");
518     return 0;/*NOTREACHED*/
519 }
520
521 Int discrArity(d)                      /* Find arity of discriminator      */
522 Cell d; {
523     switch (whatIs(d)) {
524         case NAME      : return name(d).arity;
525         case TUPLE     : return tupleOf(d);
526         case CHARCELL  : return 0;
527 #if !OVERLOADED_CONSTANTS
528         case INTCELL   :
529         case BIGCELL   :
530         case FLOATCELL : return 0;
531 #endif /* !OVERLOADED_CONSTANTS */
532
533 #if TREX
534         case AP        : switch (whatIs(fun(d))) {
535 #if NPLUSK
536                              case ADDPAT : return 1;
537 #endif
538                              case EXT    : return 2;
539                              default     : return 0;
540                          }
541 #else
542 #if NPLUSK
543         case AP        : return (whatIs(fun(d))==ADDPAT) ? 1 : 0;
544 #else
545         case AP        : return 0;      /* must be an Int or Float lit     */
546 #endif
547 #endif
548     }
549     internal("discrArity");
550     return 0;/*NOTREACHED*/
551 }
552
553 static Bool local eqNumDiscr(d1,d2)     /* Determine whether two numeric   */
554 Cell d1, d2; {                          /* descriptors have same value     */
555 #if NPLUSK
556     if (whatIs(fun(d1))==ADDPAT)
557         return whatIs(fun(d2))==ADDPAT && bignumEq(snd(fun(d1)),snd(fun(d2)));
558 #endif
559 #if OVERLOADED_CONSTANTS
560     d1 = arg(d1);
561     d2 = arg(d2);
562 #endif
563     if (isInt(d1))
564         return isInt(d2) && intEq(d1,d2);
565     if (isFloat(d1))
566         return isFloat(d2) && floatEq(d1,d2);
567     if (isBignum(d1))
568         return isBignum(d2) && bignumEq(d1,d2);
569     internal("eqNumDiscr");
570     return FALSE;/*NOTREACHED*/
571 }
572
573 #if TREX
574 static Bool local isExtDiscr(d)         /* Test of extension discriminator */
575 Cell d; {
576     return isAp(d) && isExt(fun(d));
577 }
578
579 static Bool local eqExtDiscr(d1,d2)     /* Determine whether two extension */
580 Cell d1, d2; {                          /* discriminators have same label  */
581     return fun(d1)==fun(d2);
582 }
583 #endif
584
585 /*-------------------------------------------------------------------------*/