diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java index 03558ffd3..c3336da3a 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java @@ -19,6 +19,7 @@ import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; +import org.tribuo.Feature; import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; @@ -29,6 +30,7 @@ import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; +import org.tribuo.math.la.VectorTuple; import org.tribuo.provenance.ModelProvenance; import java.io.IOException; @@ -107,6 +109,43 @@ public List getOutlierScores() { return outlierScores; } + /** + * Returns a deep copy of the cluster exemplars. + * @return The cluster exemplars. + */ + public List getClusterExemplars() { + List list = new ArrayList<>(clusterExemplars.size()); + for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) { + list.add(e.copy()); + } + return list; + } + + /** + * Returns the features in each cluster exemplar. + *

+ * In many cases this should be used in preference to {@link #getClusterExemplars()} + * as it performs the mapping from Tribuo's internal feature ids to + * the externally visible feature names. + * @return The cluster exemplars. + */ + public List>> getClusters() { + List>> list = new ArrayList<>(clusterExemplars.size()); + + for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) { + List features = new ArrayList<>(e.getFeatures().numActiveElements()); + + for (VectorTuple v : e.getFeatures()) { + Feature f = new Feature(featureIDMap.get(v.index).getName(),v.value); + features.add(f); + } + + list.add(new Pair<>(e.getLabel(),features)); + } + + return list; + } + @Override public Prediction predict(Example example) { SGDVector vector; diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java index 6c0992c84..2e8e9153e 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java @@ -50,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.PriorityQueue; import java.util.TreeMap; import java.util.TreeSet; @@ -805,7 +806,7 @@ public TrainerProvenance getProvenance() { /** * A cluster exemplar, with attributes for the point's label, outlier score and its features. */ - final static class ClusterExemplar implements Serializable { + public final static class ClusterExemplar implements Serializable { private static final long serialVersionUID = 1L; private final Integer label; @@ -820,19 +821,38 @@ final static class ClusterExemplar implements Serializable { this.maxDistToEdge = maxDistToEdge; } - Integer getLabel() { + /** + * Get the label in this exemplar. + * @return The label. + */ + public Integer getLabel() { return label; } - Double getOutlierScore() { + /** + * Get the outlier score in this exemplar. + * @return The outlier score. + */ + public Double getOutlierScore() { return outlierScore; } - SGDVector getFeatures() { + /** + * Get the feature vector in this exemplar. + * @return The feature vector. + */ + public SGDVector getFeatures() { return features; } - Double getMaxDistToEdge() { + /** + * Get the maximum distance from this exemplar to the edge of the cluster. + *

+ * For models trained in 4.2 this will return {@link Double#NEGATIVE_INFINITY} as that information is + * not produced by 4.2 models. + * @return The distance to the edge of the cluster. + */ + public Double getMaxDistToEdge() { if (maxDistToEdge != null) { return maxDistToEdge; } @@ -840,6 +860,33 @@ Double getMaxDistToEdge() { return Double.NEGATIVE_INFINITY; } } + + /** + * Copies this cluster exemplar. + * @return A deep copy of this cluster exemplar. + */ + public ClusterExemplar copy() { + return new ClusterExemplar(label,outlierScore,features.copy(),maxDistToEdge); + } + + @Override + public String toString() { + double dist = maxDistToEdge == null ? Double.NEGATIVE_INFINITY : maxDistToEdge; + return "ClusterExemplar(label="+label+",outlierScore="+outlierScore+",vector="+features+",maxDistToEdge="+dist+")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClusterExemplar that = (ClusterExemplar) o; + return label.equals(that.label) && outlierScore.equals(that.outlierScore) && features.equals(that.features) && Objects.equals(maxDistToEdge, that.maxDistToEdge); + } + + @Override + public int hashCode() { + return Objects.hash(label, outlierScore, features, maxDistToEdge); + } } } diff --git a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java index 8474e548f..925012a1b 100644 --- a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java +++ b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.tribuo.DataSource; import org.tribuo.Dataset; +import org.tribuo.Feature; import org.tribuo.Model; import org.tribuo.MutableDataset; import org.tribuo.Prediction; @@ -37,7 +38,10 @@ import org.tribuo.data.columnar.processors.response.EmptyResponseProcessor; import org.tribuo.data.csv.CSVDataSource; import org.tribuo.evaluation.TrainTestSplitter; +import org.tribuo.impl.ArrayExample; import org.tribuo.math.distance.DistanceType; +import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.SGDVector; import org.tribuo.test.Helpers; import java.io.FileInputStream; @@ -47,8 +51,10 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -57,6 +63,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; /** @@ -205,8 +212,35 @@ public static void runBasicTrainPredict(HdbscanTrainer trainer) { Dataset testData = new MutableDataset<>(splitter.getTest()); HdbscanModel model = trainer.train(trainData); + + for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) { + assertTrue(e.getMaxDistToEdge() > 0.0); + } + List clusterLabels = model.getClusterLabels(); List outlierScores = model.getOutlierScores(); + List>> exemplarLists = model.getClusters(); + List exemplars = model.getClusterExemplars(); + + assertEquals(exemplars.size(), exemplarLists.size()); + + // Check there's at least one exemplar per label + Set exemplarLabels = exemplarLists.stream().map(Pair::getA).collect(Collectors.toSet()); + Set clusterLabelSet = new HashSet<>(clusterLabels); + // Remove the noise label + clusterLabelSet.remove(Integer.valueOf(0)); + assertEquals(exemplarLabels,clusterLabelSet); + + for (int i = 0; i < exemplars.size(); i++) { + HdbscanTrainer.ClusterExemplar e = exemplars.get(i); + Pair> p = exemplarLists.get(i); + assertEquals(model.getFeatureIDMap().size(), e.getFeatures().size()); + assertEquals(p.getB().size(), e.getFeatures().size()); + SGDVector otherFeatures = DenseVector.createDenseVector( + new ArrayExample<>(trainData.getOutputFactory().getUnknownOutput(), p.getB()), + model.getFeatureIDMap(), false); + assertEquals(otherFeatures, e.getFeatures()); + } int [] expectedIntClusterLabels = {4,3,4,5,3,5,3,4,3,4,5,5,3,4,4,0,3,4,0,5,5,3,3,4,4,4,4,4,4,4,4,4,4,0,4,5,3,5,3,4,3,4,4,3,0,5,0,4,4,4,4,4,5,4,3,4,4,4,4,4,5,3,4,3,5,3,4,5,3,4,0,5,4,4,4,4,4,5,4,4,4,4,4,5,3,4,4,3,4,3,5,5,0,5,4,4,3,5,5,4,5,5,3,5,4,4,3,5,4,5,5,5,4,4,5,5,3,5,4,4,3,5,5,3,5,4,4,5,5,5,3,5,4,5,3,4,3,5,4,4,3,3,5,4,4,5,5,4,3,4,5,4,5,4,3,3,3,4,5,4,5,5,3,4,3,3,4,5,3,5,5,5,5,5,4,4,3,4,5,5,4,4,3,4,3,4,5,4,4,5,4,3,3,0,3,5,5,3,3,3,4,3,3,5,5,5,5,3,5,5,3,5,3,4,5,3,3,3,4,4,3,3,3,5,3,4,5,3,5,5,5,3,5,3,5,4,5,4,4,5,5,5,3,5,4,5,5,4,4,4,5,4,5,4,3,3,4,5,4,4,3,3,3,4,5,4,4,4,4,5,4,4,4,5,3,5,4,5,3,5,3,5,4,4,0,4,4,5,3,4,5,5,0,5,4,5,3,4,3,5,5,4,5,5,5,5,5,5,3,5,4,3,3,5,3,4,5,4,3,5,4,3,3,3,5,4,5,4,5,5,4,3,5,4,5,4,5,4,3,4,5,4,4,5,5,5,3,4,5,4,0,3,5,3,4,3,3,5,5,5,4,4,3,3,4,3,5,3,3,4,3,5,3,4,5,4,4,3,4,4,3,3,5,4,4,5,3,5,3,3,4,5,3,4,5,5,4,4,4,5,5,5,5,3,3,4,4,4,4,4,3,5,4,3,4,4,5,3,5,3,4,5,4,4,5,3,4,4,4,5,5,4,5,0,4,5,3,4,5,4,4,4,5,4,4,4,0,3,4,5,5,4,4,3,3,4,3,3,4,5,5,4,3,5,4,4,4,4,5,4,4,3,4,5,5,4,3,4,5,4,3,5,5,5,3,4,4,4,4,4,4,5,3,3,3,5,5,4,5,3,5,3,5,4,5,3,4,5,4,3,5,4,4,5,5,0,3,3,5,5,3,0,5,5,5,5,3,4,5,4,3,3,4,5,4,4,0,5,3,4,4,4,4,5,5,5,3,5,4,3,3,5,3,4,3,5,3,3,4,3,5,4,3,4,3,0,4,5,5,5,3,4,3,5,5,4,5,4,4,4,5,4,3,4,3,4,5,3,5,4,5,3,0,4,0,4,3,3,4,3,0,3,3,3,3,4,4,5,3,3,5,4,4,4,5,5,5,3,3,4,4,3,4,5,3,4,4,5,3,4,4,4,3,4,4,4,5,4,4,5,5,5,4,4,4,5,5,5,5,4,3,4,3,3,3,4,4,5,4,5,4,4,4,4,4,5,4,5,5,5,4,3,5,3,5,4,5,4,4,5,0,5,3,4,5,4,4,5,3,4,4,3,5,4,4,4,5,3,3,4,4,5,5,5,3,4,3,4,5,5,4,4,3,3,4,4,5,5,5,3,4,3,4,4,4,5,5,5,0,4,5,5,3,3,4,5,4,3,3,4,3,4,5,4,3,4,5,5,3,3,4,4,3,3,5,4,5,3,4,5,4,3,3,4,5,5,5,3,3,4,4,5,5,5,4,5,5,5,4,4,4,5,4,5,5,3,3,4,4,3,5,5,3,3,4,4,5,3,3,3}; List expectedClusterLabels = Arrays.stream(expectedIntClusterLabels).boxed().collect(Collectors.toList()); @@ -258,6 +292,11 @@ public void deserializeHdbscanModelV42Test() { fail("There is a problem deserializing the model file " + serializedModelPath); } + // In v4.2 models this value is unset and defaults to negative infinity. + for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) { + assertEquals(Double.NEGATIVE_INFINITY, e.getMaxDistToEdge()); + } + ClusteringFactory clusteringFactory = new ClusteringFactory(); ResponseProcessor emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory); Map regexMappingProcessors = new HashMap<>();