move graphical sorting demo into SortingDemo, make a command-line sorting regression...
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import java.io.*;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.interpreter.*;
7 import edu.berkeley.fleet.fpga.*;
8 import org.ibex.graphics.*;
9
10 public class MergeSort {
11
12     private final Fleet fleet;
13     private final Ship mem1;
14     private final Ship mem2;
15     private final int arity;
16     private final ShipPool pool;
17     private final Program program;
18     private final DataFlowGraph dfg;
19
20     private ParameterNode[] pn0;
21     private ParameterNode[] pn1;
22     private ParameterNode[] pn2;
23     private ParameterNode[] pn3;
24     private ParameterNode[] pn4;
25     private ParameterNode[] pn_base1;
26     private ParameterNode   pn_base2;
27     private ParameterNode   pn5;
28     private ParameterNode   pn6;
29     private ParameterNode   pn_end;
30
31     CodeBag cb2 = null;
32     Destination next_dest = null;
33
34     public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
35         this.fleet = fleet;
36         this.mem1 = mem1;
37         this.mem2 = mem2;
38         this.arity = arity;
39         this.pool = pool;
40         this.program = program;
41         this.dfg = new DataFlowGraph(fleet, pool);
42         next_dest = makeDfgCodeBag(dfg, program);
43         cb2 = dfg.build(new CodeBag(dfg.fleet, program));
44         cb2.seal();
45     }
46
47     public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
48
49         MemoryNode mem_read  = new MemoryNode(dfg, mem1);
50         MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
51
52         AluNode sm = new AluNode(dfg, "MAXMERGE");
53
54         pn0 = new ParameterNode[arity];
55         pn1 = new ParameterNode[arity];
56         pn2 = new ParameterNode[arity];
57         pn3 = new ParameterNode[arity];
58         pn4 = new ParameterNode[arity];
59         pn_base1 = new ParameterNode[arity];
60         pn_base2 = new ParameterNode(dfg, true);
61         pn_end = new ParameterNode(dfg);
62         pn5 = new ParameterNode(dfg);
63         pn6 = new ParameterNode(dfg, true);
64
65         // So far: we have four spare Counter ships; one can be used for resetting
66         for(int i=0; i<arity; i++) {
67
68             DownCounterNode c0 = new DownCounterNode(dfg);
69             DownCounterNode c1 = new DownCounterNode(dfg);
70
71             c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
72             c0.incr.connect(new ForeverNode(dfg, 1).out);
73             c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
74             c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
75
76             RepeatNode r1 = new RepeatNode(dfg);
77             r1.val.connect(c1.out);
78             r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
79
80             AluNode alu1 = new AluNode(dfg, "ADD");
81             AluNode alu2 = new AluNode(dfg, "ADD");
82             alu1.in1.connect(r1.out);
83             alu1.in2.connect(c0.out);
84             alu1.out.connect(alu2.in2);
85             alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
86             alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
87
88             PunctuatorNode punc = new PunctuatorNode(dfg, -1);
89             punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
90             punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
91             punc.out.connect(i==0 ? sm.in1 : sm.in2);
92         }
93
94         UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
95         unpunc.val.connect(sm.out);
96         unpunc.count.connect(pn6.out);
97
98         DownCounterNode cw = new DownCounterNode(dfg);
99         cw.start.connect(pn5.out);
100         cw.incr.connect(new OnceNode(dfg, 1).out);
101
102         AluNode alu = new AluNode(dfg, "ADD");
103         alu.in1.connect(pn_base2.out);
104         cw.out.connect(alu.in2);
105         mem_write.inAddrWrite.connect(alu.out);
106         mem_write.inDataWrite.connect(unpunc.out);
107
108         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
109         discard.count.connect(pn_end.out);
110         mem_write.outWrite.connect(discard.val);
111         DoneNode done = new DoneNode(dfg, program);
112         discard.out.connect(done.in);
113
114         return done.getDestinationToSendNextCodeBagDescriptorTo();
115     }
116
117     public CodeBag makeInstance(int offset, int length) throws Exception {
118
119         long base_read = offset;
120         long base_write = length + offset;
121         int stride = 1;
122         ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
123         codeBags.add(new CodeBag(dfg.fleet, program));
124         for(;;) {
125             CodeBag cb = codeBags.get(codeBags.size()-1);
126             boolean last = stride*2 >= length;
127             CodeBag next;
128             if (last) {
129                 next = program.getEndProgramCodeBag();
130             } else {
131                 next = new CodeBag(dfg.fleet, program);
132                 codeBags.add(codeBags.size(), next);
133             }
134
135             for(int i=0; i<arity; i++) {
136                 if (pn0[i]!=null) pn0[i].set(cb, stride);
137                 if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
138                 if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
139                 if (pn3[i]!=null) pn3[i].set(cb, stride);
140                 if (pn4[i]!=null) pn4[i].set(cb, stride);
141                 if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
142             }
143             pn5.set(cb, length);
144             pn6.set(cb, 2*stride+1);
145             pn_base2.set(cb, base_write);
146             pn_end.set(cb, length);
147             cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
148             cb.sendWord(next.getDescriptor(), next_dest);
149
150             cb.seal();
151             long base = base_read; base_read=base_write; base_write=base;
152             if (last) break;
153             if ((stride*2) < length) stride *= 2;
154         }
155
156         return codeBags.get(0);
157     }
158
159     public static void main(String[] s) throws Exception {
160         if (s.length != 4) {
161             System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
162             System.exit(-1);
163         }
164         Fleet fleet = null;
165         if (s[0].equals("fpga"))             fleet = new Fpga();
166         else if (s[0].equals("interpreter")) {
167             fleet = new Interpreter();
168             Log.log = null;
169         }
170         ShipPool pool = new ShipPool(fleet);
171         Ship memory = pool.allocateShip(s[1]);
172         int base = Integer.parseInt(s[2]);
173         int length = Integer.parseInt(s[3]);
174
175         Random random = new Random(System.currentTimeMillis());        
176         BitVector[] vals  = new BitVector[length];
177         long[] longs = new long[length];
178         for(int i=0; i<vals.length; i++) {
179             vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
180             for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
181             longs[i] = vals[i].toLong();
182         }
183         Arrays.sort(longs);
184
185         Ship codemem = pool.allocateShip("Memory");
186
187         FleetProcess fp = fleet.run(new Instruction[0]);
188         ShipPool pool2 = new ShipPool(pool);
189
190         MemoryUtils.writeMem(fp, pool2, memory, base, vals);
191         pool2.releaseAll();
192
193         Program program = new Program(codemem);
194         CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
195         pool2.releaseAll();
196         long ret = program.run(fp, cb, pool2);
197         pool2.releaseAll();
198
199         MemoryUtils.readMem(fp, pool2, memory, base, vals);
200         pool2.releaseAll();
201
202         System.out.println();
203         int fails = 0;
204         for(int i=0; i<vals.length; i++)
205             if (vals[i].toLong() != longs[i]) {
206                 System.out.println("disagreement!  on index " +
207                                    i + "\n  expected="+
208                                    new BitVector(fleet.getWordWidth()).set(longs[i])+
209                                    "\n       got="+vals[i]);
210                 fails++;
211             }
212         System.out.println("done! ("+fails+" failures)");
213         if (fails>0) System.exit(-1);
214         
215     }
216     
217 }