Skip to content

Commit

Permalink
Exposes views on the HDBSCAN cluster exemplars (#229)
Browse files Browse the repository at this point in the history
* Make ClusterExemplar public and add accessors to HdbscanModel.

* Adding a more friendly getClusters method to HdbscanModel.

* Updating ClusterExemplar copy, toString, equals and hashCode.

* Adding more test coverage of HdbscanModel.getClusterExemplars() and getClusters().
  • Loading branch information
Craigacp authored Apr 15, 2022
1 parent db1bf09 commit 38d60e6
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
Expand All @@ -29,6 +30,7 @@
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;

import java.io.IOException;
Expand Down Expand Up @@ -107,6 +109,43 @@ public List<Double> getOutlierScores() {
return outlierScores;
}

/**
 * Returns a deep copy of the cluster exemplars.
 * <p>
 * Each exemplar is copied via {@link HdbscanTrainer.ClusterExemplar#copy()},
 * so mutating the returned list or its elements does not affect this model.
 * @return The cluster exemplars.
 */
public List<HdbscanTrainer.ClusterExemplar> getClusterExemplars() {
    List<HdbscanTrainer.ClusterExemplar> copies = new ArrayList<>(clusterExemplars.size());
    for (HdbscanTrainer.ClusterExemplar exemplar : clusterExemplars) {
        copies.add(exemplar.copy());
    }
    return copies;
}

/**
 * Returns the features in each cluster exemplar.
 * <p>
 * In many cases this should be used in preference to {@link #getClusterExemplars()}
 * as it performs the mapping from Tribuo's internal feature ids to
 * the externally visible feature names.
 * @return The cluster exemplars.
 */
public List<Pair<Integer,List<Feature>>> getClusters() {
    List<Pair<Integer,List<Feature>>> output = new ArrayList<>(clusterExemplars.size());

    for (HdbscanTrainer.ClusterExemplar exemplar : clusterExemplars) {
        SGDVector vector = exemplar.getFeatures();
        List<Feature> named = new ArrayList<>(vector.numActiveElements());

        // Map each active element's internal id back to its feature name.
        for (VectorTuple tuple : vector) {
            String featureName = featureIDMap.get(tuple.index).getName();
            named.add(new Feature(featureName, tuple.value));
        }

        output.add(new Pair<>(exemplar.getLabel(), named));
    }

    return output;
}

@Override
public Prediction<ClusterID> predict(Example<ClusterID> example) {
SGDVector vector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
Expand Down Expand Up @@ -805,7 +806,7 @@ public TrainerProvenance getProvenance() {
/**
* A cluster exemplar, with attributes for the point's label, outlier score and its features.
*/
final static class ClusterExemplar implements Serializable {
public final static class ClusterExemplar implements Serializable {
private static final long serialVersionUID = 1L;

private final Integer label;
Expand All @@ -820,26 +821,72 @@ final static class ClusterExemplar implements Serializable {
this.maxDistToEdge = maxDistToEdge;
}

Integer getLabel() {
/**
* Get the label in this exemplar.
* @return The label.
*/
public Integer getLabel() {
return label;
}

Double getOutlierScore() {
/**
* Get the outlier score in this exemplar.
* @return The outlier score.
*/
public Double getOutlierScore() {
return outlierScore;
}

SGDVector getFeatures() {
/**
* Get the feature vector in this exemplar.
* @return The feature vector.
*/
public SGDVector getFeatures() {
return features;
}

Double getMaxDistToEdge() {
/**
* Get the maximum distance from this exemplar to the edge of the cluster.
* <p>
* For models trained in 4.2 this will return {@link Double#NEGATIVE_INFINITY} as that information is
* not produced by 4.2 models.
* @return The distance to the edge of the cluster.
*/
public Double getMaxDistToEdge() {
if (maxDistToEdge != null) {
return maxDistToEdge;
}
else {
return Double.NEGATIVE_INFINITY;
}
}

/**
 * Copies this cluster exemplar.
 * <p>
 * The feature vector is deep-copied; the boxed label, score and
 * distance are immutable and shared.
 * @return A deep copy of this cluster exemplar.
 */
public ClusterExemplar copy() {
    SGDVector featureCopy = features.copy();
    return new ClusterExemplar(label, outlierScore, featureCopy, maxDistToEdge);
}

@Override
public String toString() {
    // Models trained before 4.3 have no edge distance; report the sentinel.
    double edgeDist = (maxDistToEdge == null) ? Double.NEGATIVE_INFINITY : maxDistToEdge;
    StringBuilder sb = new StringBuilder("ClusterExemplar(label=");
    sb.append(label);
    sb.append(",outlierScore=").append(outlierScore);
    sb.append(",vector=").append(features);
    sb.append(",maxDistToEdge=").append(edgeDist);
    sb.append(')');
    return sb.toString();
}

@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    // ClusterExemplar is final, so an instanceof check is equivalent to
    // comparing runtime classes (and rejects null).
    if (!(o instanceof ClusterExemplar)) {
        return false;
    }
    ClusterExemplar other = (ClusterExemplar) o;
    // maxDistToEdge may be null on deserialized 4.2 models, hence Objects.equals.
    return label.equals(other.label)
            && outlierScore.equals(other.outlierScore)
            && features.equals(other.features)
            && Objects.equals(maxDistToEdge, other.maxDistToEdge);
}

@Override
public int hashCode() {
    // Same accumulation Objects.hash performs, spelled out explicitly;
    // Objects.hashCode tolerates the possibly-null maxDistToEdge.
    int result = 1;
    result = 31 * result + Objects.hashCode(label);
    result = 31 * result + Objects.hashCode(outlierScore);
    result = 31 * result + Objects.hashCode(features);
    result = 31 * result + Objects.hashCode(maxDistToEdge);
    return result;
}
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
import org.junit.jupiter.api.Test;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Feature;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
Expand All @@ -37,7 +38,10 @@
import org.tribuo.data.columnar.processors.response.EmptyResponseProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.test.Helpers;

import java.io.FileInputStream;
Expand All @@ -47,8 +51,10 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
Expand All @@ -57,6 +63,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
Expand Down Expand Up @@ -205,8 +212,35 @@ public static void runBasicTrainPredict(HdbscanTrainer trainer) {
Dataset<ClusterID> testData = new MutableDataset<>(splitter.getTest());

HdbscanModel model = trainer.train(trainData);

for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) {
assertTrue(e.getMaxDistToEdge() > 0.0);
}

List<Integer> clusterLabels = model.getClusterLabels();
List<Double> outlierScores = model.getOutlierScores();
List<Pair<Integer,List<Feature>>> exemplarLists = model.getClusters();
List<HdbscanTrainer.ClusterExemplar> exemplars = model.getClusterExemplars();

assertEquals(exemplars.size(), exemplarLists.size());

// Check there's at least one exemplar per label
Set<Integer> exemplarLabels = exemplarLists.stream().map(Pair::getA).collect(Collectors.toSet());
Set<Integer> clusterLabelSet = new HashSet<>(clusterLabels);
// Remove the noise label
clusterLabelSet.remove(Integer.valueOf(0));
assertEquals(exemplarLabels,clusterLabelSet);

for (int i = 0; i < exemplars.size(); i++) {
HdbscanTrainer.ClusterExemplar e = exemplars.get(i);
Pair<Integer, List<Feature>> p = exemplarLists.get(i);
assertEquals(model.getFeatureIDMap().size(), e.getFeatures().size());
assertEquals(p.getB().size(), e.getFeatures().size());
SGDVector otherFeatures = DenseVector.createDenseVector(
new ArrayExample<>(trainData.getOutputFactory().getUnknownOutput(), p.getB()),
model.getFeatureIDMap(), false);
assertEquals(otherFeatures, e.getFeatures());
}

int [] expectedIntClusterLabels = {4,3,4,5,3,5,3,4,3,4,5,5,3,4,4,0,3,4,0,5,5,3,3,4,4,4,4,4,4,4,4,4,4,0,4,5,3,5,3,4,3,4,4,3,0,5,0,4,4,4,4,4,5,4,3,4,4,4,4,4,5,3,4,3,5,3,4,5,3,4,0,5,4,4,4,4,4,5,4,4,4,4,4,5,3,4,4,3,4,3,5,5,0,5,4,4,3,5,5,4,5,5,3,5,4,4,3,5,4,5,5,5,4,4,5,5,3,5,4,4,3,5,5,3,5,4,4,5,5,5,3,5,4,5,3,4,3,5,4,4,3,3,5,4,4,5,5,4,3,4,5,4,5,4,3,3,3,4,5,4,5,5,3,4,3,3,4,5,3,5,5,5,5,5,4,4,3,4,5,5,4,4,3,4,3,4,5,4,4,5,4,3,3,0,3,5,5,3,3,3,4,3,3,5,5,5,5,3,5,5,3,5,3,4,5,3,3,3,4,4,3,3,3,5,3,4,5,3,5,5,5,3,5,3,5,4,5,4,4,5,5,5,3,5,4,5,5,4,4,4,5,4,5,4,3,3,4,5,4,4,3,3,3,4,5,4,4,4,4,5,4,4,4,5,3,5,4,5,3,5,3,5,4,4,0,4,4,5,3,4,5,5,0,5,4,5,3,4,3,5,5,4,5,5,5,5,5,5,3,5,4,3,3,5,3,4,5,4,3,5,4,3,3,3,5,4,5,4,5,5,4,3,5,4,5,4,5,4,3,4,5,4,4,5,5,5,3,4,5,4,0,3,5,3,4,3,3,5,5,5,4,4,3,3,4,3,5,3,3,4,3,5,3,4,5,4,4,3,4,4,3,3,5,4,4,5,3,5,3,3,4,5,3,4,5,5,4,4,4,5,5,5,5,3,3,4,4,4,4,4,3,5,4,3,4,4,5,3,5,3,4,5,4,4,5,3,4,4,4,5,5,4,5,0,4,5,3,4,5,4,4,4,5,4,4,4,0,3,4,5,5,4,4,3,3,4,3,3,4,5,5,4,3,5,4,4,4,4,5,4,4,3,4,5,5,4,3,4,5,4,3,5,5,5,3,4,4,4,4,4,4,5,3,3,3,5,5,4,5,3,5,3,5,4,5,3,4,5,4,3,5,4,4,5,5,0,3,3,5,5,3,0,5,5,5,5,3,4,5,4,3,3,4,5,4,4,0,5,3,4,4,4,4,5,5,5,3,5,4,3,3,5,3,4,3,5,3,3,4,3,5,4,3,4,3,0,4,5,5,5,3,4,3,5,5,4,5,4,4,4,5,4,3,4,3,4,5,3,5,4,5,3,0,4,0,4,3,3,4,3,0,3,3,3,3,4,4,5,3,3,5,4,4,4,5,5,5,3,3,4,4,3,4,5,3,4,4,5,3,4,4,4,3,4,4,4,5,4,4,5,5,5,4,4,4,5,5,5,5,4,3,4,3,3,3,4,4,5,4,5,4,4,4,4,4,5,4,5,5,5,4,3,5,3,5,4,5,4,4,5,0,5,3,4,5,4,4,5,3,4,4,3,5,4,4,4,5,3,3,4,4,5,5,5,3,4,3,4,5,5,4,4,3,3,4,4,5,5,5,3,4,3,4,4,4,5,5,5,0,4,5,5,3,3,4,5,4,3,3,4,3,4,5,4,3,4,5,5,3,3,4,4,3,3,5,4,5,3,4,5,4,3,3,4,5,5,5,3,3,4,4,5,5,5,4,5,5,5,4,4,4,5,4,5,5,3,3,4,4,3,5,5,3,3,4,4,5,3,3,3};
List<Integer> expectedClusterLabels = Arrays.stream(expectedIntClusterLabels).boxed().collect(Collectors.toList());
Expand Down Expand Up @@ -258,6 +292,11 @@ public void deserializeHdbscanModelV42Test() {
fail("There is a problem deserializing the model file " + serializedModelPath);
}

// In v4.2 models this value is unset and defaults to negative infinity.
for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) {
assertEquals(Double.NEGATIVE_INFINITY, e.getMaxDistToEdge());
}

ClusteringFactory clusteringFactory = new ClusteringFactory();
ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
Expand Down

0 comments on commit 38d60e6

Please sign in to comment.