From e0ddcbf4bfaab8d638eb2cec234d09491243c415 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 Dec 2021 17:33:59 -0500 Subject: [PATCH 1/5] Moving the options from TrainTest so it uses HdbscanOptions. --- .../clustering/hdbscan/HdbscanOptions.java | 18 ++++++---- .../tribuo/clustering/hdbscan/TrainTest.java | 33 ++++--------------- spotbugs-exclude.xml | 2 +- 3 files changed, 18 insertions(+), 35 deletions(-) diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java index 69230fb44..244a5a58c 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java @@ -25,32 +25,36 @@ /** * OLCUT {@link Options} for the HDBSCAN* implementation. */ -final public class HdbscanOptions implements Options { +public final class HdbscanOptions implements Options { private static final Logger logger = Logger.getLogger(HdbscanOptions.class.getName()); + @Override + public String getOptionsDescription() { + return "Options for configuring a HdbscanTrainer."; + } /** - * The minimum number of points required to form a cluster. + * The minimum number of points required to form a cluster. Defaults to 5. */ - @Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.") + @Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster.") public int minClusterSize = 5; /** * Distance function in HDBSCAN*. Defaults to EUCLIDEAN. */ - @Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.") + @Option(longName = "distance-function", usage = "Distance function to use for various distance calculations.") public Distance distanceType = Distance.EUCLIDEAN; /** - * The number of nearest-neighbors to use in the initial density approximation. + * The number of nearest-neighbors to use in the initial density approximation. Defaults to 5. */ @Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " + - "The value includes the point itself. Defaults to 5.") + "The value includes the point itself.") public int k = 5; /** * Number of threads to use for training the hdbscan model. Defaults to 2. */ - @Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.") + @Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model.") public int numThreads = 2; /** diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/TrainTest.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/TrainTest.java index 94d100e75..adfb4175e 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/TrainTest.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/TrainTest.java @@ -17,7 +17,6 @@ package org.tribuo.clustering.hdbscan; import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; -import com.oracle.labs.mlrg.olcut.config.Option; import com.oracle.labs.mlrg.olcut.config.Options; import com.oracle.labs.mlrg.olcut.config.UsageException; import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; @@ -35,14 +34,14 @@ /** * Build and run a HDBSCAN* clustering model for a standard dataset. */ -final public class TrainTest { +public final class TrainTest { private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); /** * Options for the HDBSCAN* CLI. */ - public static class HdbscanOptions implements Options { + public static class HdbscanCLIOptions implements Options { @Override public String getOptionsDescription() { return "Trains and evaluates a HDBSCAN* model on the specified dataset."; @@ -54,29 +53,9 @@ public String getOptionsDescription() { public DataOptions general; /** - * The minimum number of points required to form a cluster. + * The HDBSCAN options */ - @Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.") - public int minClusterSize = 5; - - /** - * Distance function in HDBSCAN*. Defaults to EUCLIDEAN. - */ - @Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.") - public HdbscanTrainer.Distance distanceType = HdbscanTrainer.Distance.EUCLIDEAN; - - /** - * The number of nearest-neighbors to use in the initial density approximation. - */ - @Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " + - "The value includes the point itself. Defaults to 5.") - public int k = 5; - - /** - * Number of threads to use for training the hdbscan model. Defaults to 2. - */ - @Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.") - public int numThreads = 2; + public HdbscanOptions hdbscanOptions; } /** @@ -89,7 +68,7 @@ public static void main(String[] args) throws IOException { // Use the labs format logging. LabsLogFormatter.setAllLogFormatters(); - HdbscanOptions o = new HdbscanOptions(); + HdbscanCLIOptions o = new HdbscanCLIOptions(); ConfigurationManager cm; try { cm = new ConfigurationManager(args,o); @@ -108,7 +87,7 @@ public static void main(String[] args) throws IOException { Pair,Dataset> data = o.general.load(factory); Dataset train = data.getA(); - HdbscanTrainer trainer = new HdbscanTrainer(o.minClusterSize, o.distanceType, o.k, o.numThreads); + HdbscanTrainer trainer = o.hdbscanOptions.getTrainer(); Model model = trainer.train(train); logger.info("Finished training model"); ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train); diff --git a/spotbugs-exclude.xml b/spotbugs-exclude.xml index b49d2614d..d8eaf9ede 100644 --- a/spotbugs-exclude.xml +++ b/spotbugs-exclude.xml @@ -40,7 +40,7 @@ positive warnings. - + From 62df72bf1e6813e01f9b93a2800eb9af6828d36e Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 Dec 2021 17:34:45 -0500 Subject: [PATCH 2/5] Suppressing the unchecked array warning in ExtendedMinimumSpanningTree. --- .../hdbscan/ExtendedMinimumSpanningTree.java | 11 +++++++---- .../org/tribuo/clustering/hdbscan/HdbscanModel.java | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/ExtendedMinimumSpanningTree.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/ExtendedMinimumSpanningTree.java index 0fe2580d8..91b1a703a 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/ExtendedMinimumSpanningTree.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/ExtendedMinimumSpanningTree.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.List; /** * An Extended Minimum Spanning Tree graph. Includes the functionality to sort the edge weights in ascending order. @@ -27,7 +28,7 @@ final class ExtendedMinimumSpanningTree { private final EMSTTriple[] emstTriples; - private final ArrayList[] edges; + private final List[] edges; /** * Constructs an ExtendedMinimumSpanningTree, including creating an edge list for each vertex from the @@ -40,8 +41,10 @@ final class ExtendedMinimumSpanningTree { */ ExtendedMinimumSpanningTree(int numVertices, int[] firstVertices, int[] secondVertices, double[] edgeWeights) { this.numVertices = numVertices; - - this.edges = new ArrayList[numVertices]; + // Only integer arraylists are inserted into this array, and it's not accessible outside. + @SuppressWarnings("unchecked") + List[] edgeTmp = (List[]) new ArrayList[numVertices]; + this.edges = edgeTmp; for (int i = 0; i < this.edges.length; i++) { this.edges[i] = new ArrayList<>(1 + edgeWeights.length / numVertices); } @@ -82,7 +85,7 @@ public double getEdgeWeightAtIndex(int index) { return this.emstTriples[index].edgeWeight; } - public ArrayList getEdgeListForVertex(int vertex) { + public List getEdgeListForVertex(int vertex) { return this.edges[vertex]; } diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java index 5342bf72c..381db95b1 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java @@ -42,7 +42,7 @@ * current clustering. The model is not updated with the new data. This is a novel prediction technique which * leverages the computed cluster exemplars from the HDBSCAN* algorithm. */ -final public class HdbscanModel extends Model { +public final class HdbscanModel extends Model { private static final long serialVersionUID = 1L; private final List clusterLabels; From ed4c02c130cff94881ab240f5e7d10286278c857 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 2 Dec 2021 17:35:51 -0500 Subject: [PATCH 3/5] Adding setInvocationCount to HdbscanTrainer, and clarifying the semantics of Trainer.train(Dataset, Map, int) by allowing negative -1 as an argument. --- .../clustering/hdbscan/HdbscanTrainer.java | 13 +++++++-- .../clustering/hdbscan/TestHdbscan.java | 29 +++++++++++++++++++ Core/src/main/java/org/tribuo/Trainer.java | 10 +++++-- 3 files changed, 47 insertions(+), 5 deletions(-) diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java index 5ca8e2a06..b04d34d0b 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java @@ -67,7 +67,7 @@ * HDBSCAN* * */ -final public class HdbscanTrainer implements Trainer { +public final class HdbscanTrainer implements Trainer { private static final Logger logger = Logger.getLogger(HdbscanTrainer.class.getName()); static final int OUTLIER_NOISE_CLUSTER_LABEL = 0; @@ -182,7 +182,7 @@ public HdbscanModel train(Dataset examples, Map r ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); - return new HdbscanModel("", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, + return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distanceType); } @@ -196,6 +196,15 @@ public int getInvocationCount() { return trainInvocationCounter; } + @Override + public void setInvocationCount(int newInvocationCount) { + if(newInvocationCount < 0){ + throw new IllegalArgumentException("The supplied invocationCount is less than zero."); + } else { + trainInvocationCounter = newInvocationCount; + } + } + /** * Calculates the core distance for every point in the data set. * diff --git a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java index 391adfbb9..731062fc3 100644 --- a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java +++ b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java @@ -42,6 +42,7 @@ import java.nio.file.Paths; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -69,6 +70,34 @@ public static void setup() { logger.setLevel(Level.WARNING); } + @Test + public void testInvocationCounter() { + ClusteringFactory clusteringFactory = new ClusteringFactory(); + ResponseProcessor emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory); + Map regexMappingProcessors = new HashMap<>(); + regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1")); + regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2")); + regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3")); + RowProcessor rowProcessor = new RowProcessor<>(emptyResponseProcessor,regexMappingProcessors); + CSVDataSource csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"),rowProcessor,false); + Dataset dataset = new MutableDataset<>(csvSource); + + HdbscanTrainer trainer = new HdbscanTrainer(7, Distance.EUCLIDEAN, 7,4); + for (int i = 0; i < 5; i++) { + HdbscanModel model = trainer.train(dataset); + } + + assertEquals(5,trainer.getInvocationCount()); + + trainer.setInvocationCount(0); + + assertEquals(0,trainer.getInvocationCount()); + + Model model = trainer.train(dataset, Collections.emptyMap(), 3); + + assertEquals(4, trainer.getInvocationCount()); + } + @Test public void testEndToEndTrainWithCSVData() { ClusteringFactory clusteringFactory = new ClusteringFactory(); diff --git a/Core/src/main/java/org/tribuo/Trainer.java b/Core/src/main/java/org/tribuo/Trainer.java index dedf50413..8f80fda3f 100644 --- a/Core/src/main/java/org/tribuo/Trainer.java +++ b/Core/src/main/java/org/tribuo/Trainer.java @@ -64,12 +64,16 @@ default public Model train(Dataset examples) { * * @param examples the data set containing the examples. * @param runProvenance Training run specific provenance (e.g., fold number). - * @param invocationCount The state of the RNG the trainer should be set to before training + * @param invocationCount The invocation counter that the trainer should be set to before training, which in most + * cases alters the state of the RNG inside this trainer. If the value is set to + * {@link #INCREMENT_INVOCATION_COUNT} then the invocation count is not changed. * @return a predictive model that can be used to generate predictions for new examples. */ public default Model train(Dataset examples, Map runProvenance, int invocationCount) { - synchronized (this){ - setInvocationCount(invocationCount); + synchronized (this) { + if (invocationCount != INCREMENT_INVOCATION_COUNT) { + setInvocationCount(invocationCount); + } return train(examples, runProvenance); } } From 862c7cf66e89a4d749c06120670c75720ad60324 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 7 Dec 2021 12:16:03 -0500 Subject: [PATCH 4/5] Adding SparseVector support to HdbscanTrainer. --- .../org/tribuo/clustering/hdbscan/HdbscanTrainer.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java index b04d34d0b..1849b738a 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java @@ -27,6 +27,7 @@ import org.tribuo.clustering.ImmutableClusteringInfo; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; +import org.tribuo.math.la.SparseVector; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.provenance.impl.TrainerProvenanceImpl; @@ -147,10 +148,14 @@ public HdbscanModel train(Dataset examples, Map r } ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); - SGDVector[] data = new DenseVector[examples.size()]; + SGDVector[] data = new SGDVector[examples.size()]; int n = 0; for (Example example : examples) { - data[n] = DenseVector.createDenseVector(example, featureMap, false); + if (example.size() == featureMap.size()) { + data[n] = DenseVector.createDenseVector(example, featureMap, false); + } else { + data[n] = SparseVector.createSparseVector(example, featureMap, false); + } n++; } From e0ad08d5d2aa1e74787239608b24052fb6c38e42 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 8 Dec 2021 13:26:16 -0500 Subject: [PATCH 5/5] Adding a sparse data test for HDBSCAN. --- .../tribuo/clustering/hdbscan/TestHdbscan.java | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java index 731062fc3..d3da765f4 100644 --- a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java +++ b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java @@ -241,7 +241,7 @@ public void testTrainTestEvaluation() { runEvaluation(t); } - public void runInvalidExample(HdbscanTrainer trainer) { + public static void runInvalidExample(HdbscanTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); Model m = trainer.train(p.getA()); @@ -254,7 +254,7 @@ public void testInvalidExample() { runInvalidExample(t); } - public void runEmptyExample(HdbscanTrainer trainer) { + public static void runEmptyExample(HdbscanTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); Model m = trainer.train(p.getA()); @@ -267,4 +267,16 @@ public void testEmptyExample() { runEmptyExample(t); } + public static void runSparseData(HdbscanTrainer trainer) { + Pair,Dataset> p = ClusteringDataGenerator.sparseTrainTest(); + Model m = trainer.train(p.getA()); + ClusteringEvaluator e = new ClusteringEvaluator(); + e.evaluate(m,p.getB()); + } + + @Test + public void testSparseData() { + runSparseData(t); + } + }