[project @ 1999-04-27 10:06:47 by sewardj]
[ghc-hetmet.git] / ghc / interpreter / optimise.c
1
2 /* --------------------------------------------------------------------------
3  * Optimiser
4  *
5  * Copyright (c) The University of Nottingham and Yale University, 1994-1997.
6  * All rights reserved. See NOTICE for details and conditions of use etc...
7  * Hugs version 1.4, December 1997
8  *
9  * $RCSfile: optimise.c,v $
10  * $Revision: 1.5 $
11  * $Date: 1999/04/27 10:06:57 $
12  * ------------------------------------------------------------------------*/
13
14 #include "prelude.h"
15 #include "storage.h"
16 #include "backend.h"
17 #include "connect.h"
18 #include "errors.h"
19 #include "link.h"
20 #include "Assembler.h"
21
22 /* #define DEBUG_OPTIMISE */
23
24 /* --------------------------------------------------------------------------
25  * Local functions
26  * ------------------------------------------------------------------------*/
27
28 Int nLoopBreakersInlined;
29 Int nLetvarsInlined;
30 Int nTopvarsInlined;
31 Int nCaseOfLet;
32 Int nCaseOfCase;
33 Int nCaseOfPrimCase;
34 Int nCaseOfCon;
35 Int nCaseOfOuter;
36 Int nLetBindsDropped;
37 Int nLetrecGroupsDropped;
38 Int nLambdasMerged;
39 Int nCaseDefaultsDropped;
40 Int nAppsMerged;
41 Int nLetsFloatedOutOfFn;
42 Int nLetsFloatedIntoCase;
43 Int nCasesFloatedOutOfFn;
44 Int nBetaReductions;
45
46 Int nTotSizeIn;
47 Int nTotSizeOut;
48
49 Int  rDepth;
50 Bool copyInTopvar;
51 Bool inDBuilder;
52
53 static void local optimiseTopBind( StgVar v );
54
55 typedef
56    enum {
57       CTX_SCRUT,
58       CTX_OTHER
59    }
60    InlineCtx;
61
62 /* Exactly like whatIs except it avoids a fn call for STG tags */
63 #define whatIsStg(xx) ((isPair(xx) ? (isTag(fst(xx)) ? fst(xx) : AP) : whatIs(xx)))
64
65
66 /* --------------------------------------------------------------------------
67  * Transformation stats
68  * ------------------------------------------------------------------------*/
69
70 void initOptStats ( void )
71 {
72    nLoopBreakersInlined  = 0;
73    nLetvarsInlined       = 0;
74    nTopvarsInlined       = 0;
75    nCaseOfLet            = 0;
76    nCaseOfCase           = 0;
77    nCaseOfPrimCase       = 0;
78    nCaseOfCon            = 0;
79    nCaseOfOuter          = 0;
80    nLetBindsDropped      = 0;
81    nLetrecGroupsDropped  = 0;
82    nLambdasMerged        = 0;
83    nCaseDefaultsDropped  = 0;
84    nAppsMerged           = 0;
85    nLetsFloatedOutOfFn   = 0;
86    nLetsFloatedIntoCase  = 0;
87    nCasesFloatedOutOfFn  = 0;
88    nBetaReductions       = 0;
89    nTotSizeIn            = 0;
90    nTotSizeOut           = 0;
91 }
92
93 void printOptStats ( FILE* f )
94 {
95    fflush(stdout); fflush(stderr); fflush(f);
96    fprintf(f, "\n\n" );
97    fprintf(f, "Inlining:     topvar %-5d        letvar %-5d"
98               "      loopbrkr %-5d      betaredn %-5d\n",
99               nTopvarsInlined, nLetvarsInlined, nLoopBreakersInlined, 
100               nBetaReductions );
101    fprintf(f, "Case-of-:        let %-5d          case %-5d"
102               "           con %-5d         case# %-5d\n",
103               nCaseOfLet, nCaseOfCase, nCaseOfCon, nCaseOfPrimCase );
104    fprintf(f, "Dropped:     letbind %-5d      letgroup %-5d"
105               "       default %-5d\n",
106               nLetBindsDropped, nLetrecGroupsDropped, nCaseDefaultsDropped );
107    fprintf(f, "Merges:       lambda %-5d           app %-5d\n",
108               nLambdasMerged, nAppsMerged  );
109    fprintf(f, "Fn-float:        let %-5d          case %-5d\n",
110               nLetsFloatedOutOfFn, nCasesFloatedOutOfFn );
111    fprintf(f, "Misc:     case-outer %-5d let-into-case %-5d\n",
112               nCaseOfOuter, nLetsFloatedIntoCase );
113    fprintf(f, "total size:       in %-5d           out %-5d\n",
114               nTotSizeIn, nTotSizeOut );
115    fprintf(f, "\n" );
116 }
117
118
119 /* --------------------------------------------------------------------------
120  * How big is this STG tree (viz (primarily), do I want to inline it?)
121  * ------------------------------------------------------------------------*/
122
123 Int stgSize_list ( List es )
124 {
125    Int n = 0;
126    for (; nonNull(es); es=tl(es)) n += stgSize(hd(es));
127    return n;
128 }
129
130 Int stgSize ( StgExpr e )
131 {
132    List xs;
133    Int n = 1;
134
135    if (isNull(e)) return 0;
136
137    switch(whatIsStg(e)) {
138       case STGVAR:
139          break;
140       case LETREC:
141          for (xs = stgLetBinds(e); nonNull(xs);xs=tl(xs)) 
142             n += stgSize(stgVarBody(hd(xs)));
143          n += stgSize(stgLetBody(e));
144          break;
145       case LAMBDA:
146          n += stgSize(stgLambdaBody(e));
147          break;
148       case CASE:
149          n += stgSize_list(stgCaseAlts(e));
150          n += stgSize(stgCaseScrut(e));
151          break;
152       case PRIMCASE:
153          n += stgSize_list(stgPrimCaseAlts(e));
154          n += stgSize(stgPrimCaseScrut(e));
155          break;
156       case STGAPP:
157          n += stgSize_list(stgAppArgs(e));
158          n += stgSize(stgAppFun(e));
159          break;
160       case STGPRIM:
161          n += stgSize_list(stgPrimArgs(e));
162          n += stgSize(stgPrimOp(e));
163          break;
164       case STGCON:
165          n += stgSize_list(stgConArgs(e));
166          n += stgSize(stgConCon(e));
167          break;
168       case DEEFALT:
169          n  = stgSize(stgDefaultBody(e));
170          break;
171       case CASEALT:
172          n  = stgSize(stgCaseAltBody(e));
173          break;
174       case PRIMALT:
175          n  = stgSize(stgPrimAltBody(e));
176          break;
177       case INTCELL:
178       case STRCELL:
179       case PTRCELL:
180       case CHARCELL:
181       case FLOATCELL:
182       case BIGCELL:
183       case NAME:
184       case TUPLE:
185          break;
186       default:
187          fprintf(stderr, "sizeStg: unknown stuff %d\n",whatIsStg(e));
188          assert(0);
189    }
190    return n;
191 }
192
193
194 /* --------------------------------------------------------------------------
195  * Stacks of pairs of collectable things.  Used to implement associations.
196  * cloneStg() uses its stack to map old var names to new ones.
197  * ------------------------------------------------------------------------*/
198
199 #define M_PAIRS 400
200 #define SP_NOT_IN_USE (-123456789)
201
202 typedef
203    struct { Cell pfst; Cell psnd; } 
204    StgPair;
205
206 static Int     spClone;
207 static StgPair pairClone[M_PAIRS];
208
209 void markPairs ( void )
210 {
211    Int i;
212    if (spClone != SP_NOT_IN_USE) {
213       for (i = 0; i <= spClone; i++) {
214          mark(pairClone[i].pfst);
215          mark(pairClone[i].psnd);
216       }
217    }
218 }
219
220 void pushClone ( Cell a, Cell b )
221 {
222    spClone++;
223    if (spClone >= M_PAIRS) internal("pushClone -- M_PAIRS too small");
224    pairClone[spClone].pfst = a;
225    pairClone[spClone].psnd = b;
226 }
227
228 void dropClone ( void )
229 {
230    if (spClone < 0) internal("dropClone");
231    spClone--;
232 }
233
234 Cell findClone ( Cell x )
235 {
236    Int i;
237    for (i = spClone; i >= 0; i--)
238       if (pairClone[i].pfst == x)
239          return pairClone[i].psnd;
240    return NIL;
241 }
242
243
244 /* --------------------------------------------------------------------------
245  * Cloning of STG trees
246  * ------------------------------------------------------------------------*/
247
248 /* Clone v to create a new var.  Works for both StgVar and StgPrimVar. */
249 StgVar cloneStgVar ( StgVar v )
250 {
251   return ap(STGVAR,triple(stgVarBody(v),stgVarRep(v),NIL));
252 }
253
254
255 /* For each StgVar in origVars, make a new one with cloneStgVar,
256    and push the (old,new) pair on the clone pair stack.  Returns
257    the list of new vars.
258 */
259 List cloneStg_addVars ( List origVars )
260 {
261    List newVars = NIL;
262    while (nonNull(origVars)) {
263       StgVar newv = cloneStgVar(hd(origVars));
264       pushClone ( hd(origVars), newv );
265       newVars    = cons(newv,newVars);
266       origVars   = tl(origVars);
267    }
268    newVars = rev(newVars);
269    return newVars;
270 }
271
272
273 void cloneStg_dropVars ( List vs )
274 {
275    for (; nonNull(vs); vs=tl(vs)) 
276       dropClone();
277 }
278
279
280 /* Print the clone pair stack.  Just for debugging purposes. */
281 void ppCloneEnv ( char* s )
282 {
283    Int i;
284    fflush(stdout);fflush(stderr);
285    printf ( "\nenv-%s\n", s );
286    for (i = 0; i <= spClone; i++) {
287       printf ( "\t" ); 
288       ppStgExpr(pairClone[i].pfst);
289       ppStgExpr(pairClone[i].psnd);
290       printf ( "\n" );
291    };
292    printf ( "vne-%s\n", s );
293 }
294
295
296 StgExpr cloneStg ( StgExpr e )
297 {
298    List xs, newvs;
299    StgVar newv;
300    StgExpr t;
301
302    switch(whatIsStg(e)) {
303       case STGVAR:
304          newv = findClone(e);
305          if (nonNull(newv)) return newv; else return e;
306       case LETREC:
307          newvs = cloneStg_addVars ( stgLetBinds(e) );
308          for (xs = newvs; nonNull(xs);xs=tl(xs)) 
309             stgVarBody(hd(xs)) = cloneStg(stgVarBody(hd(xs)));
310          t = mkStgLet(newvs,cloneStg(stgLetBody(e)));
311          cloneStg_dropVars ( stgLetBinds(e) );
312          return t;
313       case LAMBDA:
314          newvs = cloneStg_addVars ( stgLambdaArgs(e) );
315          t = mkStgLambda(newvs, cloneStg(stgLambdaBody(e)));
316          cloneStg_dropVars ( stgLambdaArgs(e) );
317          return t;
318       case CASE:
319          xs = dupList(stgCaseAlts(e)); 
320          mapOver(cloneStg,xs);
321          return mkStgCase(cloneStg(stgCaseScrut(e)),xs);
322       case PRIMCASE:
323          xs = dupList(stgPrimCaseAlts(e));
324          mapOver(cloneStg,xs);
325          return mkStgPrimCase(cloneStg(stgPrimCaseScrut(e)),xs);
326       case STGAPP:
327          xs = dupList(stgAppArgs(e));
328          mapOver(cloneStg,xs);
329          return mkStgApp(cloneStg(stgAppFun(e)),xs);
330       case STGPRIM:
331          xs = dupList(stgPrimArgs(e));
332          mapOver(cloneStg,xs);
333          return mkStgPrim(cloneStg(stgPrimOp(e)),xs);
334       case STGCON:
335          xs = dupList(stgConArgs(e));
336          mapOver(cloneStg,xs);
337          return mkStgCon(cloneStg(stgConCon(e)),xs);
338       case DEEFALT:
339          newv = cloneStgVar(stgDefaultVar(e));
340          pushClone ( stgDefaultVar(e), newv );
341          t = mkStgDefault(newv,cloneStg(stgDefaultBody(e)));
342          dropClone();
343          return t;
344       case CASEALT:
345          newvs = cloneStg_addVars ( stgCaseAltVars(e) );
346          t = mkStgCaseAlt(stgCaseAltCon(e),newvs,
347                           cloneStg(stgCaseAltBody(e)));
348          cloneStg_dropVars ( stgCaseAltVars(e) );
349          return t;
350       case PRIMALT:
351          newvs = cloneStg_addVars ( stgPrimAltVars(e) );
352          t = mkStgPrimAlt(newvs, cloneStg(stgPrimAltBody(e)));
353          cloneStg_dropVars ( stgPrimAltVars(e) );
354          return t;
355       case INTCELL:
356       case STRCELL:
357       case PTRCELL:
358       case BIGCELL:
359       case CHARCELL:
360       case FLOATCELL:
361       case NAME:
362       case TUPLE:
363          return e;
364       default:
365          fprintf(stderr, "cloneStg: unknown stuff %d\n",whatIsStg(e));
366          assert(0);
367    }
368 }
369
370
371 /* Main entry point.  Checks against re-entrant use. */
372 StgExpr cloneStgTop ( StgExpr e )
373 {
374    StgExpr res;
375    if (spClone != SP_NOT_IN_USE) 
376       internal("cloneStgTop");
377    spClone = -1;
378    res = cloneStg ( e );
379    assert(spClone == -1);
380    spClone = SP_NOT_IN_USE;
381    return res;
382 }
383
384
385
386 /* --------------------------------------------------------------------------
387  * Sets of StgVars, used by the strongly-connected-components machinery.  
388  * Represented as an array of variables.  The vars
389  * must be in strictly nondecreasing order.  Each value may appear
390  * more than once, so as to make deletion relatively cheap.
391
392  * After a garbage collection happens, the values may have changed,
393  * so the array will need to be sorted.
394
395  * Using a binary search, membership costs O(log N).  Union and
396  * intersection cost O(N + M).  Deletion of a single element costs
397  * O(N) in the worst case, although if it happens infrequently
398  * compared to the other ops, it should asymptotically approach O(1).
399  * ------------------------------------------------------------------------*/
400
401 #define M_VAR_SETS 4000
402 #define MIN_VAR_SET_SIZE 4
403 #define M_UNION_TMP 20000
404
405 typedef
406    struct {
407       Int   nextfree;
408       Bool  inUse;
409       Int   size;
410       Int   used;
411       Cell* vs;
412    }
413    StgVarSetRec;
414
415 typedef Int StgVarSet;
416
417 StgVarSetRec varSet[M_VAR_SETS];
418 Int varSet_nfree;
419 Int varSet_nextfree;
420 Cell union_tmp[M_UNION_TMP];
421
422 #if 0 /* unused since unnecessary */
423 /* Shellsort set elems to restore representation invariants */
424 static Int shellCells_incs[10] 
425    = { 1, 4, 13, 40, 121, 364, 1093, 3280, 9841, 29524 };
426 static void shellCells ( Cell* a, Int lo, Int hi )
427 {
428    Int i, j, h, N, hp;
429    Cell v;
430
431    N = hi - lo + 1; if (N < 2) return;
432    hp = 0; 
433    while (hp < 10 && shellCells_incs[hp] < N) hp++; hp--;
434
435    for (; hp >= 0; hp--) {
436       h = shellCells_incs[hp];
437       i = lo + h;
438       while (1) {
439          if (i > hi) break;
440          v = a[i];
441          j = i;
442          while (a[j-h] > v) {
443             a[j] = a[j-h]; j = j - h;
444             if (j <= (lo + h - 1)) break;
445          }
446          a[j] = v; i++;
447       }
448    }
449 }
450 #endif
451
452 /* check that representation invariant still holds */
453 static void checkCells ( Cell* a, Int lo, Int hi )
454 {
455    Int i;
456    for (i = lo; i < hi; i++)
457       if (a[i] > a[i+1])
458          internal("checkCells");
459 }
460
461
462 /* Mark set contents for GC */
463 void markStgVarSets ( void )
464 {
465    Int i, j;
466    for (i = 0; i < M_VAR_SETS; i++)
467       if (varSet[i].inUse)
468          for (j = 0; j < varSet[i].used; j++)
469             mark(varSet[i].vs[j]);
470 }
471
472
473 /* Check representation invariants after GC */
474 void checkStgVarSets ( void )
475 {
476    Int i;
477    for (i = 0; i < M_VAR_SETS; i++)
478       if (varSet[i].inUse)
479          checkCells ( varSet[i].vs, 0, varSet[i].used-1 );
480 }
481
482
483 /* Allocate a set of a given size */
484 StgVarSet allocStgVarSet ( Int size )
485 {
486    Int i, j;
487    if (varSet_nextfree == -1)
488       internal("allocStgVarSet -- run out of var sets");
489    i = varSet_nextfree;
490    varSet_nextfree = varSet[i].nextfree;
491    varSet[i].inUse = TRUE;
492    j = MIN_VAR_SET_SIZE;
493    while (j <= size) j *= 2;
494    varSet[i].used = 0;
495    varSet[i].size = j;
496    varSet[i].vs = malloc(j * sizeof(StgVar) );
497    if (!varSet[i].vs) 
498       internal("allocStgVarSet -- can't malloc memory");
499    varSet_nfree--;
500    return i;
501 }
502
503
504 /* resize (upwards) */
505 void resizeStgVarSet ( StgVarSet s, Int size )
506 {
507    Cell* tmp;
508    Cell* tmp2;
509    Int i;
510    Int j = MIN_VAR_SET_SIZE;
511    while (j <= size) j *= 2;
512    if (j < varSet[s].size) return;
513    tmp = varSet[s].vs;
514    tmp2 = malloc( j * sizeof(StgVar) );
515    if (!tmp2) internal("resizeStgVarSet -- can't malloc memory");
516    varSet[s].vs = tmp2;
517    for (i = 0; i < varSet[s].used; i++)
518       tmp2[i] = tmp[i];
519    free(tmp);
520 }
521
522
523 /* Deallocation ... */
524 void freeStgVarSet ( StgVarSet s )
525 {
526    if (s < 0 || s >= M_VAR_SETS || 
527        !varSet[s].inUse || !varSet[s].vs)
528       internal("freeStgVarSet");
529    free(varSet[s].vs);
530    varSet[s].inUse = FALSE;
531    varSet[s].vs = NULL;
532    varSet[s].nextfree = varSet_nextfree;
533    varSet_nextfree = s;
534    varSet_nfree++;
535 }
536
537
538 /* Initialisation */
539 void initStgVarSets ( void )
540 {
541    Int i;
542    for (i = M_VAR_SETS-1; i >= 0; i--) {
543       varSet[i].inUse = FALSE;
544       varSet[i].vs = NULL;
545       varSet[i].nextfree = i+1;
546    }
547    varSet[M_VAR_SETS-1].nextfree = -1;
548    varSet_nextfree = 0;
549    varSet_nfree = M_VAR_SETS;
550 }
551
552
553 /* Find a var using binary search */
554 Int findInStgVarSet ( StgVarSet s, StgVar v )
555 {
556    Int lo, mid, hi;
557    lo = 0;
558    hi = varSet[s].used-1;
559    while (1) {
560       if (lo > hi) return -1;
561       mid = (hi+lo)/2;
562       if (varSet[s].vs[mid] == v) return mid;
563       if (varSet[s].vs[mid] < v) lo = mid+1; else hi = mid-1;
564    }
565 }
566
567
568 Bool elemStgVarSet ( StgVarSet s, StgVar v )
569 {
570    return findInStgVarSet(s,v) != -1;
571 }
572
573 void ppSet ( StgVarSet s )
574 {
575    Int i;
576    fprintf(stderr, "{ ");
577    for (i = 0; i < varSet[s].used; i++)
578       fprintf(stderr, "%d ", varSet[s].vs[i] );
579    fprintf(stderr, "}\n" );
580 }
581
582
583 void deleteFromStgVarSet ( StgVarSet s, StgVar v )
584 {
585    Int i, j;
586    i = findInStgVarSet(s,v);
587    if (i == -1) return;
588    j = varSet[s].used-1;
589    for (; i < j; i++) varSet[s].vs[i] = varSet[s].vs[i+1];
590    varSet[s].used--;
591 }
592
593
594 void singletonStgVarSet ( StgVarSet s, StgVar v )
595 {
596    varSet[s].used  = 1;
597    varSet[s].vs[0] = v;
598 }
599
600
601 void emptyStgVarSet ( StgVarSet s )
602 {
603    varSet[s].used = 0;
604 }
605
606
607 void copyStgVarSets ( StgVarSet dst, StgVarSet src )
608 {
609    Int i;
610    varSet[dst].used = varSet[src].used;
611    for (i = 0; i < varSet[dst].used; i++)
612       varSet[dst].vs[i] = varSet[src].vs[i];
613 }
614
615
616 Int sizeofVarSet ( StgVarSet s )
617 {
618    return varSet[s].used;
619 }
620
621
622 void unionStgVarSets ( StgVarSet dst, StgVarSet src )
623 {
624    StgVar v1;
625    Int pd, ps, i, res_used, tmp_used, dst_used, src_used;
626    StgVar* dst_vs;
627    StgVar* src_vs;
628    StgVar* tmp_vs;
629
630    dst_vs = varSet[dst].vs;
631
632    /* fast track a common (~ 50%) case */
633    if (varSet[src].used == 1) {
634       v1 = varSet[src].vs[0];
635       pd = findInStgVarSet(dst,v1);
636       if (pd != -1) return;
637       if (varSet[dst].used < varSet[dst].size) {
638          i = varSet[dst].used;
639          while (i > 0 && dst_vs[i-1] > v1) {
640             dst_vs[i] = dst_vs[i-1];
641             i--;
642          }
643          dst_vs[i] = v1;
644          varSet[dst].used++;
645          return;
646       }
647    }
648
649    res_used = varSet[dst].used + varSet[src].used;
650    if (res_used > M_UNION_TMP) 
651       internal("unionStgVarSets -- M_UNION_TMP too small");
652
653    resizeStgVarSet(dst,res_used);
654    dst_vs = varSet[dst].vs;
655    src_vs = varSet[src].vs;
656    tmp_vs = union_tmp;
657    tmp_used = 0;
658    dst_used = varSet[dst].used;
659    src_used = varSet[src].used;
660
661    /* merge the two sets into tmp */
662    pd = ps = 0;
663    while (pd < dst_used || ps < src_used) {
664       if (pd == dst_used)
665          tmp_vs[tmp_used++] = src_vs[ps++];
666       else
667       if (ps == src_used)
668          tmp_vs[tmp_used++] = dst_vs[pd++];
669       else {
670          StgVar vald = dst_vs[pd];
671          StgVar vals = src_vs[ps];
672          if (vald < vals)
673             tmp_vs[tmp_used++] = vald, pd++;
674          else
675          if (vald > vals)
676             tmp_vs[tmp_used++] = vals, ps++;
677          else
678             tmp_vs[tmp_used++] = vals, ps++, pd++;
679       }
680    }
681
682    /* copy setTmp back to dst */
683    varSet[dst].used = tmp_used;
684    for (i = 0; i < tmp_used; i++) {
685       dst_vs[i] = tmp_vs[i];
686    }
687 }
688
689
690
691 /* --------------------------------------------------------------------------
692  * Strongly-connected-components machinery for STG let bindings.
693  * Arranges let bindings in minimal mutually recursive groups, and
694  * then throws away any groups not referred to in the body of the let.
695  *
696  * How it works: does a bottom-up sweep of the tree.  Each call returns
697  * the set of variables free in the tree.  All nodes except LETREC are
698  * boring.  
699  * 
700  * When 'let v1=e1 .. vn=en in e' is encountered:
701  * -- recursively make a call on e.  This returns fvs(e) and scc-ifies
702  *    inside e as well.
703  * -- do recursive calls for e1 .. en too, giving fvs(e1) ... fvs(en).
704  *
705  * Then, using fvs(e1) ... fvs(en), the dependancy graph for v1 ... vn
706  * can be cheaply computed.  Using that, compute the strong components
707  * and rearrange the let binding accordingly.
708  * Finally, for each of the strong components, we can use fvs(en) to 
709  * cheaply determine if the component is used in the body of the let,
710  * and if not, it can be omitted.
711  *
712  * oaScc destructively modifies the tree -- when it gets to a let --
713  * we need to pass the address of the expression to scc, not the
714  * (more usual) heap index of it.
715  *
716  * The main requirement of this algorithm is an efficient implementation
717  * of sets of variables.  Because there is no name shadowing in these
718  * trees, either mentioned-sets or free-sets would be ok, although 
719  * free sets are presumably smaller.
720  * ------------------------------------------------------------------------*/
721
722
723 #define  SCC             stgScc          /* make scc algorithm for StgVars */
724 #define  LOWLINK         stgLowlink
725 #define  DEPENDS(t)      thd3(t)
726 #define  SETDEPENDS(c,v) thd3(c)=v
727 #include "scc.c"
728 #undef   SETDEPENDS
729 #undef   DEPENDS
730 #undef   LOWLINK
731 #undef   SCC
732
733
734 StgVarSet oaScc ( StgExpr* e_orig )
735 {
736    Bool grpUsed;
737    StgExpr e;
738    StgVarSet e_fvs, s1, s2;
739    List bs, bs2, bs3, bsFinal, augs, augsL;
740
741    bs=bs2=bs3=bsFinal=augs=augsL=e_fvs=s1=s2=e=NIL;
742    grpUsed=FALSE;
743
744    e = *e_orig;
745
746    //fprintf(stderr,"\n==================\n");
747    //ppStgExpr(*e_orig);
748    //fprintf(stderr,"\n\n");fflush(stderr);fflush(stdout);
749
750
751    switch(whatIsStg(e)) {
752       case LETREC:
753          /* first, recurse into the let body */
754          e_fvs = oaScc(&stgLetBody(*e_orig));
755
756          /* Make bs :: [StgVar] and e :: Stgexpr. */
757          bs = stgLetBinds(e);
758          e  = stgLetBody(e);
759
760          /* make augs :: [(StgVar,fvs(bindee),NIL)] */
761          augs = NIL;
762          for (; nonNull(bs); bs=tl(bs)) {
763             StgVarSet fvs_bindee = oaScc(&stgVarBody(hd(bs)));
764             augs = cons( triple(hd(bs),mkInt(fvs_bindee),NIL), augs );
765          }
766
767          bs2=bs3=bsFinal=augsL=s1=s2=NIL;
768
769          /* In each of the triples in aug, replace the NIL field with 
770             a list of the let-bound vars appearing in the bindee.
771             ie, construct the adjacency list for the graph. 
772             giving 
773             augs :: [(StgVar,fvs(bindee),[pointers-back-to-this-list-of-pairs])]
774          */
775          for (bs=augs;nonNull(bs);bs=tl(bs)) {
776             augsL = NIL;
777             for (bs2=augs;nonNull(bs2);bs2=tl(bs2))
778                if (elemStgVarSet( intOf(snd3(hd(bs))), fst3(hd(bs2)) ))
779                   augsL = cons(hd(bs2),augsL);
780             thd3(hd(bs)) = augsL;
781          }
782
783          bs2=bs3=bsFinal=augsL=s1=s2=NIL;
784
785          /* Do the Biz.  
786             augs becomes :: [[(StgVar,fvs(bindee),aux_info_field)]] */
787          augs = stgScc(augs);
788
789          /* work backwards through augs, reconstructing the expression,
790             dumping any unused groups as you go.
791          */
792          bsFinal = NIL;
793          for (augs=rev(augs); nonNull(augs); augs=tl(augs)) {
794             bs2 = NIL;
795             for (augsL=hd(augs);nonNull(augsL); augsL=tl(augsL))
796                bs2 = cons(fst3(hd(augsL)),bs2);
797             grpUsed = FALSE;
798             for (bs3=bs2;nonNull(bs3);bs3=tl(bs3))
799                if (elemStgVarSet(e_fvs,hd(bs3))) { grpUsed=TRUE; break; }
800             if (grpUsed) {
801                //e = mkStgLet(bs2,e);
802                bsFinal = dupOnto(bs2,bsFinal);
803                for (augsL=hd(augs);nonNull(augsL);augsL=tl(augsL)) {
804                   unionStgVarSets(e_fvs, intOf(snd3(hd(augsL))) );
805                   freeStgVarSet(intOf(snd3(hd(augsL))));
806                }
807             } else {
808                nLetrecGroupsDropped++;
809                for (augsL=hd(augs);nonNull(augsL);augsL=tl(augsL)) {
810                   freeStgVarSet(intOf(snd3(hd(augsL))));
811                }
812             }
813          }
814          //*e_orig = e;
815          *e_orig = mkStgLet(bsFinal,e);
816          return e_fvs;
817
818       case LAMBDA:
819          s1 = oaScc(&stgLambdaBody(e));
820          for (bs=stgLambdaArgs(e);nonNull(bs);bs=tl(bs))
821             deleteFromStgVarSet(s1,hd(bs));
822          return s1;
823       case CASE:
824          s1 = oaScc(&stgCaseScrut(e));
825          for (bs=stgCaseAlts(e);nonNull(bs);bs=tl(bs)) {
826             s2 = oaScc(&hd(bs));
827             unionStgVarSets(s1,s2);
828             freeStgVarSet(s2);
829          }
830          return s1;
831       case PRIMCASE:
832          s1 = oaScc(&stgPrimCaseScrut(e));
833          for (bs=stgPrimCaseAlts(e);nonNull(bs);bs=tl(bs)) {
834             s2 = oaScc(&hd(bs));
835             unionStgVarSets(s1,s2);
836             freeStgVarSet(s2);
837          }
838          return s1;
839       case STGAPP:
840          s1 = oaScc(&stgAppFun(e));
841          for (bs=stgAppArgs(e);nonNull(bs);bs=tl(bs)) {
842             s2 = oaScc(&hd(bs));
843             unionStgVarSets(s1,s2);
844             freeStgVarSet(s2);
845          }
846          return s1;
847       case STGPRIM:
848          s1 = oaScc(&stgPrimOp(e));
849          for (bs=stgPrimArgs(e);nonNull(bs);bs=tl(bs)) {
850             s2 = oaScc(&hd(bs));
851             unionStgVarSets(s1,s2);
852             freeStgVarSet(s2);
853          }
854          return s1;
855       case STGCON:
856          s1 = allocStgVarSet(0);
857          for (bs=stgPrimArgs(e);nonNull(bs);bs=tl(bs)) {
858             s2 = oaScc(&hd(bs));
859             unionStgVarSets(s1,s2);
860             freeStgVarSet(s2);
861          }
862          return s1;
863       case CASEALT:
864          s1 = oaScc(&stgCaseAltBody(e));
865          for (bs=stgCaseAltVars(e);nonNull(bs);bs=tl(bs))
866             deleteFromStgVarSet(s1,hd(bs));
867          return s1;
868       case DEEFALT:
869          s1 = oaScc(&stgDefaultBody(e));
870          deleteFromStgVarSet(s1,stgDefaultVar(e));
871          return s1;
872       case PRIMALT:
873          s1 = oaScc(&stgPrimAltBody(e));
874          for (bs=stgPrimAltVars(e);nonNull(bs);bs=tl(bs))
875             deleteFromStgVarSet(s1,hd(bs));
876          return s1;
877       case STGVAR:
878          s1 = allocStgVarSet(1);
879          singletonStgVarSet(s1,e);
880          return s1;
881       case NAME:
882       case INTCELL:
883       case STRCELL:
884       case PTRCELL:
885       case BIGCELL:
886       case CHARCELL:
887       case FLOATCELL:
888          return allocStgVarSet(0);
889          break;
890       default:
891          fprintf(stderr, "oaScc: unknown stuff %d\n",whatIsStg(e));
892          assert(0);
893    }
894 }
895
896
897
898 /* --------------------------------------------------------------------------
899  * Occurrence analyser.  Marks each let-bound var with the number of times
900  * it is used, or some number >= OCC_IN_LAMBDA if it is used inside a lambda.
901  *
902  * Firstly, oaPre traverses the tree, attaching a mutable INT cell to each
903  * let bound var, and NIL-ing the counts on all other vars.
904  *
905  * Then oaCount traveses the tree.  Because variables are represented by
906  * pointers in the heap, we can just increment the count field of each
907  * variable we see.  However, to deal with lambdas, the Hugs stack holds
908  * all let-bound variables currently in scope, and the uppermost portion
909  * of the stack, stack(spBase .. sp) inclusive, denotes the variables
910  * introduced into scope since the nearest enclosing lambda.  When a 
911  * let-bound var is seen, we search stack(spBase .. sp).  If it appears
912  * there, no lambda exists between the binding site and this usage of the
913  * var, so we can safely increment its use.  Otherwise, we must set it to
914  * OCC_IN_LAMBDA.
915  *
916  * When passing a lambda, spBase is set to sp+1, so as to effectively
917  * empty the set of vars-bound-since-the-latest-lambda.
918  * 
919  * Because oaPre pre-annotates the tree with mutable INT cells, oaCount
920  * doesn't allocate any heap at all.
921  * ------------------------------------------------------------------------*/
922
923 static int spBase;
924
925
926 #define OCC_IN_LAMBDA 50  /* any number > 1 will do */
927 #define nullCount(vv) stgVarInfo(vv)=NIL
928 #define nullCounts(vvs) { List tt=(vvs);for(;nonNull(tt);tt=tl(tt)) nullCount(hd(tt));}
929
930
931
932 void oaPre ( StgExpr e )
933 {
934    List bs;
935    switch(whatIsStg(e)) {
936       case LETREC:
937          for (bs = stgLetBinds(e);nonNull(bs);bs=tl(bs))
938             stgVarInfo(hd(bs)) = mkInt(0);
939          for (bs = stgLetBinds(e);nonNull(bs);bs=tl(bs))
940             oaPre(stgVarBody(hd(bs)));
941          oaPre(stgLetBody(e));
942          break;
943       case LAMBDA:
944          nullCounts(stgLambdaArgs(e));
945          oaPre(stgLambdaBody(e));
946          break;
947       case CASE:
948          oaPre(stgCaseScrut(e));
949          mapProc(oaPre,stgCaseAlts(e));
950          break;
951       case PRIMCASE:
952          oaPre(stgPrimCaseScrut(e));
953          mapProc(oaPre,stgPrimCaseAlts(e));
954          break;
955       case STGAPP:
956          oaPre(stgAppFun(e));
957          mapProc(oaPre,stgAppArgs(e));
958          break;
959       case STGPRIM:
960          mapProc(oaPre,stgPrimArgs(e));
961          break;
962       case STGCON:
963          mapProc(oaPre,stgConArgs(e));
964          break;
965       case CASEALT:
966          nullCounts(stgCaseAltVars(e));
967          oaPre(stgCaseAltBody(e));
968          break;
969       case DEEFALT:
970          nullCount(stgDefaultVar(e));
971          oaPre(stgDefaultBody(e));
972          break;
973       case PRIMALT:
974          nullCounts(stgPrimAltVars(e));
975          oaPre(stgPrimAltBody(e));
976          break;
977       case STGVAR:
978       case NAME:
979       case INTCELL:
980       case STRCELL:
981       case PTRCELL:
982       case BIGCELL:
983       case CHARCELL:
984       case FLOATCELL:
985          break;
986       default:
987          fprintf(stderr, "oaPre: unknown stuff %d\n",whatIsStg(e));
988          assert(0);
989    }
990 }
991
992
993 /* In oaCount:
994    -- the stack is always the set of let-bound vars currently
995       in scope.  viz, stack(0 .. sp) inclusive.
996    -- spBase is always >= 0 and <= sp.  
997       stack(spBase .. sp) inclusive will be the let vars bound
998       since the nearest enclosing lambda.  When entering a lambda,
999       we set spBase=sp+1 so as record this fact, and restore spBase
1000       afterwards.
1001 */
1002 void oaCount ( StgExpr e )
1003 {
1004    List bs;
1005    Int  spBase_saved;
1006
1007    switch(whatIsStg(e)) {
1008       case LETREC:
1009          for (bs = stgLetBinds(e);nonNull(bs);bs=tl(bs))
1010             push(hd(bs));
1011          for (bs = stgLetBinds(e);nonNull(bs);bs=tl(bs))
1012             oaCount(stgVarBody(hd(bs)));
1013          oaCount(stgLetBody(e));
1014          for (bs = stgLetBinds(e);nonNull(bs);bs=tl(bs))
1015             drop();
1016          break;
1017       case LAMBDA:
1018          spBase_saved = spBase;
1019          spBase = sp+1;
1020          oaCount(stgLambdaBody(e));
1021          spBase = spBase_saved;
1022          break;
1023       case CASE:
1024          oaCount(stgCaseScrut(e));
1025          mapProc(oaCount,stgCaseAlts(e));
1026          break;
1027       case PRIMCASE:
1028          oaCount(stgPrimCaseScrut(e));
1029          mapProc(oaCount,stgPrimCaseAlts(e));
1030          break;
1031       case STGAPP:
1032          oaCount(stgAppFun(e));
1033          mapProc(oaCount,stgAppArgs(e));
1034          break;
1035       case STGPRIM:
1036          mapProc(oaCount,stgPrimArgs(e));
1037          break;
1038       case STGCON:
1039          mapProc(oaCount,stgConArgs(e));
1040          break;
1041       case CASEALT:
1042          nullCounts(stgCaseAltVars(e));
1043          oaCount(stgCaseAltBody(e));
1044          break;
1045       case DEEFALT:
1046          nullCount(stgDefaultVar(e));
1047          oaCount(stgDefaultBody(e));
1048          break;
1049       case PRIMALT:
1050          nullCounts(stgPrimAltVars(e));
1051          oaCount(stgPrimAltBody(e));
1052          break;
1053       case STGVAR:
1054          if (isInt(stgVarInfo(e))) {
1055             Int i, j;
1056             j = -1;
1057             for (i = sp; i >= spBase; i--)
1058                if (stack(i) == e) { j = i; break; };
1059             if (j == -1)
1060                stgVarInfo(e) = mkInt(OCC_IN_LAMBDA); else
1061                stgVarInfo(e) = mkInt(1 + intOf(stgVarInfo(e)));
1062          }
1063          break;
1064       case NAME:
1065       case INTCELL:
1066       case STRCELL:
1067       case PTRCELL:
1068       case BIGCELL:
1069       case CHARCELL:
1070       case FLOATCELL:
1071          break;
1072       default:
1073          fprintf(stderr, "oaCount: unknown stuff %d\n",whatIsStg(e));
1074          assert(0);
1075    }
1076 }
1077
1078 void stgTopSanity ( char*, StgVar );
1079
1080 /* Top level entry point for the occurrence analyser. */
1081 void oaTop ( StgVar v )
1082 {
1083    assert (varSet_nfree == M_VAR_SETS);
1084    freeStgVarSet(oaScc(&stgVarBody(v)));
1085    assert (varSet_nfree == M_VAR_SETS);
1086    oaPre(stgVarBody(v));
1087    clearStack(); spBase = 0;
1088    oaCount(stgVarBody(v));
1089    assert(stackEmpty());
1090    stgTopSanity("oaTop",stgVarBody(v));
1091 }
1092
1093
1094 /* --------------------------------------------------------------------------
1095  * Transformation machinery proper
1096  * ------------------------------------------------------------------------*/
1097
1098 #define streq(aa,bb) (strcmp((aa),(bb))==0)
1099 /* Return TRUE if the non-default alts in the given list are exhaustive.
1100    If in doubt, return FALSE.
1101 */
1102 Bool stgAltsExhaustive ( List alts )
1103 {
1104    Int   nDefnCons;
1105    Name  con;
1106    Tycon t;
1107    List  cs;
1108    char* s;
1109    List  alts0 = alts;
1110    while (nonNull(alts) && isDefaultAlt(hd(alts))) alts=tl(alts);
1111    if (isNull(alts)) {
1112       return FALSE;
1113    } else {
1114       con = stgCaseAltCon(hd(alts));
1115       /* special case: dictionary constructor */
1116       if (strncmp("Make.",textToStr(name(con).text),5)==0)
1117          return TRUE;
1118       /* special case: constructor boxing an unboxed value. */
1119       if (isBoxingCon(con))
1120          return TRUE;
1121       /* some other special cases which are not boxingCons */
1122       s = textToStr(name(con).text);
1123       if (streq(s,"Integer#")
1124           || streq(s,"Ref#")
1125           || streq(s,"PrimMutableArray#")
1126           || streq(s,"PrimMutableByteArray#")
1127           || streq(s,"PrimByteArray#")
1128           || streq(s,"PrimArray#")
1129          )
1130          return TRUE;
1131       if (strcmp("Ref#",textToStr(name(con).text))==0)
1132          return TRUE;
1133       /* special case: Tuples */
1134       if (isTuple(con) || (isName(con) && con==nameUnit))
1135          return TRUE;
1136       if (isNull(name(con).parent)) internal("stgAltsExhaustive(1)");
1137       t = name(con).parent;
1138       cs = tycon(t).defn;
1139       if (tycon(t).what != DATATYPE) internal("stgAltsExhaustive(2)");
1140       nDefnCons = length(cs);
1141       for (; nonNull(alts0);alts0=tl(alts0)) {
1142          if (isDefaultAlt(hd(alts0))) continue;
1143          nDefnCons--;
1144       }
1145    }
1146    return nDefnCons == 0;
1147 }
1148 #undef streq
1149
1150
1151 /* If in doubt, return FALSE. 
1152 */
1153 Bool isManifestCon ( StgExpr e )
1154 {
1155    StgExpr altB;
1156    switch (whatIsStg(e)) {
1157       case STGCON: return TRUE;
1158       case LETREC: return isManifestCon(stgLetBody(e));
1159       case CASE:   if (length(stgCaseAlts(e))==1) {                      
1160                       if (isDefaultAlt(hd(stgCaseAlts(e))))
1161                          altB = stgDefaultBody(hd(stgCaseAlts(e))); else
1162                          altB = stgCaseAltBody(hd(stgCaseAlts(e)));
1163                          return isManifestCon(altB);
1164                    } else {
1165                       return FALSE;
1166                    }
1167       default:     return FALSE;
1168    }
1169 }
1170
1171
1172 /* Like isManifestCon, but doesn't give up at non-singular cases */
1173 Bool constructsCon ( StgExpr e )
1174 {
1175    List    as;
1176    switch (whatIsStg(e)) {
1177       case STGCON:   return TRUE;
1178       case LETREC:   return constructsCon(stgLetBody(e));
1179       case CASE:     for (as = stgCaseAlts(e); nonNull(as); as=tl(as))
1180                         if (!constructsCon(hd(as))) return FALSE;
1181                      return TRUE;
1182       case PRIMCASE: for (as = stgPrimCaseAlts(e); nonNull(as); as=tl(as))
1183                         if (!constructsCon(hd(as))) return FALSE;
1184                      return TRUE;
1185       case CASEALT:  return constructsCon(stgCaseAltBody(e));
1186       case DEEFALT:  return constructsCon(stgDefaultBody(e));
1187       case PRIMALT:  return constructsCon(stgPrimAltBody(e));
1188       default:       return FALSE;
1189    }
1190 }
1191
1192
1193 /* Inline v in the special case where expr is
1194    case v of C a1 ... an -> E
1195    and v's bindee returns a product constructed with C.
1196    and v does not appear in E
1197    and v does not appear in letDefs (ie, this expr isn't
1198        part of the definition of v.
1199 */
1200 void tryLoopbreakerHack ( List letDefs, StgExpr expr )
1201 {
1202    List       alts;
1203    StgExpr    scrut, ee, v_bindee;
1204    StgCaseAlt alt;
1205   
1206    assert (whatIsStg(expr)==CASE);
1207    alts      = stgCaseAlts(expr);
1208    scrut     = stgCaseScrut(expr);
1209    if (whatIsStg(scrut) != STGVAR || isNull(stgVarBody(scrut))) return;
1210    if (length(alts) != 1 || isDefaultAlt(hd(alts))) return;
1211    if (!stgAltsExhaustive(alts)) return;
1212    alt       = hd(alts);
1213    ee        = stgCaseAltBody(alt);
1214    if (nonNull(cellIsMember(scrut,letDefs))) return;
1215
1216    v_bindee  = stgVarBody(scrut);
1217    if (!isManifestCon(v_bindee)) return;
1218
1219    stgCaseScrut(expr) = cloneStgTop(v_bindee);
1220    nLoopBreakersInlined++;
1221 }
1222
1223
1224 /* Traverse a tree.  Replace let-bound vars marked as used-once
1225    by their definitions.  Replace references to top-level
1226    values marked inlineMe with their bodies.  Carry around a list
1227    of let-bound variables whose definitions we are currently in
1228    so as to know not to inline let-bound vars in their own
1229    definitions.
1230 */
1231 StgExpr copyIn ( List letDefs, InlineCtx ctx, StgExpr e )
1232 {
1233    List bs;
1234
1235    switch(whatIsStg(e)) {
1236       // these are the only two interesting cases
1237       case STGVAR:
1238          assert(isPtr(stgVarInfo(e)) || isNull(stgVarInfo(e)) || 
1239                 isInt(stgVarInfo(e)));
1240          if (isInt(stgVarInfo(e)) && intOf(stgVarInfo(e))==1) {
1241             nLetvarsInlined++;
1242             return cloneStgTop(stgVarBody(e)); 
1243          } else
1244             return e;
1245       case NAME:
1246          // if we're not inlining top vars on this round, do nothing
1247          if (!copyInTopvar) return e;
1248          // if it doesn't want to be inlined, do nothing
1249          if (!name(e).inlineMe) return e;
1250          // we decline to inline dictionary builders inside other builders
1251          if (inDBuilder && name(e).isDBuilder) {
1252            //fprintf(stderr, "decline to inline dbuilder %s\n", textToStr(name(e).text));
1253             return e;
1254          }
1255          // in fact, only inline dict builders into a case scrutinee
1256          if (name(e).isDBuilder && ctx != CTX_SCRUT)
1257             return e;
1258
1259 #if DEBUG_OPTIMISE
1260 assert( stgSize(stgVarBody(name(e).stgVar)) == name(e).stgSize );
1261 #endif
1262
1263          // only inline large dict builders if it returns a manifest con
1264          if (name(e).isDBuilder &&
1265              name(e).stgSize > 180 && 
1266              !isManifestCon(stgVarBody(name(e).stgVar)))
1267             return e;
1268 #if 0
1269          // if it's huge, don't inline into a boring place
1270          if (ctx != CTX_SCRUT &&
1271              name(e).stgSize > 270)
1272             return e;
1273 #endif
1274
1275          nTopvarsInlined++;
1276          return cloneStgTop(stgVarBody(name(e).stgVar));
1277
1278       // the rest are a boring recursive traversal of the tree      
1279       case LETREC:
1280          stgLetBody(e) = copyIn(letDefs,CTX_OTHER,stgLetBody(e));
1281          letDefs = dupOnto(stgLetBinds(e),letDefs);
1282          for (bs=stgLetBinds(e);nonNull(bs);bs=tl(bs))
1283             stgVarBody(hd(bs)) = copyIn(letDefs,CTX_OTHER,stgVarBody(hd(bs)));
1284          break;
1285       case LAMBDA:
1286          stgLambdaBody(e) = copyIn(letDefs,CTX_OTHER,stgLambdaBody(e));
1287          break;
1288       case CASE:
1289          stgCaseScrut(e) = copyIn(letDefs,CTX_SCRUT,stgCaseScrut(e));
1290          map2Over(copyIn,letDefs,CTX_OTHER,stgCaseAlts(e));
1291          if (copyInTopvar) tryLoopbreakerHack(letDefs,e);
1292          break;
1293       case PRIMCASE:
1294          stgPrimCaseScrut(e) = copyIn(letDefs,CTX_OTHER,stgPrimCaseScrut(e));
1295          map2Over(copyIn,letDefs,CTX_OTHER,stgPrimCaseAlts(e));
1296          break;
1297       case STGAPP:
1298          stgAppFun(e) = copyIn(letDefs,CTX_OTHER,stgAppFun(e));
1299          break;
1300       case CASEALT:
1301          stgCaseAltBody(e) = copyIn(letDefs,CTX_OTHER,stgCaseAltBody(e));
1302          break;
1303       case DEEFALT:
1304          stgDefaultBody(e) = copyIn(letDefs,CTX_OTHER,stgDefaultBody(e));
1305          break;
1306       case PRIMALT:
1307          stgPrimAltBody(e) = copyIn(letDefs,CTX_OTHER,stgPrimAltBody(e));
1308          break;
1309       case STGPRIM:
1310       case STGCON:
1311       case INTCELL:
1312       case STRCELL:
1313       case PTRCELL:
1314       case CHARCELL:
1315       case FLOATCELL:
1316          break;
1317       default:
1318          fprintf(stderr, "copyIn: unknown stuff %d\n",whatIsStg(e));
1319          ppStgExpr(e);
1320          printf("\n");
1321          print(e,1000);
1322          printf("\n");
1323          assert(0);
1324    }
1325    return e;
1326 }
1327
1328
1329
1330 /* case (C a1 ... an) of
1331       B ...       -> ...
1332       C v1 ... vn -> e
1333       D ...       -> ...
1334    ==>
1335    e with v1/a1 ... vn/an
1336 */
1337 StgExpr doCaseOfCon ( StgExpr expr, Bool* done )
1338 {
1339    StgExpr    scrut, e;
1340    StgVar     apC;
1341    StgCaseAlt theAlt;
1342    List       alts, altvs, as, sub;
1343
1344    *done  = FALSE;
1345    alts   = stgCaseAlts(expr);
1346    scrut  = stgCaseScrut(expr);
1347
1348    apC    = stgConCon(scrut);
1349
1350    theAlt = NIL;
1351    for (alts = stgCaseAlts(expr); nonNull(alts); alts=tl(alts))
1352       if (!isDefaultAlt(hd(alts)) && stgCaseAltCon(hd(alts)) == apC) {
1353          theAlt = hd(alts);
1354          break;
1355       }
1356
1357    if (isNull(theAlt)) return expr;
1358    altvs  = stgCaseAltVars(theAlt);
1359    e      = stgCaseAltBody(theAlt);
1360    as     = stgConArgs(scrut);
1361
1362    if (length(as)!=length(altvs)) return expr;
1363
1364    sub = NIL;
1365    while (nonNull(altvs)) {
1366       sub   = cons(pair(hd(altvs),hd(as)),sub);
1367       as    = tl(as);
1368       altvs = tl(altvs);
1369    }
1370    nCaseOfCon++;
1371    *done = TRUE;
1372    return zubstExpr(sub,e);
1373 }
1374
1375
1376 /* case (let binds in e) of alts
1377    ===>
1378    let binds in case e of alts
1379 */
1380 StgExpr doCaseOfLet ( StgExpr expr, Bool* done )
1381 {
1382    StgExpr letexpr, e;
1383    List    binds, alts;
1384
1385    letexpr = stgCaseScrut(expr);
1386    e       = stgLetBody(letexpr);
1387    binds   = stgLetBinds(letexpr);
1388    alts    = stgCaseAlts(expr);
1389    nCaseOfLet++;
1390    *done   = TRUE;
1391    return mkStgLet(binds,mkStgCase(e,alts));
1392 }
1393
1394
1395
1396 /* case (case e of p1 -> e1 ... pn -> en) of
1397       q1 -> h1
1398       ...
1399       qk -> hk
1400    ===>
1401    case e of 
1402       p1 -> case e1 of q1 -> h1 ... qk -> hk
1403       ...
1404       pn -> case en of q1 -> h1 ... qk -> kl
1405 */
1406 StgExpr doCaseOfCase ( StgExpr expr )
1407 {
1408    StgExpr innercase, e, tmpcase, protocase;
1409    List ps_n_es, qs_n_hs, newAlts;
1410    StgCaseAlt newAlt, p_n_e;
1411
1412    nCaseOfCase++;
1413
1414    innercase = stgCaseScrut(expr);
1415    e = stgCaseScrut(innercase);
1416    ps_n_es = stgCaseAlts(innercase);
1417    qs_n_hs = stgCaseAlts(expr);
1418
1419    /* protocase = case (hole-to-fill-in) of q1 -> h1 ... qk -> hk */
1420    protocase = mkStgCase( mkInt(0), qs_n_hs);
1421
1422    newAlts = NIL;
1423    for (;nonNull(ps_n_es);ps_n_es = tl(ps_n_es)) {
1424       tmpcase = cloneStgTop(protocase);
1425       p_n_e = hd(ps_n_es);
1426       if (isDefaultAlt(p_n_e)) {
1427          stgCaseScrut(tmpcase) = stgDefaultBody(p_n_e);
1428          newAlt = mkStgDefault(stgDefaultVar(p_n_e), tmpcase);
1429       } else {
1430          stgCaseScrut(tmpcase) = stgCaseAltBody(p_n_e);
1431          newAlt = mkStgCaseAlt(stgCaseAltCon(p_n_e),stgCaseAltVars(p_n_e),tmpcase);
1432       }
1433       newAlts = cons(newAlt,newAlts);
1434    }
1435    newAlts = rev(newAlts);
1436    return
1437       mkStgCase(e, newAlts);
1438 }
1439
1440
1441
1442 /* case (case# e of p1 -> e1 ... pn -> en) of
1443       q1 -> h1
1444       ...
1445       qk -> hk
1446    ===>
1447    case# e of 
1448       p1 -> case e1 of q1 -> h1 ... qk -> hk
1449       ...
1450       pn -> case en of q1 -> h1 ... qk -> kl
1451 */
1452 StgExpr doCaseOfPrimCase ( StgExpr expr )
1453 {
1454    StgExpr innercase, e, tmpcase, protocase;
1455    List ps_n_es, qs_n_hs, newAlts;
1456    StgCaseAlt newAlt, p_n_e;
1457
1458    nCaseOfPrimCase++;
1459
1460    innercase = stgCaseScrut(expr);
1461    e = stgPrimCaseScrut(innercase);
1462    ps_n_es = stgPrimCaseAlts(innercase);
1463    qs_n_hs = stgCaseAlts(expr);
1464
1465    /* protocase = case (hole-to-fill-in) of q1 -> h1 ... qk -> hk */
1466    protocase = mkStgCase( mkInt(0), qs_n_hs);
1467
1468    newAlts = NIL;
1469    for (;nonNull(ps_n_es);ps_n_es = tl(ps_n_es)) {
1470       tmpcase = cloneStgTop(protocase);
1471       p_n_e = hd(ps_n_es);
1472       stgPrimCaseScrut(tmpcase) = stgPrimAltBody(p_n_e);
1473       newAlt = mkStgPrimAlt(stgPrimAltVars(p_n_e),tmpcase);
1474       newAlts = cons(newAlt,newAlts);
1475    }
1476    newAlts = rev(newAlts);  
1477    return
1478       mkStgPrimCase(e, newAlts);
1479 }
1480
1481
1482 Bool isStgCaseWithSingleNonDefaultAlt ( StgExpr e )
1483 {
1484    return
1485       whatIsStg(e)==CASE &&
1486       length(stgCaseAlts(e))==1 &&
1487       !isDefaultAlt(hd(stgCaseAlts(e)));
1488 }
1489
1490
1491 /* Do simplifications on an Stg tree.  Invariant is that the
1492    input and output trees should have no name shadowing.
1493
1494    -- let { } in e
1495       ===>
1496       e
1497
1498    -- dump individual let-bindings with usage counts of zero
1499
1500    -- dump let-binding groups for which none of the bound vars
1501       occur in the let body
1502
1503    -- (\v1 ... vn -> e) a1 ... am
1504       ===>
1505       -- the usual beta reduction.  There are no constraints on n and m, so
1506          the result can be a lambda term (if n > m), or an application of e 
1507          to the unused args (if n < m).
1508
1509
1510   Scheme is: bottom-up traversal of the tree.  First simplify child
1511   trees.  Then try to do local transformations.  If a local transformation 
1512   succeeds, jump to the local-transformation code for whatever node
1513   is produced -- so as to try and maximise the amount of work which
1514   happens on each call to simplify.
1515 */
1516 StgExpr simplify ( List caseEnv, StgExpr e )
1517 {
1518    List bs, bs2;
1519    Bool done;
1520    Int  n;
1521
1522    restart:
1523    switch(whatIsStg(e)) {
1524       case STGVAR:
1525          return e;
1526       case NAME:
1527          return e;
1528
1529       case LETREC:
1530
1531          /* first dump dead binds, so as not to waste effort simplifying them */
1532          bs2=NIL;
1533          for (bs=stgLetBinds(e);nonNull(bs);bs=tl(bs))
1534             if (!isInt(stgVarInfo(hd(bs))) ||
1535                 intOf(stgVarInfo(hd(bs))) > 0) {
1536                bs2=cons(hd(bs),bs2);
1537             } else {
1538                nLetBindsDropped++;
1539             }
1540          if (isNull(bs2)) { e = stgLetBody(e); goto restart; };
1541          stgLetBinds(e) = rev(bs2);
1542
1543          for (bs=stgLetBinds(e);nonNull(bs);bs=tl(bs))
1544             stgVarBody(hd(bs)) = simplify(caseEnv,stgVarBody(hd(bs)));
1545          stgLetBody(e) = simplify(caseEnv,stgLetBody(e));
1546
1547          /* Merge let ... in let ... in e.  Grouping lets together
1548             sometimes reduces the number of iterations needed.
1549             oaScc should do this anyway, but this just to make sure.
1550          */
1551          while (whatIsStg(stgLetBody(e))==LETREC) {
1552             stgLetBinds(e) = dupOnto(stgLetBinds(stgLetBody(e)),stgLetBinds(e));
1553             stgLetBody(e) = stgLetBody(stgLetBody(e));
1554          }
1555
1556          let_local:
1557          /* let binds in case v-not-in-binds of singleAlt -> expr
1558             ===>
1559             case v-not-in-binds of singleAlt -> let binds in expr
1560          */
1561          if (isStgCaseWithSingleNonDefaultAlt(stgLetBody(e)) &&
1562              whatIsStg(stgCaseScrut(stgLetBody(e)))==STGVAR &&
1563              isNull(cellIsMember(stgCaseScrut(stgLetBody(e)),stgLetBinds(e)))) {
1564             StgVar     v = stgCaseScrut(stgLetBody(e));
1565             StgCaseAlt a = hd(stgCaseAlts(stgLetBody(e)));
1566             nLetsFloatedIntoCase++;
1567             e = mkStgCase( 
1568                    v, 
1569                    singleton( 
1570                       mkStgCaseAlt(
1571                          stgCaseAltCon(a),
1572                          stgCaseAltVars(a), 
1573                          mkStgLet(stgLetBinds(e),stgCaseAltBody(a))
1574                       )
1575                    )
1576                 );
1577             assert(whatIsStg(e)==CASE);
1578             goto case_local;
1579          }
1580           
1581          break;
1582
1583       case LAMBDA:
1584          stgLambdaBody(e) = simplify(caseEnv,stgLambdaBody(e));
1585
1586          lambda_local:
1587          while (whatIsStg(stgLambdaBody(e))==LAMBDA) {
1588             nLambdasMerged++;
1589             stgLambdaArgs(e) = appendOnto(stgLambdaArgs(e),
1590                                           stgLambdaArgs(stgLambdaBody(e)));
1591             stgLambdaBody(e) = stgLambdaBody(stgLambdaBody(e));
1592          }
1593          break;
1594
1595
1596       case CASE:
1597          stgCaseScrut(e) = simplify(caseEnv,stgCaseScrut(e));
1598          if (isStgCaseWithSingleNonDefaultAlt(e) &&
1599              (whatIsStg(stgCaseScrut(e))==STGVAR ||
1600               whatIsStg(stgCaseScrut(e))==NAME)) {
1601             List caseEnv2 = cons(
1602                                pair(stgCaseScrut(e),stgCaseAltVars(hd(stgCaseAlts(e)))),
1603                                caseEnv
1604                             );
1605             map1Over(simplify,caseEnv2,stgCaseAlts(e));
1606          } else {
1607             map1Over(simplify,caseEnv,stgCaseAlts(e));
1608          }
1609
1610          case_local:
1611          /* zap redundant default alternatives */
1612          if (stgAltsExhaustive(stgCaseAlts(e))) {
1613             Bool droppedDef = FALSE;
1614             bs2 = NIL;
1615             for (bs = dupList(stgCaseAlts(e));nonNull(bs);bs=tl(bs))
1616                if (!isDefaultAlt(hd(bs))) {
1617                   bs2=cons(hd(bs),bs2); 
1618                } else {
1619                   droppedDef = TRUE;
1620                }
1621             bs2 = rev(bs2);
1622             stgCaseAlts(e) = bs2;
1623             if (droppedDef) nCaseDefaultsDropped++;
1624          }
1625         
1626          switch (whatIsStg(stgCaseScrut(e))) {
1627             case CASE:
1628                /* attempt case-of-case */
1629                n = length(stgCaseAlts(e));
1630                if (n==1 || 
1631                            (n <= 3 && 
1632                             (stgSize(e)-stgSize(stgCaseScrut(e))) < 100 &&
1633                             constructsCon(stgCaseScrut(e)))
1634                   ) {
1635                   e = doCaseOfCase(e);
1636                   assert(whatIsStg(e)==CASE);
1637                   goto case_local;
1638                }
1639                break;
1640             case PRIMCASE:
1641                /* attempt case-of-case# */
1642                n = length(stgCaseAlts(e));
1643                if (n==1 || 
1644                            (n <= 3 && 
1645                             (stgSize(e)-stgSize(stgCaseScrut(e))) < 100 &&
1646                             constructsCon(stgCaseScrut(e)))
1647                   ) {
1648                   e = doCaseOfPrimCase(e);
1649                   assert(whatIsStg(e)==PRIMCASE);
1650                   goto primcase_local;
1651                }
1652                break;
1653             case LETREC:
1654                /* attempt case-of-let */
1655                e = doCaseOfLet(e,&done);
1656                if (done) { assert(whatIsStg(e)==LETREC); goto let_local; };
1657                break;
1658             case STGCON:
1659                /* attempt case-of-constructor */
1660                e = doCaseOfCon(e,&done);
1661                /* we don't know what the result is, so can't jump to local */
1662                break;
1663             case NAME:
1664             case STGVAR: {
1665                /* attempt to remove case on something already cased on */
1666                List outervs, innervs, sub;
1667                Cell lookupResult;
1668                if (!isStgCaseWithSingleNonDefaultAlt(e)) break;
1669                lookupResult = cellAssoc(stgCaseScrut(e),caseEnv);
1670                if (isNull(lookupResult)) break;
1671                outervs = snd(lookupResult);
1672                nCaseOfOuter++;
1673                sub = NIL;
1674                innervs = stgCaseAltVars(hd(stgCaseAlts(e)));
1675                for (; nonNull(outervs) && nonNull(innervs);
1676                       outervs=tl(outervs), innervs=tl(innervs))
1677                   sub = cons(pair(hd(innervs),hd(outervs)),sub);
1678                assert (isNull(outervs) && isNull(innervs));
1679                return zubstExpr(sub, stgCaseAltBody(hd(stgCaseAlts(e))));
1680                }
1681             default:
1682                break;
1683          }
1684          break;
1685       case CASEALT:
1686          stgCaseAltBody(e) = simplify(caseEnv,stgCaseAltBody(e));
1687          break;
1688       case DEEFALT:
1689          stgDefaultBody(e) = simplify(caseEnv,stgDefaultBody(e));
1690          break;
1691       case PRIMALT:
1692          stgPrimAltBody(e) = simplify(caseEnv,stgPrimAltBody(e));
1693          break;
1694       case PRIMCASE:
1695          stgPrimCaseScrut(e) = simplify(caseEnv,stgPrimCaseScrut(e));
1696          map1Over(simplify,caseEnv,stgPrimCaseAlts(e));
1697          primcase_local:
1698          break;
1699       case STGAPP: {
1700          List    sub, formals;
1701          StgExpr subd_body;
1702          StgExpr fun;
1703          List    args;
1704
1705          stgAppFun(e) = simplify(caseEnv,stgAppFun(e));
1706          map1Over(simplify,caseEnv,stgAppArgs(e));
1707
1708          fun  = stgAppFun(e);
1709          args = stgAppArgs(e);
1710
1711          switch (whatIsStg(fun)) {
1712             case STGAPP:
1713                nAppsMerged++;
1714                stgAppArgs(e) = appendOnto(stgAppArgs(fun),args);
1715                stgAppFun(e) = stgAppFun(fun);
1716                break;
1717             case LETREC:
1718                /* (let binds in f) args  ==> let binds in (f args) */
1719                nLetsFloatedOutOfFn++;
1720                e = mkStgLet(stgLetBinds(fun),mkStgApp(stgLetBody(fun),args));
1721                assert(whatIsStg(e)==LETREC);
1722                goto let_local;
1723                break;
1724             case CASE:
1725                if (length(stgCaseAlts(fun))==1 && 
1726                    !isDefaultAlt(hd(stgCaseAlts(fun)))) {
1727                   StgCaseAlt theAlt = hd(stgCaseAlts(fun));
1728                   /* (case e of alt -> f) args  ==> case e of alt -> f args */
1729                   e = mkStgCase(
1730                          stgCaseScrut(fun),
1731                          singleton(mkStgCaseAlt(stgCaseAltCon(theAlt),
1732                                                 stgCaseAltVars(theAlt),
1733                                                  mkStgApp(stgCaseAltBody(theAlt),args))
1734                          )
1735                       );
1736                   nCasesFloatedOutOfFn++;
1737                   assert(whatIsStg(e)==CASE);
1738                   goto case_local;
1739                }
1740                break;
1741             case LAMBDA: {
1742                sub      = NIL;
1743                formals  = stgLambdaArgs(fun);
1744                while (nonNull(formals) && nonNull(args)) {
1745                   sub     = cons(pair(hd(formals),hd(args)),sub);
1746                   formals = tl(formals);
1747                   args    = tl(args);
1748                }
1749                subd_body = zubstExpr(sub,stgLambdaBody(fun));
1750
1751                nBetaReductions++;
1752                assert(isNull(formals) || isNull(args));
1753                if (isNull(formals) && isNull(args)) {
1754                   /* fn and args match exactly */
1755                   e = subd_body;
1756                   return e;
1757                }
1758                else
1759                if (isNull(formals) && nonNull(args)) {
1760                   /* more args than we could deal with.  Build a new Ap. */
1761                   e = mkStgApp(subd_body,args);
1762                   return e;
1763                }
1764                else
1765                if (nonNull(formals) && isNull(args)) {
1766                   /* partial application.  We get a new Lambda */
1767                   e = mkStgLambda(formals,subd_body);
1768                   return e;
1769                }
1770                }
1771                break;
1772             default:
1773                break;
1774          }
1775          }
1776          break;
1777       case STGPRIM:
1778          break;
1779       case STGCON:
1780          break;
1781       case INTCELL:
1782       case STRCELL:
1783       case PTRCELL:
1784       case CHARCELL:
1785       case FLOATCELL:
1786          break;
1787       default:
1788          fprintf(stderr, "simplify: unknown stuff %d\n",whatIsStg(e));
1789          ppStgExpr(e);
1790          printf("\n");
1791          print(e,1000);
1792          printf("\n");
1793          assert(0);
1794    }
1795    return e;
1796 }
1797
1798
1799 /* Restore STG representation invariants broken by simplify.
1800    -- Let-bind any constructor applications which appear
1801       anywhere other than a let.
1802    -- Let-bind non-atomic case scrutinees (ToDo).
1803 */
1804 StgExpr restoreStg ( StgExpr e )
1805 {
1806    List bs;
1807    StgVar newv;
1808
1809    if (isNull(e)) return e;
1810
1811    switch(whatIsStg(e)) {
1812       case LETREC:
1813          for (bs=stgLetBinds(e); nonNull(bs); bs=tl(bs)) {
1814             if (whatIsStg(stgVarBody(hd(bs))) == STGCON) {
1815               /* do nothing */
1816             } 
1817             else
1818             if (whatIsStg(stgVarBody(hd(bs))) == LAMBDA) {
1819                stgLambdaBody(stgVarBody(hd(bs))) 
1820                   = restoreStg(stgLambdaBody(stgVarBody(hd(bs))));
1821             }
1822             else {
1823                stgVarBody(hd(bs)) = restoreStg(stgVarBody(hd(bs)));
1824             }
1825          }      
1826          stgLetBody(e) = restoreStg(stgLetBody(e));
1827          break;
1828       case LAMBDA:
1829          /* note that the check in LETREC above ensures we won't
1830             get here for legitimate (let-bound) lambdas. */
1831          stgLambdaBody(e) = restoreStg(stgLambdaBody(e));
1832          newv = mkStgVar(e,NIL);
1833          e = mkStgLet(singleton(newv),newv);
1834          break;
1835       case CASE:
1836          stgCaseScrut(e) = restoreStg(stgCaseScrut(e));
1837          mapOver(restoreStg,stgCaseAlts(e));
1838          if (!isAtomic(stgCaseScrut(e))) {
1839             newv = mkStgVar(stgCaseScrut(e),NIL);
1840             return mkStgLet(singleton(newv),mkStgCase(newv,stgCaseAlts(e)));
1841          }
1842          break;
1843       case PRIMCASE:
1844          stgPrimCaseScrut(e) = restoreStg(stgPrimCaseScrut(e));
1845          mapOver(restoreStg,stgPrimCaseAlts(e));
1846          break;
1847       case STGAPP:
1848          stgAppFun(e) = restoreStg(stgAppFun(e));
1849          mapOver(restoreStg,stgAppArgs(e)); /* probably incorrect */
1850          if (!isAtomic(stgAppFun(e))) {
1851             newv = mkStgVar(stgAppFun(e),NIL);
1852             e = mkStgLet(singleton(newv),mkStgApp(newv,stgAppArgs(e)));
1853          }
1854          break;
1855       case STGPRIM:
1856          mapOver(restoreStg,stgPrimArgs(e));
1857          break;
1858       case STGCON:
1859          /* note that the check in LETREC above ensures we won't
1860             get here for legitimate constructor applications. */
1861          mapOver(restoreStg,stgConArgs(e));
1862          newv = mkStgVar(e,NIL);
1863          return mkStgLet(singleton(newv),newv);
1864          break;
1865       case CASEALT:
1866          stgCaseAltBody(e) = restoreStg(stgCaseAltBody(e));
1867          if (whatIsStg(stgCaseAltBody(e))==LAMBDA) {
1868             newv = mkStgVar(stgCaseAltBody(e),NIL);
1869             stgCaseAltBody(e) = mkStgLet(singleton(newv),newv);
1870          }
1871          break;
1872       case DEEFALT:
1873          stgDefaultBody(e) = restoreStg(stgDefaultBody(e));
1874          if (whatIsStg(stgDefaultBody(e))==LAMBDA) {
1875             newv = mkStgVar(stgDefaultBody(e),NIL);
1876             stgDefaultBody(e) = mkStgLet(singleton(newv),newv);
1877          }
1878          break;
1879       case PRIMALT:
1880          stgPrimAltBody(e) = restoreStg(stgPrimAltBody(e));
1881          break;
1882       case STGVAR:
1883       case NAME:
1884       case INTCELL:
1885       case STRCELL:
1886       case PTRCELL:
1887       case CHARCELL:
1888       case FLOATCELL:
1889          break;
1890       default:
1891          fprintf(stderr, "restoreStg: unknown stuff %d\n",whatIsStg(e));
1892          ppStgExpr(e);
1893          printf("\n");
1894          assert(0);
1895    }
1896    return e;
1897 }
1898
1899
1900 StgExpr restoreStgTop ( StgExpr e )
1901 {
1902    if (whatIs(e)==LAMBDA)
1903       stgLambdaBody(e) = restoreStg(stgLambdaBody(e)); else
1904       e = restoreStg(e);
1905    return e;
1906 }
1907
1908
1909 void simplTopRefs ( StgExpr e )
1910 {
1911    List bs;
1912
1913    switch(whatIsStg(e)) {
1914      /* the only interesting case */
1915       case NAME:
1916          if (name(e).inlineMe && !name(e).simplified) {
1917             /* printf("\n((%d)) request for %s\n",rDepth, textToStr(name(e).text)); */
1918             name(e).simplified = TRUE;
1919             optimiseTopBind(name(e).stgVar);
1920             /* printf("((%d)) done    for %s\n",rDepth, textToStr(name(e).text)); */
1921          }
1922          break;
1923       case LETREC:
1924          simplTopRefs(stgLetBody(e));
1925          for (bs=stgLetBinds(e); nonNull(bs); bs=tl(bs))
1926             simplTopRefs(stgVarBody(hd(bs)));
1927          break;
1928       case LAMBDA:
1929          simplTopRefs(stgLambdaBody(e));
1930          break;
1931       case CASE:
1932          simplTopRefs(stgCaseScrut(e));
1933          mapProc(simplTopRefs,stgCaseAlts(e));
1934          break;
1935       case PRIMCASE:
1936          simplTopRefs(stgPrimCaseScrut(e));
1937          mapProc(simplTopRefs,stgPrimCaseAlts(e));
1938          break;
1939       case STGAPP:
1940          simplTopRefs(stgAppFun(e));
1941          mapProc(simplTopRefs,stgAppArgs(e));
1942          break;
1943       case STGCON:
1944          mapProc(simplTopRefs,stgConArgs(e));
1945          break;
1946       case STGPRIM:
1947          simplTopRefs(stgPrimOp(e));
1948          mapProc(simplTopRefs,stgPrimArgs(e));
1949          break;
1950       case CASEALT:
1951          simplTopRefs(stgCaseAltBody(e));
1952          break;
1953       case DEEFALT:
1954          simplTopRefs(stgDefaultBody(e));
1955          break;
1956       case PRIMALT:
1957          simplTopRefs(stgPrimAltBody(e));
1958          break;
1959       case INTCELL:
1960       case STRCELL:
1961       case PTRCELL:
1962       case BIGCELL:
1963       case CHARCELL:
1964       case FLOATCELL:
1965       case TUPLE:
1966       case STGVAR:
1967          break;
1968       default:
1969          fprintf(stderr, "simplTopRefs: unknown stuff %d\n",whatIsStg(e));
1970          ppStgExpr(e);
1971          printf("\n");
1972          print(e,1000);
1973          printf("\n");
1974          assert(0);
1975    }
1976 }
1977
1978 char* maybeName ( StgVar v )
1979 {
1980    Name n = nameFromStgVar(v);
1981    if (isNull(n)) return "(unknown)";
1982    return textToStr(name(n).text);
1983 }
1984
1985
1986 /* --------------------------------------------------------------------------
1987  * Sanity checking (weak :-(
1988  * ------------------------------------------------------------------------*/
1989
1990 Bool stgError;
1991
1992 int stgSanity_checkStack ( StgVar v )
1993 {
1994    int i, j;
1995    j = 0;
1996    for (i = 0; i <= sp; i++)
1997       if (stack(i)==v) j++;
1998    return j;
1999 }
2000
2001 void stgSanity_dropVar ( StgVar v )
2002 {
2003    drop();
2004 }
2005
2006 void stgSanity_pushVar ( StgVar v )
2007 {
2008    if (stgSanity_checkStack(v) != 0) stgError = TRUE;
2009    push(v);
2010 }
2011
2012
2013 void stgSanity ( StgExpr e )
2014 {
2015    List bs;
2016
2017    switch(whatIsStg(e)) {
2018       case LETREC:
2019          mapProc(stgSanity_pushVar,stgLetBinds(e));
2020          stgSanity(stgLetBody(e));
2021          for (bs=stgLetBinds(e); nonNull(bs); bs=tl(bs))
2022              stgSanity(stgVarBody(hd(bs)));
2023          mapProc(stgSanity_dropVar,stgLetBinds(e));
2024          break;
2025       case LAMBDA:
2026          mapProc(stgSanity_pushVar,stgLambdaArgs(e));
2027          stgSanity(stgLambdaBody(e));
2028          mapProc(stgSanity_dropVar,stgLambdaArgs(e));
2029          break;
2030       case CASE:
2031          stgSanity(stgCaseScrut(e));
2032          mapProc(stgSanity,stgCaseAlts(e));
2033          break;
2034       case PRIMCASE:
2035          stgSanity(stgPrimCaseScrut(e));
2036          mapProc(stgSanity,stgPrimCaseAlts(e));
2037          break;
2038       case STGAPP:
2039          stgSanity(stgAppFun(e));
2040          mapProc(stgSanity,stgAppArgs(e));
2041          break;
2042       case STGCON:
2043          stgSanity(stgConCon(e));
2044          mapProc(stgSanity,stgConArgs(e));
2045          break;
2046       case STGPRIM:
2047          stgSanity(stgPrimOp(e));
2048          mapProc(stgSanity,stgPrimArgs(e));
2049          break;
2050       case CASEALT:
2051          mapProc(stgSanity_pushVar,stgCaseAltVars(e));
2052          stgSanity(stgCaseAltBody(e));
2053          mapProc(stgSanity_dropVar,stgCaseAltVars(e));
2054          break;
2055       case DEEFALT:
2056          stgSanity_pushVar(stgDefaultVar(e));
2057          stgSanity(stgDefaultBody(e));
2058          stgSanity_dropVar(stgDefaultVar(e));
2059          break;
2060       case PRIMALT:
2061          mapProc(stgSanity_pushVar,stgPrimAltVars(e));
2062          stgSanity(stgPrimAltBody(e));
2063          mapProc(stgSanity_dropVar,stgPrimAltVars(e));
2064          break;
2065       case STGVAR:
2066          if (stgSanity_checkStack(e) == 1) break;
2067          if (nonNull(nameFromStgVar(e))) return;
2068          break;
2069       case NAME:
2070       case INTCELL:
2071       case STRCELL:
2072       case PTRCELL:
2073       case CHARCELL:
2074       case FLOATCELL:
2075       case TUPLE:
2076          break;
2077       default:
2078          fprintf(stderr, "stgSanity: unknown stuff %d\n",whatIsStg(e));
2079          ppStgExpr(e);
2080          printf("\n");
2081          print(e,1000);
2082          printf("\n");
2083          assert(0);
2084    }
2085 }
2086
2087
2088 void stgTopSanity ( char* caller, StgExpr e )
2089 {
2090 return;
2091    clearStack();
2092    assert(sp == -1);
2093    stgError = FALSE;
2094    stgSanity(e);
2095    assert(sp == -1);
2096    if (stgError) {
2097       fprintf(stderr, "\n\nstgTopSanity (caller = %s):\n\n", caller );
2098       ppStgExpr ( e );
2099       printf( "\n\n" );
2100       assert(0);
2101    }
2102 }
2103
2104
2105 /* Check if e is in a form which the code generator can deal with.
2106  * stgexpr-ness is what we need to enforce.  The extended version,
2107  * expr, may only occur as the rhs of a let binding.
2108  *
2109  * stgexpr ::= case atom of alts
2110  *           | case# primop{atom*} of primalts
2111  *           | let v_i = expr_i in stgexpr
2112  *           | var{atom*}
2113  *
2114  * expr ::= stgexpr
2115  *        | \v_i -> stgexpr
2116  *        | con{atoms}
2117  *
2118  *  alt ::= con vars -> stgexpr      (primalt and default similarly)
2119  *
2120  * atom ::= var | int | char etc     (unboxed, that is)
2121  */
2122 Bool isStgExpr     ( StgExpr e );
2123 Bool isStgFullExpr ( StgExpr e );
2124
2125 Bool isStgExpr ( StgExpr e )
2126 {
2127    List bs;
2128    switch (whatIs(e)) {
2129       case LAMBDA:
2130       case STGCON:
2131          return FALSE;
2132       case LETREC:
2133          for (bs=stgLetBinds(e); nonNull(bs); bs=tl(bs))
2134             if (!isStgFullExpr(stgVarBody(hd(bs))))
2135                return FALSE;
2136          return isStgExpr(stgLetBody(e));
2137       case CASE:
2138          for (bs=stgCaseAlts(e); nonNull(bs); bs=tl(bs))
2139             if (!isStgExpr(hd(bs))) return FALSE;
2140          return isAtomic(stgCaseScrut(e));
2141       case PRIMCASE:
2142          for (bs=stgPrimCaseAlts(e); nonNull(bs); bs=tl(bs))
2143             if (!isStgExpr(hd(bs))) return FALSE;
2144          if (isAtomic(stgPrimCaseScrut(e))) return TRUE;
2145          if (whatIs(stgPrimCaseScrut(e))==STGPRIM)
2146             return isStgExpr(stgPrimCaseScrut(e));
2147          return FALSE;
2148       case STGVAR:
2149       case NAME:
2150          return TRUE;
2151       case STGAPP:
2152          for (bs=stgAppArgs(e); nonNull(bs); bs=tl(bs))
2153             if (!isAtomic(hd(bs))) return FALSE;
2154          if (isStgVar(stgAppFun(e)) || isName(stgAppFun(e))) return TRUE;
2155          return FALSE;
2156       case STGPRIM:
2157          for (bs=stgPrimArgs(e); nonNull(bs); bs=tl(bs))
2158             if (!isAtomic(hd(bs))) return FALSE;
2159          if (isName(stgPrimOp(e))) return TRUE;
2160          return FALSE;
2161       case CASEALT:
2162          return isStgExpr(stgCaseAltBody(e));
2163       case DEEFALT:
2164          return isStgExpr(stgDefaultBody(e));
2165       case PRIMALT:
2166          return isStgExpr(stgPrimAltBody(e));
2167       default:
2168          return FALSE;
2169    }
2170 }
2171
2172
2173 Bool isStgFullExpr ( StgExpr e )
2174 {
2175    List bs;
2176    switch (whatIs(e)) {
2177       case LAMBDA:
2178          return isStgExpr(stgLambdaBody(e));
2179       case STGCON:
2180          for (bs=stgConArgs(e); nonNull(bs); bs=tl(bs))
2181             if (!isAtomic(hd(bs))) return FALSE;
2182          if (isName(stgConCon(e)) || isTuple(stgConCon(e)))
2183             return TRUE;
2184          return FALSE;
2185       default:
2186          return isStgExpr(e);
2187    }
2188 }
2189
2190
2191 /* --------------------------------------------------------------------------
2192  * Top level calls
2193  * ------------------------------------------------------------------------*/
2194
2195 /* Set ddumpSimpl to TRUE if you want to see simplified code. */
2196 static Bool ddumpSimpl = FALSE;
2197
2198 /* Leave this one alone ... */
2199 static Bool noisy;
2200
2201
2202 static void local optimiseTopBind( StgVar v )
2203 {
2204    Bool ppPrel = FALSE;
2205    Int  n, m;
2206    Name naam;
2207    Int  oldSize, newSize;
2208    Bool me;
2209
2210    /* printf( "[[%d]] looking at %s\n", rDepth, maybeName(v)); */
2211    assert(whatIsStg(v)==STGVAR);
2212
2213    rDepth++;
2214    if (nonNull(stgVarBody(v))) simplTopRefs(stgVarBody(v));
2215    rDepth--;
2216
2217    /* debugging ... */
2218    //me= 0&& 0==strcmp("tcUnify",maybeName(v));
2219    me= 0&& 0==strcmp("ttt",maybeName(v));
2220
2221    nTotSizeIn += stgSize(stgVarBody(v));
2222    if (noisy) {
2223       printf( "%28s: in %4d    ", maybeName(v),stgSize(stgVarBody(v))); 
2224       fflush(stdout);
2225    }
2226
2227    inDBuilder = FALSE;
2228    naam = nameFromStgVar(v);
2229    if (nonNull(naam) && name(naam).isDBuilder) inDBuilder = TRUE;
2230
2231 #if DEBUG_OPTIMISE
2232    if (nonNull(naam)) {
2233       assert(name(naam).stgSize == stgSize(stgVarBody(name(naam).stgVar)));
2234    }
2235 #endif
2236
2237    if (me) {
2238       fflush(stdout); fflush(stderr);
2239       fprintf ( stderr, "{{%d}}-----------------------------\n", -v );fflush(stderr);
2240       printStg ( stderr, v );
2241       fprintf(stderr, "\n" );
2242    }
2243
2244    stgTopSanity ( "initial", stgVarBody(v));
2245
2246    if (nonNull(stgVarBody(v))) {
2247       oldSize = -1;
2248
2249       for (n = 0; n < 8; n++) { // originally 7
2250          if (noisy) printf("%4d", stgSize(stgVarBody(v)));
2251          copyInTopvar = TRUE;
2252          stgTopSanity ( "outer-1", stgVarBody(v));
2253          oaTop ( v );
2254          stgTopSanity ( "outer-2", stgVarBody(v));
2255          stgVarBody(v) = copyIn ( NIL, CTX_OTHER, stgVarBody(v) );
2256          stgTopSanity ( "outer-3", stgVarBody(v));
2257          stgVarBody(v) = simplify ( NIL, stgVarBody(v) );
2258          stgTopSanity ( "outer-4", stgVarBody(v));
2259
2260          for (m = 0; m < 3; m++) { // oprignally 3
2261             if (noisy) printf("."); 
2262             fflush(stdout);
2263             copyInTopvar = FALSE;
2264             stgTopSanity ( "inner-1", stgVarBody(v));
2265             oaTop ( v );
2266             stgTopSanity ( "inner-2", stgVarBody(v));
2267             stgVarBody(v) = copyIn ( NIL, CTX_OTHER, stgVarBody(v) );
2268             stgTopSanity ( "inner-3", stgVarBody(v));
2269             stgVarBody(v) = simplify ( NIL, stgVarBody(v) );
2270
2271             if (me && 0) {
2272                fprintf(stderr,"\n-%d- - - - - - - - - - - - - -\n", n+1);
2273                printStg ( stderr,v );
2274             }
2275             stgTopSanity ( "inner-post", stgVarBody(v));
2276
2277          }
2278
2279          if (me && 1) {
2280             fprintf(stderr,"\n-%d-=-=-=-=-=-=-=-=-=-=-=-=-=-\n", n+1);
2281             printStg ( stderr,v );
2282          }
2283
2284          stgTopSanity ( "outer-post", stgVarBody(v));
2285
2286          newSize = stgSize ( stgVarBody(v) );
2287          if (newSize == oldSize) break;
2288          oldSize = newSize;
2289       }
2290       n++; for (; n < 8; n++) for (m = 0; m <= 3+3; m++) if (noisy) printf ( " " );
2291       if (noisy) printf(" --> %4d\n", stgSize(stgVarBody(v)) );
2292       stgVarBody(v) = restoreStgTop ( stgVarBody(v) );
2293
2294       if (nonNull(naam)) {
2295          assert(name(naam).stgVar == v);
2296          name(naam).stgSize = stgSize(stgVarBody(v));
2297       }
2298
2299 #if DEBUG_OPTIMISE
2300       /* debugging ... */
2301       if (!isStgFullExpr(stgVarBody(v))) {
2302          fprintf(stderr, "\n\nrestoreStg failed!\n\n" );
2303          printStg(stderr, v);
2304          fprintf(stderr, "\n" );
2305          exit(1);
2306       }
2307 #endif
2308    }
2309
2310    nTotSizeOut += stgSize(stgVarBody(v));
2311
2312    if (me) {
2313       fprintf(stderr,"\n=============================\n");
2314       printStg ( stderr,v );
2315       fprintf(stderr, "\n\n" );
2316       fflush(stderr);
2317       if (me) exit(1);
2318    }
2319 }
2320
2321
2322 void optimiseTopBinds ( List bs )
2323 {
2324    List t;
2325    Name n;
2326    Target ta = 0;
2327
2328    noisy = ddumpSimpl && (lastModule() != modulePrelude);
2329
2330    optimiser(RESET);
2331    if (noisy) printf("\n");
2332    initOptStats();
2333
2334    for (t = bs; nonNull(t); t=tl(t)) {
2335       n = nameFromStgVar(hd(t));
2336       if (isNull(n) || !name(n).simplified) {
2337          rDepth = 0;
2338          optimiseTopBind(hd(t));
2339       }
2340       soFar(ta++);
2341    }
2342    if (noisy) printOptStats ( stderr );
2343    optimiser(RESET);
2344 }
2345
2346
2347 /* --------------------------------------------------------------------------
2348  * Optimiser control:
2349  * ------------------------------------------------------------------------*/
2350
2351 Void optimiser(what)
2352 Int what; {
2353
2354     switch (what) {
2355         case INSTALL :
2356         case RESET   : spClone = SP_NOT_IN_USE;
2357                        initStgVarSets();
2358                        daSccs = NIL;
2359                        break;
2360
2361         case MARK    : markPairs();
2362                        markStgVarSets();
2363                        mark(daSccs);
2364                        break;
2365
2366         case GCDONE  : checkStgVarSets();
2367                        break;
2368     }
2369 }
2370
2371 /*-------------------------------------------------------------------------*/