diff --git a/Java/benchmark/pom.xml b/Java/benchmark/pom.xml index 2760ae02..a5fdc4ca 100644 --- a/Java/benchmark/pom.xml +++ b/Java/benchmark/pom.xml @@ -6,7 +6,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 randomcutforest-benchmark diff --git a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java index ca959781..4ee272a0 100644 --- a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java +++ b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java @@ -18,7 +18,6 @@ import java.util.List; import java.util.Random; -import org.github.jamm.MemoryMeter; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Level; @@ -131,10 +130,6 @@ public RandomCutForest scoreAndUpdate(BenchmarkState state, Blackhole blackhole) } blackhole.consume(score); - if (!forest.parallelExecutionEnabled) { - MemoryMeter meter = new MemoryMeter(); - System.out.println(" forest size " + meter.measureDeep(forest)); - } return forest; } diff --git a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java index 76faf4b0..cafec955 100644 --- a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java +++ b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java @@ -18,7 +18,6 @@ import java.util.List; import java.util.Random; -import org.github.jamm.MemoryMeter; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Level; @@ -130,10 +129,6 @@ public RandomCutForest scoreAndUpdate(BenchmarkState state, Blackhole blackhole) } blackhole.consume(score); - 
if (!forest.parallelExecutionEnabled) { - MemoryMeter meter = new MemoryMeter(); - System.out.println(" forest size " + meter.measureDeep(forest)); - } return forest; } diff --git a/Java/core/pom.xml b/Java/core/pom.xml index d801923d..39e321a7 100644 --- a/Java/core/pom.xml +++ b/Java/core/pom.xml @@ -6,7 +6,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 randomcutforest-core diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java b/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java index f475d181..6cd6f188 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java @@ -16,6 +16,7 @@ package com.amazon.randomcutforest; import java.util.Objects; +import java.util.function.Supplier; import com.amazon.randomcutforest.tree.IBoundingBoxView; @@ -38,11 +39,19 @@ private CommonUtils() { * @throws IllegalArgumentException if {@code condition} is false. */ public static void checkArgument(boolean condition, String message) { + if (!condition) { throw new IllegalArgumentException(message); } } + // a lazy equivalent of the above, which avoids parameter evaluation + public static void checkArgument(boolean condition, Supplier messageSupplier) { + if (!condition) { + throw new IllegalArgumentException(messageSupplier.get()); + } + } + /** * Throws an {@link IllegalStateException} with the specified message if the * specified input is false. 
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java index a28e4a0b..9b36e65a 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java @@ -336,8 +336,9 @@ public RandomCutForest singlePrecisionForest(RandomCutForest.Builder builder, tree = extTrees.get(i); } else if (treeStates != null) { tree = treeMapper.toModel(treeStates.get(i), context, random.nextLong()); - sampler.getSample().forEach(s -> tree.addPoint(s.getValue(), s.getSequenceIndex())); + sampler.getSample().forEach(s -> tree.addPointToPartialTree(s.getValue(), s.getSequenceIndex())); tree.setConfig(Config.BOUNDING_BOX_CACHE_FRACTION, treeStates.get(i).getBoundingBoxCacheFraction()); + tree.validateAndReconstruct(); } else { // using boundingBoxCahce for the new tree tree = new RandomCutTree.Builder().capacity(state.getSampleSize()).randomSeed(random.nextLong()) diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java index e2b4caf6..7d77d26b 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java @@ -59,9 +59,8 @@ public AbstractNodeStore toModel(NodeStoreState state, CompactRandomCutTreeConte } // note boundingBoxCache is not set deliberately return AbstractNodeStore.builder().capacity(capacity).useRoot(root).leftIndex(leftIndex).rightIndex(rightIndex) - .cutDimension(cutDimension).cutValues(cutValue) - .dimensions(compactRandomCutTreeContext.getPointStore().getDimensions()) - .pointStoreView(compactRandomCutTreeContext.getPointStore()).build(); + 
.cutDimension(cutDimension).cutValues(cutValue).dimension(compactRandomCutTreeContext.getDimension()) + .build(); } @Override diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java index 1306475b..a17ff9aa 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java @@ -23,6 +23,7 @@ @Data public class CompactRandomCutTreeContext { private int maxSize; + private int dimension; private IPointStore pointStore; private Precision precision; } diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java index d79ad247..49caf14b 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java @@ -33,11 +33,12 @@ public class RandomCutTreeMapper @Override public RandomCutTree toModel(CompactRandomCutTreeState state, CompactRandomCutTreeContext context, long seed) { + int dimension = (state.getDimensions() != 0) ? state.getDimensions() : context.getPointStore().getDimensions(); + context.setDimension(dimension); AbstractNodeStoreMapper nodeStoreMapper = new AbstractNodeStoreMapper(); nodeStoreMapper.setRoot(state.getRoot()); AbstractNodeStore nodeStore = nodeStoreMapper.toModel(state.getNodeStoreState(), context); - int dimension = (state.getDimensions() != 0) ? 
state.getDimensions() : context.getPointStore().getDimensions(); // boundingBoxcache is not set deliberately; // it should be set after the partial tree is complete // likewise all the leaves, including the root, should be set to diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java index 04ddffb1..c87217ba 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java @@ -17,14 +17,8 @@ import static com.amazon.randomcutforest.CommonUtils.checkArgument; -import java.util.Arrays; -import java.util.HashMap; import java.util.Stack; -import java.util.function.Function; -import com.amazon.randomcutforest.MultiVisitor; -import com.amazon.randomcutforest.Visitor; -import com.amazon.randomcutforest.store.IPointStoreView; import com.amazon.randomcutforest.store.IndexIntervalManager; /** @@ -44,8 +38,6 @@ */ public abstract class AbstractNodeStore { - public static double SWITCH_FRACTION = 0.499; - public static int Null = -1; public static boolean DEFAULT_STORE_PARENT = false; @@ -57,74 +49,19 @@ public abstract class AbstractNodeStore { * number_of_leaves + X */ protected final int capacity; - protected final int dimensions; protected final float[] cutValue; - protected double boundingboxCacheFraction; protected IndexIntervalManager freeNodeManager; - protected double[] rangeSumData; - protected float[] boundingBoxData; - protected final IPointStoreView pointStoreView; - protected final HashMap leafMass; - protected boolean centerOfMassEnabled; - protected boolean storeSequenceIndexesEnabled; - protected float[] pointSum; - protected HashMap> sequenceMap; public AbstractNodeStore(AbstractNodeStore.Builder builder) { this.capacity = builder.capacity; - this.dimensions = builder.dimensions; if ((builder.leftIndex == null)) { freeNodeManager = 
new IndexIntervalManager(capacity); } - this.boundingboxCacheFraction = builder.boundingBoxCacheFraction; cutValue = (builder.cutValues != null) ? builder.cutValues : new float[capacity]; - leafMass = new HashMap<>(); - int cache_limit = (int) Math.floor(boundingboxCacheFraction * capacity); - rangeSumData = new double[cache_limit]; - boundingBoxData = new float[2 * dimensions * cache_limit]; - this.pointStoreView = builder.pointStoreView; - this.centerOfMassEnabled = builder.centerOfMassEnabled; - this.storeSequenceIndexesEnabled = builder.storeSequencesEnabled; - if (this.centerOfMassEnabled) { - pointSum = new float[(capacity) * dimensions]; - } - if (this.storeSequenceIndexesEnabled) { - sequenceMap = new HashMap<>(); - } } protected abstract int addNode(Stack pathToRoot, float[] point, long sendex, int pointIndex, int childIndex, - int cutDimension, float cutValue, BoundingBox box); - - protected int addLeaf(int pointIndex, long sequenceIndex) { - if (storeSequenceIndexesEnabled) { - HashMap leafMap = sequenceMap.remove(pointIndex); - if (leafMap == null) { - leafMap = new HashMap<>(); - } - Integer count = leafMap.remove(sequenceIndex); - if (count != null) { - leafMap.put(sequenceIndex, count + 1); - } else { - leafMap.put(sequenceIndex, 1); - } - sequenceMap.put(pointIndex, leafMap); - } - return pointIndex + capacity + 1; - } - - public void removeLeaf(int leafPointIndex, long sequenceIndex) { - HashMap leafMap = sequenceMap.remove(leafPointIndex); - checkArgument(leafMap != null, " leaf index not found in tree"); - Integer count = leafMap.remove(sequenceIndex); - checkArgument(count != null, " sequence index not found in leaf"); - if (count > 1) { - leafMap.put(sequenceIndex, count - 1); - sequenceMap.put(leafPointIndex, leafMap); - } else if (leafMap.size() > 0) { - sequenceMap.put(leafPointIndex, leafMap); - } - } + int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box); public boolean isLeaf(int index) { return index > capacity; @@ 
-134,7 +71,7 @@ public boolean isInternal(int index) { return index < capacity && index >= 0; } - public abstract void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex); + public abstract void assignInPartialTree(int savedParent, float[] point, int childReference); public abstract int getLeftIndex(int index); @@ -142,251 +79,14 @@ public boolean isInternal(int index) { public abstract void setRoot(int index); - public float[] getPointSum(int index) { - checkArgument(centerOfMassEnabled, " enable center of mass"); - return (isLeaf(index)) ? pointStoreView.getScaledPoint(getPointIndex(index), getMass(index)) - : Arrays.copyOfRange(pointSum, index * dimensions, (index + 1) * dimensions); - } - - public void invalidatePointSum(int index) { - for (int i = 0; i < dimensions; i++) { - pointSum[index * dimensions + i] = 0; - } - } - - public void recomputePointSum(int index) { - float[] left = getPointSum(getLeftIndex(index)); - float[] right = getPointSum(getRightIndex(index)); - for (int i = 0; i < dimensions; i++) { - pointSum[index * dimensions + i] = left[i] + right[i]; - } - } - - public void increaseLeafMass(int index) { - int y = (index - capacity - 1); - leafMass.merge(y, 1, Integer::sum); - } - - public int decreaseLeafMass(int index) { - int y = (index - capacity - 1); - Integer value = leafMass.remove(y); - if (value != null) { - if (value > 1) { - leafMass.put(y, (value - 1)); - return value; - } else { - return 1; - } - } else { - return 0; - } - } - - public void resizeCache(double fraction) { - if (fraction == 0) { - rangeSumData = null; - boundingBoxData = null; - } else { - int limit = (int) Math.floor(fraction * capacity); - rangeSumData = (rangeSumData == null) ? new double[limit] : Arrays.copyOf(rangeSumData, limit); - boundingBoxData = (boundingBoxData == null) ? 
new float[limit * 2 * dimensions] - : Arrays.copyOf(boundingBoxData, limit * 2 * dimensions); - } - boundingboxCacheFraction = fraction; - } - - public int translate(int index) { - if (rangeSumData.length <= index) { - return Integer.MAX_VALUE; - } else { - return index; - } - } - - void copyBoxToData(int idx, BoundingBox box) { - int base = 2 * idx * dimensions; - int mid = base + dimensions; - System.arraycopy(box.getMinValues(), 0, boundingBoxData, base, dimensions); - System.arraycopy(box.getMaxValues(), 0, boundingBoxData, mid, dimensions); - rangeSumData[idx] = box.getRangeSum(); - } - - public boolean checkContainsAndAddPoint(int index, float[] point) { - int idx = translate(index); - if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) { - int base = 2 * idx * dimensions; - int mid = base + dimensions; - double rangeSum = 0; - for (int i = 0; i < dimensions; i++) { - boundingBoxData[base + i] = Math.min(boundingBoxData[base + i], point[i]); - } - for (int i = 0; i < dimensions; i++) { - boundingBoxData[mid + i] = Math.max(boundingBoxData[mid + i], point[i]); - } - for (int i = 0; i < dimensions; i++) { - rangeSum += boundingBoxData[mid + i] - boundingBoxData[base + i]; - } - boolean answer = (rangeSumData[idx] == rangeSum); - rangeSumData[idx] = rangeSum; - return answer; - } - return false; - } - - public BoundingBox getBox(int index) { - if (isLeaf(index)) { - float[] point = pointStoreView.get(getPointIndex(index)); - return new BoundingBox(point, point); - } else { - checkArgument(isInternal(index), " incomplete state"); - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { - if (rangeSumData[idx] != 0) { - // return non-trivial boxes - return getBoxFromData(idx); - } else { - BoundingBox box = reconstructBox(index, pointStoreView); - copyBoxToData(idx, box); - return box; - } - } - return reconstructBox(index, pointStoreView); - } - } - - public BoundingBox reconstructBox(int index, IPointStoreView pointStoreView) { - BoundingBox 
mutatedBoundingBox = getBox(getLeftIndex(index)); - growNodeBox(mutatedBoundingBox, pointStoreView, index, getRightIndex(index)); - return mutatedBoundingBox; - } - - boolean checkStrictlyContains(int index, float[] point) { - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { - int base = 2 * idx * dimensions; - int mid = base + dimensions; - boolean isInside = true; - for (int i = 0; i < dimensions && isInside; i++) { - if (point[i] >= boundingBoxData[mid + i] || boundingBoxData[base + i] >= point[i]) { - isInside = false; - } - } - return isInside; - } - return false; - } - - public boolean checkContainsAndRebuildBox(int index, float[] point, IPointStoreView pointStoreView) { - int idx = translate(index); - if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) { - if (!checkStrictlyContains(index, point)) { - BoundingBox mutatedBoundingBox = reconstructBox(index, pointStoreView); - copyBoxToData(idx, mutatedBoundingBox); - return false; - } - return true; - } - return false; - } - - public BoundingBox getBoxFromData(int idx) { - int base = 2 * idx * dimensions; - int mid = base + dimensions; - - return new BoundingBox(Arrays.copyOfRange(boundingBoxData, base, base + dimensions), - Arrays.copyOfRange(boundingBoxData, mid, mid + dimensions)); - } - - protected void addBox(int index, float[] point, BoundingBox box) { - if (isInternal(index)) { - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum - copyBoxToData(idx, box); - checkContainsAndAddPoint(index, point); - } - } - } - - public void growNodeBox(BoundingBox box, IPointStoreView pointStoreView, int node, int sibling) { - if (isLeaf(sibling)) { - float[] point = pointStoreView.get(getPointIndex(sibling)); - box.addPoint(point); - } else { - checkArgument(isInternal(sibling), " incomplete state " + sibling); - int siblingIdx = translate(sibling); - if (siblingIdx != Integer.MAX_VALUE) { - if (rangeSumData[siblingIdx] != 0) { - 
box.addBox(getBoxFromData(siblingIdx)); - } else { - BoundingBox newBox = getBox(siblingIdx); - copyBoxToData(siblingIdx, newBox); - box.addBox(newBox); - } - return; - } - growNodeBox(box, pointStoreView, sibling, getLeftIndex(sibling)); - growNodeBox(box, pointStoreView, sibling, getRightIndex(sibling)); - return; - } - } - - public double probabilityOfCut(int node, float[] point, IPointStoreView pointStoreView, - BoundingBox otherBox) { - int nodeIdx = translate(node); - if (nodeIdx != Integer.MAX_VALUE && rangeSumData[nodeIdx] != 0) { - int base = 2 * nodeIdx * dimensions; - int mid = base + dimensions; - double minsum = 0; - double maxsum = 0; - for (int i = 0; i < dimensions; i++) { - minsum += Math.max(boundingBoxData[base + i] - point[i], 0); - } - for (int i = 0; i < dimensions; i++) { - maxsum += Math.max(point[i] - boundingBoxData[mid + i], 0); - } - double sum = maxsum + minsum; - - if (sum == 0.0) { - return 0.0; - } - return sum / (rangeSumData[nodeIdx] + sum); - } else if (otherBox != null) { - return otherBox.probabilityOfCut(point); - } else { - BoundingBox box = getBox(node); - return box.probabilityOfCut(point); - } - } - protected abstract void decreaseMassOfInternalNode(int node); protected abstract void increaseMassOfInternalNode(int node); - protected void manageAncestorsAdd(Stack path, float[] point, IPointStoreView pointStoreview) { + protected void manageInternalNodesPartial(Stack path) { while (!path.isEmpty()) { int index = path.pop()[0]; increaseMassOfInternalNode(index); - if (pointSum != null) { - recomputePointSum(index); - } - if (boundingboxCacheFraction > 0.0) { - checkContainsAndRebuildBox(index, point, pointStoreview); - checkContainsAndAddPoint(index, point); - } - } - } - - protected void manageAncestorsDelete(Stack path, float[] point, IPointStoreView pointStoreview) { - boolean resolved = false; - while (!path.isEmpty()) { - int index = path.pop()[0]; - decreaseMassOfInternalNode(index); - if (pointSum != null) { - 
recomputePointSum(index); - } - if (boundingboxCacheFraction > 0.0 && !resolved) { - resolved = checkContainsAndRebuildBox(index, point, pointStoreview); - } } } @@ -410,22 +110,8 @@ public Stack getPath(int root, float[] point, boolean verbose) { public abstract void deleteInternalNode(int index); - public int getLeafMass(int index) { - int y = (index - capacity - 1); - Integer value = leafMass.get(y); - if (value != null) { - return value + 1; - } else { - return 1; - } - } - public abstract int getMass(int index); - public int getPointIndex(int index) { - return index - capacity - 1; - } - protected boolean leftOf(float cutValue, int cutDimension, float[] point) { return point[cutDimension] <= cutValue; } @@ -447,88 +133,16 @@ public int getSibling(int node, int parent) { public abstract void replaceParentBySibling(int grandParent, int parent, int node); - public HashMap> getSequenceMap() { - return sequenceMap; - } - public abstract int getCutDimension(int index); public double getCutValue(int index) { return cutValue[index]; } - public double getBoundingboxCacheFraction() { - return boundingboxCacheFraction; - } - - protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, int root, - IPointStoreView pointStoreView, Function projectToTree) { - NodeView currentNodeView = new NodeView(this, pointStoreView, root); - traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, root, 0); - } - protected boolean toLeft(float[] point, int currentNodeOffset) { return point[getCutDimension(currentNodeOffset)] <= cutValue[currentNodeOffset]; } - BoundingBox getLeftBox(int index) { - return getBox(getLeftIndex(index)); - } - - BoundingBox getRightBox(int index) { - return getBox(getRightIndex(index)); - } - - protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, NodeView currentNodeView, - int node, int depthOfNode) { - if (isLeaf(node)) { - currentNodeView.setCurrentNode(node, getPointIndex(node), true); - 
visitor.acceptLeaf(currentNodeView, depthOfNode); - } else { - checkArgument(isInternal(node), " incomplete state " + node + " " + depthOfNode); - if (toLeft(point, node)) { - traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1); - currentNodeView.updateToParent(node, getRightIndex(node), !visitor.isConverged()); - } else { - traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, getRightIndex(node), depthOfNode + 1); - currentNodeView.updateToParent(node, getLeftIndex(node), !visitor.isConverged()); - } - visitor.accept(currentNodeView, depthOfNode); - } - } - - protected void traverseTreeMulti(float[] point, MultiVisitor visitor, int root, - IPointStoreView pointStoreView, Function liftToTree) { - NodeView currentNodeView = new NodeView(this, pointStoreView, root); - traverseTreeMulti(point, visitor, currentNodeView, root, 0); - } - - protected void traverseTreeMulti(float[] point, MultiVisitor visitor, NodeView currentNodeView, int node, - int depthOfNode) { - if (isLeaf(node)) { - currentNodeView.setCurrentNode(node, getPointIndex(node), false); - visitor.acceptLeaf(currentNodeView, depthOfNode); - } else { - checkArgument(isInternal(node), " incomplete state"); - currentNodeView.setCurrentNodeOnly(node); - if (visitor.trigger(currentNodeView)) { - traverseTreeMulti(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1); - MultiVisitor newVisitor = visitor.newCopy(); - currentNodeView.setCurrentNodeOnly(getRightIndex(node)); - traverseTreeMulti(point, newVisitor, currentNodeView, getRightIndex(node), depthOfNode + 1); - currentNodeView.updateToParent(node, getLeftIndex(node), false); - visitor.combine(newVisitor); - } else if (toLeft(point, node)) { - traverseTreeMulti(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1); - currentNodeView.updateToParent(node, getRightIndex(node), false); - } else { - traverseTreeMulti(point, visitor, currentNodeView, getRightIndex(node), 
depthOfNode + 1); - currentNodeView.updateToParent(node, getLeftIndex(node), false); - } - visitor.accept(currentNodeView, depthOfNode); - } - } - public abstract int[] getCutDimension(); public abstract int[] getRightIndex(); @@ -552,24 +166,14 @@ public int size() { */ public static class Builder> { - protected int dimensions; protected int capacity; protected int[] leftIndex; protected int[] rightIndex; protected int[] cutDimension; protected float[] cutValues; - protected int root; - protected double boundingBoxCacheFraction; - protected boolean centerOfMassEnabled; - protected boolean storeSequencesEnabled; protected boolean storeParent = DEFAULT_STORE_PARENT; - protected IPointStoreView pointStoreView; - - // dimension of the points being stored - public T dimensions(int dimensions) { - this.dimensions = dimensions; - return (T) this; - } + protected int dimension; + protected int root; // maximum number of points in the store public T capacity(int capacity) { @@ -577,6 +181,11 @@ public T capacity(int capacity) { return (T) this; } + public T dimension(int dimension) { + this.dimension = dimension; + return (T) this; + } + public T useRoot(int root) { this.root = root; return (T) this; @@ -602,33 +211,12 @@ public T cutValues(float[] cutValues) { return (T) this; } - public T pointStoreView(IPointStoreView pointStoreView) { - this.pointStoreView = pointStoreView; - return (T) this; - } - - public T boundingBoxCacheFraction(double boundingBoxCacheFraction) { - this.boundingBoxCacheFraction = boundingBoxCacheFraction; - return (T) this; - } - - public T centerOfMassEnabled(boolean centerOfMassEnabled) { - this.centerOfMassEnabled = centerOfMassEnabled; - return (T) this; - } - public T storeParent(boolean storeParent) { this.storeParent = storeParent; return (T) this; } - public T storeSequencesEnabled(boolean storeSequencesEnabled) { - this.storeSequencesEnabled = storeSequencesEnabled; - return (T) this; - } - public AbstractNodeStore build() { - 
checkArgument(pointStoreView != null, " a point store view is required "); if (leftIndex == null) { checkArgument(rightIndex == null, " incorrect option of right indices"); checkArgument(cutValues == null, "incorrect option of cut values"); @@ -640,9 +228,9 @@ public AbstractNodeStore build() { } // capacity is numbner of internal nodes - if (capacity < 256 && pointStoreView.getDimensions() <= 256) { + if (capacity < 256 && dimension <= 256) { return new NodeStoreSmall(this); - } else if (capacity < Character.MAX_VALUE && pointStoreView.getDimensions() <= Character.MAX_VALUE) { + } else if (capacity < Character.MAX_VALUE && dimension <= Character.MAX_VALUE) { return new NodeStoreMedium(this); } else { return new NodeStoreLarge(this); diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java index 14b7e277..97266385 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java @@ -51,12 +51,11 @@ public void makeTree(List list, int seed) { int[] cutDimension = new int[numberOfLeaves - 1]; float[] cutValue = new float[numberOfLeaves - 1]; root = makeTreeInt(list, seed, 0, this.gVecBuild, leftIndex, rightIndex, cutDimension, cutValue); - nodeStore = AbstractNodeStore.builder().storeSequencesEnabled(false).pointStoreView(pointStoreView) - .dimensions(dimension).capacity(numberOfLeaves - 1).leftIndex(leftIndex).rightIndex(rightIndex) - .cutDimension(cutDimension).cutValues(cutValue).build(); + nodeStore = AbstractNodeStore.builder().dimension(dimension).capacity(numberOfLeaves - 1) + .leftIndex(leftIndex).rightIndex(rightIndex).cutDimension(cutDimension).cutValues(cutValue).build(); // the cuts are specififed; now build tree for (int i = 0; i < list.size(); i++) { - addPoint(list.get(i), 0L); + addPointToPartialTree(list.get(i), 0L); } } else { root = Null; diff --git 
a/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java index 097786e6..2b6d9c0d 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java @@ -34,10 +34,14 @@ public interface ITree extends ITraversable, IDynamicConf double[] liftFromTree(double[] result); - public int[] projectMissingIndices(int[] list); + int[] projectMissingIndices(int[] list); PointReference addPoint(PointReference point, long sequenceIndex); + void addPointToPartialTree(PointReference point, long sequenceIndex); + + void validateAndReconstruct(); + PointReference deletePoint(PointReference point, long sequenceIndex); default long getRandomSeed() { diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java index d6a40a03..2e3c386b 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java @@ -90,7 +90,7 @@ public NodeStoreLarge(AbstractNodeStore.Builder builder) { @Override public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex, - int cutDimension, float cutValue, BoundingBox box) { + int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) { int index = freeNodeManager.takeIndex(); this.cutValue[index] = cutValue; this.cutDimension[index] = (byte) cutDimension; @@ -101,8 +101,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.rightIndex[index] = (pointIndex + capacity + 1); this.leftIndex[index] = childIndex; } - this.mass[index] = (getMass(childIndex) + 1) % (capacity + 1); - addLeaf(pointIndex, sequenceIndex); + this.mass[index] = (((childMassIfLeaf > 0) ? 
childMassIfLeaf : getMass(childIndex)) + 1) % (capacity + 1); + int parentIndex = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0]; if (this.parentIndex != null) { this.parentIndex[index] = parentIndex; @@ -110,13 +110,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.parentIndex[childIndex] = (index); } } - addBox(index, point, box); if (parentIndex != Null) { spliceEdge(parentIndex, childIndex, index); - manageAncestorsAdd(pathToRoot, point, pointStoreView); - } - if (pointSum != null) { - recomputePointSum(index); } return index; } @@ -151,18 +146,20 @@ public void deleteInternalNode(int index) { if (parentIndex != null) { parentIndex[index] = capacity; } - if (pointSum != null) { - invalidatePointSum(index); - } - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { - rangeSumData[idx] = 0.0; - } freeNodeManager.releaseIndex(index); } public int getMass(int index) { - return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? mass[index] : (capacity + 1); + return mass[index] != 0 ? 
mass[index] : (capacity + 1); + } + + @Override + public void assignInPartialTree(int node, float[] point, int childReference) { + if (leftOf(node, point)) { + leftIndex[node] = childReference; + } else { + rightIndex[node] = childReference; + } } public void spliceEdge(int parent, int node, int newNode) { @@ -205,15 +202,4 @@ public int[] getRightIndex() { return Arrays.copyOf(rightIndex, rightIndex.length); } - @Override - public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) { - int node = pathToRoot.lastElement()[0]; - if (leftOf(node, point)) { - leftIndex[node] = (pointIndex + capacity + 1); - } else { - rightIndex[node] = (pointIndex + capacity + 1); - } - manageAncestorsAdd(pathToRoot, point, pointStoreView); - } - } diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java index b2a4ae40..b2556416 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java @@ -93,7 +93,7 @@ public NodeStoreMedium(AbstractNodeStore.Builder builder) { @Override public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex, - int cutDimension, float cutValue, BoundingBox box) { + int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) { int index = freeNodeManager.takeIndex(); this.cutValue[index] = cutValue; this.cutDimension[index] = (char) cutDimension; @@ -104,8 +104,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.rightIndex[index] = (pointIndex + capacity + 1); this.leftIndex[index] = childIndex; } - this.mass[index] = (char) ((getMass(childIndex) + 1) % (capacity + 1)); - addLeaf(pointIndex, sequenceIndex); + this.mass[index] = (char) ((((childMassIfLeaf > 0) ? 
childMassIfLeaf : getMass(childIndex)) + 1) + % (capacity + 1)); int parentIndex = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0]; if (this.parentIndex != null) { this.parentIndex[index] = (char) parentIndex; @@ -113,17 +113,21 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.parentIndex[childIndex] = (char) (index); } } - addBox(index, point, box); if (parentIndex != Null) { spliceEdge(parentIndex, childIndex, index); - manageAncestorsAdd(pathToRoot, point, pointStoreView); - } - if (pointSum != null) { - recomputePointSum(index); } return index; } + @Override + public void assignInPartialTree(int node, float[] point, int childReference) { + if (leftOf(node, point)) { + leftIndex[node] = childReference; + } else { + rightIndex[node] = childReference; + } + } + public int getLeftIndex(int index) { return leftIndex[index]; } @@ -155,18 +159,11 @@ public void deleteInternalNode(int index) { if (parentIndex != null) { parentIndex[index] = (char) capacity; } - if (pointSum != null) { - invalidatePointSum(index); - } - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { - rangeSumData[idx] = 0.0; - } freeNodeManager.releaseIndex(index); } public int getMass(int index) { - return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? mass[index] : (capacity + 1); + return mass[index] != 0 ? 
mass[index] : (capacity + 1); } public void spliceEdge(int parent, int node, int newNode) { @@ -209,14 +206,4 @@ public int[] getRightIndex() { return Arrays.copyOf(rightIndex, rightIndex.length); } - @Override - public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) { - int node = pathToRoot.lastElement()[0]; - if (leftOf(node, point)) { - leftIndex[node] = (pointIndex + capacity + 1); - } else { - rightIndex[node] = (pointIndex + capacity + 1); - } - manageAncestorsAdd(pathToRoot, point, pointStoreView); - } } diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java index 39f7ce07..c0f8d5c0 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java @@ -98,7 +98,7 @@ public NodeStoreSmall(AbstractNodeStore.Builder builder) { @Override public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex, - int cutDimension, float cutValue, BoundingBox box) { + int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) { int index = freeNodeManager.takeIndex(); this.cutValue[index] = cutValue; this.cutDimension[index] = (byte) cutDimension; @@ -109,8 +109,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.rightIndex[index] = (char) (pointIndex + capacity + 1); this.leftIndex[index] = (char) childIndex; } - this.mass[index] = (byte) (((byte) getMass(childIndex) + 1) % (capacity + 1)); - addLeaf(pointIndex, sequenceIndex); + this.mass[index] = (byte) ((((childMassIfLeaf > 0) ? childMassIfLeaf : getMass(childIndex)) + 1) + % (capacity + 1)); int parentIndex = (pathToRoot.size() == 0) ? 
Null : pathToRoot.lastElement()[0]; if (this.parentIndex != null) { this.parentIndex[index] = (byte) parentIndex; @@ -118,26 +118,19 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i this.parentIndex[childIndex] = (byte) (index); } } - addBox(index, point, box); if (parentIndex != Null) { spliceEdge(parentIndex, childIndex, index); - manageAncestorsAdd(pathToRoot, point, pointStoreView); - } - if (pointSum != null) { - recomputePointSum(index); } return index; } @Override - public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) { - int node = pathToRoot.lastElement()[0]; + public void assignInPartialTree(int node, float[] point, int childReference) { if (leftOf(node, point)) { - leftIndex[node] = (char) (pointIndex + capacity + 1); + leftIndex[node] = (char) childReference; } else { - rightIndex[node] = (char) (pointIndex + capacity + 1); + rightIndex[node] = (char) childReference; } - manageAncestorsAdd(pathToRoot, point, pointStoreView); } public int getLeftIndex(int index) { @@ -171,18 +164,11 @@ public void deleteInternalNode(int index) { if (parentIndex != null) { parentIndex[index] = (byte) capacity; } - if (pointSum != null) { - invalidatePointSum(index); - } - int idx = translate(index); - if (idx != Integer.MAX_VALUE) { - rangeSumData[idx] = 0.0; - } freeNodeManager.releaseIndex(index); } public int getMass(int index) { - return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? (mass[index] & 0xff) : (capacity + 1); + return mass[index] != 0 ? 
(mass[index] & 0xff) : (capacity + 1); } public void spliceEdge(int parent, int node, int newNode) { diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java index 7481c7a9..6f9ca01a 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java @@ -22,40 +22,42 @@ import com.amazon.randomcutforest.store.IPointStoreView; public class NodeView implements INodeView { - AbstractNodeStore nodeStore; + + public static double SWITCH_FRACTION = 0.499; + + RandomCutTree tree; int currentNodeOffset; float[] leafPoint; - IPointStoreView pointStoreView; BoundingBox currentBox; - public NodeView(AbstractNodeStore nodeStore, IPointStoreView pointStoreView, int root) { + public NodeView(RandomCutTree tree, IPointStoreView pointStoreView, int root) { this.currentNodeOffset = root; - this.pointStoreView = pointStoreView; - this.nodeStore = nodeStore; + this.tree = tree; } public int getMass() { - return nodeStore.getMass(currentNodeOffset); + return tree.getMass(currentNodeOffset); } public IBoundingBoxView getBoundingBox() { if (currentBox == null) { - return nodeStore.getBox(currentNodeOffset); + return tree.getBox(currentNodeOffset); } return currentBox; } public IBoundingBoxView getSiblingBoundingBox(float[] point) { - return (toLeft(point)) ? nodeStore.getRightBox(currentNodeOffset) : nodeStore.getLeftBox(currentNodeOffset); + return (toLeft(point)) ? 
tree.getBox(tree.nodeStore.getRightIndex(currentNodeOffset)) + : tree.getBox(tree.nodeStore.getLeftIndex(currentNodeOffset)); } public int getCutDimension() { - return nodeStore.getCutDimension(currentNodeOffset); + return tree.nodeStore.getCutDimension(currentNodeOffset); } @Override public double getCutValue() { - return nodeStore.getCutValue(currentNodeOffset); + return tree.nodeStore.getCutValue(currentNodeOffset); } public float[] getLeafPoint() { @@ -64,8 +66,8 @@ public float[] getLeafPoint() { public HashMap getSequenceIndexes() { checkState(isLeaf(), "can only be invoked for a leaf"); - if (nodeStore.storeSequenceIndexesEnabled) { - return nodeStore.sequenceMap.get(nodeStore.getPointIndex(currentNodeOffset)); + if (tree.storeSequenceIndexesEnabled) { + return tree.getSequenceMap(tree.getPointIndex(currentNodeOffset)); } else { return new HashMap<>(); } @@ -73,23 +75,22 @@ public HashMap getSequenceIndexes() { @Override public double probailityOfSeparation(float[] point) { - return nodeStore.probabilityOfCut(currentNodeOffset, point, pointStoreView, currentBox); + return tree.probabilityOfCut(currentNodeOffset, point, currentBox); } @Override public int getLeafPointIndex() { - checkState(isLeaf(), "cannot invoke 'getLeafPointIndex' from a non-leaf node"); - return nodeStore.getPointIndex(currentNodeOffset); + return tree.getPointIndex(currentNodeOffset); } public boolean isLeaf() { - return nodeStore.isLeaf(currentNodeOffset); + return tree.nodeStore.isLeaf(currentNodeOffset); } protected void setCurrentNode(int newNode, int index, boolean setBox) { currentNodeOffset = newNode; - leafPoint = pointStoreView.get(index); - if (setBox && nodeStore.boundingboxCacheFraction < AbstractNodeStore.SWITCH_FRACTION) { + leafPoint = tree.pointStoreView.get(index); + if (setBox && tree.boundingBoxCacheFraction < SWITCH_FRACTION) { currentBox = new BoundingBox(leafPoint, leafPoint); } } @@ -100,14 +101,15 @@ protected void setCurrentNodeOnly(int newNode) { public void 
updateToParent(int parent, int currentSibling, boolean updateBox) { currentNodeOffset = parent; - if (updateBox && nodeStore.boundingboxCacheFraction < AbstractNodeStore.SWITCH_FRACTION) { - nodeStore.growNodeBox(currentBox, pointStoreView, parent, currentSibling); + if (updateBox && tree.boundingBoxCacheFraction < SWITCH_FRACTION) { + tree.growNodeBox(currentBox, tree.pointStoreView, parent, currentSibling); } } // this function exists for matching the behavior of RCF2.0 and will be replaced // this function explicitly uses the encoding of the new nodestore protected boolean toLeft(float[] point) { - return point[nodeStore.getCutDimension(currentNodeOffset)] <= nodeStore.getCutValue(currentNodeOffset); + return point[tree.nodeStore.getCutDimension(currentNodeOffset)] <= tree.nodeStore + .getCutValue(currentNodeOffset); } } diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java index dd9789be..540db178 100644 --- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java +++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java @@ -20,7 +20,10 @@ import static com.amazon.randomcutforest.CommonUtils.checkState; import static com.amazon.randomcutforest.tree.AbstractNodeStore.Null; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; import java.util.Optional; import java.util.Random; import java.util.Stack; @@ -67,6 +70,11 @@ public class RandomCutTree implements ITree { protected double boundingBoxCacheFraction; protected int outputAfter; protected int dimension; + protected final HashMap leafMass; + protected double[] rangeSumData; + protected float[] boundingBoxData; + protected float[] pointSum; + protected HashMap> sequenceMap; protected RandomCutTree(Builder builder) { pointStoreView = builder.pointStoreView; @@ -76,23 +84,28 @@ protected RandomCutTree(Builder 
builder) { outputAfter = builder.outputAfter.orElse(numberOfLeaves / 4); dimension = (builder.dimension != 0) ? builder.dimension : pointStoreView.getDimensions(); nodeStore = (builder.nodeStore != null) ? builder.nodeStore - : AbstractNodeStore.builder().capacity(numberOfLeaves - 1).dimensions(dimension) - .boundingBoxCacheFraction(builder.boundingBoxCacheFraction).pointStoreView(pointStoreView) - .centerOfMassEnabled(builder.centerOfMassEnabled) - .storeSequencesEnabled(builder.storeSequenceIndexesEnabled).build(); - // note the number of internal nodes is one less than sampleSize - // the RCF V2_0 states used this notion + : AbstractNodeStore.builder().capacity(numberOfLeaves - 1).dimension(dimension).build(); this.boundingBoxCacheFraction = builder.boundingBoxCacheFraction; this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled; this.centerOfMassEnabled = builder.centerOfMassEnabled; this.root = builder.root; + leafMass = new HashMap<>(); + int cache_limit = (int) Math.floor(boundingBoxCacheFraction * (numberOfLeaves - 1)); + rangeSumData = new double[cache_limit]; + boundingBoxData = new float[2 * dimension * cache_limit]; + if (this.centerOfMassEnabled) { + pointSum = new float[(numberOfLeaves - 1) * dimension]; + } + if (this.storeSequenceIndexesEnabled) { + sequenceMap = new HashMap<>(); + } } @Override public void setConfig(String name, T value, Class clazz) { if (Config.BOUNDING_BOX_CACHE_FRACTION.equals(name)) { checkArgument(Double.class.isAssignableFrom(clazz), - String.format("Setting '%s' must be a double value", name)); + () -> String.format("Setting '%s' must be a double value", name)); setBoundingBoxCacheFraction((Double) value); } else { throw new IllegalArgumentException("Unsupported configuration setting: " + name); @@ -104,7 +117,7 @@ public T getConfig(String name, Class clazz) { checkNotNull(clazz, "clazz must not be null"); if (Config.BOUNDING_BOX_CACHE_FRACTION.equals(name)) { 
checkArgument(clazz.isAssignableFrom(Double.class), - String.format("Setting '%s' must be a double value", name)); + () -> String.format("Setting '%s' must be a double value", name)); return clazz.cast(boundingBoxCacheFraction); } else { throw new IllegalArgumentException("Unsupported configuration setting: " + name); @@ -118,7 +131,7 @@ public T getConfig(String name, Class clazz) { public void setBoundingBoxCacheFraction(double fraction) { checkArgument(0 <= fraction && fraction <= 1, "incorrect parameter"); boundingBoxCacheFraction = fraction; - nodeStore.resizeCache(fraction); + resizeCache(fraction); } /** @@ -134,7 +147,7 @@ public void setBoundingBoxCacheFraction(double fraction) { * @param box A bounding box that we want to find a random cut for. * @return A new Cut corresponding to a random cut in the bounding box. */ - protected static Cut randomCut(double factor, float[] point, BoundingBox box) { + protected Cut randomCut(double factor, float[] point, BoundingBox box) { double range = 0.0; for (int i = 0; i < point.length; i++) { @@ -148,7 +161,7 @@ protected static Cut randomCut(double factor, float[] point, BoundingBox box) { range += maxValue - minValue; } - checkArgument(range > 0, " the union is a single point " + Arrays.toString(point) + checkArgument(range > 0, () -> " the union is a single point " + Arrays.toString(point) + "or the box is inappropriate, box" + box.toString() + "factor =" + factor); double breakPoint = factor * range; @@ -223,40 +236,42 @@ protected static Cut randomCut(double factor, float[] point, BoundingBox box) { } + /** + * the following function adds a point to the tree + * + * @param pointIndex the number corresponding to the point + * @param sequenceIndex sequence index of the point + * @return the value of the point index where the point was added; this is + * pointIndex if there are no duplicates; otherwise it is the value of + * the point being duplicated. 
+ */ public Integer addPoint(Integer pointIndex, long sequenceIndex) { if (root == Null) { - root = nodeStore.addLeaf(pointIndex, sequenceIndex); + root = convertToLeaf(pointIndex); + addLeaf(pointIndex, sequenceIndex); return pointIndex; } else { - float[] point = pointStoreView.get(pointIndex); + float[] point = projectToTree(pointStoreView.get(pointIndex)); + checkArgument(point.length == dimension, () -> " mismatch in dimensions for " + pointIndex); Stack pathToRoot = nodeStore.getPath(root, point, false); int[] first = pathToRoot.pop(); int leafNode = first[0]; int savedParent = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0]; - if (!nodeStore.isLeaf(leafNode)) { - // this corresponds to rebuilding a partial tree - if (savedParent == Null) { - root = pointIndex + numberOfLeaves; // note this capacity is nodestore.capacity + 1 - } else { - nodeStore.addToPartialTree(pathToRoot, point, pointIndex); - nodeStore.manageAncestorsAdd(pathToRoot, point, pointStoreView); - nodeStore.addLeaf(pointIndex, sequenceIndex); - } - return pointIndex; - } int leafSavedSibling = first[1]; int sibling = leafSavedSibling; - int leafPointIndex = nodeStore.getPointIndex(leafNode); - float[] oldPoint = pointStoreView.get(leafPointIndex); + int leafPointIndex = getPointIndex(leafNode); + float[] oldPoint = projectToTree(pointStoreView.get(leafPointIndex)); + checkArgument(oldPoint.length == dimension, () -> " mismatch in dimensions for " + pointIndex); + Stack parentPath = new Stack<>(); if (Arrays.equals(point, oldPoint)) { - nodeStore.increaseLeafMass(leafNode); + increaseLeafMass(leafNode); checkArgument(!nodeStore.freeNodeManager.isEmpty(), "incorrect/impossible state"); - nodeStore.manageAncestorsAdd(pathToRoot, point, pointStoreView); - nodeStore.addLeaf(leafPointIndex, sequenceIndex); + manageAncestorsAdd(pathToRoot, point); + addLeaf(leafPointIndex, sequenceIndex); return leafPointIndex; } else { int node = leafNode; @@ -293,14 +308,11 @@ public Integer 
addPoint(Integer pointIndex, long sequenceIndex) { parentPath.push(new int[] { node, sibling }); } - if (savedDim == Integer.MAX_VALUE) { - randomCut(factor, point, currentBox); - throw new IllegalStateException(" cut failed "); - } + checkArgument(savedDim != Integer.MAX_VALUE, () -> " cut failed at index " + pointIndex); if (currentBox.contains(point) || parent == Null) { break; } else { - nodeStore.growNodeBox(currentBox, pointStoreView, parent, sibling); + growNodeBox(currentBox, pointStoreView, parent, sibling); int[] next = pathToRoot.pop(); node = next[0]; sibling = next[1]; @@ -318,8 +330,15 @@ public Integer addPoint(Integer pointIndex, long sequenceIndex) { assert (pathToRoot.lastElement()[0] == savedParent); } - int mergedNode = nodeStore.addNode(pathToRoot, point, sequenceIndex, pointIndex, savedNode, savedDim, - savedCutValue, savedBox); + int childMassIfLeaf = isLeaf(savedNode) ? getLeafMass(savedNode) : 0; + int mergedNode = nodeStore.addNode(pathToRoot, point, sequenceIndex, pointIndex, savedNode, + childMassIfLeaf, savedDim, savedCutValue, savedBox); + addLeaf(pointIndex, sequenceIndex); + addBox(mergedNode, point, savedBox); + manageAncestorsAdd(pathToRoot, point); + if (pointSum != null) { + recomputePointSum(mergedNode); + } if (savedParent == Null) { root = mergedNode; } @@ -328,25 +347,77 @@ public Integer addPoint(Integer pointIndex, long sequenceIndex) { } } - public Integer deletePoint(Integer pointIndex, long sequenceIndex) { + protected void manageAncestorsAdd(Stack path, float[] point) { + while (!path.isEmpty()) { + int index = path.pop()[0]; + nodeStore.increaseMassOfInternalNode(index); + if (pointSum != null) { + recomputePointSum(index); + } + if (boundingBoxCacheFraction > 0.0) { + checkContainsAndRebuildBox(index, point, pointStoreView); + checkContainsAndAddPoint(index, point); + } + } + } - if (root == Null) { - throw new IllegalStateException(" deleting from an empty tree"); + /** + * the following is the same as in addPoint() 
except this function is used to + * rebuild the tree structure. This function does not create auxiliary arrays, + * which should be performed using validateAndReconstruct() + * + * @param pointIndex index of point (in point store) + * @param sequenceIndex sequence index (stored in sampler) + */ + public void addPointToPartialTree(Integer pointIndex, long sequenceIndex) { + + checkArgument(root != Null, " a null root is not a partial tree"); + float[] point = projectToTree(pointStoreView.get(pointIndex)); + checkArgument(point.length == dimension, () -> " incorrect projection at index " + pointIndex); + + Stack pathToRoot = nodeStore.getPath(root, point, false); + int[] first = pathToRoot.pop(); + int leafNode = first[0]; + int savedParent = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0]; + if (!nodeStore.isLeaf(leafNode)) { + if (savedParent == Null) { + root = convertToLeaf(pointIndex); + } else { + nodeStore.assignInPartialTree(savedParent, point, convertToLeaf(pointIndex)); + nodeStore.manageInternalNodesPartial(pathToRoot); + addLeaf(pointIndex, sequenceIndex); + } + return; } - float[] point = pointStoreView.get(pointIndex); + int leafPointIndex = getPointIndex(leafNode); + float[] oldPoint = projectToTree(pointStoreView.get(leafPointIndex)); + + checkArgument(oldPoint.length == dimension && Arrays.equals(point, oldPoint), + () -> "incorrect state on adding " + pointIndex); + increaseLeafMass(leafNode); + checkArgument(!nodeStore.freeNodeManager.isEmpty(), "incorrect/impossible state"); + nodeStore.manageInternalNodesPartial(pathToRoot); + addLeaf(leafPointIndex, sequenceIndex); + return; + } + + public Integer deletePoint(Integer pointIndex, long sequenceIndex) { + + checkArgument(root != Null, " deleting from an empty tree"); + float[] point = projectToTree(pointStoreView.get(pointIndex)); + checkArgument(point.length == dimension, () -> " incorrect projection at index " + pointIndex); Stack pathToRoot = nodeStore.getPath(root, point, false); 
int[] first = pathToRoot.pop(); int leafSavedSibling = first[1]; int leafNode = first[0]; - int leafPointIndex = nodeStore.getPointIndex(leafNode); + int leafPointIndex = getPointIndex(leafNode); - if (leafPointIndex != pointIndex && !pointStoreView.pointEquals(leafPointIndex, point)) { - throw new IllegalStateException(" deleting wrong node " + leafPointIndex + " instead of " + pointIndex); - } else if (storeSequenceIndexesEnabled) { - nodeStore.removeLeaf(leafPointIndex, sequenceIndex); - } + checkArgument(leafPointIndex == pointIndex, + () -> " deleting wrong node " + leafPointIndex + " instead of " + pointIndex); + + removeLeaf(leafPointIndex, sequenceIndex); - if (nodeStore.decreaseLeafMass(leafNode) == 0) { + if (decreaseLeafMass(leafNode) == 0) { if (pathToRoot.size() == 0) { root = Null; } else { @@ -357,16 +428,406 @@ public Integer deletePoint(Integer pointIndex, long sequenceIndex) { } else { int grandParent = pathToRoot.lastElement()[0]; nodeStore.replaceParentBySibling(grandParent, parent, leafNode); - nodeStore.manageAncestorsDelete(pathToRoot, point, pointStoreView); + manageAncestorsDelete(pathToRoot, point); } nodeStore.deleteInternalNode(parent); + if (pointSum != null) { + invalidatePointSum(parent); + } + int idx = translate(parent); + if (idx != Integer.MAX_VALUE) { + rangeSumData[idx] = 0.0; + } } } else { - nodeStore.manageAncestorsDelete(pathToRoot, point, pointStoreView); + manageAncestorsDelete(pathToRoot, point); } return leafPointIndex; } + protected void manageAncestorsDelete(Stack path, float[] point) { + boolean resolved = false; + while (!path.isEmpty()) { + int index = path.pop()[0]; + nodeStore.decreaseMassOfInternalNode(index); + if (pointSum != null) { + recomputePointSum(index); + } + if (boundingBoxCacheFraction > 0.0 && !resolved) { + resolved = checkContainsAndRebuildBox(index, point, pointStoreView); + } + } + } + + //// leaf, nonleaf representations + + public boolean isLeaf(int index) { + // note that numberOfLeaves - 1 
corresponds to an unspecified leaf in partial tree + // 0 .. numberOfLeaves - 2 corresponds to internal nodes + return index >= numberOfLeaves; + } + + public boolean isInternal(int index) { + // note that numberOfLeaves - 1 corresponds to an unspecified leaf in partial tree + // 0 .. numberOfLeaves - 2 corresponds to internal nodes + return index < numberOfLeaves - 1; + } + + public int convertToLeaf(int pointIndex) { + return pointIndex + numberOfLeaves; + } + + public int getPointIndex(int index) { + checkArgument(index >= numberOfLeaves, () -> " does not have a point associated " + index); + return index - numberOfLeaves; + } + + public int getLeftChild(int index) { + checkArgument(isInternal(index), () -> "incorrect call to get left Index " + index); + return nodeStore.getLeftIndex(index); + } + + public int getRightChild(int index) { + checkArgument(isInternal(index), () -> "incorrect call to get right child " + index); + return nodeStore.getRightIndex(index); + } + + public int getCutDimension(int index) { + checkArgument(isInternal(index), () -> "incorrect call to get cut dimension " + index); + return nodeStore.getCutDimension(index); + } + + public double getCutValue(int index) { + checkArgument(isInternal(index), () -> "incorrect call to get cut value " + index); + return nodeStore.getCutValue(index); + } + + ///// mass assignments; separating leaves and internal nodes + + protected int getMass(int index) { + return (isLeaf(index)) ? getLeafMass(index) : nodeStore.getMass(index); + } + + protected int getLeafMass(int index) { + int y = (index - numberOfLeaves); + Integer value = leafMass.get(y); + return (value != null) ? 
value + 1 : 1; + } + + protected void increaseLeafMass(int index) { + int y = (index - numberOfLeaves); + leafMass.merge(y, 1, Integer::sum); + } + + protected int decreaseLeafMass(int index) { + int y = (index - numberOfLeaves); + Integer value = leafMass.remove(y); + if (value != null) { + if (value > 1) { + leafMass.put(y, (value - 1)); + return value; + } else { + return 1; + } + } else { + return 0; + } + } + + @Override + public int getMass() { + return root == Null ? 0 : isLeaf(root) ? getLeafMass(root) : nodeStore.getMass(root); + } + + /////// Bounding box + + public void resizeCache(double fraction) { + if (fraction == 0) { + rangeSumData = null; + boundingBoxData = null; + } else { + int limit = (int) Math.floor(fraction * (numberOfLeaves - 1)); + rangeSumData = (rangeSumData == null) ? new double[limit] : Arrays.copyOf(rangeSumData, limit); + boundingBoxData = (boundingBoxData == null) ? new float[limit * 2 * dimension] + : Arrays.copyOf(boundingBoxData, limit * 2 * dimension); + } + boundingBoxCacheFraction = fraction; + } + + protected int translate(int index) { + if (rangeSumData == null || rangeSumData.length <= index) { + return Integer.MAX_VALUE; + } else { + return index; + } + } + + void copyBoxToData(int idx, BoundingBox box) { + int base = 2 * idx * dimension; + int mid = base + dimension; + System.arraycopy(box.getMinValues(), 0, boundingBoxData, base, dimension); + System.arraycopy(box.getMaxValues(), 0, boundingBoxData, mid, dimension); + rangeSumData[idx] = box.getRangeSum(); + } + + boolean checkContainsAndAddPoint(int index, float[] point) { + int idx = translate(index); + if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) { + int base = 2 * idx * dimension; + int mid = base + dimension; + double rangeSum = 0; + for (int i = 0; i < dimension; i++) { + boundingBoxData[base + i] = Math.min(boundingBoxData[base + i], point[i]); + } + for (int i = 0; i < dimension; i++) { + boundingBoxData[mid + i] = Math.max(boundingBoxData[mid + i], 
point[i]); + } + for (int i = 0; i < dimension; i++) { + rangeSum += boundingBoxData[mid + i] - boundingBoxData[base + i]; + } + boolean answer = (rangeSumData[idx] == rangeSum); + rangeSumData[idx] = rangeSum; + return answer; + } + return false; + } + + public BoundingBox getBox(int index) { + if (isLeaf(index)) { + float[] point = projectToTree(pointStoreView.get(getPointIndex(index))); + checkArgument(point.length == dimension, () -> "failure in projection at index " + index); + return new BoundingBox(point, point); + } else { + checkArgument(isInternal(index), " incomplete state"); + int idx = translate(index); + if (idx != Integer.MAX_VALUE) { + if (rangeSumData[idx] != 0) { + // return non-trivial boxes + return getBoxFromData(idx); + } else { + BoundingBox box = reconstructBox(index, pointStoreView); + copyBoxToData(idx, box); + return box; + } + } + return reconstructBox(index, pointStoreView); + } + } + + BoundingBox reconstructBox(int index, IPointStoreView pointStoreView) { + BoundingBox mutatedBoundingBox = getBox(nodeStore.getLeftIndex(index)); + growNodeBox(mutatedBoundingBox, pointStoreView, index, nodeStore.getRightIndex(index)); + return mutatedBoundingBox; + } + + boolean checkStrictlyContains(int index, float[] point) { + int idx = translate(index); + if (idx != Integer.MAX_VALUE) { + int base = 2 * idx * dimension; + int mid = base + dimension; + boolean isInside = true; + for (int i = 0; i < dimension && isInside; i++) { + if (point[i] >= boundingBoxData[mid + i] || boundingBoxData[base + i] >= point[i]) { + isInside = false; + } + } + return isInside; + } + return false; + } + + boolean checkContainsAndRebuildBox(int index, float[] point, IPointStoreView pointStoreView) { + int idx = translate(index); + if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) { + if (!checkStrictlyContains(index, point)) { + BoundingBox mutatedBoundingBox = reconstructBox(index, pointStoreView); + copyBoxToData(idx, mutatedBoundingBox); + return false; + } + 
return true; + } + return false; + } + + BoundingBox getBoxFromData(int idx) { + int base = 2 * idx * dimension; + int mid = base + dimension; + + return new BoundingBox(Arrays.copyOfRange(boundingBoxData, base, base + dimension), + Arrays.copyOfRange(boundingBoxData, mid, mid + dimension)); + } + + void addBox(int index, float[] point, BoundingBox box) { + if (isInternal(index)) { + int idx = translate(index); + if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum + copyBoxToData(idx, box); + checkContainsAndAddPoint(index, point); + } + } + } + + void growNodeBox(BoundingBox box, IPointStoreView pointStoreView, int node, int sibling) { + if (isLeaf(sibling)) { + float[] point = projectToTree(pointStoreView.get(getPointIndex(sibling))); + checkArgument(point.length == dimension, () -> " incorrect projection at index " + sibling); + box.addPoint(point); + } else { + if (!isInternal(sibling)) { + throw new IllegalStateException(" incomplete state " + sibling); + } + int siblingIdx = translate(sibling); + if (siblingIdx != Integer.MAX_VALUE) { + if (rangeSumData[siblingIdx] != 0) { + box.addBox(getBoxFromData(siblingIdx)); + } else { + BoundingBox newBox = getBox(siblingIdx); + copyBoxToData(siblingIdx, newBox); + box.addBox(newBox); + } + return; + } + growNodeBox(box, pointStoreView, sibling, nodeStore.getLeftIndex(sibling)); + growNodeBox(box, pointStoreView, sibling, nodeStore.getRightIndex(sibling)); + return; + } + } + + public double probabilityOfCut(int node, float[] point, BoundingBox otherBox) { + int nodeIdx = translate(node); + if (nodeIdx != Integer.MAX_VALUE && rangeSumData[nodeIdx] != 0) { + int base = 2 * nodeIdx * dimension; + int mid = base + dimension; + double minsum = 0; + double maxsum = 0; + for (int i = 0; i < dimension; i++) { + minsum += Math.max(boundingBoxData[base + i] - point[i], 0); + } + for (int i = 0; i < dimension; i++) { + maxsum += Math.max(point[i] - boundingBoxData[mid + i], 0); + } + double sum = maxsum + 
minsum; + + if (sum == 0.0) { + return 0.0; + } + return sum / (rangeSumData[nodeIdx] + sum); + } else if (otherBox != null) { + return otherBox.probabilityOfCut(point); + } else { + BoundingBox box = getBox(node); + return box.probabilityOfCut(point); + } + } + + /// additional information at nodes + + public float[] getPointSum(int index) { + checkArgument(centerOfMassEnabled, " enable center of mass"); + if (isLeaf(index)) { + float[] point = projectToTree(pointStoreView.get(getPointIndex(index))); + checkArgument(point.length == dimension, () -> " incorrect projection"); + int mass = getMass(index); + for (int i = 0; i < point.length; i++) { + point[i] *= mass; + } + return point; + } else { + return Arrays.copyOfRange(pointSum, index * dimension, (index + 1) * dimension); + } + } + + public void invalidatePointSum(int index) { + for (int i = 0; i < dimension; i++) { + pointSum[index * dimension + i] = 0; + } + } + + public void recomputePointSum(int index) { + float[] left = getPointSum(nodeStore.getLeftIndex(index)); + float[] right = getPointSum(nodeStore.getRightIndex(index)); + for (int i = 0; i < dimension; i++) { + pointSum[index * dimension + i] = left[i] + right[i]; + } + } + + public HashMap getSequenceMap(int index) { + HashMap hashMap = new HashMap<>(); + List list = getSequenceList(index); + for (Long e : list) { + hashMap.merge(e, 1, Integer::sum); + } + return hashMap; + } + + public List getSequenceList(int index) { + return sequenceMap.get(index); + } + + protected void addLeaf(int pointIndex, long sequenceIndex) { + if (storeSequenceIndexesEnabled) { + List leafList = sequenceMap.remove(pointIndex); + if (leafList == null) { + leafList = new ArrayList<>(1); + } + leafList.add(sequenceIndex); + sequenceMap.put(pointIndex, leafList); + } + } + + public void removeLeaf(int leafPointIndex, long sequenceIndex) { + if (storeSequenceIndexesEnabled) { + List leafList = sequenceMap.remove(leafPointIndex); + checkArgument(leafList != null, " leaf index 
not found in tree"); + checkArgument(leafList.remove(sequenceIndex), " sequence index not found in leaf"); + if (!leafList.isEmpty()) { + sequenceMap.put(leafPointIndex, leafList); + } + } + } + + //// validations + + public void validateAndReconstruct() { + if (root != Null) { + validateAndReconstruct(root); + } + } + + /** + * This function is supposed to validate the integrity of the tree and rebuild + * internal data structures. At this moment the only internal structure is the + * pointsum. + * + * @param index the node of a tree + * @return a bounding box of the points + */ + public BoundingBox validateAndReconstruct(int index) { + if (isLeaf(index)) { + return getBox(index); + } else { + BoundingBox leftBox = validateAndReconstruct(getLeftChild(index)); + BoundingBox rightBox = validateAndReconstruct(getRightChild(index)); + if (leftBox.maxValues[getCutDimension(index)] > getCutValue(index) + || rightBox.minValues[getCutDimension(index)] <= getCutValue(index)) { + throw new IllegalStateException(" incorrect bounding state at index " + index + " cut value " + + getCutValue(index) + "cut dimension " + getCutDimension(index) + " left Box " + + leftBox.toString() + " right box " + rightBox.toString()); + } + if (centerOfMassEnabled) { + recomputePointSum(index); + } + rightBox.addBox(leftBox); + int idx = translate(index); + if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum + copyBoxToData(idx, rightBox); + } + return rightBox; + } + } + + //// traversals + /** * Starting from the root, traverse the canonical path to a leaf node and visit * the nodes along the path. 
The canonical path is determined by the input @@ -390,11 +851,31 @@ public Integer deletePoint(Integer pointIndex, long sequenceIndex) { public R traverse(float[] point, IVisitorFactory visitorFactory) { checkState(root != Null, "this tree doesn't contain any nodes"); Visitor visitor = visitorFactory.newVisitor(this, point); - nodeStore.traversePathToLeafAndVisitNodes(projectToTree(point), visitor, root, pointStoreView, - this::liftFromTree); + NodeView currentNodeView = new NodeView(this, pointStoreView, root); + traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, root, 0); return visitorFactory.liftResult(this, visitor.getResult()); } + protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, NodeView currentNodeView, + int node, int depthOfNode) { + if (isLeaf(node)) { + currentNodeView.setCurrentNode(node, getPointIndex(node), true); + visitor.acceptLeaf(currentNodeView, depthOfNode); + } else { + checkArgument(isInternal(node), () -> " incomplete state " + node + " " + depthOfNode); + if (nodeStore.toLeft(point, node)) { + traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, nodeStore.getLeftIndex(node), + depthOfNode + 1); + currentNodeView.updateToParent(node, nodeStore.getRightIndex(node), !visitor.isConverged()); + } else { + traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, nodeStore.getRightIndex(node), + depthOfNode + 1); + currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), !visitor.isConverged()); + } + visitor.accept(currentNodeView, depthOfNode); + } + } + /** * This is a traversal method which follows the standard traversal path (defined * in {@link #traverse(float[], IVisitorFactory)}) but at Node in checks to see @@ -416,17 +897,35 @@ public R traverseMulti(float[] point, IMultiVisitorFactory visitorFactory checkNotNull(visitorFactory, "visitor must not be null"); checkState(root != Null, "this tree doesn't contain any nodes"); MultiVisitor visitor = 
visitorFactory.newVisitor(this, point); - nodeStore.traverseTreeMulti(projectToTree(point), visitor, root, pointStoreView, this::liftFromTree); + NodeView currentNodeView = new NodeView(this, pointStoreView, root); + traverseTreeMulti(point, visitor, currentNodeView, root, 0); return visitorFactory.liftResult(this, visitor.getResult()); } - /** - * - * @return the mass of the tree - */ - @Override - public int getMass() { - return root == Null ? 0 : nodeStore.getMass(root); + protected void traverseTreeMulti(float[] point, MultiVisitor visitor, NodeView currentNodeView, int node, + int depthOfNode) { + if (nodeStore.isLeaf(node)) { + currentNodeView.setCurrentNode(node, getPointIndex(node), false); + visitor.acceptLeaf(currentNodeView, depthOfNode); + } else { + checkArgument(nodeStore.isInternal(node), " incomplete state"); + currentNodeView.setCurrentNodeOnly(node); + if (visitor.trigger(currentNodeView)) { + traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getLeftIndex(node), depthOfNode + 1); + MultiVisitor newVisitor = visitor.newCopy(); + currentNodeView.setCurrentNodeOnly(nodeStore.getRightIndex(node)); + traverseTreeMulti(point, newVisitor, currentNodeView, nodeStore.getRightIndex(node), depthOfNode + 1); + currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), false); + visitor.combine(newVisitor); + } else if (nodeStore.toLeft(point, node)) { + traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getLeftIndex(node), depthOfNode + 1); + currentNodeView.updateToParent(node, nodeStore.getRightIndex(node), false); + } else { + traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getRightIndex(node), depthOfNode + 1); + currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), false); + } + visitor.accept(currentNodeView, depthOfNode); + } } public int getNumberOfLeaves() { diff --git a/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java 
b/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java index dd814892..6a1b7915 100644 --- a/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java +++ b/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java @@ -152,4 +152,30 @@ public void testRoundTripForSingleNodeForest() { } } + private static float[] generate(int input) { + return new float[] { (float) (20 * Math.sin(input / 10.0)), (float) (20 * Math.cos(input / 10.0)) }; + } + + @Test + void benchmarkMappers() { + long seed = new Random().nextLong(); + System.out.println(" Seed " + seed); + Random random = new Random(seed); + + RandomCutForest rcf = RandomCutForest.builder().dimensions(2 * 10).shingleSize(10).sampleSize(628) + .internalShinglingEnabled(true).randomSeed(random.nextLong()).build(); + for (int i = 0; i < 10000; i++) { + rcf.update(generate(i)); + } + RandomCutForestMapper mapper = new RandomCutForestMapper(); + mapper.setSaveExecutorContextEnabled(true); + mapper.setSaveTreeStateEnabled(true); + for (int j = 0; j < 1000; j++) { + RandomCutForest newRCF = mapper.toModel(mapper.toState(rcf)); + float[] test = generate(10000 + j); + assertEquals(newRCF.getAnomalyScore(test), rcf.getAnomalyScore(test), 1e-6); + rcf.update(test); + } + } + } diff --git a/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java b/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java index 90bb2564..505da148 100644 --- a/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java +++ b/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java @@ -94,7 +94,7 @@ public void setUp() { assertEquals(pointStoreFloat.add(new float[] { 0, 1 }, 5), 4); assertEquals(pointStoreFloat.add(new float[] { 0, 0 }, 6), 5); - assertThrows(IllegalStateException.class, () -> tree.deletePoint(0, 1)); + assertThrows(IllegalArgumentException.class, () -> 
tree.deletePoint(0, 1)); tree.addPoint(0, 1); when(rng.nextDouble()).thenReturn(0.625); @@ -118,57 +118,48 @@ public void testInitialTreeState() { int node = tree.getRoot(); // the second double[] is intentional IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(1)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(1)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); assertThat(tree.getMass(), is(5)); - assertArrayEquals(new double[] { -1, 2 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, -1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1); - - node = tree.nodeStore.getRightIndex(node); + assertArrayEquals(new double[] { -1, 2 }, toDoubleArray(tree.getPointSum(node)), EPSILON); + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1); + + node = tree.getRightChild(node); expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new BoundingBox(new float[] { 1, 1 })); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - 
assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON)); - assertThat(tree.nodeStore.getMass(node), is(4)); - assertArrayEquals(new double[] { 0.0, 3.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 1, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1); - - node = tree.nodeStore.getLeftIndex(node); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON)); + assertThat(tree.getMass(node), is(4)); + assertArrayEquals(new double[] { 0.0, 3.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); + + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1); + + node = tree.getLeftChild(node); expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); - assertThat(tree.nodeStore.getMass(node), is(3)); - assertArrayEquals(new double[] { -1.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - 
is(new float[] { -1, 0 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 0, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(2)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(4L), 1); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1); - assertThrows(IllegalStateException.class, () -> tree.deletePoint(5, 6)); + assertThat(tree.getBox(node), is(expectedBox)); + + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getMass(node), is(3)); + assertArrayEquals(new double[] { -1.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1); + + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(2)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(4L), 1); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1); + assertThrows(IllegalArgumentException.class, () -> 
tree.deletePoint(5, 6)); } @Test @@ -180,44 +171,37 @@ public void testDeletePointWithLeafSibling() { int node = tree.getRoot(); IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(1)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(1)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); assertThat(tree.getMass(), is(4)); - assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); + assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, -1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1); + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1); // sibling node moves up and bounding box recomputed - node = tree.nodeStore.getRightIndex(node); + node = tree.getRightChild(node); expectedBox = new BoundingBox(new float[] { 0, 1 }).getMergedBox(new float[] { 1, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON)); - 
assertThat(tree.nodeStore.getMass(node), is(3)); - assertArrayEquals(new double[] { 1.0, 3.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { 0, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(2)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(4L), 1); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(5L), 1); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 1, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON)); + assertThat(tree.getMass(node), is(3)); + assertArrayEquals(new double[] { 1.0, 3.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); + + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { 0, 1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(2)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(4L), 1); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(5L), 1); + + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + 
assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1); } @Test @@ -228,41 +212,34 @@ public void testDeletePointWithNonLeafSibling() { int node = tree.getRoot(); IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 0, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(1)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(1)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); assertThat(tree.getMass(), is(4)); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, -1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1); + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1); // sibling node moves up and bounding box stays the same - node = tree.nodeStore.getRightIndex(node); + node = tree.getRightChild(node); expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - 
assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, 0 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 0, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(2)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(4L), 1); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); + + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1); + + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(2)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(4L), 1); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1); } @Test @@ -273,65 
+250,55 @@ public void testDeletePointWithMassGreaterThan1() { int node = tree.getRoot(); IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(1)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(1)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); assertThat(tree.getMass(), is(4)); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, -1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1); - assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, -1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1); - - node = tree.nodeStore.getRightIndex(node); + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1); + assertArrayEquals(new double[] { -1.0, 1.0 }, 
toDoubleArray(tree.getPointSum(node)), EPSILON); + + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1); + + node = tree.getRightChild(node); expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 1, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON)); - assertThat(tree.nodeStore.getMass(node), is(3)); + assertThat(tree.getBox(node), is(expectedBox)); + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON)); + assertThat(tree.getMass(node), is(3)); - assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); + assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 1, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1); + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1); - node = tree.nodeStore.getLeftIndex(node); + node = tree.getLeftChild(node); expectedBox = 
new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 }); - assertThat(tree.nodeStore.getBox(node), is(expectedBox)); - assertEquals(expectedBox.toString(), tree.nodeStore.getBox(node).toString()); - assertThat(tree.nodeStore.getCutDimension(node), is(0)); - assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON)); + assertThat(tree.getBox(node), is(expectedBox)); + assertEquals(expectedBox.toString(), tree.getBox(node).toString()); + assertThat(tree.getCutDimension(node), is(0)); + assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON)); assertThat(tree.getMass(), is(4)); - assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))), - is(new float[] { -1, 0 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1); - - assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true)); - assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))), - is(new float[] { 0, 1 })); - assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1)); - assertEquals(tree.nodeStore.getSequenceMap() - .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1); + assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON); + + assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 })); + assertThat(tree.getMass(tree.getLeftChild(node)), is(1)); + 
assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1); + + assertThat(tree.isLeaf(tree.getRightChild(node)), is(true)); + assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 })); + assertThat(tree.getMass(tree.getRightChild(node)), is(1)); + assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1); } @Test @@ -392,6 +359,7 @@ public void testfloat() { System.out.println("rangesum " + box.getRangeSum()); double factor = 1.0 - 1e-16; System.out.println(factor); - Cut cut = RandomCutTree.randomCut(factor, possible, box); + RandomCutTree tree = RandomCutTree.builder().dimension(trials).build(); + Cut cut = tree.randomCut(factor, possible, box); } } diff --git a/Java/examples/pom.xml b/Java/examples/pom.xml index 94058604..96766042 100644 --- a/Java/examples/pom.xml +++ b/Java/examples/pom.xml @@ -7,7 +7,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 randomcutforest-examples diff --git a/Java/parkservices/pom.xml b/Java/parkservices/pom.xml index 85f590b7..6977c6ef 100644 --- a/Java/parkservices/pom.xml +++ b/Java/parkservices/pom.xml @@ -6,7 +6,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 randomcutforest-parkservices diff --git a/Java/pom.xml b/Java/pom.xml index bad56e22..5540e139 100644 --- a/Java/pom.xml +++ b/Java/pom.xml @@ -4,7 +4,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 pom software.amazon.randomcutforest:randomcutforest diff --git a/Java/serialization/pom.xml b/Java/serialization/pom.xml index 20c37c99..1f2a33db 100644 --- a/Java/serialization/pom.xml +++ b/Java/serialization/pom.xml @@ -7,7 +7,7 @@ software.amazon.randomcutforest randomcutforest-parent - 3.5.1 + 3.6.0 randomcutforest-serialization diff --git a/Java/testutils/pom.xml b/Java/testutils/pom.xml index 467f925d..9d464baf 100644 --- a/Java/testutils/pom.xml +++ b/Java/testutils/pom.xml 
@@ -4,7 +4,7 @@ randomcutforest-parent software.amazon.randomcutforest - 3.5.1 + 3.6.0 randomcutforest-testutils