init
[org.ibex.arenaj.git] / src / edu / berkeley / cs / megacz / Transformer.java
1 package edu.berkeley.cs.megacz;
2 import soot.*;
3 import soot.jimple.*;
4 import soot.util.*;
5 import java.io.*;
6 import java.util.*;
7
8 public class Transformer {    
9     public static void main(String[] args) 
10     {
11         if(args.length == 0)
12         {
13             System.out.println("Syntax: java ashes.examples.countgotos.Main [soot options]");
14             System.exit(0);
15         }            
16         
17         PackManager.v().getPack("jtp").add(new Transform("jtp.instrumenter", GotoInstrumenter.v()));
18
19         // Just in case, resolve the PrintStream SootClass.
20         Scene.v().addBasicClass("java.io.PrintStream",SootClass.SIGNATURES);
21         soot.Main.main(args);
22     }
23
24     static class GotoInstrumenter extends BodyTransformer {
25         private static GotoInstrumenter instance = new GotoInstrumenter();
26         private GotoInstrumenter() {}
27
28         public static GotoInstrumenter v() { return instance; }
29
30         private boolean addedFieldToMainClassAndLoadedPrintStream = false;
31         private SootClass javaIoPrintStream;
32
33         private Local addTmpRef(Body body) {
34             Local tmpRef = Jimple.v().newLocal("tmpRef", RefType.v("java.io.PrintStream"));
35             body.getLocals().add(tmpRef);
36             return tmpRef;
37         }
38      
39         private Local addTmpLong(Body body) {
40             Local tmpLong = Jimple.v().newLocal("tmpLong", LongType.v()); 
41             body.getLocals().add(tmpLong);
42             return tmpLong;
43         }
44
45         private void addStmtsToBefore(Chain units, Stmt s, SootField gotoCounter, Local tmpRef, Local tmpLong)
46         {
47             // insert "tmpRef = java.lang.System.out;" 
48             units.insertBefore(Jimple.v().newAssignStmt( 
49                                                         tmpRef, Jimple.v().newStaticFieldRef( 
50                                                                                              Scene.v().getField("<java.lang.System: java.io.PrintStream out>").makeRef())), s);
51
52             // insert "tmpLong = gotoCounter;" 
53             units.insertBefore(Jimple.v().newAssignStmt(tmpLong, 
54                                                         Jimple.v().newStaticFieldRef(gotoCounter.makeRef())), s);
55             
56             // insert "tmpRef.println(tmpLong);" 
57             SootMethod toCall = javaIoPrintStream.getMethod("void println(long)");                    
58             units.insertBefore(Jimple.v().newInvokeStmt(
59                                                         Jimple.v().newVirtualInvokeExpr(tmpRef, toCall.makeRef(), tmpLong)), s);
60         }
61
62         protected void internalTransform(Body body, String phaseName, Map options) {
63             SootClass sClass = body.getMethod().getDeclaringClass();
64             SootField gotoCounter = null;
65             boolean addedLocals = false;
66             Local tmpRef = null, tmpLong = null;
67             Chain units = body.getUnits();
68
69             System.out.println("sClass is " + sClass);
70         
71             // Add code at the end of the main method to print out the 
72             // gotoCounter (this only works in simple cases, because you may have multiple returns or System.exit()'s )
73             synchronized(this)
74                 {
75                     if (!Scene.v().getMainClass().
76                         declaresMethod("void main(java.lang.String[])"))
77                         throw new RuntimeException("couldn't find main() in mainClass");
78
79                     if (addedFieldToMainClassAndLoadedPrintStream)
80                         gotoCounter = Scene.v().getMainClass().getFieldByName("gotoCount");
81                     else
82                         {
83                             // Add gotoCounter field
84                             gotoCounter = new SootField("gotoCount", LongType.v(), 
85                                                         Modifier.STATIC);
86                             Scene.v().getMainClass().addField(gotoCounter);
87
88                             javaIoPrintStream = Scene.v().getSootClass("java.io.PrintStream");
89
90                             addedFieldToMainClassAndLoadedPrintStream = true;
91                         }
92                 }
93             
94             // Add code to increase goto counter each time a goto is encountered
95             {
96                 boolean isMainMethod = body.getMethod().getSubSignature().equals("void main(java.lang.String[])");
97
98                 Local tmpLocal = Jimple.v().newLocal("tmp", LongType.v());
99                 body.getLocals().add(tmpLocal);
100                 
101                 Iterator stmtIt = units.snapshotIterator();
102             
103                 while(stmtIt.hasNext())
104                     {
105                         Stmt s = (Stmt) stmtIt.next();
106
107                         if(s instanceof GotoStmt)
108                             {
109                                 AssignStmt toAdd1 = Jimple.v().newAssignStmt(tmpLocal, 
110                                                                              Jimple.v().newStaticFieldRef(gotoCounter.makeRef()));
111                                 AssignStmt toAdd2 = Jimple.v().newAssignStmt(tmpLocal,
112                                                                              Jimple.v().newAddExpr(tmpLocal, LongConstant.v(1L)));
113                                 AssignStmt toAdd3 = Jimple.v().newAssignStmt(Jimple.v().newStaticFieldRef(gotoCounter.makeRef()), 
114                                                                              tmpLocal);
115
116                                 // insert "tmpLocal = gotoCounter;"
117                                 units.insertBefore(toAdd1, s);
118                         
119                                 // insert "tmpLocal = tmpLocal + 1L;" 
120                                 units.insertBefore(toAdd2, s);
121
122                                 // insert "gotoCounter = tmpLocal;" 
123                                 units.insertBefore(toAdd3, s);
124                             }
125                         else if (s instanceof InvokeStmt)
126                             {
127                                 InvokeExpr iexpr = (InvokeExpr) ((InvokeStmt)s).getInvokeExpr();
128                                 if (iexpr instanceof StaticInvokeExpr)
129                                     {
130                                         SootMethod target = ((StaticInvokeExpr)iexpr).getMethod();
131                         
132                                         if (target.getSignature().equals("<java.lang.System: void exit(int)>"))
133                                             {
134                                                 if (!addedLocals)
135                                                     {
136                                                         tmpRef = addTmpRef(body); tmpLong = addTmpLong(body);
137                                                         addedLocals = true;
138                                                     }
139                                                 addStmtsToBefore(units, s, gotoCounter, tmpRef, tmpLong);
140                                             }
141                                     }
142                             }
143                         else if (isMainMethod && (s instanceof ReturnStmt || s instanceof ReturnVoidStmt))
144                             {
145                                 if (!addedLocals)
146                                     {
147                                         tmpRef = addTmpRef(body); tmpLong = addTmpLong(body);
148                                         addedLocals = true;
149                                     }
150                                 addStmtsToBefore(units, s, gotoCounter, tmpRef, tmpLong);
151                             }
152                     }
153             }
154         }
155     }
156 }