Skip to content

Commit

Permalink
KMeans DenseVector support (#201)
Browse files Browse the repository at this point in the history
* Adding DenseVector support to KMeansTrainer.

* Adding docs to KMeansTrainer.mStep.

* Letting KMeansModel pick dynamically between dense and sparse vectors at predict time.

Description
Relaxed the method signatures in `KMeansTrainer` to allow the use of `DenseVector` or `SparseVector`, and changed the train & predict methods to dynamically pick between sparse and dense vectors. This is similar to the change for #112.

This speeds up k-means in dense spaces, because dense vector operations are faster than sparse vector operations when the feature space is dense.

This changes the signature of the protected `mStep` method, relaxing one of the argument types from `SparseVector[]` to `SGDVector[]`. This change means the train method will no longer call subclasses which override the `mStep` method, and so it's a breaking change for subclasses of `KMeansTrainer`. If users have tagged their override with `@Override` then the compiler will warn them and it should be a one line change.

Motivation
Improves the speed of K-means when working in dense spaces.
  • Loading branch information
Craigacp authored Dec 17, 2021
1 parent 0d491ff commit de2e5d4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;
Expand Down Expand Up @@ -114,7 +115,12 @@ public List<List<Feature>> getCentroids() {

@Override
public Prediction<ClusterID> predict(Example<ClusterID> example) {
SparseVector vector = SparseVector.createSparseVector(example,featureIDMap,false);
SGDVector vector;
if (example.size() == featureIDMap.size()) {
vector = DenseVector.createDenseVector(example, featureIDMap, false);
} else {
vector = SparseVector.createSparseVector(example, featureIDMap, false);
}
if (vector.numActiveElements() == 0) {
throw new IllegalArgumentException("No features found in Example " + example.toString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
* of threads used in the training step. The thread pool is local to an invocation of train,
* so there can be multiple concurrent trainings.
* <p>
* The train method will instantiate dense examples as dense vectors, speeding up the computation.
* <p>
* Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
* is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
* {@link java.lang.SecurityManager}.
Expand Down Expand Up @@ -225,12 +227,16 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
}

int[] oldCentre = new int[examples.size()];
SparseVector[] data = new SparseVector[examples.size()];
SGDVector[] data = new SGDVector[examples.size()];
double[] weights = new double[examples.size()];
int n = 0;
for (Example<ClusterID> example : examples) {
weights[n] = example.getWeight();
data[n] = SparseVector.createSparseVector(example, featureMap, false);
if (example.size() == featureMap.size()) {
data[n] = DenseVector.createDenseVector(example, featureMap, false);
} else {
data[n] = SparseVector.createSparseVector(example, featureMap, false);
}
oldCentre[n] = -1;
n++;
}
Expand Down Expand Up @@ -321,7 +327,7 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(),
examples.getProvenance(), trainerProvenance, runProvenance);

return new KMeansModel("", provenance, featureMap, outputMap, centroidVectors, distanceType);
return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distanceType);
}

@Override
Expand Down Expand Up @@ -377,11 +383,11 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
* Initialisation method called at the start of each train call when using kmeans++ centroid initialisation.
*
* @param centroids The number of centroids to create.
* @param data The dataset of {@link SparseVector} to use.
* @param data The dataset of {@link SGDVector} to use.
* @param rng The RNG to use.
* @return A {@link DenseVector} array of centroids.
*/
private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng,
Distance distanceType) {
if (centroids > data.length) {
throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
Expand Down Expand Up @@ -423,7 +429,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
// sample from probabilities to get the new centroid from data
double[] cdf = Util.generateCDF(probabilities);
int idx = Util.sampleFromCDF(cdf, rng);
centroidVectors[i] = data[idx].densify();
centroidVectors[i] = DenseVector.createDenseVector(data[idx].toArray());
}
return centroidVectors;
}
Expand All @@ -435,9 +441,9 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
* @param rng The RNG to use.
* @return A {@link DenseVector} representing a centroid.
*/
private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) {
    // Choose a uniformly random example and copy its values (via toArray)
    // into a fresh DenseVector to serve as a centroid.
    int chosen = rng.nextInt(data.length);
    double[] values = data[chosen].toArray();
    return DenseVector.createDenseVector(values);
}

/**
Expand Down Expand Up @@ -467,14 +473,17 @@ private static double getDistance(DenseVector cluster, SGDVector vector, Distanc

/**
* Runs the mStep, writing to the {@code centroidVectors} array.
* <p>
* Note in 4.2 this method changed signature slightly, and overrides of the old
* version will not match.
* @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
* If the fjp is null then the computation is executed sequentially on the main thread.
* @param centroidVectors The centroid vectors to write out.
* @param clusterAssignments The current cluster assignments.
* @param data The data points.
* @param weights The example weights.
*/
protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SGDVector[] data, double[] weights) {
// M step
Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
DenseVector newCentroid = centroidVectors[e.getKey()];
Expand Down

0 comments on commit de2e5d4

Please sign in to comment.