Single threaded K-Means training no longer uses a ForkJoinPool (#197)
* Stop using a FJP for single-threaded k-means training.

* Fixing a bug in weighted k-means calculations.

* Adding a note to KMeansTrainer's javadoc about the FJP requiring the modifyThread permission.

* Adding a custom ForkJoinWorkerThreadFactory to KMeansTrainer and KNNModel to make them work with custom security managers.

* Tightening the check so the custom thread factory is only used when there is a security manager.
Craigacp authored Dec 9, 2021
1 parent ec40222 commit 3eb8d35
Showing 2 changed files with 139 additions and 94 deletions.
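Before the diff itself, the shape of the change in one standalone sketch (illustrative code, not from the repository): a ForkJoinPool is only constructed when training is parallel, and submitting the stream's terminal operation into that pool confines the work to the pool's threads; with one thread the stream simply runs on the caller, so no pool (and none of the permissions discussed below) is needed.

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

public class DispatchSketch {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        int numThreads = 2; // KMeansTrainer takes this as a constructor argument
        if (numThreads > 1) {
            ForkJoinPool fjp = new ForkJoinPool(numThreads);
            // Submitting the terminal operation into the pool confines the
            // parallel stream's work to the pool's worker threads.
            fjp.submit(() -> IntStream.range(0, 8).parallel()
                    .forEach(i -> System.out.println(Thread.currentThread().getName()))).get();
            fjp.shutdown();
        } else {
            // Single threaded: no pool is created, work runs on the calling thread.
            IntStream.range(0, 8).forEach(i -> System.out.println(Thread.currentThread().getName()));
        }
    }
}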
KMeansTrainer.java
@@ -34,6 +34,8 @@
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

+ import java.security.AccessController;
+ import java.security.PrivilegedAction;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
@@ -45,7 +47,9 @@
import java.util.SplittableRandom;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
+ import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
+ import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
@@ -63,6 +67,10 @@
* of threads used in the training step. The thread pool is local to an invocation of train,
* so there can be multiple concurrent trainings.
* <p>
+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
+ * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
+ * {@link java.lang.SecurityManager}.
+ * <p>
* See:
* <pre>
* J. Friedman, T. Hastie, &amp; R. Tibshirani.
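For deployments that do run under a SecurityManager, the privileges named in this new javadoc note are granted in a Java policy file. A hypothetical fragment (the codeBase URL is illustrative only, not from the commit):

// Hypothetical java.policy fragment; adjust the codeBase to your jar location.
grant codeBase "file:/path/to/tribuo-clustering-kmeans.jar" {
    permission java.lang.RuntimePermission "modifyThread";
    permission java.lang.RuntimePermission "modifyThreadGroup";
};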
@@ -80,6 +88,9 @@
public class KMeansTrainer implements Trainer<ClusterID> {
private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());

+ // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+ private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
/**
* Possible distance functions.
*/
@@ -138,8 +149,7 @@ public enum Initialisation {
/**
* for olcut.
*/
- private KMeansTrainer() {
- }
+ private KMeansTrainer() { }

/**
* Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
@@ -202,7 +212,17 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
}
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();

- ForkJoinPool fjp = new ForkJoinPool(numThreads);
+ boolean parallel = numThreads > 1;
+ ForkJoinPool fjp;
+ if (parallel) {
+ if (System.getSecurityManager() == null) {
+ fjp = new ForkJoinPool(numThreads);
+ } else {
+ fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
+ }
+ } else {
+ fjp = null;
+ }

int[] oldCentre = new int[examples.size()];
SparseVector[] data = new SparseVector[examples.size()];
@@ -221,62 +241,65 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> ru
centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
break;
case PLUSPLUS:
- centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType);
+ centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
break;
default:
throw new IllegalStateException("Unknown initialisation " + initialisationType);
}

Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
for (int i = 0; i < centroids; i++) {
- clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>()));
+ clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
}

+ AtomicInteger changeCounter = new AtomicInteger(0);
+ Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
+ double minDist = Double.POSITIVE_INFINITY;
+ int clusterID = -1;
+ int id = e.idx;
+ SGDVector vector = e.vector;
+ for (int j = 0; j < centroids; j++) {
+ DenseVector cluster = centroidVectors[j];
+ double distance = getDistance(cluster, vector, distanceType);
+ if (distance < minDist) {
+ minDist = distance;
+ clusterID = j;
+ }
+ }
+
+ clusterAssignments.get(clusterID).add(id);
+ if (oldCentre[id] != clusterID) {
+ // Changed the centroid of this vector.
+ oldCentre[id] = clusterID;
+ changeCounter.incrementAndGet();
+ }
+ };

boolean converged = false;

for (int i = 0; (i < iterations) && !converged; i++) {
- //logger.log(Level.INFO,"Beginning iteration " + i);
- AtomicInteger changeCounter = new AtomicInteger(0);
+ logger.log(Level.FINE,"Beginning iteration " + i);
+ changeCounter.set(0);

for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
e.getValue().clear();
}

// E step
- Stream<SparseVector> vecStream = Arrays.stream(data);
+ Stream<SGDVector> vecStream = Arrays.stream(data);
Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
- Stream<IntAndVector> eStream;
- if (numThreads > 1) {
- eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
+ Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
+ if (parallel) {
+ Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
+ try {
+ fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException("Parallel execution failed", e);
+ }
} else {
- eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
- }
- try {
- fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
- double minDist = Double.POSITIVE_INFINITY;
- int clusterID = -1;
- int id = e.idx;
- SparseVector vector = e.vector;
- for (int j = 0; j < centroids; j++) {
- DenseVector cluster = centroidVectors[j];
- double distance = getDistance(cluster, vector, distanceType);
- if (distance < minDist) {
- minDist = distance;
- clusterID = j;
- }
- }
-
- clusterAssignments.get(clusterID).add(id);
- if (oldCentre[id] != clusterID) {
- // Changed the centroid of this vector.
- oldCentre[id] = clusterID;
- changeCounter.incrementAndGet();
- }
- })).get();
- } catch (InterruptedException | ExecutionException e) {
- throw new RuntimeException("Parallel execution failed", e);
+ zipStream.forEach(eStepFunc);
}
//logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated.");
logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");

mStep(fjp, centroidVectors, clusterAssignments, data, weights);
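For context, a hedged usage sketch (the constructor shape is assumed from Tribuo 4.x, not shown in this diff): a trainer built with numThreads == 1 now runs the E step above via zipStream.forEach on the calling thread with fjp left null, while numThreads > 1 keeps the pool-based path.

import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;

public class TrainerSetup {
    public static void main(String[] args) {
        // Assumed Tribuo 4.x constructor: (centroids, iterations, distance, numThreads, seed).
        KMeansTrainer singleThreaded = new KMeansTrainer(5, 10, Distance.EUCLIDEAN, 1, 1L);
        KMeansTrainer multiThreaded = new KMeansTrainer(5, 10, Distance.EUCLIDEAN, 4, 1L);
        System.out.println(singleThreaded); // trains with no ForkJoinPool at all
        System.out.println(multiThreaded);  // trains inside a bounded ForkJoinPool
    }
}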

@@ -355,18 +378,15 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
*
* @param centroids The number of centroids to create.
* @param data The dataset of {@link SparseVector} to use.
- * @param featureMap The feature map to use for centroid sampling.
* @param rng The RNG to use.
* @return A {@link DenseVector} array of centroids.
*/
- private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
- ImmutableFeatureMap featureMap, SplittableRandom rng,
+ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
Distance distanceType) {
if (centroids > data.length) {
throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
}

- int numFeatures = featureMap.size();
double[] minDistancePerVector = new double[data.length];
Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);

@@ -375,7 +395,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
DenseVector[] centroidVectors = new DenseVector[centroids];

// set first centroid randomly from the data
- centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng);
+ centroidVectors[0] = getRandomCentroidFromData(data, rng);

// Set each uninitialised centroid remaining
for (int i = 1; i < centroids; i++) {
@@ -384,8 +404,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
// go through every vector and see if the min distance to the
// newest centroid is smaller than previous min distance for vec
for (int j = 0; j < data.length; j++) {
- SparseVector curVec = data[j];
- double tempDistance = getDistance(prevCentroid, curVec, distanceType);
+ double tempDistance = getDistance(prevCentroid, data[j], distanceType);
minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
}

@@ -404,7 +423,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
// sample from probabilities to get the new centroid from data
double[] cdf = Util.generateCDF(probabilities);
int idx = Util.sampleFromCDF(cdf, rng);
- centroidVectors[i] = sparseToDense(data[idx], numFeatures);
+ centroidVectors[i] = data[idx].densify();
}
return centroidVectors;
}
@@ -413,39 +432,22 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
* Randomly select a piece of data as the starting centroid.
*
* @param data The dataset of {@link SparseVector} to use.
- * @param numFeatures The number of features.
* @param rng The RNG to use.
* @return A {@link DenseVector} representing a centroid.
*/
- private static DenseVector getRandomCentroidFromData(SparseVector[] data,
- int numFeatures, SplittableRandom rng) {
- int rand_idx = rng.nextInt(data.length);
- return sparseToDense(data[rand_idx], numFeatures);
- }
-
- /**
- * Create a {@link DenseVector} from the data contained in a
- * {@link SparseVector}.
- *
- * @param vec The {@link SparseVector} to be transformed.
- * @param numFeatures The number of features.
- * @return A {@link DenseVector} containing the information from vec.
- */
- private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
- DenseVector dense = new DenseVector(numFeatures);
- dense.intersectAndAddInPlace(vec);
- return dense;
+ private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
+ int randIdx = rng.nextInt(data.length);
+ return data[randIdx].densify();
+ }

/**
- *
* Compute the distance between the two vectors.
* @param cluster A {@link DenseVector} representing a centroid.
* @param vector A {@link SGDVector} representing an example.
* @param distanceType The distance metric to employ.
* @return A double representing the distance from vector to centroid.
*/
- private static double getDistance(DenseVector cluster, SGDVector vector,
- Distance distanceType) {
+ private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
double distance;
switch (distanceType) {
case EUCLIDEAN:
@@ -463,30 +465,41 @@ private static double getDistance(DenseVector cluster, SGDVector vector,
return distance;
}

+ /**
+ * Runs the mStep, writing to the {@code centroidVectors} array.
+ * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
+ * If the fjp is null then the computation is executed sequentially on the main thread.
+ * @param centroidVectors The centroid vectors to write out.
+ * @param clusterAssignments The current cluster assignments.
+ * @param data The data points.
+ * @param weights The example weights.
+ */
protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
// M step
- Stream<Entry<Integer, List<Integer>>> mStream;
- if (numThreads > 1) {
- mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
+ Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
+ DenseVector newCentroid = centroidVectors[e.getKey()];
+ newCentroid.fill(0.0);
+
+ double weightSum = 0.0;
+ for (Integer idx : e.getValue()) {
+ newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
+ weightSum += weights[idx];
+ }
+ if (weightSum != 0.0) {
+ newCentroid.scaleInPlace(1.0 / weightSum);
+ }
+ };
+
+ Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
+ if (fjp != null) {
+ Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
+ try {
+ fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException("Parallel execution failed", e);
+ }
} else {
- mStream = clusterAssignments.entrySet().stream();
- }
- try {
- fjp.submit(() -> mStream.forEach((e) -> {
- DenseVector newCentroid = centroidVectors[e.getKey()];
- newCentroid.fill(0.0);
-
- int counter = 0;
- for (Integer idx : e.getValue()) {
- newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
- counter++;
- }
- if (counter > 0) {
- newCentroid.scaleInPlace(1.0 / counter);
- }
- })).get();
- } catch (InterruptedException | ExecutionException e) {
- throw new RuntimeException("Parallel execution failed", e);
+ mStream.forEach(mStepFunc);
}
}
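This hunk also carries the weighted k-means fix from the commit message: the old M step divided each centroid's weighted sum by the number of assigned examples (counter), while the fixed version divides by the sum of the weights (weightSum). A tiny standalone check with made-up numbers (not Tribuo code):

public class WeightedMeanCheck {
    public static void main(String[] args) {
        // Two 1-D points with unequal weights: x = {0, 4}, w = {1, 3}.
        double[] x = {0.0, 4.0};
        double[] w = {1.0, 3.0};
        double weightedSum = 0.0;
        double weightSum = 0.0;
        for (int i = 0; i < x.length; i++) {
            weightedSum += w[i] * x[i];
            weightSum += w[i];
        }
        System.out.println(weightedSum / x.length);  // old behaviour: 6.0, outside the data range
        System.out.println(weightedSum / weightSum); // fixed: 3.0, the true weighted mean
    }
}

With uniform weights of 1.0 the two divisors coincide, which is why unweighted datasets were unaffected by the bug.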

Expand All @@ -505,11 +518,25 @@ public TrainerProvenance getProvenance() {
*/
static class IntAndVector {
final int idx;
- final SparseVector vector;
+ final SGDVector vector;

- public IntAndVector(int idx, SparseVector vector) {
+ /**
+ * Constructs an index and vector tuple.
+ * @param idx The index.
+ * @param vector The vector.
+ */
+ public IntAndVector(int idx, SGDVector vector) {
this.idx = idx;
this.vector = vector;
}
}

+ /**
+ * Used to allow FJPs to work with OpenSearch's SecureSM.
+ */
+ private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+ public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+ return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
+ }
+ }
}