added GraphViz support
[sbp.git] / src / edu / berkeley / sbp / util / GraphViz.java
diff --git a/src/edu/berkeley/sbp/util/GraphViz.java b/src/edu/berkeley/sbp/util/GraphViz.java
new file mode 100644 (file)
index 0000000..a92e298
--- /dev/null
@@ -0,0 +1,123 @@
+package edu.berkeley.sbp.util;
+import edu.berkeley.sbp.util.*;
+import edu.berkeley.sbp.*;
+import java.io.*;
+import java.util.*;
+import java.lang.reflect.*;
+import java.lang.ref.*;
+
+public class GraphViz {
+
+    IdentityHashMap<ToGraphViz,Node> ihm = new IdentityHashMap<ToGraphViz,Node>();
+    HashMap<Node,Group> groups = new HashMap<Node,Group>();
+
+    public class Group {
+        public Group() { }
+        public void add(Node n) { groups.put(n, this); }
+    }
+
+    private static int master_idx=0;
+    public class Node {
+        private final int idx = master_idx++;
+        public String label;
+        public boolean directed = false;
+        public String color="black";
+        public ArrayList<Node> edges = new ArrayList<Node>();
+        public ArrayList<Node> inbound = new ArrayList<Node>();
+        public void edge(ToGraphViz o) {
+            Node n = o.toGraphViz(GraphViz.this);
+            if (n==null) return;
+            edges.add(n);
+            n.inbound.add(this);
+        }
+        public String name() {
+            if (inbound.size()==1 && inbound.get(0).simple())
+                return inbound.get(0).name()+":node_"+idx;
+            return "node_"+idx;
+        }
+        public void edges(PrintWriter pw) {
+            if (simple()) return;
+            for(Node n : edges)
+                pw.println("    "+name()+" -> " + n.name());
+        }
+        public int numEdges() { return edges.size(); }
+        public boolean simple() {
+            boolean simple = true;
+            if (label!=null && !label.equals("")) simple = false;
+            if (simple)
+                for(Node n : edges)
+                    //if (n.numEdges()>0) { simple = false; break; }
+                    if (n.inbound.size() > 1) { simple = false; break; }
+            return simple;
+        }
+        public void dump(PrintWriter pw) {
+            if (inbound.size() > 0) {
+                boolean good = false;
+                for(Node n : inbound)
+                    if (!n.simple())
+                        { good = true; break; }
+                if (!good) return;
+            }
+            pw.print("    "+name());
+            pw.print(" [");
+            if (directed) pw.print("ordering=out");
+            if (simple()) {
+                pw.print(" shape=record ");
+                pw.print(" label=\"{");
+                boolean first = true;
+                for(Node n : edges) {
+                    if (!first) pw.print("|");
+                    first = false;
+                    pw.print("<"+n.name()+">");
+                    pw.print(StringUtil.escapify(n.label,"\\\""));
+                }
+                pw.print("}\"");
+            } else {
+                pw.print(" label=\"");
+                pw.print(StringUtil.escapify(label,"\\\""));
+                pw.print("\"");
+            }
+            pw.print("color="+color);
+            pw.print("];\n");
+        }
+    }
+
+    public boolean hasNode(ToGraphViz o) {
+        return ihm.get(o)!=null;
+    }
+
+    public Node createNode(ToGraphViz o) {
+        Node n = ihm.get(o);
+        if (n!=null) return n;
+        n = new Node();
+        ihm.put(o, n);
+        return n;
+    }
+
+    public static interface ToGraphViz {
+        public Node    toGraphViz(GraphViz gv);
+        public boolean isTransparent();
+        public boolean isHidden();
+    }
+
+    public void dump(PrintWriter pw) {
+        IdentityHashMap<Node,Node> done = new IdentityHashMap<Node,Node>();
+        pw.println("digraph G {\n");
+        for(Group g : groups.values()) {
+            pw.println("  { rank=same;\n");
+            for(Node n : groups.keySet())
+                if (groups.get(n)==g) {
+                    done.put(n,n);
+                    n.dump(pw);
+                }
+            pw.println("  }\n");
+        }
+        for(Node n : ihm.values()) {
+            if (done.get(n)!=null) continue;
+            n.dump(pw);
+        }
+        for(Node n : ihm.values()) n.edges(pw);
+        pw.println("}\n");
+    }
+
+}