b478462bc138e7d8f20665fddff9b33e75655a59
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import java.io.*;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.fpga.*;
7 import org.ibex.graphics.*;
8
9 public class MergeSort {
10
11     private final Fleet fleet;
12     private final Ship mem1;
13     private final Ship mem2;
14     private final int arity;
15     private final ShipPool pool;
16     private final Program program;
17     private final DataFlowGraph dfg;
18
19     private ParameterNode[] pn0;
20     private ParameterNode[] pn1;
21     private ParameterNode[] pn2;
22     private ParameterNode[] pn3;
23     private ParameterNode[] pn4;
24     private ParameterNode[] pn_base1;
25     private ParameterNode   pn_base2;
26     private ParameterNode   pn5;
27     private ParameterNode   pn6;
28     private ParameterNode   pn_end;
29
30     CodeBag cb2 = null;
31     Destination next_dest = null;
32
33     public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
34         this.fleet = fleet;
35         this.mem1 = mem1;
36         this.mem2 = mem2;
37         this.arity = arity;
38         this.pool = pool;
39         this.program = program;
40         this.dfg = new DataFlowGraph(fleet, pool);
41         /*
42         pool.allocateShip(mem1);
43         if (mem2 != mem1) pool.allocateShip(mem2);
44         */
45         next_dest = makeDfgCodeBag(dfg);
46         cb2 = dfg.build(new CodeBag(dfg.fleet, program));
47         cb2.seal();
48     }
49
50     public Destination makeDfgCodeBag(DataFlowGraph dfg) {
51
52         MemoryNode mem_read  = new MemoryNode(dfg, mem1);
53         MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
54
55         AluNode sm = new AluNode(dfg, "MAXMERGE");
56
57         pn0 = new ParameterNode[arity];
58         pn1 = new ParameterNode[arity];
59         pn2 = new ParameterNode[arity];
60         pn3 = new ParameterNode[arity];
61         pn4 = new ParameterNode[arity];
62         pn_base1 = new ParameterNode[arity];
63         pn_base2 = new ParameterNode(dfg, true);
64         pn_end = new ParameterNode(dfg);
65         pn5 = new ParameterNode(dfg);
66         pn6 = new ParameterNode(dfg, true);
67
68         // So far: we have four spare Counter ships; one can be used for resetting
69         for(int i=0; i<arity; i++) {
70
71             DownCounterNode c0 = new DownCounterNode(dfg);
72             DownCounterNode c1 = new DownCounterNode(dfg);
73
74             c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
75             c0.incr.connect(new ForeverNode(dfg, 1).out);
76             c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
77             c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
78
79             RepeatNode r1 = new RepeatNode(dfg);
80             r1.val.connect(c1.out);
81             r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
82
83             AluNode alu1 = new AluNode(dfg, "ADD");
84             AluNode alu2 = new AluNode(dfg, "ADD");
85             alu1.in1.connect(r1.out);
86             alu1.in2.connect(c0.out);
87             alu1.out.connect(alu2.in2);
88             alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
89             alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
90
91             PunctuatorNode punc = new PunctuatorNode(dfg, -1);
92             punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
93             punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
94             punc.out.connect(i==0 ? sm.in1 : sm.in2);
95         }
96
97         UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
98         unpunc.val.connect(sm.out);
99         unpunc.count.connect(pn6.out);
100
101         DownCounterNode cw = new DownCounterNode(dfg);
102         cw.start.connect(pn5.out);
103         cw.incr.connect(new OnceNode(dfg, 1).out);
104
105         AluNode alu = new AluNode(dfg, "ADD");
106         alu.in1.connect(pn_base2.out);
107         cw.out.connect(alu.in2);
108         mem_write.inAddrWrite.connect(alu.out);
109         mem_write.inDataWrite.connect(unpunc.out);
110
111         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
112         discard.count.connect(pn_end.out);
113         mem_write.outWrite.connect(discard.val);
114         DoneNode done = new DoneNode(dfg, program);
115         discard.out.connect(done.in);
116
117         return done.getDestinationToSendNextCodeBagDescriptorTo();
118     }
119
120     public CodeBag build(DataFlowGraph dfg,
121                          CodeBag ctx,
122                          int vals_length, int stride_length,
123                          long base1,
124                          long base2,
125                          CodeBag next) throws Exception {
126        
127         for(int i=0; i<arity; i++) {
128             if (pn0[i]!=null) pn0[i].set(ctx, stride_length);
129             if (pn1[i]!=null) pn1[i].set(ctx, vals_length + i*stride_length);
130             if (pn2[i]!=null) pn2[i].set(ctx, 2*stride_length);
131             if (pn3[i]!=null) pn3[i].set(ctx, stride_length);
132             if (pn4[i]!=null) pn4[i].set(ctx, stride_length);
133             if (pn_base1[i]!=null) pn_base1[i].set(ctx, base1);
134         }
135         pn5.set(ctx, vals_length);
136         pn6.set(ctx, 2*stride_length+1);
137         pn_base2.set(ctx, base2);
138         pn_end.set(ctx, vals_length);
139         ctx.sendWord(cb2.getDescriptor(), program.getCBDDestination());
140         ctx.sendWord(next.getDescriptor(), next_dest);
141         ctx.seal();
142
143         return ctx;
144     }
145
146     public static void main(String[] s) throws Exception {
147         //mergeSort(1024*64, 4, "Dvi",
148         mergeSort(1024*128, 4, "Dvi",
149                   4194304
150                   );
151                   //548*478);
152     }
153
154     /** demo */
155     public static void main0(String[] s) throws Exception {
156         PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream("stats.txt")));
157         //int inflight = 1;
158         for(int inflight=1; inflight <= 8; inflight++)
159         for(int count = 32; count < 2097152; count *= 2) {
160             System.out.println("==============================================================================");
161             System.out.println("count="+count);
162             System.out.println("inflight="+inflight);
163             //long time = timeit(count, inflight);
164             long time = mergeSort(count, inflight, "DDR2", 0);
165             pw.println(inflight + ", " + count + ", " + time);
166             pw.flush();
167         }
168     }
169
170     public static long timeit(int count, int inflight) throws Exception {
171         Fleet fleet = new Fpga();
172         FleetProcess fp = fleet.run(new Instruction[0]);
173         ShipPool pool = new ShipPool(fleet);
174         //Program program = new Program(pool.allocateShip("Memory"));
175         CodeBag cb = new CodeBag(fleet);
176
177         Ship counter1 = pool.allocateShip("Counter");
178         Ship counter2 = pool.allocateShip("Counter");
179
180         Ship timer    = pool.allocateShip("Timer");
181         Ship debug    = pool.allocateShip("Debug");
182
183         LoopFactory lf;
184
185         lf = cb.loopFactory(debug.getDock("in"), 2);
186         lf.recvWord();
187         lf.deliver();
188
189
190         // First counter //////////////////////////////////////////////////////////////////////////////
191
192         lf = cb.loopFactory(counter1.getDock("in1"), 1);
193         lf.recvToken();
194         lf.literal(count);
195         lf.deliver();
196
197         lf = cb.loopFactory(counter1.getDock("in2"), 1);
198         lf.literal(1);
199         lf.deliver();
200
201         lf = cb.loopFactory(counter1.getDock("inOp"), 1);
202         lf.literal("COUNT");
203         lf.deliver();
204
205         lf = cb.loopFactory(counter1.getDock("out"), 0);
206         lf.recvToken();
207         lf.collectWord();
208         lf.sendWord(counter2.getDock("in2"));
209
210
211         // Second counter //////////////////////////////////////////////////////////////////////////////
212
213         lf = cb.loopFactory(counter2.getDock("in1"), 1);
214         lf.literal(count);
215         lf.deliver();
216         lf.literal(1);
217         lf.deliver();
218         lf.literal(1);
219         lf.deliver();
220
221         lf = cb.loopFactory(counter2.getDock("in2"), 1);
222         for(int i=0; i<inflight; i++)
223             lf.sendToken(counter1.getDock("out"));
224         lf = lf.makeNext(0);
225         lf.recvWord();
226         lf.sendToken(counter1.getDock("out"));
227         lf.deliver();
228
229         lf = cb.loopFactory(counter2.getDock("inOp"), 1);
230         lf.literal("DROP_C1_V2");
231         lf.deliver();
232         lf.literal("PASS_C1_V1");
233         lf.deliver();
234
235         lf = cb.loopFactory(counter2.getDock("out"), 1);
236         lf.collectWord();
237         lf.sendToken(timer.getDock("out"));
238
239
240         // Timer //////////////////////////////////////////////////////////////////////////////
241
242         lf = cb.loopFactory(timer.getDock("out"), 1);
243         lf.collectWord();
244         lf.sendToken(counter1.getDock("in1"));
245         lf.sendWord(debug.getDock("in"));
246         lf.recvToken();
247         lf.collectWord();
248         lf.sendWord(debug.getDock("in"));
249
250         FpgaDock out = (FpgaDock)counter1.getDock("out");
251         FpgaDock in  = (FpgaDock)counter2.getDock("in2");
252         System.out.println("distance is " + out.getPathLength((FpgaDestination)in.getDataDestination()));
253         System.out.println("reverse distance is " + in.getPathLength((FpgaDestination)out.getDataDestination()));
254
255         for(Instruction i : cb.emit()) System.out.println(i);
256         cb.dispatch(fp, true);
257         long time1 = fp.recvWord().toLong();
258         System.out.println("got " + time1);
259         long time2 = fp.recvWord().toLong();
260         System.out.println("got " + time2);
261         System.out.println("diff=" + (time2-time1));
262
263         fp.terminate();
264
265         return (time2-time1);
266     }
267
268     public static long mergeSort(int vals_length, int inflight, String shipType, int clearAmount) throws Exception {
269         Node.CAPACITY = inflight;
270
271         Fleet fleet = new Fpga();
272         FleetProcess fp = fleet.run(new Instruction[0]);
273         ShipPool pool = new ShipPool(fleet);
274         Ship mem1 = pool.allocateShip(shipType);
275
276
277         if (clearAmount > 0)
278             randomizeMemory(fp, pool, mem1, 0, clearAmount, false);
279
280         //randomizeMemory(fp, pool, mem1, 0, vals_length, true);
281
282         BitVector[] bvs = new BitVector[vals_length];
283
284         long index = 0;
285         Picture p = new Picture(new FileInputStream("campus.png"));
286         for(int y=0; y<(478/2); y++)
287             for(int x=0; x<544; x++) {
288                 if (index >= vals_length) break;
289                 int pixel = (x>=p.width) ? 0 : p.data[p.width*y+x];
290                 long r = (pixel>>0)  & 0xff;
291                 long g = (pixel>>8)  & 0xff;
292                 long b = (pixel>>16) & 0xff;
293                 r >>= 2;
294                 g >>= 2;
295                 b >>= 2;
296                 //r = ~(-1L<<6);
297                 //g = ~(-1L<<6);
298                 //b = ~(-1L<<6);
299                 bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( r | (g<<6) | (b<<12) | (index<<18) );
300                 index++;
301             }
302
303         for(; index<vals_length; index++) {
304             long tag = index<<18;
305             bvs[(int)index] = new BitVector(fleet.getWordWidth()).set( tag );
306         }
307
308         System.out.println("final index " + index);
309
310         Random random = new Random(System.currentTimeMillis());
311         for(int i=0; i<bvs.length*10; i++) {
312             int from = Math.abs(random.nextInt()) % bvs.length;
313             int to   = Math.abs(random.nextInt()) % bvs.length;
314             BitVector bv = bvs[from];
315             bvs[from] = bvs[to];
316             bvs[to] = bv;
317         }
318
319         MemoryUtils.writeMem(fp, pool, mem1, offset, bvs);
320
321         Program program = new Program(pool.allocateShip("Memory"));
322         long ret = new MergeSort(fleet, program, pool, 2, mem1, mem1).main(fp, vals_length);
323
324         //long ret = 0;
325         // verify the cleanup?
326         //CleanupUtils.verifyClean(fp);
327         //MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvs);
328
329         BitVector[] bvx = new BitVector[1024];
330         MemoryUtils.readMem(fp, new ShipPool(fp.getFleet()), mem1, 0, bvx);
331         for(int i=0; i<bvx.length; i++)
332             System.out.println(bvx[i]);
333         /*
334         System.out.println("results:");
335         for(int i=0; i<vals_length-1; i++)
336             if ( (bvs[i].toLong() & ~(-1L<<18)) != ~(-1L<<18))
337                 System.out.println(bvs[i]);
338         */
339         /*
340         for(int i=0; i<vals_length-1; i++) {
341             if (bvs[i].toLong() > bvs[i+1].toLong())
342                 System.out.println("sort failure at "+i+":\n  "+bvs[i]+"\n  "+bvs[i+1]);
343         }
344         */
345         fp.terminate();
346         return ret;
347     }
348
349     static int offset = 40;
350
351     //static int offset = 32;
352     //static int offset = 544*2;
353     //static int offset = 544;
354
355     public long main(FleetProcess fp, int vals_length) throws Exception {
356
357         long base_read = offset;
358         //long base_write = offset+((vals_length/544)+1)*544;
359         //long base_write = offset;
360         long base_write = vals_length + offset;
361
362         int stride = 1;
363         ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
364         codeBags.add(new CodeBag(dfg.fleet, program));
365         for(; ;) {
366             CodeBag cb = codeBags.get(codeBags.size()-1);
367             //System.out.println("cb="+i+", stride="+stride);
368             boolean last = /*(base_write==offset) && */ (stride*2 >= vals_length);
369             CodeBag next;
370             if (last) {
371                 next = program.getEndProgramCodeBag();
372             } else {
373                 next = new CodeBag(dfg.fleet, program);
374                 codeBags.add(codeBags.size(), next);
375             }
376             build(dfg, cb, vals_length, stride, base_read, base_write, next);
377             cb.seal();
378             long base = base_read; base_read=base_write; base_write=base;
379             //i++;
380             if (last) break;
381             if ((stride*2) < vals_length) stride *= 2;
382         }
383         System.out.println("done building codebags; installing...");
384         System.out.println("cb.length=="+codeBags.size());
385
386         Ship button = pool.allocateShip("Button");
387         CodeBag cb = new CodeBag(dfg.fleet, program);
388         LoopFactory lf = cb.loopFactory(button.getDock("out"), 1);
389         lf.collectWord();
390         lf.literal(codeBags.get(0).getDescriptor());
391         lf.sendWord(program.getCBDDestination());
392         cb.seal();
393
394         // FIXME
395         return program.run(fp, cb, new ShipPool(fp.getFleet()));
396     }
397
398     public static void randomizeMemory(FleetProcess fp, ShipPool pool_, Ship memory, long start, long length, boolean randomize) {
399         ShipPool pool = new ShipPool(pool_);
400         Ship mem = pool.allocateShip("Memory");
401         Program prog = new Program(mem);
402
403         DataFlowGraph dfg = new DataFlowGraph(fp.getFleet(), pool);
404         DownCounterNode dcn = new DownCounterNode(dfg);
405         dcn.start.connectOnce(length);
406         dcn.incr.connectOnce(1);
407
408         AluNode alu = new AluNode(dfg, "ADD");
409         alu.in1.connectForever(start);
410         alu.in2.connect(dcn.out);
411
412         MemoryNode mn = new MemoryNode(dfg, memory);
413         mn.inAddrWrite.connect(alu.out);
414
415         AluNode aluAnd = new AluNode(dfg, "AND");
416         if (randomize) {
417             aluAnd.in1.connect(new RandomNode(dfg).out);
418         } else {
419             //aluAnd.in1.connectForever( ~(-1L<<36) );
420             aluAnd.in1.connectForever( 0 );
421         }
422         aluAnd.in2.connectForever( ~(-1<<18) );
423         mn.inDataWrite.connect(aluAnd.out);
424         
425         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
426         discard.count.connectOnce(length);
427         discard.val.connect(mn.outWrite);
428         DoneNode done = new DoneNode(dfg, prog);
429         discard.out.connect(done.in);
430
431         CodeBag cb = new CodeBag(fp.getFleet(), prog);
432         dfg.build(cb);
433         cb.seal();
434
435         CodeBag cb2 = new CodeBag(fp.getFleet(), prog);
436         Ship button = fp.getFleet().getShip("Button",0);
437
438         LoopFactory lf = cb2.loopFactory(button.getDock("out"), 1);
439         //lf.collectWord();
440         lf.literal(prog.getEndProgramCodeBag().getDescriptor());
441         lf.sendWord(done.getDestinationToSendNextCodeBagDescriptorTo());
442         lf.literal(cb.getDescriptor());
443         lf.sendWord(prog.getCBDDestination());
444         cb2.seal();
445
446         System.out.println("dispatching randomization codebag...");
447         prog.run(fp, cb2, pool);
448         System.out.println("  randomization done.");
449         pool.releaseAll();
450     }
451 }