Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds setInvocationCount to HdbscanTrainer #198

Merged
merged 5 commits into the base branch from the contributor's branch
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* An Extended Minimum Spanning Tree graph. Includes the functionality to sort the edge weights in ascending order.
Expand All @@ -27,7 +28,7 @@ final class ExtendedMinimumSpanningTree {

private final EMSTTriple[] emstTriples;

private final ArrayList<Integer>[] edges;
private final List<Integer>[] edges;

/**
* Constructs an ExtendedMinimumSpanningTree, including creating an edge list for each vertex from the
Expand All @@ -40,8 +41,10 @@ final class ExtendedMinimumSpanningTree {
*/
ExtendedMinimumSpanningTree(int numVertices, int[] firstVertices, int[] secondVertices, double[] edgeWeights) {
this.numVertices = numVertices;

this.edges = new ArrayList[numVertices];
// Only integer arraylists are inserted into this array, and it's not accessible outside.
@SuppressWarnings("unchecked")
List<Integer>[] edgeTmp = (List<Integer>[]) new ArrayList[numVertices];
this.edges = edgeTmp;
for (int i = 0; i < this.edges.length; i++) {
this.edges[i] = new ArrayList<>(1 + edgeWeights.length / numVertices);
}
Expand Down Expand Up @@ -82,7 +85,7 @@ public double getEdgeWeightAtIndex(int index) {
return this.emstTriples[index].edgeWeight;
}

public ArrayList<Integer> getEdgeListForVertex(int vertex) {
public List<Integer> getEdgeListForVertex(int vertex) {
return this.edges[vertex];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
* current clustering. The model is not updated with the new data. This is a novel prediction technique which
* leverages the computed cluster exemplars from the HDBSCAN* algorithm.
*/
final public class HdbscanModel extends Model<ClusterID> {
public final class HdbscanModel extends Model<ClusterID> {
private static final long serialVersionUID = 1L;

private final List<Integer> clusterLabels;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,36 @@
/**
* OLCUT {@link Options} for the HDBSCAN* implementation.
*/
final public class HdbscanOptions implements Options {
public final class HdbscanOptions implements Options {
private static final Logger logger = Logger.getLogger(HdbscanOptions.class.getName());
@Override
public String getOptionsDescription() {
return "Options for configuring a HdbscanTrainer.";
}

/**
* The minimum number of points required to form a cluster.
* The minimum number of points required to form a cluster. Defaults to 5.
*/
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.")
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster.")
public int minClusterSize = 5;

/**
* Distance function in HDBSCAN*. Defaults to EUCLIDEAN.
*/
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.")
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations.")
public Distance distanceType = Distance.EUCLIDEAN;

/**
* The number of nearest-neighbors to use in the initial density approximation.
* The number of nearest-neighbors to use in the initial density approximation. Defaults to 5.
*/
@Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " +
"The value includes the point itself. Defaults to 5.")
"The value includes the point itself.")
public int k = 5;

/**
* Number of threads to use for training the hdbscan model. Defaults to 2.
*/
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.")
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model.")
public int numThreads = 2;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
Expand Down Expand Up @@ -67,7 +68,7 @@
* <a href="http://lapad-web.icmc.usp.br/?portfolio_1=a-handful-of-experiments">HDBSCAN*</a>
* </pre>
*/
final public class HdbscanTrainer implements Trainer<ClusterID> {
public final class HdbscanTrainer implements Trainer<ClusterID> {
private static final Logger logger = Logger.getLogger(HdbscanTrainer.class.getName());

static final int OUTLIER_NOISE_CLUSTER_LABEL = 0;
Expand Down Expand Up @@ -147,10 +148,14 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
}
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();

SGDVector[] data = new DenseVector[examples.size()];
SGDVector[] data = new SGDVector[examples.size()];
int n = 0;
for (Example<ClusterID> example : examples) {
data[n] = DenseVector.createDenseVector(example, featureMap, false);
if (example.size() == featureMap.size()) {
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
data[n] = DenseVector.createDenseVector(example, featureMap, false);
} else {
data[n] = SparseVector.createSparseVector(example, featureMap, false);
}
n++;
}

Expand Down Expand Up @@ -182,7 +187,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(),
examples.getProvenance(), trainerProvenance, runProvenance);

return new HdbscanModel("", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
clusterExemplars, distanceType);
}

Expand All @@ -196,6 +201,15 @@ public int getInvocationCount() {
return trainInvocationCounter;
}

/**
 * Sets the training invocation counter for this trainer.
 * <p>
 * Allows the trainer's invocation state to be fast-forwarded (e.g. to reproduce
 * a model trained after several earlier invocations).
 * @param newInvocationCount The invocation count to set. Must be non-negative.
 * @throws IllegalArgumentException If {@code newInvocationCount} is negative.
 */
@Override
public void setInvocationCount(int newInvocationCount) {
    // Guard clause: reject invalid counts up front, including the value for diagnostics.
    if (newInvocationCount < 0) {
        throw new IllegalArgumentException("The supplied invocationCount is less than zero, received " + newInvocationCount);
    }
    trainInvocationCounter = newInvocationCount;
}

/**
* Calculates the core distance for every point in the data set.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.tribuo.clustering.hdbscan;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
Expand All @@ -35,14 +34,14 @@
/**
* Build and run a HDBSCAN* clustering model for a standard dataset.
*/
final public class TrainTest {
public final class TrainTest {

private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

/**
* Options for the HDBSCAN* CLI.
*/
public static class HdbscanOptions implements Options {
public static class HdbscanCLIOptions implements Options {
@Override
public String getOptionsDescription() {
return "Trains and evaluates a HDBSCAN* model on the specified dataset.";
Expand All @@ -54,29 +53,9 @@ public String getOptionsDescription() {
public DataOptions general;

/**
* The minimum number of points required to form a cluster.
* The HDBSCAN options
*/
@Option(longName = "minimum-cluster-size", usage = "The minimum number of points required to form a cluster. Defaults to 5.")
public int minClusterSize = 5;

/**
* Distance function in HDBSCAN*. Defaults to EUCLIDEAN.
*/
@Option(longName = "distance-function", usage = "Distance function to use for various distance calculations. Defaults to EUCLIDEAN.")
public HdbscanTrainer.Distance distanceType = HdbscanTrainer.Distance.EUCLIDEAN;

/**
* The number of nearest-neighbors to use in the initial density approximation.
*/
@Option(longName = "k-nearest-neighbors", usage = "The number of nearest-neighbors to use in the initial density approximation. " +
"The value includes the point itself. Defaults to 5.")
public int k = 5;

/**
* Number of threads to use for training the hdbscan model. Defaults to 2.
*/
@Option(longName = "hdbscan-num-threads", usage = "Number of threads to use for training the hdbscan model. Defaults to 2.")
public int numThreads = 2;
public HdbscanOptions hdbscanOptions;
}

/**
Expand All @@ -89,7 +68,7 @@ public static void main(String[] args) throws IOException {
// Use the labs format logging.
LabsLogFormatter.setAllLogFormatters();

HdbscanOptions o = new HdbscanOptions();
HdbscanCLIOptions o = new HdbscanCLIOptions();
ConfigurationManager cm;
try {
cm = new ConfigurationManager(args,o);
Expand All @@ -108,7 +87,7 @@ public static void main(String[] args) throws IOException {
Pair<Dataset<ClusterID>,Dataset<ClusterID>> data = o.general.load(factory);
Dataset<ClusterID> train = data.getA();

HdbscanTrainer trainer = new HdbscanTrainer(o.minClusterSize, o.distanceType, o.k, o.numThreads);
HdbscanTrainer trainer = o.hdbscanOptions.getTrainer();
Model<ClusterID> model = trainer.train(train);
logger.info("Finished training model");
ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -69,6 +70,34 @@ public static void setup() {
logger.setLevel(Level.WARNING);
}

@Test
public void testInvocationCounter() {
    // Build a small clustering dataset from the bundled CSV fixture: three numeric
    // feature columns, no response (EmptyResponseProcessor), loaded via the row processor.
    ClusteringFactory clusteringFactory = new ClusteringFactory();
    ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
    Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
    regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1"));
    regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2"));
    regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3"));
    RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor,regexMappingProcessors);
    CSVDataSource<ClusterID> csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"),rowProcessor,false);
    Dataset<ClusterID> dataset = new MutableDataset<>(csvSource);

    // Phase 1: each call to train() should bump the invocation counter by one.
    // The trained models themselves are intentionally discarded.
    HdbscanTrainer trainer = new HdbscanTrainer(7, Distance.EUCLIDEAN, 7,4);
    for (int i = 0; i < 5; i++) {
        HdbscanModel model = trainer.train(dataset);
    }

    assertEquals(5,trainer.getInvocationCount());

    // Phase 2: the counter can be reset explicitly via setInvocationCount.
    trainer.setInvocationCount(0);

    assertEquals(0,trainer.getInvocationCount());

    // Phase 3: train(examples, provenance, invocationCount) first sets the counter
    // to the supplied value (3), then training itself increments it, giving 4.
    Model<ClusterID> model = trainer.train(dataset, Collections.emptyMap(), 3);

    assertEquals(4, trainer.getInvocationCount());
}

@Test
public void testEndToEndTrainWithCSVData() {
ClusteringFactory clusteringFactory = new ClusteringFactory();
Expand Down Expand Up @@ -212,7 +241,7 @@ public void testTrainTestEvaluation() {
runEvaluation(t);
}

public void runInvalidExample(HdbscanTrainer trainer) {
public static void runInvalidExample(HdbscanTrainer trainer) {
assertThrows(IllegalArgumentException.class, () -> {
Pair<Dataset<ClusterID>, Dataset<ClusterID>> p = ClusteringDataGenerator.denseTrainTest();
Model<ClusterID> m = trainer.train(p.getA());
Expand All @@ -225,7 +254,7 @@ public void testInvalidExample() {
runInvalidExample(t);
}

public void runEmptyExample(HdbscanTrainer trainer) {
public static void runEmptyExample(HdbscanTrainer trainer) {
assertThrows(IllegalArgumentException.class, () -> {
Pair<Dataset<ClusterID>, Dataset<ClusterID>> p = ClusteringDataGenerator.denseTrainTest();
Model<ClusterID> m = trainer.train(p.getA());
Expand All @@ -238,4 +267,16 @@ public void testEmptyExample() {
runEmptyExample(t);
}

/**
 * Trains the supplied trainer on the sparse synthetic dataset and evaluates
 * the resulting model on the matching test split. Completing without an
 * exception is the success condition; no assertions are made on the scores.
 * @param trainer The HDBSCAN* trainer to exercise.
 */
public static void runSparseData(HdbscanTrainer trainer) {
    Pair<Dataset<ClusterID>,Dataset<ClusterID>> splits = ClusteringDataGenerator.sparseTrainTest();
    Model<ClusterID> trained = trainer.train(splits.getA());
    new ClusteringEvaluator().evaluate(trained, splits.getB());
}

// Smoke test: HDBSCAN* training must accept sparse feature vectors (examples
// with fewer active features than the feature map) without throwing.
@Test
public void testSparseData() {
    runSparseData(t);
}

}
10 changes: 7 additions & 3 deletions Core/src/main/java/org/tribuo/Trainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,16 @@ default public Model<T> train(Dataset<T> examples) {
*
* @param examples the data set containing the examples.
* @param runProvenance Training run specific provenance (e.g., fold number).
* @param invocationCount The state of the RNG the trainer should be set to before training
* @param invocationCount The invocation counter that the trainer should be set to before training, which in most
* cases alters the state of the RNG inside this trainer. If the value is set to
* {@link #INCREMENT_INVOCATION_COUNT} then the invocation count is not changed.
* @return a predictive model that can be used to generate predictions for new examples.
*/
public default Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
synchronized (this){
setInvocationCount(invocationCount);
synchronized (this) {
if (invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
return train(examples, runProvenance);
}
}
Expand Down
2 changes: 1 addition & 1 deletion spotbugs-exclude.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ positive warnings.
<Class name="org.tribuo.classification.xgboost.TrainTest"/>
<Class name="org.tribuo.classification.xgboost.TrainTest$TrainTestOptions"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest$HdbscanOptions"/>
<Class name="org.tribuo.clustering.hdbscan.TrainTest$HdbscanCLIOptions"/>
<Class name="org.tribuo.clustering.kmeans.TrainTest"/>
<Class name="org.tribuo.clustering.kmeans.TrainTest$KMeansOptions"/>
<Class name="org.tribuo.data.ConfigurableTrainTest"/>
Expand Down