1 package edu.berkeley.fleet.dataflow;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.interpreter.*;
7 import edu.berkeley.fleet.fpga.*;
8 import org.ibex.graphics.*;
10 public class MergeSort {
// NOTE(review): the assignments of fleet, mem1, mem2, arity and pool are not
// visible in this listing (presumably in the constructor) -- confirm.
// Fleet this sorter targets.
12 private final Fleet fleet;
// Memory ship the merge pass reads its input from (mem_read in makeDfgCodeBag).
13 private final Ship mem1;
// Memory ship the merged output is written to; when it is the same ship as
// mem1, makeDfgCodeBag reuses a single MemoryNode for reads and writes.
14 private final Ship mem2;
// Number of merge inputs; sizes the per-input parameter arrays below.
// NOTE(review): makeDfgCodeBag only distinguishes input 0 from the rest
// (read port 1 / sm.in1 vs read port 2 / sm.in2), so arity > 2 would wire
// several inputs onto the same ports -- looks like only 2 is supported.
15 private final int arity;
// Ship pool; the constructor hands its pool argument to the DataFlowGraph.
16 private final ShipPool pool;
// Program that owns the code bags created by this sorter.
17 private final Program program;
// Dataflow graph for one merge pass, wired once in the constructor.
18 private final DataFlowGraph dfg;
// Per-input graph parameters (index = input number). Each is wired in
// makeDfgCodeBag and given a per-pass value in makeInstance.
// pn0[i]: start of DownCounterNode c0 (set to stride).
20 private ParameterNode[] pn0;
// pn1[i]: start of DownCounterNode c1 (set to length + i*stride).
21 private ParameterNode[] pn1;
// pn2[i]: incr of DownCounterNode c1 (set to 2*stride).
22 private ParameterNode[] pn2;
// pn3[i]: count input of RepeatNode r1, which repeats c1's output (set to stride).
23 private ParameterNode[] pn3;
// pn4[i]: count input of the input's PunctuatorNode (set to stride).
24 private ParameterNode[] pn4;
// pn_base1[i]: base added to input i's read address (set to the pass's read base).
25 private ParameterNode[] pn_base1;
// Base added to the write address (set to the pass's write base).
26 private ParameterNode pn_base2;
// Start of the write-address DownCounterNode cw.
// NOTE(review): no visible line in makeInstance sets pn5 -- confirm it is set
// on an elided line.
27 private ParameterNode pn5;
// Count input of the UnPunctuatorNode on the merger output (set to 2*stride+1).
28 private ParameterNode pn6;
// Count input of the UnPunctuatorNode that feeds the DoneNode (set to length).
29 private ParameterNode pn_end;
// Destination returned by makeDfgCodeBag (the DoneNode's "next code bag
// descriptor" destination); makeInstance sends each successor bag's
// descriptor here. NOTE(review): package-private and non-final, unlike the
// fields above.
32 Destination next_dest = null;
/**
 * Wires the merge-pass dataflow graph once and builds it into a code bag.
 *
 * <p>The graph is wired by {@code makeDfgCodeBag}. The destination it returns is
 * kept in {@code next_dest`, and the built graph is stored in {@code cb2}. Both
 * are used by {@code makeInstance}.
 *
 * <p>NOTE(review): the assignments of this.fleet, this.mem1, this.mem2,
 * this.arity and this.pool are not visible in this listing. makeDfgCodeBag
 * reads mem1, mem2 and arity, so they must already be set before it is
 * called below -- confirm.
 *
 * <p>NOTE(review): makeDfgCodeBag is public and MergeSort is not final. A
 * subclass override would therefore run before the subclass is initialized.
 *
 * @param fleet   fleet the DataFlowGraph is created for
 * @param program program that owns the generated code bags
 * @param pool    ship pool handed to the DataFlowGraph
 * @param arity   number of merge inputs (main passes 2)
 * @param mem1    memory ship the input is read from
 * @param mem2    memory ship the output is written to (may equal mem1)
 */
34 public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
40 this.program = program;
41 this.dfg = new DataFlowGraph(fleet, pool);
// Wire the graph; keep the DoneNode destination that will receive the
// descriptor of each pass's successor code bag.
42 next_dest = makeDfgCodeBag(dfg, program);
// Build the wired graph into a fresh CodeBag. makeInstance sends its
// descriptor to program.getCBDDestination() from every bag it creates.
43 cb2 = dfg.build(new CodeBag(dfg.fleet, program));
/**
 * Wires one merge pass into {@code dfg}. Returns the destination that should
 * receive the descriptor of the next code bag.
 *
 * <p>Read side, for each input {@code i} in {@code 0..arity-1}:
 * <ul>
 *   <li>The read address is {@code pn_base1[i] + (r1.out + c0.out)}.</li>
 *   <li>{@code c0} is a DownCounterNode with start {@code pn0[i]} and incr from
 *       {@code ForeverNode(dfg, 1)}.</li>
 *   <li>{@code r1} is a RepeatNode with count {@code pn3[i]}. Its value comes
 *       from DownCounterNode {@code c1}, which has start {@code pn1[i]} and
 *       incr {@code pn2[i]}.</li>
 *   <li>The word read goes through a PunctuatorNode (constructed with -1,
 *       count {@code pn4[i]}) into the "MAXMERGE" AluNode {@code sm}.</li>
 *   <li>Input 0 uses read port 1 and {@code sm.in1}. Every other input uses
 *       read port 2 and {@code sm.in2}.</li>
 * </ul>
 *
 * <p>Write side:
 * <ul>
 *   <li>{@code sm.out} goes through an UnPunctuatorNode (count {@code pn6})
 *       to the write-data port.</li>
 *   <li>The write address is {@code pn_base2 + cw.out}. {@code cw} is a
 *       DownCounterNode started at {@code pn5}, with incr from
 *       {@code OnceNode(dfg, 1)}.</li>
 *   <li>The write port's {@code outWrite} feeds an UnPunctuatorNode
 *       constructed with {@code true} (named discard, count {@code pn_end}).
 *       Its output drives a DoneNode.</li>
 * </ul>
 *
 * <p>Side effects:
 * <ul>
 *   <li>allocates every ParameterNode stored in the pn* fields;</li>
 *   <li>reads the fields mem1, mem2 and arity.</li>
 * </ul>
 *
 * <p>NOTE(review): with arity > 2, inputs 1..arity-1 would all share read
 * port 2 and sm.in2, so only arity == 2 looks supported -- confirm.
 *
 * @param dfg     graph to add nodes to; shadows the field of the same name
 *                (the constructor passes that field)
 * @param program program handed to the DoneNode
 * @return the DoneNode's destination for the next code bag descriptor
 */
47 public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
// Reuse one MemoryNode for both reads and writes when both use the same ship.
49 MemoryNode mem_read = new MemoryNode(dfg, mem1);
50 MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
// ALU that combines the input streams (opcode "MAXMERGE").
52 AluNode sm = new AluNode(dfg, "MAXMERGE");
// Per-input parameter arrays; entries are filled in the loop below.
54 pn0 = new ParameterNode[arity];
55 pn1 = new ParameterNode[arity];
56 pn2 = new ParameterNode[arity];
57 pn3 = new ParameterNode[arity];
58 pn4 = new ParameterNode[arity];
59 pn_base1 = new ParameterNode[arity];
// Shared (not per-input) parameters.
60 pn_base2 = new ParameterNode(dfg, true);
61 pn_end = new ParameterNode(dfg);
62 pn5 = new ParameterNode(dfg);
63 pn6 = new ParameterNode(dfg, true);
65 // So far: we have four spare Counter ships; one can be used for resetting
66 for(int i=0; i<arity; i++) {
// Read-address generator for input i: pn_base1[i] + r1.out + c0.out.
68 DownCounterNode c0 = new DownCounterNode(dfg);
69 DownCounterNode c1 = new DownCounterNode(dfg);
71 c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
72 c0.incr.connect(new ForeverNode(dfg, 1).out);
73 c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
74 c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
// r1 repeats c1's output; its count comes from pn3[i].
76 RepeatNode r1 = new RepeatNode(dfg);
77 r1.val.connect(c1.out);
78 r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
// alu1 = r1 + c0; alu2 = pn_base1[i] + alu1 -> read address.
80 AluNode alu1 = new AluNode(dfg, "ADD");
81 AluNode alu2 = new AluNode(dfg, "ADD");
82 alu1.in1.connect(r1.out);
83 alu1.in2.connect(c0.out);
84 alu1.out.connect(alu2.in2);
85 alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
// Input 0 uses read port 1; all other inputs use read port 2.
86 alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
// Data read for input i -> PunctuatorNode (count pn4[i]) -> sm.in1 / sm.in2.
88 PunctuatorNode punc = new PunctuatorNode(dfg, -1);
89 punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
90 punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
91 punc.out.connect(i==0 ? sm.in1 : sm.in2);
// Output side: merged stream -> UnPunctuatorNode (count pn6) -> write data.
94 UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
95 unpunc.val.connect(sm.out);
96 unpunc.count.connect(pn6.out);
// Write-address counter cw, started at pn5.
98 DownCounterNode cw = new DownCounterNode(dfg);
99 cw.start.connect(pn5.out);
100 cw.incr.connect(new OnceNode(dfg, 1).out);
// Write address = pn_base2 + cw.out.
102 AluNode alu = new AluNode(dfg, "ADD");
103 alu.in1.connect(pn_base2.out);
104 cw.out.connect(alu.in2);
105 mem_write.inAddrWrite.connect(alu.out);
106 mem_write.inDataWrite.connect(unpunc.out);
// Completion: write port output -> UnPunctuatorNode(true) with count pn_end
// -> DoneNode.
108 UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
109 discard.count.connect(pn_end.out);
110 mem_write.outWrite.connect(discard.val);
111 DoneNode done = new DoneNode(dfg, program);
112 discard.out.connect(done.in);
114 return done.getDestinationToSendNextCodeBagDescriptorTo();
/**
 * Builds the chain of code bags that runs the merge passes over the
 * {@code length} words starting at {@code offset}, and returns the first bag.
 *
 * <p>Each bag built here does three things:
 * <ul>
 *   <li>Sets the graph parameters for the current stride: pn0..pn4 and
 *       pn_base1 per input, plus pn6, pn_base2 and pn_end.</li>
 *   <li>Sends cb2's descriptor to {@code program.getCBDDestination()}.</li>
 *   <li>Sends its successor's descriptor to {@code next_dest}. The successor
 *       is either {@code program.getEndProgramCodeBag()} or a new CodeBag
 *       appended to the chain.</li>
 * </ul>
 * {@code last} is true when {@code stride*2 >= length}. After each bag, the
 * read and write bases swap, and the stride doubles while
 * {@code stride*2 < length}.
 *
 * <p>NOTE(review): the following are on lines elided from this listing, so
 * confirm them against the full source:
 * <ul>
 *   <li>the declaration and initial value of {@code stride};</li>
 *   <li>the per-pass loop header and its exit;</li>
 *   <li>the declaration of {@code next};</li>
 *   <li>the condition choosing between the end bag and a new bag (presumably
 *       {@code last}).</li>
 * </ul>
 * In addition, no visible line sets pn5 (the write counter's start).
 *
 * <p>NOTE(review): the bases swap after every pass. The final pass therefore
 * writes to either offset or offset+length, depending on how many passes
 * run, and callers must read the result from the matching region.
 *
 * @param offset first address of the input
 * @param length number of words to process
 * @return the first code bag of the chain
 * @throws Exception as declared by the signature
 */
117 public CodeBag makeInstance(int offset, int length) throws Exception {
// The first pass reads the input where it is and writes starting at offset+length.
119 long base_read = offset;
120 long base_write = length + offset;
// Chain of code bags; codeBags.get(0) is returned to the caller.
122 ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
123 codeBags.add(new CodeBag(dfg.fleet, program));
// Parameters for this pass go into the most recently appended bag.
125 CodeBag cb = codeBags.get(codeBags.size()-1);
// This is the final pass once 2*stride covers the whole input.
126 boolean last = stride*2 >= length;
// Successor of cb: the program's end bag, or a new bag appended to the chain.
// (The condition selecting between these is not visible in this listing.)
129 next = program.getEndProgramCodeBag();
131 next = new CodeBag(dfg.fleet, program);
132 codeBags.add(codeBags.size(), next);
// Per-input parameters for this pass. The null checks are defensive:
// makeDfgCodeBag fills every entry of these arrays.
135 for(int i=0; i<arity; i++) {
136 if (pn0[i]!=null) pn0[i].set(cb, stride);
137 if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
138 if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
139 if (pn3[i]!=null) pn3[i].set(cb, stride);
140 if (pn4[i]!=null) pn4[i].set(cb, stride);
141 if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
// Output-side parameters for this pass.
144 pn6.set(cb, 2*stride+1);
145 pn_base2.set(cb, base_write);
146 pn_end.set(cb, length);
// Send cb2's descriptor to the program's CBD destination. Send the successor's
// descriptor to next_dest (the graph's DoneNode destination).
147 cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
148 cb.sendWord(next.getDescriptor(), next_dest);
// Ping-pong: this pass's output region becomes the next pass's input region.
151 long base = base_read; base_read=base_write; base_write=base;
// Double the stride for the next pass while 2*stride is still below length.
153 if ((stride*2) < length) stride *= 2;
156 return codeBags.get(0);
159 public static void main(String[] s) throws Exception {
161 System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
165 if (s[0].equals("fpga")) fleet = new Fpga();
166 else if (s[0].equals("interpreter")) {
167 fleet = new Interpreter();
170 ShipPool pool = new ShipPool(fleet);
171 Ship memory = pool.allocateShip(s[1]);
172 int base = Integer.parseInt(s[2]);
173 int length = Integer.parseInt(s[3]);
175 Random random = new Random(System.currentTimeMillis());
176 BitVector[] vals = new BitVector[length];
177 long[] longs = new long[length];
178 for(int i=0; i<vals.length; i++) {
179 vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
180 for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
181 longs[i] = vals[i].toLong();
185 Ship codemem = pool.allocateShip("Memory");
187 FleetProcess fp = fleet.run(new Instruction[0]);
188 ShipPool pool2 = new ShipPool(pool);
190 MemoryUtils.writeMem(fp, pool2, memory, base, vals);
193 Program program = new Program(codemem);
194 CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
196 long ret = program.run(fp, cb, pool2);
199 MemoryUtils.readMem(fp, pool2, memory, base, vals);
202 System.out.println();
204 for(int i=0; i<vals.length; i++)
205 if (vals[i].toLong() != longs[i]) {
206 System.out.println("disagreement! on index " +
208 new BitVector(fleet.getWordWidth()).set(longs[i])+
212 System.out.println("done! ("+fails+" failures)");
213 if (fails>0) System.exit(-1);