MergeSort.java: reformatting
[fleet.git] / src / edu / berkeley / fleet / dataflow / MergeSort.java
1 package edu.berkeley.fleet.dataflow;
2 import java.util.*;
3 import java.io.*;
4 import edu.berkeley.fleet.loops.*;
5 import edu.berkeley.fleet.api.*;
6 import edu.berkeley.fleet.interpreter.*;
7 import edu.berkeley.fleet.fpga.*;
8 import org.ibex.graphics.*;
9
10 public class MergeSort {
11
12     private final Fleet fleet;
13     private final Ship mem1;
14     private final Ship mem2;
15     private final int arity;
16     private final ShipPool pool;
17     private final Program program;
18     private final DataFlowGraph dfg;
19
20     private ParameterNode[] pn0;
21     private ParameterNode[] pn1;
22     private ParameterNode[] pn2;
23     private ParameterNode[] pn3;
24     private ParameterNode[] pn4;
25     private ParameterNode[] pn_base1;
26     private ParameterNode   pn_base2;
27     private ParameterNode   pn5;
28     private ParameterNode   pn6;
29     private ParameterNode   pn_end;
30
31     CodeBag cb2 = null;
32     Destination next_dest = null;
33
34     public MergeSort(Fleet fleet, Program program, ShipPool pool, int arity, Ship mem1, Ship mem2) {
35         this.fleet = fleet;
36         this.mem1 = mem1;
37         this.mem2 = mem2;
38         this.arity = arity;
39         this.pool = pool;
40         this.program = program;
41         this.dfg = new DataFlowGraph(fleet, pool);
42         next_dest = makeDfgCodeBag(dfg, program);
43         cb2 = dfg.build(new CodeBag(dfg.fleet, program));
44         cb2.seal();
45     }
46
47     public Destination makeDfgCodeBag(DataFlowGraph dfg, Program program) {
48
49         MemoryNode mem_read  = new MemoryNode(dfg, mem1);
50         MemoryNode mem_write = (mem1==mem2) ? mem_read : new MemoryNode(dfg, mem2);
51
52         pn0 = new ParameterNode[arity];
53         pn1 = new ParameterNode[arity];
54         pn2 = new ParameterNode[arity];
55         pn3 = new ParameterNode[arity];
56         pn4 = new ParameterNode[arity];
57         pn_base1 = new ParameterNode[arity];
58         pn_base2 = new ParameterNode(dfg, true);
59         pn_end = new ParameterNode(dfg);
60         pn5 = new ParameterNode(dfg);
61         pn6 = new ParameterNode(dfg, true);
62
63         AluNode sm = new AluNode(dfg, "MAXMERGE");
64         // So far: we have four spare Counter ships; one can be used for resetting
65         for(int i=0; i<arity; i++) {
66
67             DownCounterNode c0 = new DownCounterNode(dfg);
68             c0.start.connect((pn0[i] = new ParameterNode(dfg, true)).out);
69             c0.incr.connect(new ForeverNode(dfg, 1).out);
70
71             DownCounterNode c1 = new DownCounterNode(dfg);
72             c1.start.connect((pn1[i] = new ParameterNode(dfg)).out);
73             c1.incr.connect((pn2[i] = new ParameterNode(dfg)).out);
74
75             RepeatNode r1 = new RepeatNode(dfg);
76             r1.val.connect(c1.out);
77             r1.count.connect((pn3[i] = new ParameterNode(dfg, true)).out);
78
79             AluNode alu1 = new AluNode(dfg, "ADD");
80             alu1.in1.connect(r1.out);
81             alu1.in2.connect(c0.out);
82
83             AluNode alu2 = new AluNode(dfg, "ADD");
84             alu2.in1.connect((pn_base1[i] = new ParameterNode(dfg, true)).out);
85             alu2.in2.connect(alu1.out);
86             alu2.out.connect(i==0 ? mem_read.inAddrRead1 : mem_read.inAddrRead2);
87
88             PunctuatorNode punc = new PunctuatorNode(dfg, -1);
89             punc.count.connect((pn4[i] = new ParameterNode(dfg, true)).out);
90             punc.val.connect(i==0 ? mem_read.outRead1 : mem_read.outRead2);
91             punc.out.connect(i==0 ? sm.in1 : sm.in2);
92         }
93
94         UnPunctuatorNode unpunc = new UnPunctuatorNode(dfg);
95         unpunc.val.connect(sm.out);
96         unpunc.count.connect(pn6.out);
97
98         DownCounterNode cw = new DownCounterNode(dfg);
99         cw.start.connect(pn5.out);
100         cw.incr.connect(new OnceNode(dfg, 1).out);
101
102         AluNode alu = new AluNode(dfg, "ADD");
103         alu.in1.connect(pn_base2.out);
104         cw.out.connect(alu.in2);
105
106         mem_write.inAddrWrite.connect(alu.out);
107         mem_write.inDataWrite.connect(unpunc.out);
108
109         UnPunctuatorNode discard = new UnPunctuatorNode(dfg, true);
110         discard.count.connect(pn_end.out);
111         mem_write.outWrite.connect(discard.val);
112         DoneNode done = new DoneNode(dfg, program);
113         discard.out.connect(done.in);
114
115         return done.getDestinationToSendNextCodeBagDescriptorTo();
116     }
117
118     public CodeBag makeInstance(int offset, int length) throws Exception {
119
120         long base_read = offset;
121         long base_write = length + offset;
122         int stride = 1;
123         ArrayList<CodeBag> codeBags = new ArrayList<CodeBag>();
124         codeBags.add(new CodeBag(dfg.fleet, program));
125         for(;;) {
126             CodeBag cb = codeBags.get(codeBags.size()-1);
127             boolean last = stride*2 >= length;
128             CodeBag next;
129             if (last) {
130                 next = program.getEndProgramCodeBag();
131             } else {
132                 next = new CodeBag(dfg.fleet, program);
133                 codeBags.add(codeBags.size(), next);
134             }
135
136             for(int i=0; i<arity; i++) {
137                 if (pn0[i]!=null) pn0[i].set(cb, stride);
138                 if (pn1[i]!=null) pn1[i].set(cb, length + i*stride);
139                 if (pn2[i]!=null) pn2[i].set(cb, 2*stride);
140                 if (pn3[i]!=null) pn3[i].set(cb, stride);
141                 if (pn4[i]!=null) pn4[i].set(cb, stride);
142                 if (pn_base1[i]!=null) pn_base1[i].set(cb, base_read);
143             }
144             if (pn5!=null) pn5.set(cb, length);
145             if (pn6!=null) pn6.set(cb, 2*stride+1);
146             if (pn_base2!=null) pn_base2.set(cb, base_write);
147             if (pn_end!=null) pn_end.set(cb, length);
148             cb.sendWord(cb2.getDescriptor(), program.getCBDDestination());
149             cb.sendWord(next.getDescriptor(), next_dest);
150
151             cb.seal();
152             long base = base_read; base_read=base_write; base_write=base;
153             if (last) break;
154             if ((stride*2) < length) stride *= 2;
155         }
156
157         return codeBags.get(0);
158     }
159
160     public static void main(String[] s) throws Exception {
161         if (s.length != 4) {
162             System.err.println("usage: java " + MergeSort.class.getName() + " <target> <shipname> <base> <length>");
163             System.exit(-1);
164         }
165         Fleet fleet = null;
166         if (s[0].equals("fpga"))             fleet = new Fpga();
167         else if (s[0].equals("interpreter")) {
168             fleet = new Interpreter();
169             Log.log = null;
170         }
171         ShipPool pool = new ShipPool(fleet);
172         Ship memory = pool.allocateShip(s[1]);
173         int base = Integer.parseInt(s[2]);
174         int length = Integer.parseInt(s[3]);
175
176         Random random = new Random(System.currentTimeMillis());        
177         BitVector[] vals  = new BitVector[length];
178         long[] longs = new long[length];
179         for(int i=0; i<vals.length; i++) {
180             vals[i] = new BitVector(fleet.getWordWidth()).set(random.nextLong());
181             for(int j=36; j<vals[i].length(); j++) vals[i].set(j, false);
182             longs[i] = vals[i].toLong();
183         }
184         Arrays.sort(longs);
185
186         Ship codemem = pool.allocateShip("Memory");
187
188         FleetProcess fp = fleet.run(new Instruction[0]);
189         ShipPool pool2 = new ShipPool(pool);
190
191         MemoryUtils.writeMem(fp, pool2, memory, base, vals);
192         pool2.releaseAll();
193
194         Program program = new Program(codemem);
195         CodeBag cb = new MergeSort(fleet, program, pool2, 2, memory, memory).makeInstance(base, length);
196         pool2.releaseAll();
197         long ret = 0;
198         ret = program.run(fp, cb, pool2);
199         pool2.releaseAll();
200
201         MemoryUtils.readMem(fp, pool2, memory, base, vals);
202         pool2.releaseAll();
203
204         System.out.println();
205         int fails = 0;
206         for(int i=0; i<vals.length; i++)
207             if (vals[i].toLong() != longs[i]) {
208                 System.out.println("disagreement!  on index " +
209                                    i + "\n  expected="+
210                                    new BitVector(fleet.getWordWidth()).set(longs[i])+
211                                    "\n       got="+vals[i]);
212                 fails++;
213             }
214         System.out.println("done! ("+fails+" failures)");
215         if (fails>0) System.exit(-1);
216         
217     }
218     
219 }