major refactoring of edu.berkeley.fleet.dataflow
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import java.io.*;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.fpga.*;
7 //import org.ibex.graphics.*;
8
9 public class MergeSort {
10
11     private final Fleet fleet;
12     private final Ship mem1;
13     private final Ship mem2;
14     private final int arity;
15     private final ShipPool pool;
16     private final Program program;
17     private final DataFlowGraph dfg;
18
19     private ParameterNode[] pn0;
20     private ParameterNode[] pn1;
21     private ParameterNode[] pn2;
22     private ParameterNode[] pn3;
23     private ParameterNode[] pn4;
24     private ParameterNode[] pn_base1;
25     private ParameterNode   pn_base2;
26     private ParameterNode   pn5;
27     private ParameterNode   pn6;
28     private ParameterNode   pn_end;
29
30     CodeBag cb2 = null;
31     Destination next_dest = null;
32
33     public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
34         this.fleet = fleet;
35         this.mem1 = mem1;
36         this.mem2 = mem2;
37         this.arity = arity;
38         this.pool = pool;
39         this.program = program;
40         this.dfg = new DataFlowGraph(fleet, pool);
41         /*
42         pool.allocateShip(mem1);
43         if (mem2 != mem1) pool.allocateShip(mem2);
44         */
45         next_dest = makeDfgCodeBag(dfg);
46         cb2 = dfg.build(new CodeBag(dfg.fleet, program));
47         cb2.seal();
48     }
49
50     public Destination makeDfgCodeBag(DataFlowGraph dfg) {
51
52         MemoryNode mem_read  = new MemoryNode(dfg, mem1);
53         MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
54
55         AluNode sm = new AluNode(dfg, "MAXMERGE");
56
57         pn0 = new ParameterNode[arity];
58         pn1 = new ParameterNode[arity];
59         pn2 = new ParameterNode[arity];
60         pn3 = new ParameterNode[arity];
61         pn4 = new ParameterNode[arity];
62         pn_base1 = new ParameterNode[arity];
63         pn_base2 = new ParameterNode(dfg, true);
64         pn_end = new ParameterNode(dfg);
65         pn5 = new ParameterNode(dfg);
66         pn6 = new ParameterNode(dfg, true);
67
68         // So far: we have four spare Counter ships; one can be used for resetting
69         for(int i=0; i<arity; i++) {
70
71             DownCounterNode c0 = new DownCounterNode(dfg);
72             DownCounterNode c1 = new DownCounterNode(dfg);
73
74             c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
75             c0.incr.connect(new ForeverNode(dfg, 1).out);
76             c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
77             c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
78
79             RepeatNode r1 = new RepeatNode(dfg);
80             r1.val.connect(c1.out);
81             r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
82
83             AluNode alu1 = new AluNode(dfg, "ADD");
84             AluNode alu2 = new AluNode(dfg, "ADD");
85             alu1.in1.connect(r1.out);
86             alu1.in2.connect(c0.out);
87             alu1.out.connect(alu2.in2);
88             alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
89             alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
90
91             PunctuatorNode punc = new PunctuatorNode(dfg, -1);
92             punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
93             punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
94             punc.out.connect(i==0 ? sm.in1 : sm.in2);
95         }
96
97         UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
98         unpunc.val.connect(sm.out);
99         unpunc.count.connect(pn6.out);
100
101         DownCounterNode cw = new DownCounterNode(dfg);
102         cw.start.connect(pn5.out);
103         cw.incr.connect(new OnceNode(dfg, 1).out);
104
105         AluNode alu = new AluNode(dfg, "ADD");
106         alu.in1.connect(pn_base2.out);
107         cw.out.connect(alu.in2);
108         mem_write.inAddrWrite.connect(alu.out);
109         mem_write.inDataWrite.connect(unpunc.out);
110
111         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
112         discard.count.connect(pn_end.out);
113         mem_write.outWrite.connect(discard.val);
114         DoneNode done = new DoneNode(dfg, program);
115         discard.out.connect(done.in);
116
117         return done.getDestinationToSendNextCodeBagDescriptorTo();
118     }
119
120     public CodeBag build(DataFlowGraph dfg,
121                          CodeBag ctx,
122                          int vals_length, int stride_length,
123                          long base1,
124                          long base2,
125                          CodeBag next) throws Exception {
126        
127         for(int i=0; i<arity; i++) {
128             if (pn0[i]!=null) pn0[i].set(ctx, stride_length);
129             if (pn1[i]!=null) pn1[i].set(ctx, vals_length + i*stride_length);
130             if (pn2[i]!=null) pn2[i].set(ctx, 2*stride_length);
131             if (pn3[i]!=null) pn3[i].set(ctx, stride_length);
132             if (pn4[i]!=null) pn4[i].set(ctx, stride_length);
133             if (pn_base1[i]!=null) pn_base1[i].set(ctx, base1);
134         }
135         pn5.set(ctx, vals_length);
136         pn6.set(ctx, 2*stride_length+1);
137         pn_base2.set(ctx, base2);
138         pn_end.set(ctx, vals_length);
139         ctx.sendWord(cb2.getDescriptor(), program.getCBDDestination());
140         ctx.sendWord(next.getDescriptor(), next_dest);
141         ctx.seal();
142
143         return ctx;
144     }
145
146     public static void main(String[] s) throws Exception {
147         //mergeSort(1024*64, 4, "Dvi",
148         mergeSort(1024*128, 4, "Dvi",
149                   4194304
150                   );
151                   //548*478);
152     }
153
154     /** demo */
155     public static void main0(String[] s) throws Exception {
156         PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream("stats.txt")));
157         //int inflight = 1;
158         for(int inflight=1; inflight <= 8; inflight++)
159         for(int count = 32; count < 2097152; count *= 2) {
160             System.out.println("==============================================================================");
161             System.out.println("count="+count);
162             System.out.println("inflight="+inflight);
163             //long time = timeit(count, inflight);
164             long time = mergeSort(count, inflight, "DDR2", 0);
165             pw.println(inflight + ", " + count + ", " + time);
166             pw.flush();
167         }
168     }
169
170     public static long timeit(int count, int inflight) throws Exception {
171         Fleet fleet = new Fpga();
172         FleetProcess fp = fleet.run(new Instruction[0]);
173         ShipPool pool = new ShipPool(fleet);
174         //Program program = new Program(pool.allocateShip("Memory"));
175         CodeBag cb = new CodeBag(fleet);
176
177         Ship counter1 = pool.allocateShip("Counter");
178         Ship counter2 = pool.allocateShip("Counter");
179
180         Ship timer    = pool.allocateShip("Timer");
181         Ship debug    = pool.allocateShip("Debug");
182
183         LoopFactory lf;
184
185         lf = cb.loopFactory(debug.getDock("in"), 2);
186         lf.recvWord();
187         lf.deliver();
188
189
190         // First counter //////////////////////////////////////////////////////////////////////////////
191
192         lf = cb.loopFactory(counter1.getDock("in1"), 1);
193         lf.recvToken();
194         lf.literal(count);
195         lf.deliver();
196
197         lf = cb.loopFactory(counter1.getDock("in2"), 1);
198         lf.literal(1);
199         lf.deliver();
200
201         lf = cb.loopFactory(counter1.getDock("inOp"), 1);
202         lf.literal("COUNT");
203         lf.deliver();
204
205         lf = cb.loopFactory(counter1.getDock("out"), 0);
206         lf.recvToken();
207         lf.collectWord();
208         lf.sendWord(counter2.getDock("in2"));
209
210
211         // Second counter //////////////////////////////////////////////////////////////////////////////
212
213         lf = cb.loopFactory(counter2.getDock("in1"), 1);
214         lf.literal(count);
215         lf.deliver();
216         lf.literal(1);
217         lf.deliver();
218         lf.literal(1);
219         lf.deliver();
220
221         lf = cb.loopFactory(counter2.getDock("in2"), 1);
222         for(int i=0; i<inflight; i++)
223             lf.sendToken(counter1.getDock("out"));
224         lf = lf.makeNext(0);
225         lf.recvWord();
226         lf.sendToken(counter1.getDock("out"));
227         lf.deliver();
228
229         lf = cb.loopFactory(counter2.getDock("inOp"), 1);
230         lf.literal("DROP_C1_V2");
231         lf.deliver();
232         lf.literal("PASS_C1_V1");
233         lf.deliver();
234
235         lf = cb.loopFactory(counter2.getDock("out"), 1);
236         lf.collectWord();
237         lf.sendToken(timer.getDock("out"));
238
239
240         // Timer //////////////////////////////////////////////////////////////////////////////
241
242         lf = cb.loopFactory(timer.getDock("out"), 1);
243         lf.collectWord();
244         lf.sendToken(counter1.getDock("in1"));
245         lf.sendWord(debug.getDock("in"));
246         lf.recvToken();
247         lf.collectWord();
248         lf.sendWord(debug.getDock("in"));
249
250         FpgaDock out = (FpgaDock)counter1.getDock("out");
251         FpgaDock in  = (FpgaDock)counter2.getDock("in2");
252         System.out.println("distance is " + out.getPathLength((FpgaDestination)in.getDataDestination()));
253         System.out.println("reverse distance is " + in.getPathLength((FpgaDestination)out.getDataDestination()));
254
255         for(Instruction i : cb.emit()) System.out.println(i);
256         cb.dispatch(fp, true);
257         long time1 = fp.recvWord().toLong();
258         System.out.println("got " + time1);
259         long time2 = fp.recvWord().toLong();
260         System.out.println("got " + time2);
261         System.out.println("diff=" + (time2-time1));
262
263         fp.terminate();
264
265         return (time2-time1);
266     }
267
268     public static long mergeSort(int vals_length, int inflight, String shipType, int clearAmount) throws Exception {
269         Node.CAPACITY = inflight;
270
271         Fleet fleet = new Fpga();
272         FleetProcess fp = fleet.run(new Instruction[0]);
273         ShipPool pool = new ShipPool(fleet);
274         Ship mem1 = pool.allocateShip(shipType);
275
276
277         if (clearAmount > 0)
278             randomizeMemory(fp, pool, mem1, 0, clearAmount, false);
279
280         //randomizeMemory(fp, pool, mem1, 0, vals_length, true);
281
282         BitVector[] bvs = new BitVector[vals_length];
283
284         long index = 0;
285         /*
286         Picture p = new Picture(new FileInputStream("campus.png"));
287         for(int y=0; y<(478/2); y++)
288             for(int x=0; x<544; x++) {
289                 if (index >= vals_length) break;
290                 int pixel = (x>=p.width) ? 0 : p.data[p.width*y+x];
291                 long r = (pixel>>0)  & 0xff;
292                 long g = (pixel>>8)  & 0xff;
293                 long b = (pixel>>16) & 0xff;
294                 r >>= 2;
295                 g >>= 2;
296                 b >>= 2;
297                 //r = ~(-1L<<6);
298                 //g = ~(-1L<<6);
299                 //b = ~(-1L<<6);
300                 bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( r | (g<<6) | (b<<12) | (index<<18) );
301                 index++;
302             }
303         */
304         for(; index<vals_length; index++) {
305             long tag = index<<18;
306             bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( tag );
307         }
308
309         System.out.println("final index " + index);
310
311         Random random = new Random(System.currentTimeMillis());
312         for(int i=0; i<bvs.length*10; i++) {
313             int from = Math.abs(random.nextInt()) % bvs.length;
314             int to   = Math.abs(random.nextInt()) % bvs.length;
315             BitVector bv = bvs[from];
316             bvs[from] = bvs[to];
317             bvs[to] = bv;
318         }
319
320         MemoryUtils.writeMem(fp, pool, mem1, offset, bvs);
321
322         Program program = new Program(pool.allocateShip("Memory"));
323         long ret = new MergeSort(fleet, program, pool, 2, mem1, mem1).main(fp, vals_length);
324
325         //long ret = 0;
326         // verify the cleanup?
327         //CleanupUtils.verifyClean(fp);
328         //MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvs);
329
330         BitVector[] bvx = new BitVector[1024];
331         MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvx);
332         for(int i=0; i<bvx.length; i++)
333             System.out.println(bvx[i]);
334         /*
335         System.out.println("results:");
336         for(int i=0; i<vals_length-1; i++)
337             if ( (bvs[i].toLong() & ~(-1L<<18)) != ~(-1L<<18))
338                 System.out.println(bvs[i]);
339         */
340         /*
341         for(int i=0; i<vals_length-1; i++) {
342             if (bvs[i].toLong() > bvs[i+1].toLong())
343                 System.out.println("sort failure at "+i+":\n  "+bvs[i]+"\n  "+bvs[i+1]);
344         }
345         */
346         fp.terminate();
347         return ret;
348     }
349
350     static int offset = 40;
351
352     //static int offset = 32;
353     //static int offset = 544*2;
354     //static int offset = 544;
355
356     public long main(FleetProcess fp, int vals_length) throws Exception {
357
358         long base_read = offset;
359         //long base_write = offset+((vals_length/544)+1)*544;
360         //long base_write = offset;
361         long base_write = vals_length + offset;
362
363         int stride = 1;
364         ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
365         codeBags.add(new CodeBag(dfg.fleet, program));
366         for(; ;) {
367             CodeBag cb = codeBags.get(codeBags.size()-1);
368             //System.out.println("cb="+i+", stride="+stride);
369             boolean last = /*(base_write==offset) && */ (stride*2 >= vals_length);
370             CodeBag next;
371             if (last) {
372                 next = program.getEndProgramCodeBag();
373             } else {
374                 next = new CodeBag(dfg.fleet, program);
375                 codeBags.add(codeBags.size(), next);
376             }
377             build(dfg, cb, vals_length, stride, base_read, base_write, next);
378             cb.seal();
379             long base = base_read; base_read=base_write; base_write=base;
380             //i++;
381             if (last) break;
382             if ((stride*2) < vals_length) stride *= 2;
383         }
384         System.out.println("done building codebags; installing...");
385         System.out.println("cb.length=="+codeBags.size());
386
387         Ship button = pool.allocateShip("Button");
388         CodeBag cb = new CodeBag(dfg.fleet, program);
389         LoopFactory lf = cb.loopFactory(button.getDock("out"), 1);
390         lf.collectWord();
391         lf.literal(codeBags.get(0).getDescriptor());
392         lf.sendWord(program.getCBDDestination());
393         cb.seal();
394
395         // FIXME
396         return program.run(fp, cb, new ShipPool(fp.getFleet()));
397     }
398
399     public static void randomizeMemory(FleetProcess fp, ShipPool pool_, Ship memory, long start, long length, boolean randomize) {
400         ShipPool pool = new ShipPool(pool_);
401         Ship mem = pool.allocateShip("Memory");
402         Program prog = new Program(mem);
403
404         DataFlowGraph dfg = new DataFlowGraph(fp.getFleet(), pool);
405         DownCounterNode dcn = new DownCounterNode(dfg);
406         dcn.start.connectOnce(length);
407         dcn.incr.connectOnce(1);
408
409         AluNode alu = new AluNode(dfg, "ADD");
410         alu.in1.connectForever(start);
411         alu.in2.connect(dcn.out);
412
413         MemoryNode mn = new MemoryNode(dfg, memory);
414         mn.inAddrWrite.connect(alu.out);
415
416         AluNode aluAnd = new AluNode(dfg, "AND");
417         if (randomize) {
418             aluAnd.in1.connect(new RandomNode(dfg).out);
419         } else {
420             //aluAnd.in1.connectForever( ~(-1L<<36) );
421             aluAnd.in1.connectForever( 0 );
422         }
423         aluAnd.in2.connectForever( ~(-1<<18) );
424         mn.inDataWrite.connect(aluAnd.out);
425         
426         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
427         discard.count.connectOnce(length);
428         discard.val.connect(mn.outWrite);
429         DoneNode done = new DoneNode(dfg, prog);
430         discard.out.connect(done.in);
431
432         CodeBag cb = new CodeBag(fp.getFleet(), prog);
433         dfg.build(cb);
434         cb.seal();
435
436         CodeBag cb2 = new CodeBag(fp.getFleet(), prog);
437         Ship button = fp.getFleet().getShip("Button",0);
438
439         LoopFactory lf = cb2.loopFactory(button.getDock("out"), 1);
440         //lf.collectWord();
441         lf.literal(prog.getEndProgramCodeBag().getDescriptor());
442         lf.sendWord(done.getDestinationToSendNextCodeBagDescriptorTo());
443         lf.literal(cb.getDescriptor());
444         lf.sendWord(prog.getCBDDestination());
445         cb2.seal();
446
447         System.out.println("dispatching randomization codebag...");
448         prog.run(fp, cb2, pool);
449         System.out.println("  randomization done.");
450         pool.releaseAll();
451     }
452 }