more cleanup
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import edu.berkeley.fleet.loops.*;
3 import edu.berkeley.fleet.api.*;
4 import edu.berkeley.fleet.fpga.*;
5 import java.util.*;
6
7 public class MergeSort {
8     public static long[] mergeSort(FleetProcess fp, Fleet fleet,
9                                    long[] vals, int vals_length, int stride_length,
10                                    Ship memoryShip1, Ship memoryShip2) throws Exception {
11
12         if (vals != null) {
13             BitVector[] mem = new BitVector[vals_length];
14             for(int i=0; i<mem.length; i++) mem[i] = new BitVector(fleet.getWordWidth()).set(vals[i]);
15             MemoryUtils.writeMem(fp, memoryShip1, 0, mem);
16         }
17
18         //////////////////////////////////////////////////////////////////////////////
19
20         DataFlowGraph proc = new DataFlowGraph(fleet);
21         DebugNode dm = new DebugNode(proc);
22
23         int end_of_data = vals_length;
24         int num_strides = end_of_data / (stride_length * 2);
25
26         MemoryNode mm  = new MemoryNode(proc, memoryShip1);
27         SortedMergeNode sm = new SortedMergeNode(proc);
28
29         // So far: we have four spare Counter ships; one can be used for resetting
30         for(int i=0; i<2; i++) {
31
32             DownCounterNode c0 = new DownCounterNode(proc);
33             DownCounterNode c1 = new DownCounterNode(proc);
34
35             c0.start.connect(new ForeverNode(proc, stride_length).out);
36             c0.incr.connect(new ForeverNode(proc, 1).out);
37
38             c1.start.connect(new OnceNode(proc, end_of_data + i*stride_length).out);
39             c1.incr.connect(new OnceNode(proc, stride_length*2).out);
40
41             RepeatNode r1 = new RepeatNode(proc);
42             r1.val.connect(c1.out);
43             r1.count.connect(new ForeverNode(proc, stride_length).out);
44
45             AluNode alu = new AluNode(proc);
46             alu.in1.connect(r1.out);
47             alu.in2.connect(c0.out);
48             alu.inOp.connect(new ForeverNode(proc, ((Node.DockInPort)alu.inOp).getConstant("ADD")).out);
49             alu.out.connect(i==0 ? mm.inAddrRead1 : mm.inAddrRead2);
50
51             PunctuatorNode punc = new PunctuatorNode(proc, -1);
52             punc.count.connect(new ForeverNode(proc, stride_length).out);
53             punc.val.connect(i==0 ? mm.outRead1 : mm.outRead2);
54             punc.out.connect(i==0 ? sm.in1 : sm.in2);
55         }
56
57         UnPunctuatorNode unpunc = new UnPunctuatorNode(proc);
58         unpunc.val.connect(sm.out);
59         unpunc.count.connect(new ForeverNode(proc, 2*stride_length).out);
60
61         DownCounterNode cw = new DownCounterNode(proc);
62         cw.start.connect(new OnceNode(proc, end_of_data).out);
63         cw.incr.connect(new OnceNode(proc, 1).out);
64
65         MemoryNode mm2 = new MemoryNode(proc, memoryShip2);
66         mm2.inAddrWrite.connect(cw.out);
67         mm2.inDataWrite.connect(unpunc.out);
68         mm2.outWrite.connect(dm.in);
69
70         //////////////////////////////////////////////////////////////////////////////
71
72         Context ctx = new Context(fp.getFleet());
73         ctx.setAutoflush(true);
74
75         ArrayList<Instruction> ai = new ArrayList<Instruction>();
76         proc.build(ctx);
77         ctx.emit(ai);
78         for(Instruction ins : ai) {
79             //System.out.println(ins);
80             fp.sendInstruction(ins);
81         }
82         fp.flush();
83
84         for(int i=0; i<vals_length; i++) {
85             System.out.print("\rreading back... " + i+"/"+vals_length+"  ");
86             BitVector rec = fp.recvWord();
87             System.out.print(" (prev result: " + rec + " = " + rec.toLong() + ")");
88         }
89         System.out.println("\rdone.                                                                    ");
90
91         //if (true) return ret;
92
93         Context ctx2 = new Context(fp.getFleet());
94         Dock debugIn = fleet.getShip("Debug",0).getDock("in");
95         Dock fred = debugIn;
96         fp.sendToken(debugIn.getInstructionDestination());
97         fp.flush();
98
99         LoopFactory lf = new LoopFactory(ctx2, debugIn, 0);
100         lf.literal(0);
101         lf.abortLoopIfTorpedoPresent();
102         lf.recvToken();
103         lf.deliver();
104
105         ctx2.dispatch(fp);
106         fp.flush();
107
108         int count = 0;
109
110         Ship counter = proc.pool.allocateShip("Counter");
111
112         for(int phase=0; phase<=3; phase++) {
113             System.out.println("== phase "+phase+" ==================================================================");
114             ctx2 = new Context(fp.getFleet());
115
116             Destination ackDestination = counter.getDock("in2").getDataDestination();
117             int expected_tokens = proc.reset(ctx2, phase, ackDestination);
118
119             Context ctx3 = new Context(fp.getFleet());
120             lf = new LoopFactory(ctx3, counter.getDock("inOp"), 1);
121             lf.literal("DROP_C1_V2");
122             lf.deliver();
123             lf.literal(5);
124             lf.deliver();
125             lf = new LoopFactory(ctx3, counter.getDock("in1"), 1);
126             lf.literal(expected_tokens-1);
127             lf.deliver();
128             lf.literal(1);
129             lf.deliver();
130             lf = new LoopFactory(ctx3, counter.getDock("in2"), 0);
131             lf.abortLoopIfTorpedoPresent();
132             lf.recvWord();
133             lf.deliver();
134             lf = new LoopFactory(ctx3, counter.getDock("out"), 1);
135             lf.collectWord();
136             lf.sendToken(counter.getDock("in2").getInstructionDestination());  // HACK: we don't check to make sure this hits
137             lf.sendToken(debugIn.getDataDestination());
138             ctx3.dispatch(fp);  // HACK: we don't check to make sure that this is "firmly in place"
139
140             for(Dock dock : DataFlowGraph.torpedoes) fp.sendToken(dock.getInstructionDestination());
141             ctx2.dispatch(fp);
142             fp.flush();
143             System.out.println("flushed");
144
145             fp.recvWord();
146             System.out.println("phase done");
147
148             System.out.println();
149         }
150
151         fp.sendToken(debugIn.getInstructionDestination());
152         fp.flush();
153
154         //System.out.println("verifying cleanup:");
155         //CleanupUtils.verifyClean(fp);
156
157         System.out.println("reading back:");
158         long[] ret = null;
159         if (vals != null) {
160             ret = new long[vals_length];
161             BitVector[] mem = new BitVector[vals_length];
162             MemoryUtils.readMem(fp, memoryShip2, 0, mem);
163             for(int i=0; i<ret.length; i++) ret[i] = mem[i].toLong();
164         }
165         return ret;
166     }
167
168     /** demo */
169     public static void main(String[] s) throws Exception {
170         Fleet fleet = new Fpga();
171         //Fleet fleet = new Interpreter(false);
172
173         Random random = new Random(System.currentTimeMillis());
174         long[] vals = new long[256];
175         for(int i=0; i<vals.length; i++) {
176             vals[i] = Math.abs(random.nextInt());
177         }
178
179         Ship mem1 = fleet.getShip("Memory", 0);
180         Ship mem2 = fleet.getShip("Memory", 1);
181         //Ship mem2 = fleet.getShip("DDR2", 0);
182
183         FleetProcess fp;
184         int stride = 1;
185         fp = null;
186
187         fp = fleet.run(new Instruction[0]);
188         MemoryUtils.writeMem(fp, mem1, 0, vals);
189         int vals_length = vals.length;
190
191         // Disable readback/writeback inside the loop
192         vals = null;
193
194         while(stride < vals_length) {
195             
196             // reset the FleetProcess
197             //fp.terminate(); fp = null;
198
199             System.out.println("stride " + stride);
200
201             // if we reset the FleetProcess, restart it
202             if (fp==null) fp = fleet.run(new Instruction[0]);
203
204             // do the mergeSort
205             vals = MergeSort.mergeSort(fp, fleet, vals, vals_length, stride, mem1, mem2);
206
207             // verify the cleanup
208             //CleanupUtils.verifyClean(fp);
209
210             Ship mem = mem1; mem1=mem2; mem2=mem;
211
212             stride = stride * 2;
213             System.out.println();
214         }
215
216         BitVector[] bvs = new BitVector[vals_length];
217         MemoryUtils.readMem(fp, mem1, 0, bvs);
218         System.out.println("results:");
219         for(int i=0; i<vals_length; i++)
220             System.out.println(bvs[i].toLong());
221     }
222
223 }