break out dataflow nodes into separate classes
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import edu.berkeley.fleet.loops.*;
3 import edu.berkeley.fleet.api.*;
4 import java.util.*;
5
6 public class MergeSort {
7     public static long[] mergeSort(FleetProcess fp, Fleet fleet,
8                                    long[] vals, int vals_length, int stride_length,
9                                    Ship memoryShip1, Ship memoryShip2) throws Exception {
10
11         if (vals != null) {
12             BitVector[] mem = new BitVector[vals_length];
13             for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
14             MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
15         }
16
17         //////////////////////////////////////////////////////////////////////////////
18
19         DataFlowGraph proc = new DataFlowGraph(fleet);
20         DebugNode dm = new DebugNode(proc);
21
22         int end_of_data = vals_length;
23         int num_strides = end_of_data / (stride_length * 2);
24
25         MemoryNode mm  = new MemoryNode(proc, memoryShip1);
26         SortedMergeNode sm = new SortedMergeNode(proc);
27
28         // So far: we have four spare Counter ships; one can be used for resetting
29         for(int i=0; i<2; i++) {
30
31             DownCounterNode c0 = new DownCounterNode(proc);
32             DownCounterNode c1 = new DownCounterNode(proc);
33
34             c0.start.connect(new ForeverNode(proc, stride_length).out);
35             c0.incr.connect(new ForeverNode(proc, 1).out);
36
37             c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
38             c1.incr.connect(new OnceNode(proc, stride_length*2).out);
39
40             RepeatNode r1 = new RepeatNode(proc);
41             r1.val.connect(c1.out);
42             r1.count.connect(new ForeverNode(proc, stride_length).out);
43
44             AluNode alu = new AluNode(proc);
45             alu.in1.connect(r1.out);
46             alu.in2.connect(c0.out);
47             alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
48             alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
49
50             PunctuatorNode punc = new PunctuatorNode(proc, -1);
51             punc.count.connect(new ForeverNode(proc, stride_length).out);
52             punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
53             punc.out.connect(i==0 ? sm.in1 : sm.in2);
54         }
55
56         UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
57         unpunc.val.connect(sm.out);
58         unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
59
60         DownCounterNode cw = new DownCounterNode(proc);
61         cw.start.connect(new OnceNode(proc, end_of_data).out);
62         cw.incr.connect(new OnceNode(proc, 1).out);
63
64         MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
65         mm2.inAddrWrite.connect(cw.out);
66         mm2.inDataWrite.connect(unpunc.out);
67         mm2.outWrite.connect(dm.in);
68
69         //////////////////////////////////////////////////////////////////////////////
70
71         Context ctx = new Context(fp.getFleet());
72         ctx.setAutoflush(true);
73
74         ArrayList<Instruction> ai = new ArrayList<Instruction>();
75         proc.build(ctx);
76         ctx.emit(ai);
77         for(Instruction ins : ai) {
78             //System.out.println(ins);
79             fp.sendInstruction(ins);
80         }
81         fp.flush();
82
83         for(int i=0; i<vals_length; i++) {
84             System.out.print("\rreading back... " + i+"/"+vals_length+"  ");
85             BitVector rec = fp.recvWord();
86             System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
87         }
88         System.out.println("\rdone.                                                                    ");
89
90         //if (true) return ret;
91
92         Context ctx2 = new Context(fp.getFleet());
93         Dock debugIn = fleet.getShip("Debug",0).getDock("in");
94         Dock fred = debugIn;
95         fp.sendToken(debugIn.getInstructionDestination());
96         fp.flush();
97
98         LoopFactory lf = new LoopFactory(ctx2, debugIn, 0);
99         lf.literal(0);
100         lf.abortLoopIfTorpedoPresent();
101         lf.recvToken();
102         lf.deliver();
103
104         ctx2.dispatch(fp);
105         fp.flush();
106
107         int count = 0;
108
109         Ship counter = proc.pool.allocateShip("Counter");
110
111         for(int phase=0; phase<=3; phase++) {
112             System.out.println("== phase "+phase+" ==================================================================");
113             ctx2 = new Context(fp.getFleet());
114
115             Destination ackDestination = counter.getDock("in2").getDataDestination();
116             proc.reset(ctx2, phase, ackDestination);
117
118             Context ctx3 = new Context(fp.getFleet());
119             lf = new LoopFactory(ctx3, counter.getDock("inOp"), 1);
120             lf.literal("DROP_C1_V2");
121             lf.deliver();
122             lf.literal(5);
123             lf.deliver();
124             lf = new LoopFactory(ctx3, counter.getDock("in1"), 1);
125             lf.literal(DataFlowGraph.reset_count-1);
126             lf.deliver();
127             lf.literal(1);
128             lf.deliver();
129             lf = new LoopFactory(ctx3, counter.getDock("in2"), 0);
130             lf.abortLoopIfTorpedoPresent();
131             lf.recvWord();
132             lf.deliver();
133             lf = new LoopFactory(ctx3, counter.getDock("out"), 1);
134             lf.collectWord();
135             lf.sendToken(counter.getDock("in2").getInstructionDestination());  // HACK: we don't check to make sure this hits
136             lf.sendToken(debugIn.getDataDestination());
137             ctx3.dispatch(fp);  // HACK: we don't check to make sure that this is "firmly in place"
138
139             for(Dock dock : DataFlowGraph.torpedoes) fp.sendToken(dock.getInstructionDestination());
140             ctx2.dispatch(fp);
141             fp.flush();
142             System.out.println("flushed");
143
144             fp.recvWord();
145             System.out.println("phase done");
146
147             System.out.println();
148         }
149
150         fp.sendToken(debugIn.getInstructionDestination());
151         fp.flush();
152
153         //System.out.println("verifying cleanup:");
154         //CleanupUtils.verifyClean(fp);
155
156         System.out.println("reading back:");
157         long[] ret = null;
158         if (vals != null) {
159             ret = new long[vals_length];
160             BitVector[] mem = new BitVector[vals_length];
161             MemoryUtils.readMem(fp, memoryShip2, 0, mem);
162             for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
163         }
164         return ret;
165     }
166 }