package edu.berkeley.fleet.dataflow;
+import java.util.*;
+import java.io.*;
import edu.berkeley.fleet.loops.*;
import edu.berkeley.fleet.api.*;
+import edu.berkeley.fleet.interpreter.*;
import edu.berkeley.fleet.fpga.*;
-import java.util.*;
+import org.ibex.graphics.*;
public class MergeSort {
- public static long[] mergeSort(FleetProcess fp, Fleet fleet,
- long[] vals, int vals_length, int stride_length,
- Ship memoryShip1, Ship memoryShip2) throws Exception {
-
- if (vals != null) {
- BitVector[] mem = new BitVector[vals_length];
- for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
- MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
- }
- //////////////////////////////////////////////////////////////////////////////
+ // Constructor arguments, stored unchanged.
+ private final Fleet fleet;
+ private final Ship mem1;
+ private final Ship mem2;
+ private final int arity;
+ private final ShipPool pool;
+ private final Program program;
+ // Dataflow graph created in the constructor from fleet and pool.
+ private final DataFlowGraph dfg;
+
+ // Parameter nodes: created in makeDfgCodeBag(), given per-pass values in makeInstance().
+ // The array-valued ones hold one node per merge input (index i in [0, arity)).
+ // pn0[i]: feeds the start input of counter c0 (whose increment is a constant 1).
+ private ParameterNode[] pn0;
+ // pn1[i]: feeds the start input of counter c1.
+ private ParameterNode[] pn1;
+ // pn2[i]: feeds the incr input of counter c1.
+ private ParameterNode[] pn2;
+ // pn3[i]: feeds the count input of the RepeatNode that repeats c1's output.
+ private ParameterNode[] pn3;
+ // pn4[i]: feeds the count input of the PunctuatorNode on the memory read output.
+ private ParameterNode[] pn4;
+ // pn_base1[i]: added to the computed offset to form the read address.
+ private ParameterNode[] pn_base1;
+ // pn_base2: added to the write counter's output to form the write address.
+ private ParameterNode pn_base2;
+ // pn5: feeds the start input of the write-address counter cw.
+ private ParameterNode pn5;
+ // pn6: feeds the count input of the UnPunctuatorNode on the merge output.
+ private ParameterNode pn6;
+ // pn_end: feeds the count input of the "discard" UnPunctuatorNode on the memory write output.
+ private ParameterNode pn_end;
+
+ // Sealed code bag built from dfg in the constructor; makeInstance() sends its descriptor once per pass.
+ CodeBag cb2 = null;
+ // Destination returned by makeDfgCodeBag(); each pass sends the next code bag's descriptor here.
+ Destination next_dest = null;
+
+
+ /**
+ * Stores the arguments, creates a DataFlowGraph from fleet and pool, populates it
+ * via makeDfgCodeBag(), then builds the graph into a new CodeBag (cb2) and seals it.
+ *
+ * @param fleet   stored; also passed to the DataFlowGraph constructor
+ * @param program stored; passed to makeDfgCodeBag() and used for the new CodeBag
+ * @param pool    stored; also passed to the DataFlowGraph constructor
+ * @param arity   number of iterations of the per-input wiring loop in makeDfgCodeBag()
+ * @param mem1    ship wrapped by the MemoryNode used for reads
+ * @param mem2    ship used for writes; may be the same ship as mem1
+ */
+ public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
+ this.fleet = fleet;
+ this.mem1 = mem1;
+ this.mem2 = mem2;
+ this.arity = arity;
+ this.pool = pool;
+ this.program = program;
+ this.dfg = new DataFlowGraph(fleet, pool);
+ // The graph must be populated before it is built into cb2.
+ next_dest = makeDfgCodeBag(dfg, program);
+ cb2 = dfg.build(new CodeBag(dfg.fleet, program));
+ cb2.seal();
+ }
- DataFlowGraph proc = new DataFlowGraph(fleet);
- DebugNode dm = new DebugNode(proc);
+ /**
+ * Populates the given graph with the nodes for one merge pass and records the
+ * ParameterNodes (pn0..pn6, pn_base1, pn_base2, pn_end) that makeInstance() later
+ * fills in for each pass. Reads the mem1, mem2 and arity fields.
+ *
+ * NOTE(review): inside the loop only two read ports and two merge inputs are ever
+ * selected (i==0 picks the first, every other i picks the second), so arity values
+ * other than 2 look unsupported -- confirm.
+ *
+ * @param dfg     graph that receives the new nodes (shadows the field of the same name)
+ * @param program passed to the DoneNode constructor
+ * @return the destination reported by the DoneNode via
+ *         getDestinationToSendNextCodeBagDescriptorTo()
+ */
+ public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
- int end_of_data = vals_length;
- int num_strides = end_of_data / (stride_length * 2);
+ // A single MemoryNode is shared when the read and write ships are the same ship.
+ MemoryNode mem_read = new MemoryNode(dfg, mem1);
+ MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
- MemoryNode mm = new MemoryNode(proc, memoryShip1);
- SortedMergeNode sm = new SortedMergeNode(proc);
+ // Per-input parameter arrays; their elements are created inside the loop below.
+ pn0 = new ParameterNode[arity];
+ pn1 = new ParameterNode[arity];
+ pn2 = new ParameterNode[arity];
+ pn3 = new ParameterNode[arity];
+ pn4 = new ParameterNode[arity];
+ pn_base1 = new ParameterNode[arity];
+ pn_base2 = new ParameterNode(dfg, true);
+ pn_end = new ParameterNode(dfg);
+ pn5 = new ParameterNode(dfg);
+ pn6 = new ParameterNode(dfg, true);
+ // ALU configured with the "MAXMERGE" operation; the punctuated input streams are connected to in1/in2 below.
+ AluNode sm = new AluNode(dfg, "MAXMERGE");
// So far: we have four spare Counter ships; one can be used for resetting
- for(int i=0; i<2; i++) {
+ for(int i=0; i<arity; i++) {
- DownCounterNode c0 = new DownCounterNode(proc);
- DownCounterNode c1 = new DownCounterNode(proc);
+ DownCounterNode c0 = new DownCounterNode(dfg);
+ c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
+ c0.incr.connect(new ForeverNode(dfg, 1).out);
- c0.start.connect(new ForeverNode(proc, stride_length).out);
- c0.incr.connect(new ForeverNode(proc, 1).out);
+ DownCounterNode c1 = new DownCounterNode(dfg);
+ c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
+ c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
- c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
- c1.incr.connect(new OnceNode(proc, stride_length*2).out);
-
- RepeatNode r1 = new RepeatNode(proc);
+ RepeatNode r1 = new RepeatNode(dfg);
r1.val.connect(c1.out);
- r1.count.connect(new ForeverNode(proc, stride_length).out);
+ r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
+
+ // alu1 adds the repeated c1 output to the c0 output; alu2 adds pn_base1[i] to that sum
+ // and drives one of the memory's read-address ports.
+ AluNode alu1 = new AluNode(dfg, "ADD");
+ alu1.in1.connect(r1.out);
+ alu1.in2.connect(c0.out);
- AluNode alu = new AluNode(proc);
- alu.in1.connect(r1.out);
- alu.in2.connect(c0.out);
- alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
- alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
+ AluNode alu2 = new AluNode(dfg, "ADD");
+ alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
+ alu2.in2.connect(alu1.out);
+ alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
- PunctuatorNode punc = new PunctuatorNode(proc, -1);
- punc.count.connect(new ForeverNode(proc, stride_length).out);
- punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
+ // Data from the matching read port passes through a PunctuatorNode into the merge ALU.
+ PunctuatorNode punc = new PunctuatorNode(dfg, -1);
+ punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
+ punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
punc.out.connect(i==0 ? sm.in1 : sm.in2);
}
- UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
+ UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
unpunc.val.connect(sm.out);
- unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
-
- DownCounterNode cw = new DownCounterNode(proc);
- cw.start.connect(new OnceNode(proc, end_of_data).out);
- cw.incr.connect(new OnceNode(proc, 1).out);
-
- MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
- mm2.inAddrWrite.connect(cw.out);
- mm2.inDataWrite.connect(unpunc.out);
- mm2.outWrite.connect(dm.in);
+ unpunc.count.connect(pn6.out);
- //////////////////////////////////////////////////////////////////////////////
+ // Write address = pn_base2 + output of counter cw (started from pn5, increment 1).
+ DownCounterNode cw = new DownCounterNode(dfg);
+ cw.start.connect(pn5.out);
+ cw.incr.connect(new OnceNode(dfg, 1).out);
- Context ctx = new Context(fp.getFleet());
- ctx.setAutoflush(true);
+ AluNode alu = new AluNode(dfg, "ADD");
+ alu.in1.connect(pn_base2.out);
+ cw.out.connect(alu.in2);
- ArrayList<Instruction> ai = new ArrayList<Instruction>();
- proc.build(ctx);
- ctx.emit(ai);
- for(Instruction ins : ai) {
- //System.out.println(ins);
- fp.sendInstruction(ins);
- }
- fp.flush();
+ mem_write.inAddrWrite.connect(alu.out);
+ mem_write.inDataWrite.connect(unpunc.out);
- for(int i=0; i<vals_length; i++) {
- System.out.print("\rreading back... " + i+"/"+vals_length+" ");
- BitVector rec = fp.recvWord();
- System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
- }
- System.out.println("\rdone. ");
-
- //if (true) return ret;
-
- Context ctx2 = new Context(fp.getFleet());
- Dock debugIn = fleet.getShip("Debug",0).getDock("in");
- Dock fred = debugIn;
- fp.sendToken(debugIn.getInstructionDestination());
- fp.flush();
-
- LoopFactory lf = new LoopFactory(ctx2, debugIn, 0);
- lf.literal(0);
- lf.abortLoopIfTorpedoPresent();
- lf.recvToken();
- lf.deliver();
-
- ctx2.dispatch(fp);
- fp.flush();
-
- int count = 0;
-
- Ship counter = proc.pool.allocateShip("Counter");
-
- for(int phase=0; phase<=3; phase++) {
- System.out.println("== phase "+phase+" ==================================================================");
- ctx2 = new Context(fp.getFleet());
-
- Destination ackDestination = counter.getDock("in2").getDataDestination();
- proc.reset(ctx2, phase, ackDestination);
-
- Context ctx3 = new Context(fp.getFleet());
- lf = new LoopFactory(ctx3, counter.getDock("inOp"), 1);
- lf.literal("DROP_C1_V2");
- lf.deliver();
- lf.literal(5);
- lf.deliver();
- lf = new LoopFactory(ctx3, counter.getDock("in1"), 1);
- lf.literal(DataFlowGraph.reset_count-1);
- lf.deliver();
- lf.literal(1);
- lf.deliver();
- lf = new LoopFactory(ctx3, counter.getDock("in2"), 0);
- lf.abortLoopIfTorpedoPresent();
- lf.recvWord();
- lf.deliver();
- lf = new LoopFactory(ctx3, counter.getDock("out"), 1);
- lf.collectWord();
- lf.sendToken(counter.getDock("in2").getInstructionDestination()); // HACK: we don't check to make sure this hits
- lf.sendToken(debugIn.getDataDestination());
- ctx3.dispatch(fp); // HACK: we don't check to make sure that this is "firmly in place"
-
- for(Dock dock : DataFlowGraph.torpedoes) fp.sendToken(dock.getInstructionDestination());
- ctx2.dispatch(fp);
- fp.flush();
- System.out.println("flushed");
-
- fp.recvWord();
- System.out.println("phase done");
-
- System.out.println();
- }
+ // The memory's outWrite output feeds the "discard" node (count from pn_end),
+ // whose output is connected to the DoneNode's input.
+ UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
+ discard.count.connect(pn_end.out);
+ mem_write.outWrite.connect(discard.val);
+ DoneNode done = new DoneNode(dfg, program);
+ discard.out.connect(done.in);
- fp.sendToken(debugIn.getInstructionDestination());
- fp.flush();
+ return done.getDestinationToSendNextCodeBagDescriptorTo();
+ }
- //System.out.println("verifying cleanup:");
- //CleanupUtils.verifyClean(fp);
+ /**
+ * Builds a chain of code bags, one per merge pass, for strides 1, 2, 4, ... until
+ * 2*stride >= length. Each bag sets the graph's parameters for its pass, sends cb2's
+ * descriptor to program.getCBDDestination(), sends the following bag's descriptor to
+ * next_dest, and is sealed. The final pass is followed by program.getEndProgramCodeBag().
+ *
+ * The first pass reads at offset and writes at offset+length; the two bases are swapped
+ * after every pass. NOTE(review): with an odd number of passes the last write therefore
+ * lands at offset+length rather than offset -- confirm callers read the right region.
+ *
+ * @param offset base address read by the first pass
+ * @param length value used for pn1, pn5 and pn_end, and as the distance between the two regions
+ * @return the code bag for the first pass
+ * @throws Exception declared by the signature; the visible body contains no explicit throw
+ */
+ public CodeBag makeInstance(int offset, int length) throws Exception {
- System.out.println("reading back:");
- long[] ret = null;
- if (vals != null) {
- ret = new long[vals_length];
- BitVector[] mem = new BitVector[vals_length];
- MemoryUtils.readMem(fp, memoryShip2, 0, mem);
- for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
+ long base_read = offset;
+ long base_write = length + offset;
+ int stride = 1;
+ ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
+ codeBags.add(new CodeBag(dfg.fleet, program));
+ for(;;) {
+ // cb is the bag for the current pass: always the most recently added one.
+ CodeBag cb = codeBags.get(codeBags.size()-1);
+ // This is the final pass once the merged run length (2*stride) reaches length.
+ boolean last = stride*2 >= length;
+ CodeBag next;
+ if (last) {
+ next = program.getEndProgramCodeBag();
+ } else {
+ next = new CodeBag(dfg.fleet, program);
+ codeBags.add(codeBags.size(), next);
+ }
+
+ for(int i=0; i<arity; i++) {
+ if (pn0[i]!=null) pn0[i].set(cb, stride);
+ if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
+ if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
+ if (pn3[i]!=null) pn3[i].set(cb, stride);
+ if (pn4[i]!=null) pn4[i].set(cb, stride);
+ if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
+ }
+ if (pn5!=null) pn5.set(cb, length);
+ if (pn6!=null) pn6.set(cb, 2*stride+1);
+ if (pn_base2!=null) pn_base2.set(cb, base_write);
+ if (pn_end!=null) pn_end.set(cb, length);
+ // Send cb2's descriptor to the program's CBD destination, then the following bag's descriptor to next_dest.
+ cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
+ cb.sendWord(next.getDescriptor(), next_dest);
+
+ cb.seal();
+ // Swap the read and write bases for the next pass.
+ long base = base_read; base_read=base_write; base_write=base;
+ if (last) break;
+ // This condition is always true here, because `last` was false.
+ if ((stride*2) < length) stride *= 2;
}
- return ret;
+
+ return codeBags.get(0);
}
- /** demo */
+ /**
+ * Command-line self-test. Writes {@code length} random words (bits 36 and above
+ * cleared) to the named memory ship at {@code base}, runs the merge sort on them,
+ * reads {@code length} words back from {@code base}, and compares them with
+ * {@code Arrays.sort} of the same values. Exits with status -1 on a usage error,
+ * an unknown target, or any mismatch.
+ *
+ * @param s {@code <target> <shipname> <base> <length>}; target is "fpga" or "interpreter"
+ * @throws Exception propagated from the fleet, memory and program calls
+ */
public static void main(String[] s) throws Exception {
- Fleet fleet = new Fpga();
- //Fleet fleet = new Interpreter(false);
-
- Random random = new Random(System.currentTimeMillis());
- long[] vals = new long[256];
- for(int i=0; i<vals.length; i++) {
- vals[i] = Math.abs(random.nextInt());
+ if (s.length != 4) {
+ System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
+ System.exit(-1);
}
-
- Ship mem1 = fleet.getShip("Memory", 0);
- Ship mem2 = fleet.getShip("Memory", 1);
- //Ship mem2 = fleet.getShip("DDR2", 0);
-
- FleetProcess fp;
- int stride = 1;
- fp = null;
-
- fp = fleet.run(new Instruction[0]);
- MemoryUtils.writeMem(fp, mem1, 0, vals);
- int vals_length = vals.length;
-
- // Disable readback/writeback inside the loop
- vals = null;
-
- while(stride < vals_length) {
-
- // reset the FleetProcess
- //fp.terminate(); fp = null;
-
- System.out.println("stride " + stride);
-
- // if we reset the FleetProcess, restart it
- if (fp==null) fp = fleet.run(new Instruction[0]);
-
- // do the mergeSort
- vals = MergeSort.mergeSort(fp, fleet, vals, vals_length, stride, mem1, mem2);
-
- // verify the cleanup
- //CleanupUtils.verifyClean(fp);
-
- Ship mem = mem1; mem1=mem2; mem2=mem;
-
- stride = stride * 2;
- System.out.println();
+ Fleet fleet = null;
+ if (s[0].equals("fpga")) fleet = new Fpga();
+ else if (s[0].equals("interpreter")) {
+ fleet = new Interpreter();
+ Log.log = null;
}
-
- BitVector[] bvs = new BitVector[vals_length];
- MemoryUtils.readMem(fp, mem1, 0, bvs);
- System.out.println("results:");
- for(int i=0; i<vals_length; i++)
- System.out.println(bvs[i].toLong());
+ // Reject an unrecognised target here; otherwise fleet stays null and the
+ // code below fails with a NullPointerException.
+ if (fleet == null) {
+ System.err.println("unknown target \"" + s[0] + "\"; expected \"fpga\" or \"interpreter\"");
+ System.exit(-1);
+ }
+ ShipPool pool = new ShipPool(fleet);
+ Ship memory = pool.allocateShip(s[1]);
+ int base = Integer.parseInt(s[2]);
+ int length = Integer.parseInt(s[3]);
+
+ // Generate the random input and, in parallel, the expected sorted result.
+ Random random = new Random(System.currentTimeMillis());
+ BitVector[] vals = new BitVector[length];
+ long[] longs = new long[length];
+ for(int i=0; i<vals.length; i++) {
+ vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
+ // Clear every bit from index 36 upward before taking the reference value.
+ for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
+ longs[i] = vals[i].toLong();
+ }
+ Arrays.sort(longs);
+
+ Ship codemem = pool.allocateShip("Memory");
+
+ FleetProcess fp = fleet.run(new Instruction[0]);
+ ShipPool pool2 = new ShipPool(pool);
+
+ MemoryUtils.writeMem(fp, pool2, memory, base, vals);
+ pool2.releaseAll();
+
+ // The same ship is passed as both the read and the write memory.
+ Program program = new Program(codemem);
+ CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
+ pool2.releaseAll();
+ program.run(fp, cb, pool2);
+ pool2.releaseAll();
+
+ // NOTE(review): makeInstance alternates the write region between base+length and base
+ // on each pass; if the pass count is odd the last pass writes at base+length, yet this
+ // reads back from base -- confirm this is intended for such lengths.
+ MemoryUtils.readMem(fp, pool2, memory, base, vals);
+ pool2.releaseAll();
+
+ System.out.println();
+ int fails = 0;
+ for(int i=0; i<vals.length; i++)
+ if (vals[i].toLong() != longs[i]) {
+ System.out.println("disagreement! on index " +
+ i + "\n expected="+
+ new BitVector(fleet.getWordWidth()).set(longs[i])+
+ "\n got="+vals[i]);
+ fails++;
+ }
+ System.out.println("done! ("+fails+" failures)");
+ if (fails>0) System.exit(-1);
+
}
-
+
}