MergeSort.java: reformatting
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
index 36cd08f..459bab5 100644 (file)
 package edu.berkeley.fleet.dataflow;
+import java.util.*;
+import java.io.*;
 import edu.berkeley.fleet.loops.*;
 import edu.berkeley.fleet.api.*;
+import edu.berkeley.fleet.interpreter.*;
 import edu.berkeley.fleet.fpga.*;
-import java.util.*;
+import org.ibex.graphics.*;
 
 public class MergeSort {
-    public static long[] mergeSort(FleetProcess fp, Fleet fleet,
-                                   long[] vals, int vals_length, int stride_length,
-                                   Ship memoryShip1, Ship memoryShip2) throws Exception {
-
-        if (vals != null) {
-            BitVector[] mem = new BitVector[vals_length];
-            for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
-            MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
-        }
 
-        //////////////////////////////////////////////////////////////////////////////
+    private final Fleet fleet;
+    private final Ship mem1;
+    private final Ship mem2;
+    private final int arity;
+    private final ShipPool pool;
+    private final Program program;
+    private final DataFlowGraph dfg;
+
+    private ParameterNode[] pn0;
+    private ParameterNode[] pn1;
+    private ParameterNode[] pn2;
+    private ParameterNode[] pn3;
+    private ParameterNode[] pn4;
+    private ParameterNode[] pn_base1;
+    private ParameterNode   pn_base2;
+    private ParameterNode   pn5;
+    private ParameterNode   pn6;
+    private ParameterNode   pn_end;
+
+    CodeBag cb2 = null;
+    Destination next_dest = null;
+
+    public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
+        this.fleet = fleet;
+        this.mem1 = mem1;
+        this.mem2 = mem2;
+        this.arity = arity;
+        this.pool = pool;
+        this.program = program;
+        this.dfg = new DataFlowGraph(fleet, pool);
+        next_dest = makeDfgCodeBag(dfg, program);
+        cb2 = dfg.build(new CodeBag(dfg.fleet, program));
+        cb2.seal();
+    }
 
-        DataFlowGraph proc = new DataFlowGraph(fleet);
-        DebugNode dm = new DebugNode(proc);
+    public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
 
-        int end_of_data = vals_length;
-        int num_strides = end_of_data / (stride_length * 2);
+        MemoryNode mem_read  = new MemoryNode(dfg, mem1);
+        MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
 
-        MemoryNode mm  = new MemoryNode(proc, memoryShip1);
-        SortedMergeNode sm = new SortedMergeNode(proc);
+        pn0 = new ParameterNode[arity];
+        pn1 = new ParameterNode[arity];
+        pn2 = new ParameterNode[arity];
+        pn3 = new ParameterNode[arity];
+        pn4 = new ParameterNode[arity];
+        pn_base1 = new ParameterNode[arity];
+        pn_base2 = new ParameterNode(dfg, true);
+        pn_end = new ParameterNode(dfg);
+        pn5 = new ParameterNode(dfg);
+        pn6 = new ParameterNode(dfg, true);
 
+        AluNode sm = new AluNode(dfg, "MAXMERGE");
         // So far: we have four spare Counter ships; one can be used for resetting
-        for(int i=0; i<2; i++) {
+        for(int i=0; i<arity; i++) {
 
-            DownCounterNode c0 = new DownCounterNode(proc);
-            DownCounterNode c1 = new DownCounterNode(proc);
+            DownCounterNode c0 = new DownCounterNode(dfg);
+            c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
+            c0.incr.connect(new ForeverNode(dfg, 1).out);
 
-            c0.start.connect(new ForeverNode(proc, stride_length).out);
-            c0.incr.connect(new ForeverNode(proc, 1).out);
+            DownCounterNode c1 = new DownCounterNode(dfg);
+            c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
+            c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
 
-            c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
-            c1.incr.connect(new OnceNode(proc, stride_length*2).out);
-
-            RepeatNode r1 = new RepeatNode(proc);
+            RepeatNode r1 = new RepeatNode(dfg);
             r1.val.connect(c1.out);
-            r1.count.connect(new ForeverNode(proc, stride_length).out);
+            r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
+
+            AluNode alu1 = new AluNode(dfg, "ADD");
+            alu1.in1.connect(r1.out);
+            alu1.in2.connect(c0.out);
 
-            AluNode alu = new AluNode(proc);
-            alu.in1.connect(r1.out);
-            alu.in2.connect(c0.out);
-            alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
-            alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
+            AluNode alu2 = new AluNode(dfg, "ADD");
+            alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
+            alu2.in2.connect(alu1.out);
+            alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
 
-            PunctuatorNode punc = new PunctuatorNode(proc, -1);
-            punc.count.connect(new ForeverNode(proc, stride_length).out);
-            punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
+            PunctuatorNode punc = new PunctuatorNode(dfg, -1);
+            punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
+            punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
             punc.out.connect(i==0 ? sm.in1 : sm.in2);
         }
 
-        UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
+        UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
         unpunc.val.connect(sm.out);
-        unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
-
-        DownCounterNode cw = new DownCounterNode(proc);
-        cw.start.connect(new OnceNode(proc, end_of_data).out);
-        cw.incr.connect(new OnceNode(proc, 1).out);
-
-        MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
-        mm2.inAddrWrite.connect(cw.out);
-        mm2.inDataWrite.connect(unpunc.out);
-        mm2.outWrite.connect(dm.in);
+        unpunc.count.connect(pn6.out);
 
-        //////////////////////////////////////////////////////////////////////////////
+        DownCounterNode cw = new DownCounterNode(dfg);
+        cw.start.connect(pn5.out);
+        cw.incr.connect(new OnceNode(dfg, 1).out);
 
-        Context ctx = new Context(fp.getFleet());
-        ctx.setAutoflush(true);
+        AluNode alu = new AluNode(dfg, "ADD");
+        alu.in1.connect(pn_base2.out);
+        cw.out.connect(alu.in2);
 
-        ArrayList<Instruction> ai = new ArrayList<Instruction>();
-        proc.build(ctx);
-        ctx.emit(ai);
-        for(Instruction ins : ai) {
-            //System.out.println(ins);
-            fp.sendInstruction(ins);
-        }
-        fp.flush();
+        mem_write.inAddrWrite.connect(alu.out);
+        mem_write.inDataWrite.connect(unpunc.out);
 
-        for(int i=0; i<vals_length; i++) {
-            System.out.print("\rreading back... " + i+"/"+vals_length+"  ");
-            BitVector rec = fp.recvWord();
-            System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
-        }
-        System.out.println("\rdone.                                                                    ");
-
-        //if (true) return ret;
-
-        Context ctx2 = new Context(fp.getFleet());
-        Dock debugIn = fleet.getShip("Debug",0).getDock("in");
-        Dock fred = debugIn;
-        fp.sendToken(debugIn.getInstructionDestination());
-        fp.flush();
-
-        LoopFactory lf = new LoopFactory(ctx2, debugIn, 0);
-        lf.literal(0);
-        lf.abortLoopIfTorpedoPresent();
-        lf.recvToken();
-        lf.deliver();
-
-        ctx2.dispatch(fp);
-        fp.flush();
-
-        int count = 0;
-
-        Ship counter = proc.pool.allocateShip("Counter");
-
-        for(int phase=0; phase<=3; phase++) {
-            System.out.println("== phase "+phase+" ==================================================================");
-            ctx2 = new Context(fp.getFleet());
-
-            Destination ackDestination = counter.getDock("in2").getDataDestination();
-            proc.reset(ctx2, phase, ackDestination);
-
-            Context ctx3 = new Context(fp.getFleet());
-            lf = new LoopFactory(ctx3, counter.getDock("inOp"), 1);
-            lf.literal("DROP_C1_V2");
-            lf.deliver();
-            lf.literal(5);
-            lf.deliver();
-            lf = new LoopFactory(ctx3, counter.getDock("in1"), 1);
-            lf.literal(DataFlowGraph.reset_count-1);
-            lf.deliver();
-            lf.literal(1);
-            lf.deliver();
-            lf = new LoopFactory(ctx3, counter.getDock("in2"), 0);
-            lf.abortLoopIfTorpedoPresent();
-            lf.recvWord();
-            lf.deliver();
-            lf = new LoopFactory(ctx3, counter.getDock("out"), 1);
-            lf.collectWord();
-            lf.sendToken(counter.getDock("in2").getInstructionDestination());  // HACK: we don't check to make sure this hits
-            lf.sendToken(debugIn.getDataDestination());
-            ctx3.dispatch(fp);  // HACK: we don't check to make sure that this is "firmly in place"
-
-            for(Dock dock : DataFlowGraph.torpedoes) fp.sendToken(dock.getInstructionDestination());
-            ctx2.dispatch(fp);
-            fp.flush();
-            System.out.println("flushed");
-
-            fp.recvWord();
-            System.out.println("phase done");
-
-            System.out.println();
-        }
+        UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
+        discard.count.connect(pn_end.out);
+        mem_write.outWrite.connect(discard.val);
+        DoneNode done = new DoneNode(dfg, program);
+        discard.out.connect(done.in);
 
-        fp.sendToken(debugIn.getInstructionDestination());
-        fp.flush();
+        return done.getDestinationToSendNextCodeBagDescriptorTo();
+    }
 
-        //System.out.println("verifying cleanup:");
-        //CleanupUtils.verifyClean(fp);
+    public CodeBag makeInstance(int offset, int length) throws Exception {
 
-        System.out.println("reading back:");
-        long[] ret = null;
-        if (vals != null) {
-            ret = new long[vals_length];
-            BitVector[] mem = new BitVector[vals_length];
-            MemoryUtils.readMem(fp, memoryShip2, 0, mem);
-            for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
+        long base_read = offset;
+        long base_write = length + offset;
+        int stride = 1;
+        ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
+        codeBags.add(new CodeBag(dfg.fleet, program));
+        for(;;) {
+            CodeBag cb = codeBags.get(codeBags.size()-1);
+            boolean last = stride*2 >= length;
+            CodeBag next;
+            if (last) {
+                next = program.getEndProgramCodeBag();
+            } else {
+                next = new CodeBag(dfg.fleet, program);
+                codeBags.add(codeBags.size(), next);
+            }
+
+            for(int i=0; i<arity; i++) {
+                if (pn0[i]!=null) pn0[i].set(cb, stride);
+                if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
+                if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
+                if (pn3[i]!=null) pn3[i].set(cb, stride);
+                if (pn4[i]!=null) pn4[i].set(cb, stride);
+                if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
+            }
+            if (pn5!=null) pn5.set(cb, length);
+            if (pn6!=null) pn6.set(cb, 2*stride+1);
+            if (pn_base2!=null) pn_base2.set(cb, base_write);
+            if (pn_end!=null) pn_end.set(cb, length);
+            cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
+            cb.sendWord(next.getDescriptor(), next_dest);
+
+            cb.seal();
+            long base = base_read; base_read=base_write; base_write=base;
+            if (last) break;
+            if ((stride*2) < length) stride *= 2;
         }
-        return ret;
+
+        return codeBags.get(0);
     }
 
-    /** demo */
     public static void main(String[] s) throws Exception {
-        Fleet fleet = new Fpga();
-        //Fleet fleet = new Interpreter(false);
-
-        Random random = new Random(System.currentTimeMillis());
-        long[] vals = new long[256];
-        for(int i=0; i<vals.length; i++) {
-            vals[i] = Math.abs(random.nextInt());
+        if (s.length != 4) {
+            System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
+            System.exit(-1);
         }
-
-        Ship mem1 = fleet.getShip("Memory", 0);
-        Ship mem2 = fleet.getShip("Memory", 1);
-        //Ship mem2 = fleet.getShip("DDR2", 0);
-
-        FleetProcess fp;
-        int stride = 1;
-        fp = null;
-
-        fp = fleet.run(new Instruction[0]);
-        MemoryUtils.writeMem(fp, mem1, 0, vals);
-        int vals_length = vals.length;
-
-        // Disable readback/writeback inside the loop
-        vals = null;
-
-        while(stride < vals_length) {
-            
-            // reset the FleetProcess
-            //fp.terminate(); fp = null;
-
-            System.out.println("stride " + stride);
-
-            // if we reset the FleetProcess, restart it
-            if (fp==null) fp = fleet.run(new Instruction[0]);
-
-            // do the mergeSort
-            vals = MergeSort.mergeSort(fp, fleet, vals, vals_length, stride, mem1, mem2);
-
-            // verify the cleanup
-            //CleanupUtils.verifyClean(fp);
-
-            Ship mem = mem1; mem1=mem2; mem2=mem;
-
-            stride = stride * 2;
-            System.out.println();
+        Fleet fleet = null;
+        if (s[0].equals("fpga"))             fleet = new Fpga();
+        else if (s[0].equals("interpreter")) {
+            fleet = new Interpreter();
+            Log.log = null;
         }
-
-        BitVector[] bvs = new BitVector[vals_length];
-        MemoryUtils.readMem(fp, mem1, 0, bvs);
-        System.out.println("results:");
-        for(int i=0; i<vals_length; i++)
-            System.out.println(bvs[i].toLong());
+        ShipPool pool = new ShipPool(fleet);
+        Ship memory = pool.allocateShip(s[1]);
+        int base = Integer.parseInt(s[2]);
+        int length = Integer.parseInt(s[3]);
+
+        Random random = new Random(System.currentTimeMillis());        
+        BitVector[] vals  = new BitVector[length];
+        long[] longs = new long[length];
+        for(int i=0; i<vals.length; i++) {
+            vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
+            for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
+            longs[i] = vals[i].toLong();
+        }
+        Arrays.sort(longs);
+
+        Ship codemem = pool.allocateShip("Memory");
+
+        FleetProcess fp = fleet.run(new Instruction[0]);
+        ShipPool pool2 = new ShipPool(pool);
+
+        MemoryUtils.writeMem(fp, pool2, memory, base, vals);
+        pool2.releaseAll();
+
+        Program program = new Program(codemem);
+        CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
+        pool2.releaseAll();
+        long ret = 0;
+        ret = program.run(fp, cb, pool2);
+        pool2.releaseAll();
+
+        MemoryUtils.readMem(fp, pool2, memory, base, vals);
+        pool2.releaseAll();
+
+        System.out.println();
+        int fails = 0;
+        for(int i=0; i<vals.length; i++)
+            if (vals[i].toLong() != longs[i]) {
+                System.out.println("disagreement!  on index " +
+                                   i + "\n  expected="+
+                                   new BitVector(fleet.getWordWidth()).set(longs[i])+
+                                   "\n       got="+vals[i]);
+                fails++;
+            }
+        System.out.println("done! ("+fails+" failures)");
+        if (fails>0) System.exit(-1);
+        
     }
-
+    
 }