unrolling forests without recursion
[sbp.git] / src / edu / berkeley / sbp / Forest.java
index 9dabf60..937a229 100644 (file)
@@ -16,6 +16,42 @@ public abstract class Forest<T> /*extends PrintableTree<Forest.MyBody<T>>*/
     private final int idx = master_idx++;
     public int toInt() { return idx; }
 
+    public abstract void expand(TaskList tl, HashSet<Tree<T>> ht);
+    public abstract void gather(TaskList tl, HashSet<Tree<T>>[] ht, HashSet<Tree<T>> target);
+
+    public static class TaskList extends ArrayList<TaskList.Task> {
+        public interface Task {
+            public void perform();
+        }
+        public class ExpandTask implements Task {
+            private Forest f;
+            private HashSet hs;
+            public ExpandTask(Forest f, HashSet hs) { this.f = f; this.hs = hs; }
+            public void perform() { f.expand(TaskList.this, hs); }
+        }
+        public class GatherTask implements Task {
+            private Forest f;
+            private HashSet[] ht;
+            private HashSet hs;
+            public GatherTask(Forest f, HashSet<Tree<?>>[] ht, HashSet<Tree<?>> hs) { this.f = f; this.hs = hs; this.ht = ht;}
+            public void perform() { f.gather(TaskList.this, ht, hs); }
+        }
+        public void expand(Forest f, HashSet hs) {
+            add(new ExpandTask(f, hs));
+        }
+        public void gather(Forest f, HashSet[] ht, HashSet hs) {
+            add(new GatherTask(f, ht, hs));
+        }
+        public void run() {
+            while(true) {
+                if (isEmpty()) return;
+                Task task = get(size()-1);
+                remove(size()-1);
+                task.perform();
+            }
+        }
+    }
+
     /** assume that this forest contains exactly one tree and return it; otherwise throw an exception */
     public final Tree<T> expand1() throws Ambiguous, ParseFailed {
         try {
@@ -27,8 +63,15 @@ public abstract class Forest<T> /*extends PrintableTree<Forest.MyBody<T>>*/
 
     /** expand this forest into a set of trees */
     public HashSet<Tree<T>> expand(boolean toss) {
+        /*
         final HashSetTreeConsumer<T> ret = new HashSetTreeConsumer<T>();
         visit(new TreeMaker2<T>(toss, ret), null, null);
+        */
+        TaskList tl = new TaskList();
+        HashSet<Tree<T>> ret = new HashSet<Tree<T>>();
+        tl.expand(this, ret);
+        tl.run();
+
         if (toss && ret.size() > 1) throw new InnerAmbiguous(this);
         return ret;
     }
@@ -71,7 +114,7 @@ public abstract class Forest<T> /*extends PrintableTree<Forest.MyBody<T>>*/
         public GraphViz.Node toGraphViz(GraphViz gv) {
             if (gv.hasNode(this)) return gv.createNode(this);
             GraphViz.Node n = gv.createNode(this);
-            n.label = StringUtil.escapify(headToString()==null?"":headToString(), "\r\n");
+            n.label = headToString()==null?"":headToString();
             n.directed = true;
             n.comment = reduction==null?null:reduction+"";
             edges(n);
@@ -113,6 +156,41 @@ public abstract class Forest<T> /*extends PrintableTree<Forest.MyBody<T>>*/
             this.reduction = reduction;
             this.labels = labels;
         }
+        public void gather(TaskList tl, HashSet<Tree<T>>[] ht, HashSet<Tree<T>> target) {
+            gather(tl, ht, target, new Tree[ht.length], 0);
+        }
+        private void gather(TaskList tl, HashSet<Tree<T>>[] ht, HashSet<Tree<T>> target, Tree[] trees, int i) {
+            if (i==ht.length) {
+                target.add(new Tree<T>(null, tag, trees));
+                return;
+            }
+            for(Tree<T> tree : ht[i]) {
+                if (unwrap && i==trees.length-1) {
+                    // I think this is wrong
+                    Tree[] trees2 = new Tree[trees.length - 1 + tree.numChildren()];
+                    System.arraycopy(trees, 0, trees2, 0, trees.length-1);
+                    for(int j=0; j<tree.numChildren(); j++)
+                        trees2[trees.length-1+j] = tree.child(j);
+                    target.add(new Tree<T>(null, tag, trees2));
+                } else {
+                    trees[i] = tree;
+                    gather(tl, ht, target, trees, i+1);
+                    trees[i] = null;
+                }
+            }
+        }
+        public void expand(TaskList tl, HashSet<Tree<T>> ht) {
+            if (singleton) {
+                tokens[0].expand(tl, ht);
+                return;
+            }
+            HashSet<Tree<T>>[] children = new HashSet[tokens.length];
+            tl.gather(this, children, ht);
+            for(int i=0; i<children.length; i++) {
+                children[i] = new HashSet<Tree<T>>();
+                tl.expand(tokens[i], children[i]);
+            }
+        }
 
         public void expand(final int i, final TreeMaker<T> h) {
             if (singleton) {
@@ -159,6 +237,16 @@ public abstract class Forest<T> /*extends PrintableTree<Forest.MyBody<T>>*/
      *  viewed, it becomes immutable
      */
     static class Ref<T> extends Forest<T> {
+        public void expand(TaskList tl, HashSet<Tree<T>> ht) {
+            for (Forest<T> f : hp) f.expand(tl, ht);
+        }
+        public void gather(TaskList tl, HashSet<Tree<T>>[] ht, HashSet<Tree<T>> target) {
+            throw new Error();
+        }
+        public HashSet<GSS.Phase.Node> parents = new HashSet<GSS.Phase.Node>();
+        public boolean contains(Forest f) {
+            return hp.contains(f);
+        }
         public boolean ambiguous() {
             if (hp.size()==0) return false;
             if (hp.size()==1) return hp.iterator().next().ambiguous();