// Source: fleet.git, src/edu/berkeley/fleet/dataflow/MergeSort.java
// Revision: 876cf918304b1ab179d8a216400aba9f34ccb2c0
package edu.berkeley.fleet.dataflow;

import java.io.*;
import java.util.*;

import edu.berkeley.fleet.api.*;
import edu.berkeley.fleet.fpga.*;
import edu.berkeley.fleet.loops.*;
import org.ibex.graphics.*;

9 public class MergeSort {
10
11     private final Fleet fleet;
12     private final Ship mem1;
13     private final Ship mem2;
14     private final int arity;
15     private final ShipPool pool;
16     private final Program program;
17     private final DataFlowGraph dfg;
18
19     private ParameterNode[] pn0;
20     private ParameterNode[] pn1;
21     private ParameterNode[] pn2;
22     private ParameterNode[] pn3;
23     private ParameterNode[] pn4;
24     private ParameterNode[] pn_base1;
25     private ParameterNode   pn_base2;
26     private ParameterNode   pn5;
27     private ParameterNode   pn6;
28     private ParameterNode   pn_end;
29
30     CodeBag cb2 = null;
31     Destination next_dest = null;
32
33     public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
34         this.fleet = fleet;
35         this.mem1 = mem1;
36         this.mem2 = mem2;
37         this.arity = arity;
38         this.pool = pool;
39         this.program = program;
40         this.dfg = new DataFlowGraph(fleet, pool);
41         /*
42         pool.allocateShip(mem1);
43         if (mem2 != mem1) pool.allocateShip(mem2);
44         */
45         next_dest = makeDfgCodeBag(dfg);
46         cb2 = dfg.build(new CodeBag(dfg.fleet, program));
47         cb2.seal();
48     }
49
50     public Destination makeDfgCodeBag(DataFlowGraph dfg) {
51
52         MemoryNode mem_read  = new MemoryNode(dfg, mem1);
53         MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
54
55         AluNode sm = new AluNode(dfg, "MAXMERGE");
56
57         pn0 = new ParameterNode[arity];
58         pn1 = new ParameterNode[arity];
59         pn2 = new ParameterNode[arity];
60         pn3 = new ParameterNode[arity];
61         pn4 = new ParameterNode[arity];
62         pn_base1 = new ParameterNode[arity];
63         pn_base2 = new ParameterNode(dfg, true);
64         pn_end = new ParameterNode(dfg);
65         pn5 = new ParameterNode(dfg);
66         pn6 = new ParameterNode(dfg, true);
67
68         // So far: we have four spare Counter ships; one can be used for resetting
69         for(int i=0; i<arity; i++) {
70
71             DownCounterNode c0 = new DownCounterNode(dfg);
72             DownCounterNode c1 = new DownCounterNode(dfg);
73
74             c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
75             c0.incr.connect(new ForeverNode(dfg, 1).out);
76             c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
77             c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
78
79             RepeatNode r1 = new RepeatNode(dfg);
80             r1.val.connect(c1.out);
81             r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
82
83             AluNode alu1 = new AluNode(dfg, "ADD");
84             AluNode alu2 = new AluNode(dfg, "ADD");
85             alu1.in1.connect(r1.out);
86             alu1.in2.connect(c0.out);
87             alu1.out.connect(alu2.in2);
88             alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
89             alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
90
91             PunctuatorNode punc = new PunctuatorNode(dfg, -1);
92             punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
93             punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
94             punc.out.connect(i==0 ? sm.in1 : sm.in2);
95         }
96
97         UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
98         unpunc.val.connect(sm.out);
99         unpunc.count.connect(pn6.out);
100
101         DownCounterNode cw = new DownCounterNode(dfg);
102         cw.start.connect(pn5.out);
103         cw.incr.connect(new OnceNode(dfg, 1).out);
104
105         AluNode alu = new AluNode(dfg, "ADD");
106         alu.in1.connect(pn_base2.out);
107         cw.out.connect(alu.in2);
108         mem_write.inAddrWrite.connect(alu.out);
109         mem_write.inDataWrite.connect(unpunc.out);
110
111         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
112         discard.count.connect(pn_end.out);
113         mem_write.outWrite.connect(discard.val);
114         DoneNode done = new DoneNode(dfg, program);
115         discard.out.connect(done.in);
116
117         return done.getDestinationToSendNextCodeBagDescriptorTo();
118     }
119
120     public CodeBag build(DataFlowGraph dfg,
121                          CodeBag ctx,
122                          int vals_length, int stride_length,
123                          long base1,
124                          long base2,
125                          CodeBag next) throws Exception {
126        
127         for(int i=0; i<arity; i++) {
128             if (pn0[i]!=null) pn0[i].set(ctx, stride_length);
129             if (pn1[i]!=null) pn1[i].set(ctx, vals_length + i*stride_length);
130             if (pn2[i]!=null) pn2[i].set(ctx, 2*stride_length);
131             if (pn3[i]!=null) pn3[i].set(ctx, stride_length);
132             if (pn4[i]!=null) pn4[i].set(ctx, stride_length);
133             if (pn_base1[i]!=null) pn_base1[i].set(ctx, base1);
134         }
135         pn5.set(ctx, vals_length);
136         pn6.set(ctx, 2*stride_length+1);
137         pn_base2.set(ctx, base2);
138         pn_end.set(ctx, vals_length);
139         ctx.sendWord(cb2.getDescriptor(), program.getCBDDestination());
140         ctx.sendWord(next.getDescriptor(), next_dest);
141         ctx.seal();
142
143         return ctx;
144     }
145
146     public static void main(String[] s) throws Exception {
147         //mergeSort(1024*64, 4, "Dvi",
148         mergeSort(1024*128, 4, "Dvi",
149                   4194304
150                   );
151                   //548*478);
152     }
153
154     /** demo */
155     public static void main0(String[] s) throws Exception {
156         PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream("stats.txt")));
157         //int inflight = 1;
158         for(int inflight=1; inflight <= 8; inflight++)
159         for(int count = 32; count < 2097152; count *= 2) {
160             System.out.println("==============================================================================");
161             System.out.println("count="+count);
162             System.out.println("inflight="+inflight);
163             //long time = timeit(count, inflight);
164             long time = mergeSort(count, inflight, "DDR2", 0);
165             pw.println(inflight + ", " + count + ", " + time);
166             pw.flush();
167         }
168     }
169
170     public static long timeit(int count, int inflight) throws Exception {
171         Fleet fleet = new Fpga();
172         FleetProcess fp = fleet.run(new Instruction[0]);
173         ShipPool pool = new ShipPool(fleet);
174         //Program program = new Program(pool.allocateShip("Memory"));
175         CodeBag cb = new CodeBag(fleet);
176
177         Ship counter1 = pool.allocateShip("Counter");
178         Ship counter2 = pool.allocateShip("Counter");
179
180         Ship timer    = pool.allocateShip("Timer");
181         Ship debug    = pool.allocateShip("Debug");
182
183         LoopFactory lf;
184
185         lf = cb.loopFactory(debug.getDock("in"), 2);
186         lf.recvWord();
187         lf.deliver();
188
189
190         // First counter //////////////////////////////////////////////////////////////////////////////
191
192         lf = cb.loopFactory(counter1.getDock("in1"), 1);
193         lf.recvToken();
194         lf.literal(count);
195         lf.deliver();
196
197         lf = cb.loopFactory(counter1.getDock("in2"), 1);
198         lf.literal(1);
199         lf.deliver();
200
201         lf = cb.loopFactory(counter1.getDock("inOp"), 1);
202         lf.literal("COUNT");
203         lf.deliver();
204
205         lf = cb.loopFactory(counter1.getDock("out"), 0);
206         lf.recvToken();
207         lf.collectWord();
208         lf.sendWord(counter2.getDock("in2"));
209
210
211         // Second counter //////////////////////////////////////////////////////////////////////////////
212
213         lf = cb.loopFactory(counter2.getDock("in1"), 1);
214         lf.literal(count);
215         lf.deliver();
216         lf.literal(1);
217         lf.deliver();
218         lf.literal(1);
219         lf.deliver();
220
221         lf = cb.loopFactory(counter2.getDock("in2"), 1);
222         for(int i=0; i<inflight; i++)
223             lf.sendToken(counter1.getDock("out"));
224         lf = lf.makeNext(0);
225         lf.recvWord();
226         lf.sendToken(counter1.getDock("out"));
227         lf.deliver();
228
229         lf = cb.loopFactory(counter2.getDock("inOp"), 1);
230         lf.literal("DROP_C1_V2");
231         lf.deliver();
232         lf.literal("PASS_C1_V1");
233         lf.deliver();
234
235         lf = cb.loopFactory(counter2.getDock("out"), 1);
236         lf.collectWord();
237         lf.sendToken(timer.getDock("out"));
238
239
240         // Timer //////////////////////////////////////////////////////////////////////////////
241
242         lf = cb.loopFactory(timer.getDock("out"), 1);
243         lf.collectWord();
244         lf.sendToken(counter1.getDock("in1"));
245         lf.sendWord(debug.getDock("in"));
246         lf.recvToken();
247         lf.collectWord();
248         lf.sendWord(debug.getDock("in"));
249
250         FpgaDock out = (FpgaDock)counter1.getDock("out");
251         FpgaDock in  = (FpgaDock)counter2.getDock("in2");
252         System.out.println("distance is " + out.getPathLength((FpgaDestination)in.getDataDestination()));
253         System.out.println("reverse distance is " + in.getPathLength((FpgaDestination)out.getDataDestination()));
254
255         for(Instruction i : cb.emit()) System.out.println(i);
256         cb.dispatch(fp, true);
257         long time1 = fp.recvWord().toLong();
258         System.out.println("got " + time1);
259         long time2 = fp.recvWord().toLong();
260         System.out.println("got " + time2);
261         System.out.println("diff=" + (time2-time1));
262
263         fp.terminate();
264
265         return (time2-time1);
266     }
267
268     public static long mergeSort(int vals_length, int inflight, String shipType, int clearAmount) throws Exception {
269         Node.CAPACITY = inflight;
270
271         Fleet fleet = new Fpga();
272         FleetProcess fp = fleet.run(new Instruction[0]);
273         ShipPool pool = new ShipPool(fleet);
274         Ship mem1 = pool.allocateShip(shipType);
275
276
277         if (clearAmount > 0)
278             randomizeMemory(fp, pool, mem1, 0, clearAmount, false);
279
280         //randomizeMemory(fp, pool, mem1, 0, vals_length, true);
281
282         BitVector[] bvs = new BitVector[vals_length];
283
284         long index = 0;
285         Picture p = new Picture(new FileInputStream("campus.png"));
286         for(int y=0; y<(478/2); y++)
287             for(int x=0; x<544; x++) {
288                 if (index >= vals_length) break;
289                 int pixel = (x>=p.width) ? 0 : p.data[p.width*y+x];
290                 long r = (pixel>>0)  & 0xff;
291                 long g = (pixel>>8)  & 0xff;
292                 long b = (pixel>>16) & 0xff;
293                 r >>= 2;
294                 g >>= 2;
295                 b >>= 2;
296                 //r = ~(-1L<<6);
297                 //g = ~(-1L<<6);
298                 //b = ~(-1L<<6);
299                 bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( r | (g<<6) | (b<<12) | (index<<18) );
300                 index++;
301             }
302
303         for(; index<vals_length; index++) {
304             long tag = index<<18;
305             bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( tag );
306         }
307
308         System.out.println("final index " + index);
309
310         Random random = new Random(System.currentTimeMillis());
311         for(int i=0; i<bvs.length*10; i++) {
312             int from = Math.abs(random.nextInt()) % bvs.length;
313             int to   = Math.abs(random.nextInt()) % bvs.length;
314             BitVector bv = bvs[from];
315             bvs[from] = bvs[to];
316             bvs[to] = bv;
317         }
318
319         MemoryUtils.writeMem(fp, pool, mem1, offset, bvs);
320
321         Ship mem = pool.allocateShip("Memory");
322         Program program = new Program(mem);
323         long ret = new MergeSort(fleet, program, pool, 2, mem1, mem1).main(fp, vals_length);
324
325         //long ret = 0;
326         // verify the cleanup?
327         //CleanupUtils.verifyClean(fp);
328         //MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvs);
329
330         BitVector[] bvx = new BitVector[1024];
331         pool.allocateShip(mem);
332         MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvx);
333         for(int i=0; i<bvx.length; i++)
334             System.out.println(bvx[i]);
335         /*
336         System.out.println("results:");
337         for(int i=0; i<vals_length-1; i++)
338             if ( (bvs[i].toLong() & ~(-1L<<18)) != ~(-1L<<18))
339                 System.out.println(bvs[i]);
340         */
341         /*
342         for(int i=0; i<vals_length-1; i++) {
343             if (bvs[i].toLong() > bvs[i+1].toLong())
344                 System.out.println("sort failure at "+i+":\n  "+bvs[i]+"\n  "+bvs[i+1]);
345         }
346         */
347         fp.terminate();
348         return ret;
349     }
350
351     static int offset = 40;
352
353     //static int offset = 32;
354     //static int offset = 544*2;
355     //static int offset = 544;
356
357     public long main(FleetProcess fp, int vals_length) throws Exception {
358
359         long base_read = offset;
360         //long base_write = offset+((vals_length/544)+1)*544;
361         //long base_write = offset;
362         long base_write = vals_length + offset;
363
364         int stride = 1;
365         ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
366         codeBags.add(new CodeBag(dfg.fleet, program));
367         for(; ;) {
368             CodeBag cb = codeBags.get(codeBags.size()-1);
369             //System.out.println("cb="+i+", stride="+stride);
370             boolean last = /*(base_write==offset) && */ (stride*2 >= vals_length);
371             CodeBag next;
372             if (last) {
373                 next = program.getEndProgramCodeBag();
374             } else {
375                 next = new CodeBag(dfg.fleet, program);
376                 codeBags.add(codeBags.size(), next);
377             }
378             build(dfg, cb, vals_length, stride, base_read, base_write, next);
379             cb.seal();
380             long base = base_read; base_read=base_write; base_write=base;
381             //i++;
382             if (last) break;
383             if ((stride*2) < vals_length) stride *= 2;
384         }
385         System.out.println("done building codebags; installing...");
386         System.out.println("cb.length=="+codeBags.size());
387
388         Ship button = pool.allocateShip("Button");
389         CodeBag cb = new CodeBag(dfg.fleet, program);
390         LoopFactory lf = cb.loopFactory(button.getDock("out"), 1);
391         lf.collectWord();
392         lf.literal(codeBags.get(0).getDescriptor());
393         lf.sendWord(program.getCBDDestination());
394         cb.seal();
395
396         ShipPool pool = new ShipPool(fp.getFleet());
397         pool.allocateShip(program.memoryShip);
398         // FIXME
399         long ret = program.run(fp, cb, pool);
400         pool.releaseShip(program.memoryShip);
401         return ret;
402     }
403
404     public static void randomizeMemory(FleetProcess fp, ShipPool pool_, Ship memory, long start, long length, boolean randomize) {
405         ShipPool pool = new ShipPool(pool_);
406         Ship mem = pool.allocateShip("Memory");
407         Program prog = new Program(mem);
408
409         DataFlowGraph dfg = new DataFlowGraph(fp.getFleet(), pool);
410         DownCounterNode dcn = new DownCounterNode(dfg);
411         dcn.start.connectOnce(length);
412         dcn.incr.connectOnce(1);
413
414         AluNode alu = new AluNode(dfg, "ADD");
415         alu.in1.connectForever(start);
416         alu.in2.connect(dcn.out);
417
418         MemoryNode mn = new MemoryNode(dfg, memory);
419         mn.inAddrWrite.connect(alu.out);
420
421         AluNode aluAnd = new AluNode(dfg, "AND");
422         if (randomize) {
423             aluAnd.in1.connect(new RandomNode(dfg).out);
424         } else {
425             //aluAnd.in1.connectForever( ~(-1L<<36) );
426             aluAnd.in1.connectForever( 0 );
427         }
428         aluAnd.in2.connectForever( ~(-1<<18) );
429         mn.inDataWrite.connect(aluAnd.out);
430         
431         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
432         discard.count.connectOnce(length);
433         discard.val.connect(mn.outWrite);
434         DoneNode done = new DoneNode(dfg, prog);
435         discard.out.connect(done.in);
436
437         CodeBag cb = new CodeBag(fp.getFleet(), prog);
438         dfg.build(cb);
439         cb.seal();
440
441         CodeBag cb2 = new CodeBag(fp.getFleet(), prog);
442         Ship button = fp.getFleet().getShip("Button",0);
443
444         LoopFactory lf = cb2.loopFactory(button.getDock("out"), 1);
445         //lf.collectWord();
446         lf.literal(prog.getEndProgramCodeBag().getDescriptor());
447         lf.sendWord(done.getDestinationToSendNextCodeBagDescriptorTo());
448         lf.literal(cb.getDescriptor());
449         lf.sendWord(prog.getCBDDestination());
450         cb2.seal();
451
452         System.out.println("dispatching randomization codebag...");
453         prog.run(fp, cb2, pool);
454         System.out.println("  randomization done.");
455         pool.releaseAll();
456     }
457 }