1 package edu.berkeley.fleet.dataflow;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.interpreter.*;
7 import edu.berkeley.fleet.fpga.*;
8 import org.ibex.graphics.*;
10 public class MergeSort {
// Configuration captured at construction: the fleet, the memory ship read from (mem1), the
// memory ship written to (mem2; may be the same ship as mem1), the number of merge inputs,
// the ship pool, the program, and the dataflow graph built from them.
12 private final Fleet fleet;
13 private final Ship mem1;
14 private final Ship mem2;
15 private final int arity;
16 private final ShipPool pool;
17 private final Program program;
18 private final DataFlowGraph dfg;
// Per-input ParameterNodes. These are arrays of length arity, allocated in makeDfgCodeBag.
// Each pass of makeInstance stores their values into that pass's CodeBag
// (value set by makeInstance in parentheses; i is the input index):
//   pn0      start of counter c0                        (stride)
//   pn1      start of counter c1                        (length + i*stride)
//   pn2      incr of counter c1                         (2*stride)
//   pn3      count of the RepeatNode fed by c1          (stride)
//   pn4      count of the input's PunctuatorNode        (stride)
//   pn_base1 base address added to the computed read offset (current read base)
20 private ParameterNode[] pn0;
21 private ParameterNode[] pn1;
22 private ParameterNode[] pn2;
23 private ParameterNode[] pn3;
24 private ParameterNode[] pn4;
25 private ParameterNode[] pn_base1;
// Single ParameterNodes for the write side, also set by makeInstance on every pass:
//   pn_base2 base address added to the write-address counter (current write base)
//   pn5      start of the write-address counter cw      (length)
//   pn6      count of the UnPunctuatorNode on the MAXMERGE output (2*stride+1)
//   pn_end   count of the UnPunctuatorNode fed by mem_write's outWrite port (length)
26 private ParameterNode pn_base2;
27 private ParameterNode pn5;
28 private ParameterNode pn6;
29 private ParameterNode pn_end;
// Destination returned by the DoneNode created in makeDfgCodeBag. makeInstance sends the
// descriptor of each pass's successor code bag here.
32 Destination next_dest = null;
/**
 * Builds the merge-sort dataflow graph once and turns it into the CodeBag stored in cb2.
 * Per-run parameter values are supplied later by makeInstance.
 *
 * @param fleet   fleet passed to the DataFlowGraph
 * @param program program used for the DoneNode and for the graph's CodeBag
 * @param pool    ship pool passed to the DataFlowGraph
 * @param arity   number of merge inputs wired up by makeDfgCodeBag
 * @param mem1    memory ship the inputs are read from
 * @param mem2    memory ship the merged output is written to (may be the same ship as mem1)
 */
34 public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
40 this.program = program;
41 this.dfg = new DataFlowGraph(fleet, pool);
// makeDfgCodeBag reads the arity, mem1 and mem2 fields, so they must already be assigned here.
// NOTE(review): makeDfgCodeBag is public and overridable. Calling it from the constructor lets a
// subclass override run before the subclass is initialized; consider making it private or final.
42 next_dest = makeDfgCodeBag(dfg, program);
// Materialize the wired graph into a fresh CodeBag. makeInstance sends cb2's descriptor on every pass.
43 cb2 = dfg.build(new CodeBag(dfg.fleet, program));
/**
 * Wires the merge-sort dataflow graph into {@code dfg}:
 * per-input read-address generators on mem1, an AluNode with opcode "MAXMERGE" combining
 * the inputs, and a write path into mem2 that ends in a DoneNode.
 * Also allocates every ParameterNode field that makeInstance later fills in.
 *
 * @param dfg     graph the nodes are added to
 * @param program program passed to the DoneNode
 * @return the DoneNode's destination for the next code bag descriptor
 *         (done.getDestinationToSendNextCodeBagDescriptorTo())
 */
47 public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
// Reuse a single MemoryNode when reading and writing the same memory ship.
49 MemoryNode mem_read = new MemoryNode(dfg, mem1);
50 MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
52 pn0 = new ParameterNode[arity];
53 pn1 = new ParameterNode[arity];
54 pn2 = new ParameterNode[arity];
55 pn3 = new ParameterNode[arity];
56 pn4 = new ParameterNode[arity];
57 pn_base1 = new ParameterNode[arity];
58 pn_base2 = new ParameterNode(dfg, true);
59 pn_end = new ParameterNode(dfg);
60 pn5 = new ParameterNode(dfg);
61 pn6 = new ParameterNode(dfg, true);
// Merge node: input i==0 feeds sm.in1, all other inputs feed sm.in2 (see the loop below).
63 AluNode sm = new AluNode(dfg, "MAXMERGE");
64 // So far: we have four spare Counter ships; one can be used for resetting
// NOTE(review): ports are chosen only by i==0 vs. i!=0, so with arity > 2 several inputs would
// share inAddrRead2/outRead2 and sm.in2. This appears to assume arity == 2; confirm.
65 for(int i=0; i<arity; i++) {
// c0: DownCounterNode seeded with pn0[i]; its incr input is a ForeverNode supplying 1.
67 DownCounterNode c0 = new DownCounterNode(dfg);
68 c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
69 c0.incr.connect(new ForeverNode(dfg, 1).out);
// c1: DownCounterNode seeded with pn1[i], stepped by pn2[i].
71 DownCounterNode c1 = new DownCounterNode(dfg);
72 c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
73 c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
// r1: RepeatNode over c1's output, with its count taken from pn3[i].
75 RepeatNode r1 = new RepeatNode(dfg);
76 r1.val.connect(c1.out);
77 r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
// Read address = pn_base1[i] + (r1 output + c0 output).
79 AluNode alu1 = new AluNode(dfg, "ADD");
80 alu1.in1.connect(r1.out);
81 alu1.in2.connect(c0.out);
83 AluNode alu2 = new AluNode(dfg, "ADD");
84 alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
85 alu2.in2.connect(alu1.out);
86 alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
// The data read back passes through a PunctuatorNode (count pn4[i]) into the merge node.
88 PunctuatorNode punc = new PunctuatorNode(dfg, -1);
89 punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
90 punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
91 punc.out.connect(i==0 ? sm.in1 : sm.in2);
// Merge output goes through an UnPunctuatorNode (count pn6) before being written.
94 UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
95 unpunc.val.connect(sm.out);
96 unpunc.count.connect(pn6.out);
// Write address = pn_base2 + cw, where cw is seeded with pn5 and its incr input is a OnceNode supplying 1.
98 DownCounterNode cw = new DownCounterNode(dfg);
99 cw.start.connect(pn5.out);
100 cw.incr.connect(new OnceNode(dfg, 1).out);
102 AluNode alu = new AluNode(dfg, "ADD");
103 alu.in1.connect(pn_base2.out);
104 cw.out.connect(alu.in2);
106 mem_write.inAddrWrite.connect(alu.out);
107 mem_write.inDataWrite.connect(unpunc.out);
// mem_write's outWrite port feeds an UnPunctuatorNode (count pn_end) whose output drives the DoneNode.
109 UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
110 discard.count.connect(pn_end.out);
111 mem_write.outWrite.connect(discard.val);
112 DoneNode done = new DoneNode(dfg, program);
113 discard.out.connect(done.in);
115 return done.getDestinationToSendNextCodeBagDescriptorTo();
/**
 * Builds the chain of per-pass code bags that drives the graph built in the constructor.
 * On each pass the code bag:
 *   - sets every ParameterNode from the current stride;
 *   - sends cb2's descriptor to program.getCBDDestination();
 *   - sends the next bag's descriptor to next_dest.
 * Read and write base addresses swap between passes, and stride doubles while stride*2 < length.
 *
 * @param offset base address of the data to sort
 * @param length number of elements; the region starting at offset+length is used as the
 *               other buffer
 * @return the first code bag of the chain
 * @throws Exception as declared by the signature
 */
118 public CodeBag makeInstance(int offset, int length) throws Exception {
// Two buffers: each pass reads from base_read and writes to base_write, then they swap.
120 long base_read = offset;
121 long base_write = length + offset;
123 ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
124 codeBags.add(new CodeBag(dfg.fleet, program));
// Fill in the most recently appended bag for the current pass.
126 CodeBag cb = codeBags.get(codeBags.size()-1);
// last: stride*2 already covers length, so this is the final pass.
127 boolean last = stride*2 >= length;
// NOTE(review): expected to continue with the end-of-program bag when last is true, and otherwise
// with a freshly appended bag for the next pass. Confirm against the branch condition.
130 next = program.getEndProgramCodeBag();
132 next = new CodeBag(dfg.fleet, program);
133 codeBags.add(codeBags.size(), next);
// Per-input parameters for this pass (see the field comments for which node each one feeds).
136 for(int i=0; i<arity; i++) {
137 if (pn0[i]!=null) pn0[i].set(cb, stride);
138 if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
139 if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
140 if (pn3[i]!=null) pn3[i].set(cb, stride);
141 if (pn4[i]!=null) pn4[i].set(cb, stride);
142 if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
// Write-side parameters for this pass.
144 if (pn5!=null) pn5.set(cb, length);
145 if (pn6!=null) pn6.set(cb, 2*stride+1);
146 if (pn_base2!=null) pn_base2.set(cb, base_write);
147 if (pn_end!=null) pn_end.set(cb, length);
// Send the graph's code bag (cb2) to the program's code-bag-descriptor destination, then send the
// next bag's descriptor to the DoneNode's destination (next_dest).
148 cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
149 cb.sendWord(next.getDescriptor(), next_dest);
// Swap buffers: this pass's output region becomes the next pass's input region.
// NOTE(review): the final pass writes to whichever region base_write names at that point, which
// depends on the number of passes. Confirm that callers read the result back from that region.
152 long base = base_read; base_read=base_write; base_write=base;
154 if ((stride*2) < length) stride *= 2;
157 return codeBags.get(0);
160 public static void main(String[] s) throws Exception {
162 System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
166 if (s[0].equals("fpga")) fleet = new Fpga();
167 else if (s[0].equals("interpreter")) {
168 fleet = new Interpreter();
171 ShipPool pool = new ShipPool(fleet);
172 Ship memory = pool.allocateShip(s[1]);
173 int base = Integer.parseInt(s[2]);
174 int length = Integer.parseInt(s[3]);
176 Random random = new Random(System.currentTimeMillis());
177 BitVector[] vals = new BitVector[length];
178 long[] longs = new long[length];
179 for(int i=0; i<vals.length; i++) {
180 vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
181 for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
182 longs[i] = vals[i].toLong();
186 Ship codemem = pool.allocateShip("Memory");
188 FleetProcess fp = fleet.run(new Instruction[0]);
189 ShipPool pool2 = new ShipPool(pool);
191 MemoryUtils.writeMem(fp, pool2, memory, base, vals);
194 Program program = new Program(codemem);
195 CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
198 ret = program.run(fp, cb, pool2);
201 MemoryUtils.readMem(fp, pool2, memory, base, vals);
204 System.out.println();
206 for(int i=0; i<vals.length; i++)
207 if (vals[i].toLong() != longs[i]) {
208 System.out.println("disagreement! on index " +
210 new BitVector(fleet.getWordWidth()).set(longs[i])+
214 System.out.println("done! ("+fails+" failures)");
215 if (fails>0) System.exit(-1);