get rid of ugly globals in DataFlowGraph
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import edu.berkeley.fleet.loops.*;
4 import edu.berkeley.fleet.api.*;
5 import edu.berkeley.fleet.fpga.*;
6 import java.util.*;
7
8 public class MergeSort {
9     public static long[] mergeSort(FleetProcess fp, Fleet fleet,
10                                    long[] vals, int vals_length, int stride_length,
11                                    Ship memoryShip1, Ship memoryShip2) throws Exception {
12
13         if (vals != null) {
14             BitVector[] mem = new BitVector[vals_length];
15             for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
16             MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
17         }
18
19         //////////////////////////////////////////////////////////////////////////////
20
21         DataFlowGraph proc = new DataFlowGraph(fleet);
22         DebugNode dm = new DebugNode(proc);
23
24         int end_of_data = vals_length;
25         int num_strides = end_of_data / (stride_length * 2);
26
27         MemoryNode mm  = new MemoryNode(proc, memoryShip1);
28         SortedMergeNode sm = new SortedMergeNode(proc);
29
30         // So far: we have four spare Counter ships; one can be used for resetting
31         for(int i=0; i<2; i++) {
32
33             DownCounterNode c0 = new DownCounterNode(proc);
34             DownCounterNode c1 = new DownCounterNode(proc);
35
36             c0.start.connect(new ForeverNode(proc, stride_length).out);
37             c0.incr.connect(new ForeverNode(proc, 1).out);
38
39             c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
40             c1.incr.connect(new OnceNode(proc, stride_length*2).out);
41
42             RepeatNode r1 = new RepeatNode(proc);
43             r1.val.connect(c1.out);
44             r1.count.connect(new ForeverNode(proc, stride_length).out);
45
46             AluNode alu = new AluNode(proc);
47             alu.in1.connect(r1.out);
48             alu.in2.connect(c0.out);
49             alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
50             alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
51
52             PunctuatorNode punc = new PunctuatorNode(proc, -1);
53             punc.count.connect(new ForeverNode(proc, stride_length).out);
54             punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
55             punc.out.connect(i==0 ? sm.in1 : sm.in2);
56         }
57
58         UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
59         unpunc.val.connect(sm.out);
60         unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
61
62         DownCounterNode cw = new DownCounterNode(proc);
63         cw.start.connect(new OnceNode(proc, end_of_data).out);
64         cw.incr.connect(new OnceNode(proc, 1).out);
65
66         MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
67         mm2.inAddrWrite.connect(cw.out);
68         mm2.inDataWrite.connect(unpunc.out);
69         mm2.outWrite.connect(dm.in);
70
71         //////////////////////////////////////////////////////////////////////////////
72
73         Context ctx = new Context(fp.getFleet());
74         ctx.setAutoflush(true);
75
76         ArrayList<Instruction> ai = new ArrayList<Instruction>();
77         proc.build(ctx);
78         ctx.emit(ai);
79         for(Instruction ins : ai) {
80             //System.out.println(ins);
81             fp.sendInstruction(ins);
82         }
83         fp.flush();
84
85         for(int i=0; i<vals_length; i++) {
86             System.out.print("\rreading back... " + i+"/"+vals_length+"  ");
87             BitVector rec = fp.recvWord();
88             System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
89         }
90         System.out.println("\rdone.                                                                    ");
91
92         //if (true) return ret;
93
94         Context ctx2 = new Context(fp.getFleet());
95         Dock debugIn = fleet.getShip("Debug",0).getDock("in");
96         Dock fred = debugIn;
97         fp.sendToken(debugIn.getInstructionDestination());
98         fp.flush();
99
100         LoopFactory lf = new LoopFactory(ctx2, debugIn, 0);
101         lf.literal(0);
102         lf.abortLoopIfTorpedoPresent();
103         lf.recvToken();
104         lf.deliver();
105
106         ctx2.dispatch(fp);
107         fp.flush();
108
109         int count = 0;
110
111         Ship counter = proc.pool.allocateShip("Counter");
112
113         for(int phase=0; phase<=3; phase++) {
114             System.out.println("== phase "+phase+" ==================================================================");
115             ctx2 = new Context(fp.getFleet());
116
117             Destination ackDestination = counter.getDock("in2").getDataDestination();
118             HashSet<Dock> sendTorpedoesTo = new HashSet<Dock>();
119             int expected_tokens = proc.reset(ctx2, phase, ackDestination, sendTorpedoesTo);
120
121             Context ctx3 = new Context(fp.getFleet());
122             lf = new LoopFactory(ctx3, counter.getDock("inOp"), 1);
123             lf.literal("DROP_C1_V2");
124             lf.deliver();
125             lf.literal(5);
126             lf.deliver();
127             lf = new LoopFactory(ctx3, counter.getDock("in1"), 1);
128             lf.literal(expected_tokens-1);
129             lf.deliver();
130             lf.literal(1);
131             lf.deliver();
132             lf = new LoopFactory(ctx3, counter.getDock("in2"), 0);
133             lf.abortLoopIfTorpedoPresent();
134             lf.recvWord();
135             lf.deliver();
136             lf = new LoopFactory(ctx3, counter.getDock("out"), 1);
137             lf.collectWord();
138             lf.sendToken(counter.getDock("in2").getInstructionDestination());  // HACK: we don't check to make sure this hits
139             lf.sendToken(debugIn.getDataDestination());
140             ctx3.dispatch(fp);  // HACK: we don't check to make sure that this is "firmly in place"
141
142             for(Dock dock : sendTorpedoesTo) fp.sendToken(dock.getInstructionDestination());
143             ctx2.dispatch(fp);
144             fp.flush();
145             System.out.println("flushed");
146
147             fp.recvWord();
148             System.out.println("phase done");
149
150             System.out.println();
151         }
152
153         fp.sendToken(debugIn.getInstructionDestination());
154         fp.flush();
155
156         //System.out.println("verifying cleanup:");
157         //CleanupUtils.verifyClean(fp);
158
159         System.out.println("reading back:");
160         long[] ret = null;
161         if (vals != null) {
162             ret = new long[vals_length];
163             BitVector[] mem = new BitVector[vals_length];
164             MemoryUtils.readMem(fp, memoryShip2, 0, mem);
165             for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
166         }
167         return ret;
168     }
169
170     /** demo */
171     public static void main(String[] s) throws Exception {
172         Fleet fleet = new Fpga();
173         //Fleet fleet = new Interpreter(false);
174
175         Random random = new Random(System.currentTimeMillis());
176         long[] vals = new long[256];
177         for(int i=0; i<vals.length; i++) {
178             vals[i] = Math.abs(random.nextInt());
179         }
180
181         Ship mem1 = fleet.getShip("Memory", 0);
182         Ship mem2 = fleet.getShip("Memory", 1);
183         //Ship mem2 = fleet.getShip("DDR2", 0);
184
185         FleetProcess fp;
186         int stride = 1;
187         fp = null;
188
189         fp = fleet.run(new Instruction[0]);
190         MemoryUtils.writeMem(fp, mem1, 0, vals);
191         int vals_length = vals.length;
192
193         // Disable readback/writeback inside the loop
194         vals = null;
195
196         while(stride < vals_length) {
197             
198             // reset the FleetProcess
199             //fp.terminate(); fp = null;
200
201             System.out.println("stride " + stride);
202
203             // if we reset the FleetProcess, restart it
204             if (fp==null) fp = fleet.run(new Instruction[0]);
205
206             // do the mergeSort
207             vals = MergeSort.mergeSort(fp, fleet, vals, vals_length, stride, mem1, mem2);
208
209             // verify the cleanup
210             //CleanupUtils.verifyClean(fp);
211
212             Ship mem = mem1; mem1=mem2; mem2=mem;
213
214             stride = stride * 2;
215             System.out.println();
216         }
217
218         BitVector[] bvs = new BitVector[vals_length];
219         MemoryUtils.readMem(fp, mem1, 0, bvs);
220         System.out.println("results:");
221         for(int i=0; i<vals_length; i++)
222             System.out.println(bvs[i].toLong());
223     }
224
225 }