updates to get some of the shutdown code to execute via Program
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import edu.berkeley.fleet.loops.*;
4 import edu.berkeley.fleet.api.*;
5 import edu.berkeley.fleet.fpga.*;
6 import java.util.*;
7
8 public class MergeSort {
9     public static long[] mergeSort(FleetProcess fp, Fleet fleet, ShipPool pool,
10                                    long[] vals, int vals_length, int stride_length,
11                                    Ship memoryShip1, Ship memoryShip2) throws Exception {
12
13         if (vals != null) {
14             BitVector[] mem = new BitVector[vals_length];
15             for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
16             MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
17         }
18
19         //////////////////////////////////////////////////////////////////////////////
20
21         DataFlowGraph proc = new DataFlowGraph(fleet, pool);
22         DebugNode dm = new DebugNode(proc);
23
24         int end_of_data = vals_length;
25         int num_strides = end_of_data / (stride_length * 2);
26
27         MemoryNode mm  = new MemoryNode(proc, memoryShip1);
28         SortedMergeNode sm = new SortedMergeNode(proc);
29
30         // So far: we have four spare Counter ships; one can be used for resetting
31         for(int i=0; i<2; i++) {
32
33             DownCounterNode c0 = new DownCounterNode(proc);
34             DownCounterNode c1 = new DownCounterNode(proc);
35
36             c0.start.connect(new ForeverNode(proc, stride_length).out);
37             c0.incr.connect(new ForeverNode(proc, 1).out);
38
39             c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
40             c1.incr.connect(new OnceNode(proc, stride_length*2).out);
41
42             RepeatNode r1 = new RepeatNode(proc);
43             r1.val.connect(c1.out);
44             r1.count.connect(new ForeverNode(proc, stride_length).out);
45
46             AluNode alu = new AluNode(proc);
47             alu.in1.connect(r1.out);
48             alu.in2.connect(c0.out);
49             alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
50             alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
51
52             PunctuatorNode punc = new PunctuatorNode(proc, -1);
53             punc.count.connect(new ForeverNode(proc, stride_length).out);
54             punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
55             punc.out.connect(i==0 ? sm.in1 : sm.in2);
56         }
57
58         UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
59         unpunc.val.connect(sm.out);
60         unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
61
62         DownCounterNode cw = new DownCounterNode(proc);
63         cw.start.connect(new OnceNode(proc, end_of_data).out);
64         cw.incr.connect(new OnceNode(proc, 1).out);
65
66         MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
67         mm2.inAddrWrite.connect(cw.out);
68         mm2.inDataWrite.connect(unpunc.out);
69         mm2.outWrite.connect(dm.in);
70
71         //////////////////////////////////////////////////////////////////////////////
72
73         Ship codeMemoryShip = proc.pool.allocateShip("Memory");
74
75         Context ctx = new Context(fp.getFleet());
76         ctx.setAutoflush(true);
77         proc.build(ctx);
78         ctx.dispatch(fp, true);
79
80         for(int i=0; i<vals_length; i++) {
81             System.out.print("\rreading back... " + i+"/"+vals_length+"  ");
82             BitVector rec = fp.recvWord();
83             System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
84         }
85         System.out.println("\rdone.                                                                    ");
86
87         Dock debugIn = fleet.getShip("Debug",0).getDock("in");
88         fp.sendToken(debugIn.getInstructionDestination());
89         fp.flush();
90
91         int count = 0;
92         Ship counter = proc.pool.allocateShip("Counter");
93
94         for(int phase=0; phase<=3; phase++) {
95             System.out.println("== phase "+phase+" ==================================================================");
96
97             LoopFactory lf;
98
99             Destination ackDestination = counter.getDock("in2").getDataDestination();
100             HashSet<Dock> sendTorpedoesTo = new HashSet<Dock>();
101             Context ctx_reset = new Context(fp.getFleet());
102             int expected_tokens = proc.reset(ctx_reset, phase, ackDestination, sendTorpedoesTo);
103
104             Context ctx_debug = new Context(fp.getFleet());
105             lf = new LoopFactory(ctx_debug, debugIn, 0);
106             lf.literal(0);
107             lf.abortLoopIfTorpedoPresent();
108             lf.recvToken();
109             lf.deliver();
110
111             Context ctx_count = new Context(fp.getFleet());
112             lf = new LoopFactory(ctx_count, counter.getDock("inOp"), 1);
113             lf.literal("DROP_C1_V2");
114             lf.deliver();
115             lf.literal(5);
116             lf.deliver();
117             lf = new LoopFactory(ctx_count, counter.getDock("in1"), 1);
118             lf.literal(expected_tokens-1);
119             lf.deliver();
120             lf.literal(1);
121             lf.deliver();
122             lf = new LoopFactory(ctx_count, counter.getDock("in2"), 0);
123             lf.abortLoopIfTorpedoPresent();
124             lf.recvWord();
125             lf.deliver();
126             lf = new LoopFactory(ctx_count, counter.getDock("out"), 1);
127             lf.collectWord();
128             lf.sendToken(counter.getDock("in2").getInstructionDestination());  // HACK: we don't check to make sure this hits
129             lf.sendToken(debugIn.getDataDestination());
130
131             Program program = new Program(codeMemoryShip);
132             CodeBag cb = program.makeCodeBag(ctx_count);
133             program.install(fp);
134             MemoryUtils.putMemoryShipInDispatchMode(fp, codeMemoryShip);
135             program.run(fp, cb);
136             fp.flush();
137
138             ctx_debug.dispatch(fp, true);
139             //ctx_count.dispatch(fp, true);  // HACK: we don't check to make sure that this is "firmly in place"
140             for(Dock dock : sendTorpedoesTo) fp.sendToken(dock.getInstructionDestination());
141             ctx_reset.dispatch(fp, true);
142             System.out.println("flushed");
143
144             fp.recvWord();
145             fp.sendToken(debugIn.getInstructionDestination());
146             System.out.println("phase done");
147
148             MemoryUtils.removeMemoryShipFromDispatchMode(fp, codeMemoryShip);
149             System.out.println();
150         }
151
152         fp.flush();
153
154         //System.out.println("verifying cleanup:");
155         //CleanupUtils.verifyClean(fp);
156
157         System.out.println("reading back:");
158         long[] ret = null;
159         if (vals != null) {
160             ret = new long[vals_length];
161             BitVector[] mem = new BitVector[vals_length];
162             MemoryUtils.readMem(fp, memoryShip2, 0, mem);
163             for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
164         }
165         return ret;
166     }
167
168     /** demo */
169     public static void main(String[] s) throws Exception {
170         Fleet fleet = new Fpga();
171         //Fleet fleet = new Interpreter(false);
172
173         Random random = new Random(System.currentTimeMillis());
174         long[] vals = new long[256];
175         for(int i=0; i<vals.length; i++) {
176             vals[i] = Math.abs(random.nextInt());
177         }
178
179         Ship mem1 = fleet.getShip("Memory", 0);
180         Ship mem2 = fleet.getShip("Memory", 1);
181         //Ship mem2 = fleet.getShip("DDR2", 0);
182
183         FleetProcess fp;
184         int stride = 1;
185         fp = null;
186
187         fp = fleet.run(new Instruction[0]);
188         MemoryUtils.writeMem(fp, mem1, 0, vals);
189         int vals_length = vals.length;
190
191         // Disable readback/writeback inside the loop
192         vals = null;
193
194         while(stride < vals_length) {
195             
196             // reset the FleetProcess
197             //fp.terminate(); fp = null;
198
199             System.out.println("stride " + stride);
200
201             // if we reset the FleetProcess, restart it
202             if (fp==null) fp = fleet.run(new Instruction[0]);
203
204             ShipPool pool = new ShipPool(fleet);
205             pool.allocateShip(mem1);
206             pool.allocateShip(mem2);
207
208             // do the mergeSort
209             vals = MergeSort.mergeSort(fp, fleet, pool, vals, vals_length, stride, mem1, mem2);
210
211             // verify the cleanup
212             //CleanupUtils.verifyClean(fp);
213
214             Ship mem = mem1; mem1=mem2; mem2=mem;
215
216             stride = stride * 2;
217             System.out.println();
218         }
219
220         BitVector[] bvs = new BitVector[vals_length];
221         MemoryUtils.readMem(fp, mem1, 0, bvs);
222         System.out.println("results:");
223         for(int i=0; i<vals_length; i++)
224             System.out.println(bvs[i].toLong());
225     }
226
227 }