diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index 7065a5d1b..7c9751d2a 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -34,6 +34,8 @@ import org.tribuo.provenance.impl.TrainerProvenanceImpl; import org.tribuo.util.Util; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.time.OffsetDateTime; import java.util.ArrayList; import java.util.Arrays; @@ -45,7 +47,9 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinWorkerThread; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.IntStream; @@ -63,6 +67,10 @@ * of threads used in the training step. The thread pool is local to an invocation of train, * so there can be multiple concurrent trainings. *

+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase + * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a + * {@link java.lang.SecurityManager}. + *

* See: *

  * J. Friedman, T. Hastie, & R. Tibshirani.
@@ -80,6 +88,9 @@
 public class KMeansTrainer implements Trainer {
     private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName());
 
+    // Thread factory for the FJP, to allow use with OpenSearch's SecureSM
+    private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory();
+
     /**
      * Possible distance functions.
      */
@@ -138,8 +149,7 @@ public enum Initialisation {
     /**
      * for olcut.
      */
-    private KMeansTrainer() {
-    }
+    private KMeansTrainer() { }
 
     /**
      * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
@@ -202,7 +212,17 @@ public KMeansModel train(Dataset examples, Map ru
         }
         ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
 
-        ForkJoinPool fjp = new ForkJoinPool(numThreads);
+        boolean parallel = numThreads > 1;
+        ForkJoinPool fjp;
+        if (parallel) {
+            if (System.getSecurityManager() == null) {
+                fjp = new ForkJoinPool(numThreads);
+            } else {
+                fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false);
+            }
+        } else {
+            fjp = null;
+        }
 
         int[] oldCentre = new int[examples.size()];
         SparseVector[] data = new SparseVector[examples.size()];
@@ -221,7 +241,7 @@ public KMeansModel train(Dataset examples, Map ru
                 centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG);
                 break;
             case PLUSPLUS:
-                centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType);
+                centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distanceType);
                 break;
             default:
                 throw new IllegalStateException("Unknown initialisation" + initialisationType);
@@ -229,54 +249,57 @@ public KMeansModel train(Dataset examples, Map ru
 
         Map> clusterAssignments = new HashMap<>();
         for (int i = 0; i < centroids; i++) {
-            clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>()));
+            clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>());
         }
 
+        AtomicInteger changeCounter = new AtomicInteger(0);
+        Consumer eStepFunc = (IntAndVector e) -> {
+            double minDist = Double.POSITIVE_INFINITY;
+            int clusterID = -1;
+            int id = e.idx;
+            SGDVector vector = e.vector;
+            for (int j = 0; j < centroids; j++) {
+                DenseVector cluster = centroidVectors[j];
+                double distance = getDistance(cluster, vector, distanceType);
+                if (distance < minDist) {
+                    minDist = distance;
+                    clusterID = j;
+                }
+            }
+
+            clusterAssignments.get(clusterID).add(id);
+            if (oldCentre[id] != clusterID) {
+                // Changed the centroid of this vector.
+                oldCentre[id] = clusterID;
+                changeCounter.incrementAndGet();
+            }
+        };
+
         boolean converged = false;
 
         for (int i = 0; (i < iterations) && !converged; i++) {
-            //logger.log(Level.INFO,"Beginning iteration " + i);
-            AtomicInteger changeCounter = new AtomicInteger(0);
+            logger.log(Level.FINE,"Beginning iteration " + i);
+            changeCounter.set(0);
 
             for (Entry> e : clusterAssignments.entrySet()) {
                 e.getValue().clear();
             }
 
             // E step
-            Stream vecStream = Arrays.stream(data);
+            Stream vecStream = Arrays.stream(data);
             Stream intStream = IntStream.range(0, data.length).boxed();
-            Stream eStream;
-            if (numThreads > 1) {
-                eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel());
+            Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
+            if (parallel) {
+                Stream parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel());
+                try {
+                    fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get();
+                } catch (InterruptedException | ExecutionException e) {
+                    throw new RuntimeException("Parallel execution failed", e);
+                }
             } else {
-                eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new);
-            }
-            try {
-                fjp.submit(() -> eStream.forEach((IntAndVector e) -> {
-                    double minDist = Double.POSITIVE_INFINITY;
-                    int clusterID = -1;
-                    int id = e.idx;
-                    SparseVector vector = e.vector;
-                    for (int j = 0; j < centroids; j++) {
-                        DenseVector cluster = centroidVectors[j];
-                        double distance = getDistance(cluster, vector, distanceType);
-                        if (distance < minDist) {
-                            minDist = distance;
-                            clusterID = j;
-                        }
-                    }
-
-                    clusterAssignments.get(clusterID).add(id);
-                    if (oldCentre[id] != clusterID) {
-                        // Changed the centroid of this vector.
-                        oldCentre[id] = clusterID;
-                        changeCounter.incrementAndGet();
-                    }
-                })).get();
-            } catch (InterruptedException | ExecutionException e) {
-                throw new RuntimeException("Parallel execution failed", e);
+                zipStream.forEach(eStepFunc);
             }
-            //logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated.");
+            logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated.");
 
             mStep(fjp, centroidVectors, clusterAssignments, data, weights);
 
@@ -355,18 +378,15 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
      *
      * @param centroids The number of centroids to create.
      * @param data The dataset of {@link SparseVector} to use.
-     * @param featureMap The feature map to use for centroid sampling.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} array of centroids.
      */
-    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data,
-                                                             ImmutableFeatureMap featureMap, SplittableRandom rng,
+    private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, SplittableRandom rng,
                                                              Distance distanceType) {
         if (centroids > data.length) {
             throw new IllegalArgumentException("The number of centroids may not exceed the number of samples.");
         }
 
-        int numFeatures = featureMap.size();
         double[] minDistancePerVector = new double[data.length];
         Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY);
 
@@ -375,7 +395,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
         DenseVector[] centroidVectors = new DenseVector[centroids];
 
         // set first centroid randomly from the data
-        centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng);
+        centroidVectors[0] = getRandomCentroidFromData(data, rng);
 
         // Set each uninitialised centroid remaining
         for (int i = 1; i < centroids; i++) {
@@ -384,8 +404,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
             // go through every vector and see if the min distance to the
             // newest centroid is smaller than previous min distance for vec
             for (int j = 0; j < data.length; j++) {
-                SparseVector curVec = data[j];
-                double tempDistance = getDistance(prevCentroid, curVec, distanceType);
+                double tempDistance = getDistance(prevCentroid, data[j], distanceType);
                 minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
             }
 
@@ -404,7 +423,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
             // sample from probabilities to get the new centroid from data
             double[] cdf = Util.generateCDF(probabilities);
             int idx = Util.sampleFromCDF(cdf, rng);
-            centroidVectors[i] = sparseToDense(data[idx], numFeatures);
+            centroidVectors[i] = data[idx].densify();
         }
         return centroidVectors;
     }
@@ -413,39 +432,22 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
      * Randomly select a piece of data as the starting centroid.
      *
      * @param data The dataset of {@link SparseVector} to use.
-     * @param numFeatures The number of features.
      * @param rng The RNG to use.
      * @return A {@link DenseVector} representing a centroid.
      */
-    private static DenseVector getRandomCentroidFromData(SparseVector[] data,
-                                                         int numFeatures, SplittableRandom rng) {
-        int rand_idx = rng.nextInt(data.length);
-        return sparseToDense(data[rand_idx], numFeatures);
-    }
-
-    /**
-     * Create a {@link DenseVector} from the data contained in a
-     * {@link SparseVector}.
-     *
-     * @param vec The {@link SparseVector} to be transformed.
-     * @param numFeatures The number of features.
-     * @return A {@link DenseVector} containing the information from vec.
-     */
-    private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
-        DenseVector dense = new DenseVector(numFeatures);
-        dense.intersectAndAddInPlace(vec);
-        return dense;
+    private static DenseVector getRandomCentroidFromData(SparseVector[] data, SplittableRandom rng) {
+        int randIdx = rng.nextInt(data.length);
+        return data[randIdx].densify();
     }
 
     /**
-     *
+     * Compute the distance between the two vectors.
      * @param cluster A {@link DenseVector} representing a centroid.
      * @param vector A {@link SGDVector} representing an example.
      * @param distanceType The distance metric to employ.
      * @return A double representing the distance from vector to centroid.
      */
-    private static double getDistance(DenseVector cluster, SGDVector vector,
-                                      Distance distanceType) {
+    private static double getDistance(DenseVector cluster, SGDVector vector, Distance distanceType) {
         double distance;
         switch (distanceType) {
             case EUCLIDEAN:
@@ -463,30 +465,41 @@ private static double getDistance(DenseVector cluster, SGDVector vector,
         return distance;
     }
 
+    /**
+     * Runs the mStep, writing to the {@code centroidVectors} array.
+     * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel.
+     *            If the fjp is null then the computation is executed sequentially on the main thread.
+     * @param centroidVectors The centroid vectors to write out.
+     * @param clusterAssignments The current cluster assignments.
+     * @param data The data points.
+     * @param weights The example weights.
+     */
     protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SparseVector[] data, double[] weights) {
         // M step
-        Stream>> mStream;
-        if (numThreads > 1) {
-            mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel());
+        Consumer>> mStepFunc = (e) -> {
+            DenseVector newCentroid = centroidVectors[e.getKey()];
+            newCentroid.fill(0.0);
+
+            double weightSum = 0.0;
+            for (Integer idx : e.getValue()) {
+                newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
+                weightSum += weights[idx];
+            }
+            if (weightSum != 0.0) {
+                newCentroid.scaleInPlace(1.0 / weightSum);
+            }
+        };
+
+        Stream>> mStream = clusterAssignments.entrySet().stream();
+        if (fjp != null) {
+            Stream>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel());
+            try {
+                fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get();
+            } catch (InterruptedException | ExecutionException e) {
+                throw new RuntimeException("Parallel execution failed", e);
+            }
         } else {
-            mStream = clusterAssignments.entrySet().stream();
-        }
-        try {
-            fjp.submit(() -> mStream.forEach((e) -> {
-                DenseVector newCentroid = centroidVectors[e.getKey()];
-                newCentroid.fill(0.0);
-
-                int counter = 0;
-                for (Integer idx : e.getValue()) {
-                    newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]);
-                    counter++;
-                }
-                if (counter > 0) {
-                    newCentroid.scaleInPlace(1.0 / counter);
-                }
-            })).get();
-        } catch (InterruptedException | ExecutionException e) {
-            throw new RuntimeException("Parallel execution failed", e);
+            mStream.forEach(mStepFunc);
         }
     }
 
@@ -505,11 +518,25 @@ public TrainerProvenance getProvenance() {
      */
     static class IntAndVector {
         final int idx;
-        final SparseVector vector;
+        final SGDVector vector;
 
-        public IntAndVector(int idx, SparseVector vector) {
+        /**
+         * Constructs an index and vector tuple.
+         * @param idx The index.
+         * @param vector The vector.
+         */
+        public IntAndVector(int idx, SGDVector vector) {
             this.idx = idx;
             this.vector = vector;
         }
     }
+
+    /**
+     * Used to allow FJPs to work with OpenSearch's SecureSM.
+     */
+    private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
+        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+            return AccessController.doPrivileged((PrivilegedAction) () -> new ForkJoinWorkerThread(pool) {});
+        }
+    }
 }
\ No newline at end of file
diff --git a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
index 8baf52be0..2b97f2fb3 100644
--- a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
+++ b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNModel.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,8 @@
 import org.tribuo.provenance.ModelProvenance;
 
 import java.io.IOException;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -42,6 +44,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinWorkerThread;
 import java.util.concurrent.Future;
 import java.util.function.BiFunction;
 import java.util.function.Function;
@@ -52,6 +55,10 @@
 
 /**
  * A k-nearest neighbours model.
+ * 

+ * Note multi-threaded prediction uses a {@link ForkJoinPool} which requires that the Tribuo codebase + * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a + * {@link java.lang.SecurityManager}. */ public class KNNModel> extends Model { @@ -59,6 +66,9 @@ public class KNNModel> extends Model { private static final long serialVersionUID = 1L; + // Thread factory for the FJP, to allow use with OpenSearch's SecureSM + private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory(); + /** * The parallel backend for batch predictions. */ @@ -121,7 +131,7 @@ public Prediction predict(Example example) { List> predictions; Stream> stream = Stream.of(vectors); if (numThreads > 1) { - ForkJoinPool fjp = new ForkJoinPool(numThreads); + ForkJoinPool fjp = System.getSecurityManager() == null ? new ForkJoinPool(numThreads) : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false); try { predictions = fjp.submit(()->StreamUtil.boundParallelism(stream.parallel()).map(distanceFunc).sorted().limit(k).map((a) -> new Prediction<>(a.output, input.numActiveElements(), example)).collect(Collectors.toList())).get(); } catch (InterruptedException | ExecutionException e) { @@ -222,7 +232,7 @@ private List> innerPredictMultithreaded(Iterable> examp private List> innerPredictStreams(Iterable> examples) { List> predictions = new ArrayList<>(); List> innerPredictions = null; - ForkJoinPool fjp = new ForkJoinPool(numThreads); + ForkJoinPool fjp = System.getSecurityManager() == null ? new ForkJoinPool(numThreads) : new ForkJoinPool(numThreads, THREAD_FACTORY, null, false); for (Example example : examples) { SparseVector input = SparseVector.createSparseVector(example, featureIDMap, false); @@ -495,4 +505,12 @@ public int compareTo(OutputDoublePair o) { } } + /** + * Used to allow FJPs to work with OpenSearch's SecureSM. + */ + private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory { + public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { + return AccessController.doPrivileged((PrivilegedAction) () -> new ForkJoinWorkerThread(pool) {}); + } + } }