Skip to content

Commit

Permalink
Adds setInvocationCount to HdbscanTrainer (#198)
Browse files Browse the repository at this point in the history
* Moving the options from TrainTest so it uses HdbscanOptions.

* Suppressing the unchecked array warning in ExtendedMinimumSpanningTree.

* Adding setInvocationCount to HdbscanTrainer, and clarifying the semantics of Trainer.train(Dataset<T>, Map<String,Provenance>, int) by allowing -1 (INCREMENT_INVOCATION_COUNT) as an argument, which leaves the invocation count unchanged.

* Adding SparseVector support to HdbscanTrainer.

* Adding a sparse data test for HDBSCAN.
  • Loading branch information
Craigacp authored Dec 9, 2021
1 parent 3eb8d35 commit d30814c
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* An Extended Minimum Spanning Tree graph. Includes the functionality to sort the edge weights in ascending order.
Expand All @@ -27,7 +28,7 @@ final class ExtendedMinimumSpanningTree {

private final EMSTTriple[] emstTriples;

private final ArrayList<Integer>[] edges;
private final List<Integer>[] edges;

/**
* Constructs an ExtendedMinimumSpanningTree, including creating an edge list for each vertex from the
Expand All @@ -40,8 +41,10 @@ final class ExtendedMinimumSpanningTree {
*/
ExtendedMinimumSpanningTree(int numVertices, int[] firstVertices, int[] secondVertices, double[] edgeWeights) {
this.numVertices = numVertices;

this.edges = new ArrayList[numVertices];
// Only integer arraylists are inserted into this array, and it's not accessible outside.
@SuppressWarnings("unchecked")
List<Integer>[] edgeTmp = (List<Integer>[]) new ArrayList[numVertices];
this.edges = edgeTmp;
for (int i = 0; i < this.edges.length; i++) {
this.edges[i] = new ArrayList<>(1 + edgeWeights.length / numVertices);
}
Expand Down Expand Up @@ -82,7 +85,7 @@ public double getEdgeWeightAtIndex(int index) {
return this.emstTriples[index].edgeWeight;
}

public ArrayList<Integer> getEdgeListForVertex(int vertex) {
public List<Integer> getEdgeListForVertex(int vertex) {
return this.edges[vertex];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
* current clustering. The model is not updated with the new data. This is a novel prediction technique which
* leverages the computed cluster exemplars from the HDBSCAN* algorithm.
*/
final public class HdbscanModel extends Model<ClusterID> {
public final class HdbscanModel extends Model<ClusterID> {
private static final long serialVersionUID = 1L;

private final List<Integer> clusterLabels;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,36 @@
/**
* OLCUT {@link Options} for the HDBSCAN* implementation.
*/
final public class HdbscanOptions implements Options {
public final class HdbscanOptions implements Options {
private static final Logger logger = Logger.getLogger(HdbscanOptions.class.getName());
@Override
public String getOptionsDescription() {
return "Options for configuring a HdbscanTrainer.";
}

/**
* The minimum number of points required to form a cluster.
* The minimum number of points required to form a cluster. Defaults to 5.
*/
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.")
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster.")
public int minClusterSize = 5;

/**
* Distance function in HDBSCAN*. Defaults to EUCLIDEAN.
*/
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.")
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations.")
public Distance distanceType = Distance.EUCLIDEAN;

/**
* The number of nearest-neighbors to use in the initial density approximation.
* The number of nearest-neighbors to use in the initial density approximation. Defaults to 5.
*/
@Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " +
"The value includes the point itself. Defaults to 5.")
"The value includes the point itself.")
public int k = 5;

/**
* Number of threads to use for training the hdbscan model. Defaults to 2.
*/
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.")
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model.")
public int numThreads = 2;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
Expand Down Expand Up @@ -67,7 +68,7 @@
* <a href="http://lapad-web.icmc.usp.br/?portfolio_1=a-handful-of-experiments">HDBSCAN*</a>
* </pre>
*/
final public class HdbscanTrainer implements Trainer<ClusterID> {
public final class HdbscanTrainer implements Trainer<ClusterID> {
private static final Logger logger = Logger.getLogger(HdbscanTrainer.class.getName());

static final int OUTLIER_NOISE_CLUSTER_LABEL = 0;
Expand Down Expand Up @@ -147,10 +148,14 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
}
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();

SGDVector[] data = new DenseVector[examples.size()];
SGDVector[] data = new SGDVector[examples.size()];
int n = 0;
for (Example<ClusterID> example : examples) {
data[n] = DenseVector.createDenseVector(example, featureMap, false);
if (example.size() == featureMap.size()) {
data[n] = DenseVector.createDenseVector(example, featureMap, false);
} else {
data[n] = SparseVector.createSparseVector(example, featureMap, false);
}
n++;
}

Expand Down Expand Up @@ -182,7 +187,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(),
examples.getProvenance(), trainerProvenance, runProvenance);

return new HdbscanModel("", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
clusterExemplars, distanceType);
}

Expand All @@ -196,6 +201,15 @@ public int getInvocationCount() {
return trainInvocationCounter;
}

/**
 * Sets the training invocation counter to the supplied value.
 * <p>
 * Subsequent calls to train will increment the counter from this value.
 * @param newInvocationCount The new invocation count, must be non-negative.
 * @throws IllegalArgumentException If the supplied count is negative.
 */
@Override
public void setInvocationCount(int newInvocationCount) {
    // Reject invalid counts up front, then fall through to the assignment.
    if (newInvocationCount < 0) {
        throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
    }
    trainInvocationCounter = newInvocationCount;
}

/**
* Calculates the core distance for every point in the data set.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.tribuo.clustering.hdbscan;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
Expand All @@ -35,14 +34,14 @@
/**
* Build and run a HDBSCAN* clustering model for a standard dataset.
*/
final public class TrainTest {
public final class TrainTest {

private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

/**
* Options for the HDBSCAN* CLI.
*/
public static class HdbscanOptions implements Options {
public static class HdbscanCLIOptions implements Options {
@Override
public String getOptionsDescription() {
return "Trains and evaluates a HDBSCAN* model on the specified dataset.";
Expand All @@ -54,29 +53,9 @@ public String getOptionsDescription() {
public DataOptions general;

/**
* The minimum number of points required to form a cluster.
* The HDBSCAN options
*/
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.")
public int minClusterSize = 5;

/**
* Distance function in HDBSCAN*. Defaults to EUCLIDEAN.
*/
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.")
public HdbscanTrainer.Distance distanceType = HdbscanTrainer.Distance.EUCLIDEAN;

/**
* The number of nearest-neighbors to use in the initial density approximation.
*/
@Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " +
"The value includes the point itself. Defaults to 5.")
public int k = 5;

/**
* Number of threads to use for training the hdbscan model. Defaults to 2.
*/
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.")
public int numThreads = 2;
public HdbscanOptions hdbscanOptions;
}

/**
Expand All @@ -89,7 +68,7 @@ public static void main(String[] args) throws IOException {
// Use the labs format logging.
LabsLogFormatter.setAllLogFormatters();

HdbscanOptions o = new HdbscanOptions();
HdbscanCLIOptions o = new HdbscanCLIOptions();
ConfigurationManager cm;
try {
cm = new ConfigurationManager(args,o);
Expand All @@ -108,7 +87,7 @@ public static void main(String[] args) throws IOException {
Pair<Dataset<ClusterID>,Dataset<ClusterID>> data = o.general.load(factory);
Dataset<ClusterID> train = data.getA();

HdbscanTrainer trainer = new HdbscanTrainer(o.minClusterSize, o.distanceType, o.k, o.numThreads);
HdbscanTrainer trainer = o.hdbscanOptions.getTrainer();
Model<ClusterID> model = trainer.train(train);
logger.info("Finished training model");
ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -69,6 +70,34 @@ public static void setup() {
logger.setLevel(Level.WARNING);
}

/**
 * Checks the invocation counter semantics of {@code HdbscanTrainer}:
 * each train call increments the counter, {@code setInvocationCount}
 * overwrites it, and the three-argument train call resets it before training.
 */
@Test
public void testInvocationCounter() {
    // Build a small clustering dataset from a CSV with three numeric columns.
    // EmptyResponseProcessor is used because clustering has no response field.
    ClusteringFactory clusteringFactory = new ClusteringFactory();
    ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1"));
    regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2"));
    regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3"));
    RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor,regexMappingProcessors);
    CSVDataSource<ClusterID> csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"),rowProcessor,false);
    Dataset<ClusterID> dataset = new MutableDataset<>(csvSource);

    HdbscanTrainer trainer = new HdbscanTrainer(7, Distance.EUCLIDEAN, 7,4);
    // Each call to train() should bump the invocation counter by one.
    for (int i = 0; i < 5; i++) {
        HdbscanModel model = trainer.train(dataset);
    }

    assertEquals(5,trainer.getInvocationCount());

    // setInvocationCount overwrites the counter directly.
    trainer.setInvocationCount(0);

    assertEquals(0,trainer.getInvocationCount());

    // Passing 3 sets the counter to 3 before training, and the train call
    // itself increments it once more, leaving 4.
    Model<ClusterID> model = trainer.train(dataset, Collections.emptyMap(), 3);

    assertEquals(4, trainer.getInvocationCount());
}

@Test
public void testEndToEndTrainWithCSVData() {
ClusteringFactory clusteringFactory = new ClusteringFactory();
Expand Down Expand Up @@ -212,7 +241,7 @@ public void testTrainTestEvaluation() {
runEvaluation(t);
}

public void runInvalidExample(HdbscanTrainer trainer) {
public static void runInvalidExample(HdbscanTrainer trainer) {
assertThrows(IllegalArgumentException.class, () -> {
Pair<Dataset<ClusterID>, Dataset<ClusterID>> p = ClusteringDataGenerator.denseTrainTest();
Model<ClusterID> m = trainer.train(p.getA());
Expand All @@ -225,7 +254,7 @@ public void testInvalidExample() {
runInvalidExample(t);
}

public void runEmptyExample(HdbscanTrainer trainer) {
public static void runEmptyExample(HdbscanTrainer trainer) {
assertThrows(IllegalArgumentException.class, () -> {
Pair<Dataset<ClusterID>, Dataset<ClusterID>> p = ClusteringDataGenerator.denseTrainTest();
Model<ClusterID> m = trainer.train(p.getA());
Expand All @@ -238,4 +267,16 @@ public void testEmptyExample() {
runEmptyExample(t);
}

/**
 * Trains a model with the supplied trainer on the sparse synthetic dataset
 * and evaluates it on the matching test split.
 * @param trainer The trainer to exercise.
 */
public static void runSparseData(HdbscanTrainer trainer) {
    Pair<Dataset<ClusterID>,Dataset<ClusterID>> pair = ClusteringDataGenerator.sparseTrainTest();
    Model<ClusterID> model = trainer.train(pair.getA());
    ClusteringEvaluator evaluator = new ClusteringEvaluator();
    evaluator.evaluate(model, pair.getB());
}

/**
 * Smoke test: HDBSCAN training and evaluation complete on sparse data.
 */
@Test
public void testSparseData() {
    runSparseData(t);
}

}
10 changes: 7 additions & 3 deletions Core/src/main/java/org/tribuo/Trainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,16 @@ default public Model<T> train(Dataset<T> examples) {
*
* @param examples the data set containing the examples.
* @param runProvenance Training run specific provenance (e.g., fold number).
* @param invocationCount The state of the RNG the trainer should be set to before training
* @param invocationCount The invocation counter that the trainer should be set to before training, which in most
* cases alters the state of the RNG inside this trainer. If the value is set to
* {@link #INCREMENT_INVOCATION_COUNT} then the invocation count is not changed.
* @return a predictive model that can be used to generate predictions for new examples.
*/
public default Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
synchronized (this){
setInvocationCount(invocationCount);
synchronized (this) {
if (invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
return train(examples, runProvenance);
}
}
Expand Down
2 changes: 1 addition & 1 deletion spotbugs-exclude.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ positive warnings.
<Class name="org.tribuo.classification.xgboost.TrainTest"/>
<Class name="org.tribuo.classification.xgboost.TrainTest$TrainTestOptions"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest$HdbscanOptions"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest$HdbscanCLIOptions"/>
<Class name="org.tribuo.clustering.kmeans.TrainTest"/>
<Class name="org.tribuo.clustering.kmeans.TrainTest$KMeansOptions"/>
<Class name="org.tribuo.data.ConfigurableTrainTest"/>
Expand Down

0 comments on commit d30814c

Please sign in to comment.