Migrates DistanceType to a Distance interface #285

Merged
2 commits merged on Oct 3, 2022
@@ -31,12 +31,12 @@
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.clustering.hdbscan.protos.HdbscanModelProto;
import org.tribuo.impl.ModelDataCarrier;
-import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
+import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

@@ -78,20 +78,20 @@ public final class HdbscanModel extends Model<ClusterID> {

// This is not final to support deserialization of older models. It will be final in a future version which doesn't
// maintain serialization compatibility with 4.X.
-private DistanceType distType;
+private org.tribuo.math.distance.Distance dist;

private final List<HdbscanTrainer.ClusterExemplar> clusterExemplars;

private final double noisePointsOutlierScore;

HdbscanModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
ImmutableOutputInfo<ClusterID> outputIDInfo, List<Integer> clusterLabels, DenseVector outlierScoresVector,
-List<HdbscanTrainer.ClusterExemplar> clusterExemplars, DistanceType distType, double noisePointsOutlierScore) {
+List<HdbscanTrainer.ClusterExemplar> clusterExemplars, org.tribuo.math.distance.Distance dist, double noisePointsOutlierScore) {
super(name,description,featureIDMap,outputIDInfo,false);
this.clusterLabels = Collections.unmodifiableList(clusterLabels);
this.outlierScoresVector = outlierScoresVector;
this.clusterExemplars = Collections.unmodifiableList(clusterExemplars);
-this.distType = distType;
+this.dist = dist;
this.noisePointsOutlierScore = noisePointsOutlierScore;
}

@@ -135,10 +135,10 @@ public static HdbscanModel deserializeFromProto(int version, String className, A
exemplars.add(HdbscanTrainer.ClusterExemplar.deserialize(p));
}

-DistanceType distType = DistanceType.valueOf(proto.getDistType());
+org.tribuo.math.distance.Distance dist = ProtoUtil.deserialize(proto.getDistance());

return new HdbscanModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
-outputDomain, clusterLabels, outlierScoresVector, exemplars, distType, proto.getNoisePointsOutlierScore());
+outputDomain, clusterLabels, outlierScoresVector, exemplars, dist, proto.getNoisePointsOutlierScore());
}

/**
@@ -222,7 +222,7 @@ public Prediction<ClusterID> predict(Example<ClusterID> example) {
if (Double.compare(noisePointsOutlierScore, 0) > 0) { // This will be true from models > 4.2
boolean isNoisePoint = true;
for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) {
-double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType);
+double distance = dist.computeDistance(clusterExemplar.getFeatures(), vector);
if (isNoisePoint && distance <= clusterExemplar.getMaxDistToEdge()) {
isNoisePoint = false;
}
@@ -239,7 +239,7 @@ public Prediction<ClusterID> predict(Example<ClusterID> example) {
}
else {
for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) {
-double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType);
+double distance = dist.computeDistance(clusterExemplar.getFeatures(), vector);
if (distance < minDistance) {
minDistance = distance;
clusterLabel = clusterExemplar.getLabel();
@@ -268,7 +268,7 @@ public ModelProto serialize() {
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.addAllClusterLabels(clusterLabels);
modelBuilder.setOutlierScoresVector(outlierScoresVector.serialize());
-modelBuilder.setDistType(distType.name());
+modelBuilder.setDistance(dist.serialize());
for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) {
modelBuilder.addClusterExemplars(e.serialize());
}
@@ -288,13 +288,13 @@ protected HdbscanModel copy(String newName, ModelProvenance newProvenance) {
List<Integer> copyClusterLabels = new ArrayList<>(clusterLabels);
List<HdbscanTrainer.ClusterExemplar> copyExemplars = new ArrayList<>(clusterExemplars);
return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, copyClusterLabels,
-copyOutlierScoresVector, copyExemplars, distType, noisePointsOutlierScore);
+copyOutlierScoresVector, copyExemplars, dist, noisePointsOutlierScore);
}

private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();
-if (distType == null) {
-distType = distanceType.getDistanceType();
+if (dist == null) {
+dist = distanceType.getDistanceType().getDistance();
}
}
}
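
Taken together, the HdbscanModel changes swap the static enum helper for an instance call: DistanceType.getDistance(a, b, distType) becomes dist.computeDistance(a, b), and the distance is serialized via its own proto instead of an enum name. Below is a minimal, illustrative sketch of that calling convention, not part of this PR's diff; the DistanceMigrationSketch class name and the DenseVector.createDenseVector factory are assumptions, while the Distance/DistanceType names come from the imports and calls shown above.

import org.tribuo.math.distance.Distance;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;

public final class DistanceMigrationSketch {
    public static void main(String[] args) {
        // Hypothetical vectors, purely for illustration.
        SGDVector a = DenseVector.createDenseVector(new double[]{1.0, 2.0, 3.0});
        SGDVector b = DenseVector.createDenseVector(new double[]{2.0, 0.0, 3.0});

        // Old style (pre-PR): a static helper switched on the DistanceType enum.
        // double d = DistanceType.getDistance(a, b, DistanceType.L2);

        // New style: the enum hands back a Distance implementation and the interface
        // method does the work, so user-defined distances can be plugged in as well.
        Distance l2 = DistanceType.L2.getDistance();
        double d = l2.computeDistance(a, b);
        System.out.println("L2 distance = " + d);
    }
}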
@@ -71,6 +71,6 @@ public String getOptionsDescription() {
*/
public HdbscanTrainer getTrainer() {
logger.info("Configuring Hdbscan Trainer");
-return new HdbscanTrainer(minClusterSize, distType, k, numThreads, nqFactoryType);
+return new HdbscanTrainer(minClusterSize, distType.getDistance(), k, numThreads, nqFactoryType);
}
}
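
The options class above keeps its DistanceType CLI enum and converts at the boundary with distType.getDistance(); code that constructs the trainer directly goes through the new Distance-typed constructor instead. A hedged sketch of that wiring follows; the constructor shape and BRUTE_FORCE constant are taken from this diff, but the package paths for the neighbour-query classes and the parameter values are assumptions.

import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;

public final class TrainerConstructionSketch {
    public static void main(String[] args) {
        // Pick a Distance via the enum bridge, or supply any custom Distance implementation.
        Distance dist = DistanceType.L2.getDistance();

        // Mirrors the getTrainer() wiring above: minClusterSize, distance, k, numThreads, query factory type.
        HdbscanTrainer trainer = new HdbscanTrainer(5, dist, 5, 1, NeighboursQueryFactoryType.BRUTE_FORCE);
        System.out.println(trainer);
    }
}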
@@ -135,7 +135,7 @@ public DistanceType getDistanceType() {
private Distance distanceType;

@Config(description = "The distance function to use.")
-private DistanceType distType;
+private org.tribuo.math.distance.Distance dist;

@Config(mandatory = true, description = "The number of nearest-neighbors to use in the initial density approximation. " +
"This includes the point itself.")
@@ -154,19 +154,18 @@ public DistanceType getDistanceType() {
/**
* for olcut.
*/
-private HdbscanTrainer() {
-}
+private HdbscanTrainer() {}

/**
* Constructs an HDBSCAN* trainer with only the minClusterSize parameter.
*
* @param minClusterSize The minimum number of points required to form a cluster.
-* {@link #distType} defaults to {@link DistanceType#L2}, {@link #k} defaults to {@link #minClusterSize},
+* {@link #dist} defaults to {@link DistanceType#L2}, {@link #k} defaults to {@link #minClusterSize},
* {@link #numThreads} defaults to 1 and {@link #neighboursQueryFactory} defaults to
* {@link NeighboursBruteForceFactory}.
*/
public HdbscanTrainer(int minClusterSize) {
-this(minClusterSize, DistanceType.L2, minClusterSize, 1, NeighboursQueryFactoryType.BRUTE_FORCE);
+this(minClusterSize, DistanceType.L2.getDistance(), minClusterSize, 1, NeighboursQueryFactoryType.BRUTE_FORCE);
}

/**
@@ -182,24 +181,24 @@ public HdbscanTrainer(int minClusterSize) {
*/
@Deprecated
public HdbscanTrainer(int minClusterSize, Distance distanceType, int k, int numThreads) {
-this(minClusterSize, distanceType.getDistanceType(), k, numThreads, NeighboursQueryFactoryType.BRUTE_FORCE);
+this(minClusterSize, distanceType.getDistanceType().getDistance(), k, numThreads, NeighboursQueryFactoryType.BRUTE_FORCE);
}

/**
* Constructs an HDBSCAN* trainer using the supplied parameters.
*
* @param minClusterSize The minimum number of points required to form a cluster.
-* @param distType The distance function.
+* @param dist The distance function.
* @param k The number of nearest-neighbors to use in the initial density approximation.
* @param numThreads The number of threads.
* @param nqFactoryType The nearest neighbour query implementation factory to use.
*/
-public HdbscanTrainer(int minClusterSize, DistanceType distType, int k, int numThreads, NeighboursQueryFactoryType nqFactoryType) {
+public HdbscanTrainer(int minClusterSize, org.tribuo.math.distance.Distance dist, int k, int numThreads, NeighboursQueryFactoryType nqFactoryType) {
this.minClusterSize = minClusterSize;
-this.distType = distType;
+this.dist = dist;
this.k = k;
this.numThreads = numThreads;
-this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, distType, numThreads);
+this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, dist, numThreads);
}

/**
@@ -211,7 +210,7 @@ public HdbscanTrainer(int minClusterSize, DistanceType distType, int k, int numT
*/
public HdbscanTrainer(int minClusterSize, int k, NeighboursQueryFactory neighboursQueryFactory) {
this.minClusterSize = minClusterSize;
-this.distType = neighboursQueryFactory.getDistanceType();
+this.dist = neighboursQueryFactory.getDistance();
this.k = k;
this.neighboursQueryFactory = neighboursQueryFactory;
}
@@ -222,19 +221,19 @@ public HdbscanTrainer(int minClusterSize, int k, NeighboursQueryFactory neighbou
@Override
public synchronized void postConfig() {
if (this.distanceType != null) {
-if (this.distType != null) {
+if (this.dist != null) {
throw new PropertyException("distType", "Both distType and distanceType must not both be set.");
} else {
-this.distType = this.distanceType.getDistanceType();
+this.dist = this.distanceType.getDistanceType().getDistance();
this.distanceType = null;
}
}

if (neighboursQueryFactory == null) {
int numberThreads = (this.numThreads <= 0) ? 1 : this.numThreads;
-this.neighboursQueryFactory = new NeighboursBruteForceFactory(distType, numberThreads);
+this.neighboursQueryFactory = new NeighboursBruteForceFactory(dist, numberThreads);
} else {
-if (!this.distType.equals(neighboursQueryFactory.getDistanceType())) {
+if (!this.dist.equals(neighboursQueryFactory.getDistance())) {
throw new PropertyException("neighboursQueryFactory", "distType and its field on the " +
"NeighboursQueryFactory must be equal.");
}
@@ -264,7 +263,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
}

DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
-ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);
+ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, dist);

double[] pointNoiseLevels = new double[data.length]; // The levels at which each point becomes noise
int[] pointLastClusters = new int[data.length]; // The last label of each point before becoming noise
@@ -284,7 +283,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

// Compute the cluster exemplars.
-List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
+List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, dist);

// Get the outlier score value for points that are predicted as noise points.
double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
@@ -295,7 +294,7 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
examples.getProvenance(), trainerProvenance, runProvenance);

return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
-clusterExemplars, distType, noisePointsOutlierScore);
+clusterExemplars, dist, noisePointsOutlierScore);
}

@Override
@@ -347,12 +346,12 @@ private static DenseVector calculateCoreDistances(SGDVector[] data, int k, Neigh
* core distances for each point.
* @param data An array of {@link DenseVector} containing the data.
* @param coreDistances A {@link DenseVector} containing the core distances for every point.
-* @param distType The distance metric to employ.
+* @param dist The distance metric to employ.
* @return An {@link ExtendedMinimumSpanningTree} representation of the data using the mutual reachability distances,
* and the graph is sorted by edge weight in ascending order.
*/
private static ExtendedMinimumSpanningTree constructEMST(SGDVector[] data, DenseVector coreDistances,
-DistanceType distType) {
+org.tribuo.math.distance.Distance dist) {
// One bit is set (true) for each attached point, and unset (false) for unattached points:
BitSet attachedPoints = new BitSet(data.length);

@@ -380,7 +379,7 @@ private static ExtendedMinimumSpanningTree constructEMST(SGDVector[] data, Dense
continue;
}

-double mutualReachabilityDistance = DistanceType.getDistance(data[currentPoint], data[neighbor], distType);
+double mutualReachabilityDistance = dist.computeDistance(data[currentPoint], data[neighbor]);
if (coreDistances.get(currentPoint) > mutualReachabilityDistance) {
mutualReachabilityDistance = coreDistances.get(currentPoint);
}
@@ -754,11 +753,11 @@ private static Map<Integer, List<Pair<Double, Integer>>> generateClusterAssignme
*
* @param data An array of {@link DenseVector} containing the data.
* @param clusterAssignments A map of the cluster labels, and the points assigned to them.
-* @param distType The distance metric to employ.
+* @param dist The distance metric to employ.
* @return A list of {@link ClusterExemplar}s which are used for predictions.
*/
private static List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, List<Pair<Double, Integer>>> clusterAssignments,
-DistanceType distType) {
+org.tribuo.math.distance.Distance dist) {
List<ClusterExemplar> clusterExemplars = new ArrayList<>();
// The formula to calculate the exemplar number. This calculates the number of exemplars to be used for this
// configuration. The appropriate number of exemplars is important for prediction. At the time, this
@@ -797,7 +796,7 @@ else if (numExemplarsThisCluster > outlierScoreIndexTree.size()) {
SGDVector features = data[partialClusterExemplar.getValue()];
double maxInnerDist = Double.NEGATIVE_INFINITY;
for (Entry<Double, Integer> entry : outlierScoreIndexTree.entrySet()) {
-double distance = DistanceType.getDistance(features, data[entry.getValue()], distType);
+double distance = dist.computeDistance(features, data[entry.getValue()]);
if (distance > maxInnerDist){
maxInnerDist = distance;
}
@@ -834,7 +833,7 @@ private static double getNoisePointsOutlierScore(Map<Integer, List<Pair<Double,

@Override
public String toString() {
return "HdbscanTrainer(minClusterSize=" + minClusterSize + ",distanceType=" + distType + ",k=" + k + ",numThreads=" + numThreads + ")";
return "HdbscanTrainer(minClusterSize=" + minClusterSize + ",distanceType=" + dist + ",k=" + k + ",numThreads=" + numThreads + ")";
}

@Override
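
Because the trainer now reads its distance off the NeighboursQueryFactory when one is supplied (and postConfig rejects a mismatched pair), the factory-based constructor carries no separate distance argument. A sketch under the same assumptions as the examples above; the COSINE enum constant and the neighbour-query package paths are assumed rather than taken from this diff.

import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.math.distance.Distance;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.neighbour.NeighboursQueryFactory;
import org.tribuo.math.neighbour.bruteforce.NeighboursBruteForceFactory;

public final class QueryFactorySketch {
    public static void main(String[] args) {
        // The factory owns the Distance; the trainer inherits it via neighboursQueryFactory.getDistance().
        Distance cosine = DistanceType.COSINE.getDistance(); // COSINE assumed to be an enum constant
        NeighboursQueryFactory nqf = new NeighboursBruteForceFactory(cosine, 4);

        // minClusterSize = 10, k = 10; the distance comes from the factory.
        HdbscanTrainer trainer = new HdbscanTrainer(10, 10, nqf);
        System.out.println(trainer);
    }
}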