diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java
index 7065a5d1b..7c9751d2a 100644
--- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java
+++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java
@@ -34,6 +34,8 @@
 import org.tribuo.provenance.impl.TrainerProvenanceImpl;
 import org.tribuo.util.Util;
 
+import java.security.AccessController;
+import java.security.PrivilegedAction;
 import java.time.OffsetDateTime;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -45,7 +47,9 @@
 import java.util.SplittableRandom;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 import java.util.stream.IntStream;
@@ -63,6 +67,10 @@
  * of threads used in the training step. The thread pool is local to an invocation of train,
  * so there can be multiple concurrent trainings.
  * <p>
+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase
+ * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
+ * {@link java.lang.SecurityManager}.
+ * <p>
  * See:
  * <pre>
  * J. Friedman, T. Hastie, & R. Tibshirani.
@@ -80,6 +88,9 @@
 public class KMeansTrainer implements Trainer<ClusterID> {
     private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());
 
+    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
     /**
      * Possible distance functions.
      */
@@ -138,8 +149,7 @@ public enum Initialisation {
     /**
      * for olcut.
      */
-    private KMeansTrainer() {
-    }
+    private KMeansTrainer() { }
 
     /**
      * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
@@ -202,7 +212,17 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
         }
         ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
 
-        ForkJoinPool fjp = new ForkJoinPool(numThreads);
+        boolean parallel = numThreads > 1;
+        ForkJoinPool fjp;
+        if (parallel) {
+            if (System.getSecurityManager() == null) {
+                fjp = new ForkJoinPool(numThreads);
+            } else {
+                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
+            }
+        } else {
+            fjp = null;
+        }
 
         int[] oldCentre = new int[examples.size()];
         SparseVector[] data = new SparseVector[examples.size()];
@@ -221,7 +241,7 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
                 centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
                 break;
             case PLUSPLUS:
-                centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType);
+                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
                 break;
             default:
                 throw new IllegalStateException("Unknown initialisation" + initialisationType);
@@ -229,54 +249,57 @@ public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
         Map<Integer, List<Integer>> clusterAssignments = new HashMap<>();
         for (int i = 0; i < centroids; i++) {
-            clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>()));
+            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
         }
 
+        AtomicInteger changeCounter = new AtomicInteger(0);
+        Consumer<IntAndVector> eStepFunc = (IntAndVector e) -> {
+            double minDist = Double.POSITIVE_INFINITY;
+            int clusterID = -1;
+            int id = e.idx;
+            SGDVector vector = e.vector;
+            for (int j = 0; j < centroids; j++) {
+                DenseVector cluster = centroidVectors[j];
+                double distance = getDistance(cluster, vector, distanceType);
+                if (distance < minDist) {
+                    minDist = distance;
+                    clusterID = j;
+                }
+            }
+
+            clusterAssignments.get(clusterID).add(id);
+            if (oldCentre[id] != clusterID) {
+                // Changed the centroid of this vector.
+                oldCentre[id] = clusterID;
+                changeCounter.incrementAndGet();
+            }
+        };
+
         boolean converged = false;
 
         for (int i = 0; (i < iterations) && !converged; i++) {
-            //logger.log(Level.INFO,"Beginning iteration " + i);
-            AtomicInteger changeCounter = new AtomicInteger(0);
+            logger.log(Level.FINE,"Beginning iteration " + i);
+            changeCounter.set(0);
 
             for (Entry<Integer, List<Integer>> e : clusterAssignments.entrySet()) {
                 e.getValue().clear();
             }
 
             // E step
-            Stream<SparseVector> vecStream = Arrays.stream(data);
+            Stream<SGDVector> vecStream = Arrays.stream(data);
             Stream<Integer> intStream = IntStream.range(0, data.length).boxed();
-            Stream<IntAndVector> eStream;
-            if (numThreads > 1) {
-                eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
+            Stream<IntAndVector> zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
+            if (parallel) {
+                Stream<IntAndVector> parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
+                try {
+                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
+                } catch (InterruptedException | ExecutionException e) {
+                    throw new RuntimeException("Parallel execution failed", e);
+                }
             } else {
-                eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
-            }
-            try {
-                fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
-                    double minDist = Double.POSITIVE_INFINITY;
-                    int clusterID = -1;
-                    int id = e.idx;
-                    SparseVector vector = e.vector;
-                    for (int j = 0; j < centroids; j++) {
-                        DenseVector cluster = centroidVectors[j];
-                        double distance = getDistance(cluster, vector, distanceType);
-                        if (distance < minDist) {
-                            minDist = distance;
-                            clusterID = j;
-                        }
-                    }
-
-                    clusterAssignments.get(clusterID).add(id);
-                    if (oldCentre[id] != clusterID) {
-                        // Changed the centroid of this vector.
-                        oldCentre[id] = clusterID;
-                        changeCounter.incrementAndGet();
-                    }
-                })).get();
-            } catch (InterruptedException | ExecutionException e) {
-                throw new RuntimeException("Parallel execution failed", e);
+                zipStream.forEach(eStepFunc);
             }
-            //logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated.");
+            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
 
             mStep(fjp, centroidVectors, clusterAssignments, data, weights);
 
@@ -355,18 +378,15 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap,
      *
      * @param centroids The number of centroids to create.
      * @param data The dataset of {@link SparseVector} to use.
-     * @param featureMap The feature map to use for centroid sampling.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} array of centroids.
      */
-    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
-                                                             ImmutableFeatureMap featureMap, SplittableRandom rng,
+    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
                                                              Distance distanceType) {
         if (centroids > data.length) {
             throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
         }
 
-        int numFeatures = featureMap.size();
         double[] minDistancePerVector = new double[data.length];
         Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);
 
@@ -375,7 +395,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
         DenseVector[] centroidVectors = new DenseVector[centroids];
 
         // set first centroid randomly from the data
-        centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng);
+        centroidVectors[0] = getRandomCentroidFromData(data, rng);
 
         // Set each uninitialised centroid remaining
         for (int i = 1; i < centroids; i++) {
@@ -384,8 +404,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
             // go through every vector and see if the min distance to the
             // newest centroid is smaller than previous min distance for vec
             for (int j = 0; j < data.length; j++) {
-                SparseVector curVec = data[j];
-                double tempDistance = getDistance(prevCentroid, curVec, distanceType);
+                double tempDistance = getDistance(prevCentroid, data[j], distanceType);
                 minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
             }
 
@@ -404,7 +423,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
             // sample from probabilities to get the new centroid from data
             double[] cdf = Util.generateCDF(probabilities);
             int idx = Util.sampleFromCDF(cdf, rng);
-            centroidVectors[i] = sparseToDense(data[idx], numFeatures);
+            centroidVectors[i] = data[idx].densify();
         }
         return centroidVectors;
     }
@@ -413,39 +432,22 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
      * Randomly select a piece of data as the starting centroid.
      *
      * @param data The dataset of {@link SparseVector} to use.
-     * @param numFeatures The number of features.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} representing a centroid.
      */
-    private static DenseVector getRandomCentroidFromData(SparseVector[] data,
-                                                         int numFeatures, SplittableRandom rng) {
-        int rand_idx = rng.nextInt(data.length);
-        return sparseToDense(data[rand_idx], numFeatures);
-    }
-
-    /**
-     * Create a {@link DenseVector} from the data contained in a
-     * {@link SparseVector}.
-     *
-     * @param vec The {@link SparseVector} to be transformed.
-     * @param numFeatures The number of features.
-     * @return A {@link DenseVector} containing the information from vec.
-     */
-    private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
-        DenseVector dense = new DenseVector(numFeatures);
-        dense.intersectAndAddInPlace(vec);
-        return dense;
+    private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
+        int randIdx = rng.nextInt(data.length);
+        return data[randIdx].densify();
     }
 
     /**
-     *
+     * Compute the distance between the two vectors.
      * @param cluster A {@link DenseVector} representing a centroid.
      * @param vector A {@link SGDVector} representing an example.
      * @param distanceType The distance metric to employ.
      * @return A double representing the distance from vector to centroid.
      */
-    private static double getDistance(DenseVector cluster, SGDVector vector,
-                                      Distance distanceType) {
+    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
         double distance;
         switch (distanceType) {
             case EUCLIDEAN:
@@ -463,30 +465,41 @@ private static double getDistance(DenseVector cluster, SGDVector vector,
         return distance;
     }
 
+    /**
+     * Runs the mStep, writing to the {@code centroidVectors} array.
+     * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
+     *            If the fjp is null then the computation is executed sequentially on the main thread.
+     * @param centroidVectors The centroid vectors to write out.
+     * @param clusterAssignments The current cluster assignments.
+     * @param data The data points.
+     * @param weights The example weights.
+     */
     protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map<Integer, List<Integer>> clusterAssignments, SparseVector[] data, double[] weights) {
         // M step
-        Stream<Entry<Integer, List<Integer>>> mStream;
-        if (numThreads > 1) {
-            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
+        Consumer<Entry<Integer, List<Integer>>> mStepFunc = (e) -> {
+            DenseVector newCentroid = centroidVectors[e.getKey()];
+            newCentroid.fill(0.0);
+
+            double weightSum = 0.0;
+            for (Integer idx : e.getValue()) {
+                newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
+                weightSum += weights[idx];
+            }
+            if (weightSum != 0.0) {
+                newCentroid.scaleInPlace(1.0 / weightSum);
+            }
+        };
+
+        Stream<Entry<Integer, List<Integer>>> mStream = clusterAssignments.entrySet().stream();
+        if (fjp != null) {
+            Stream<Entry<Integer, List<Integer>>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
+            try {
+                fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
+            } catch (InterruptedException | ExecutionException e) {
+                throw new RuntimeException("Parallel execution failed", e);
+            }
         } else {
-            mStream = clusterAssignments.entrySet().stream();
-        }
-        try {
-            fjp.submit(() -> mStream.forEach((e) -> {
-                DenseVector newCentroid = centroidVectors[e.getKey()];
-                newCentroid.fill(0.0);
-
-                int counter = 0;
-                for (Integer idx : e.getValue()) {
-                    newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
-                    counter++;
-                }
-                if (counter > 0) {
-                    newCentroid.scaleInPlace(1.0 / counter);
-                }
-            })).get();
-        } catch (InterruptedException | ExecutionException e) {
-            throw new RuntimeException("Parallel execution failed", e);
+            mStream.forEach(mStepFunc);
         }
     }
 
@@ -505,11 +518,25 @@ public TrainerProvenance getProvenance() {
      */
     static class IntAndVector {
         final int idx;
-        final SparseVector vector;
+        final SGDVector vector;
 
-        public IntAndVector(int idx, SparseVector vector) {
+        /**
+         * Constructs an index and vector tuple.
+         * @param idx The index.
+         * @param vector The vector.
+         */
+        public IntAndVector(int idx, SGDVector vector) {
            this.idx = idx;
            this.vector = vector;
        }
    }
+
+    /**
+     * Used to allow FJPs to work with OpenSearch's SecureSM.
+     */
+    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
+        }
+    }
 }
\ No newline at end of file
diff --git a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
index 8baf52be0..2b97f2fb3 100644
--- a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
+++ b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@
 import org.tribuo.provenance.ModelProvenance;
 
 import java.io.IOException;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -42,6 +44,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
 import java.util.concurrent.Future;
 import java.util.function.BiFunction;
 import java.util.function.Function;
@@ -52,6 +55,10 @@
 /**
  * A k-nearest neighbours model.
+ * <p>
+ * Note multi-threaded prediction uses a {@link ForkJoinPool} which requires that the Tribuo codebase
+ * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a
+ * {@link java.lang.SecurityManager}.
  */
 public class KNNModel<T extends Output<T>> extends Model<T> {
@@ -59,6 +66,9 @@ public class KNNModel<T extends Output<T>> extends Model<T> {
 
     private static final long serialVersionUID = 1L;
 
+    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
     /**
      * The parallel backend for batch predictions.
     */
@@ -121,7 +131,7 @@ public Prediction<T> predict(Example<T> example) {
         List<Prediction<T>> predictions;
         Stream<Pair<SGDVector,T>> stream = Stream.of(vectors);
         if (numThreads > 1) {
-            ForkJoinPool fjp = new ForkJoinPool(numThreads);
+            ForkJoinPool fjp = System.getSecurityManager() == null ? new ForkJoinPool(numThreads) : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
             try {
                 predictions = fjp.submit(()->StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get();
             } catch (InterruptedException | ExecutionException e) {
@@ -222,7 +232,7 @@ private List<Prediction<T>> innerPredictMultithreaded(Iterable<Example<T>> examples) {
     private List<Prediction<T>> innerPredictStreams(Iterable<Example<T>> examples) {
         List<Prediction<T>> predictions = new ArrayList<>();
         List<Prediction<T>> innerPredictions = null;
-        ForkJoinPool fjp = new ForkJoinPool(numThreads);
+        ForkJoinPool fjp = System.getSecurityManager() == null ? new ForkJoinPool(numThreads) : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
 
         for (Example<T> example : examples) {
             SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false);
@@ -495,4 +505,12 @@ public int compareTo(OutputDoublePair<T> o) {
         }
     }
 
+    /**
+     * Used to allow FJPs to work with OpenSearch's SecureSM.
+     */
+    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+            return AccessController.doPrivileged((PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});
+        }
+    }
 }
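
Reviewer note: both files add the same CustomForkJoinWorkerThreadFactory, so here is a minimal standalone sketch of that pattern for trying it outside Tribuo. The class name PrivilegedPoolExample and the toy summation workload are illustrative, not part of this change. The idea: when a SecurityManager such as OpenSearch's SecureSM is installed, ForkJoinPool's default factory creates worker threads with the whole call stack on the access-control context, which trips the "modifyThread"/"modifyThreadGroup" checks; creating the thread inside AccessController.doPrivileged limits the check to this class's own protection domain, where the permissions can be granted.

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.stream.IntStream;

// Standalone sketch of the pattern introduced in this patch: build the
// ForkJoinPool from a factory whose worker threads are created inside a
// privileged block when a SecurityManager is installed.
public final class PrivilegedPoolExample {

    // Mirrors CustomForkJoinWorkerThreadFactory above; ForkJoinWorkerThread's
    // constructor is protected, hence the anonymous subclass.
    private static final ForkJoinPool.ForkJoinWorkerThreadFactory FACTORY =
        pool -> AccessController.doPrivileged(
            (PrivilegedAction<ForkJoinWorkerThread>) () -> new ForkJoinWorkerThread(pool) {});

    public static void main(String[] args) throws Exception {
        int numThreads = 4;
        // Same construction logic as the patched train/predict methods.
        ForkJoinPool fjp = System.getSecurityManager() == null
            ? new ForkJoinPool(numThreads)
            : new ForkJoinPool(numThreads, FACTORY, null, false);
        try {
            // Submit a parallel stream task to the pool, as train/predict do.
            long sum = fjp.submit(
                () -> IntStream.rangeClosed(1, 1000).parallel().asLongStream().sum()).get();
            System.out.println("sum = " + sum); // prints 500500
        } finally {
            fjp.shutdown();
        }
    }
}

As the updated javadoc notes, the security policy must still grant the Tribuo code base java.lang.RuntimePermission "modifyThread" and "modifyThreadGroup"; the privileged factory only ensures those grants are evaluated against Tribuo's protection domain rather than the calling application's.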