diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index 9571b4348..abcc6fafe 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -51,9 +51,11 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.BinaryOperator; import java.util.function.Function; +import java.util.function.Supplier; import java.util.function.ToDoubleFunction; import java.util.logging.Level; import java.util.logging.Logger; @@ -342,24 +344,24 @@ public GaussianMixtureModel train(Dataset examples, Map dataMStream = Arrays.stream(data); Stream resMStream = Arrays.stream(responsibilities); Stream zipMStream = StreamUtil.zip(dataMStream, resMStream, Vectors::new); - Tensor[] zeroTensorArr = switch (covarianceType) { - case FULL -> { + Supplier zeroTensor = switch (covarianceType) { + case FULL -> () -> { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { output[j] = new DenseMatrix(numFeatures, numFeatures); } - yield output; - } - case DIAGONAL, SPHERICAL -> { + return output; + }; + case DIAGONAL, SPHERICAL -> () -> { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { output[j] = new DenseVector(numFeatures); } - yield output; - } + return output; + }; }; // Fix parallel behaviour - BiFunction mStep = switch (covarianceType) { + BiConsumer mStep = switch (covarianceType) { case FULL -> (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { // Compute covariance contribution from current input @@ -369,7 +371,6 @@ public GaussianMixtureModel train(Dataset examples, Map (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { @@ -380,7 +381,6 @@ public GaussianMixtureModel train(Dataset examples, Map (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { @@ -393,33 +393,27 @@ public GaussianMixtureModel train(Dataset examples, Map combineTensor = (Tensor[] a, Tensor[] b) -> { - Tensor[] output = new Tensor[a.length]; + BiConsumer combineTensor = (Tensor[] a, Tensor[] b) -> { for (int j = 0; j < a.length; j++) { if (a[j] instanceof DenseMatrix aMat && b[j] instanceof DenseMatrix bMat) { - output[j] = aMat.add(bMat); + aMat.intersectAndAddInPlace(bMat); } else if (a[j] instanceof DenseVector aVec && b[j] instanceof DenseVector bVec) { - output[j] = aVec.add(bVec); + aVec.intersectAndAddInPlace(bVec); } else { throw new IllegalStateException("Invalid types in reduce, expected both DenseMatrix or DenseVector, found " + a[j].getClass() + " and " + b[j].getClass()); } } - return output; }; if (parallel) { - throw new RuntimeException("Parallel mstep not implemented"); - /* try { - covariances = fjp.submit(() -> zipMStream.parallel().reduce(zeroTensorArr, mStep, combineTensor)).get(); + covariances = fjp.submit(() -> zipMStream.parallel().collect(zeroTensor, mStep, combineTensor)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } - */ } else { - covariances = zipMStream.reduce(zeroTensorArr, mStep, combineTensor); + covariances = zipMStream.collect(zeroTensor, mStep, combineTensor); } // renormalize mixing distribution diff --git a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java index 2938ecd8a..a2fc3c13c 100644 --- a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java +++ b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java @@ -46,6 +46,9 @@ public class TestGMM { private static final GMMTrainer diagonal = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.DIAGONAL, GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + private static final GMMTrainer fullParallel = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL, + GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 4, 1); + private static final GMMTrainer plusPlusFull = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL, GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); @@ -75,6 +78,11 @@ public void testPlusPlusFullEvaluation() { runEvaluation(plusPlusFull); } + @Test + public void testParallelEvaluation() { + runEvaluation(fullParallel); + } + public static void runEvaluation(GMMTrainer trainer) { Dataset data = new MutableDataset<>(new GaussianClusterDataSource(500, 1L)); Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); @@ -150,7 +158,6 @@ public void testPlusPlusInvalidExample() { runInvalidExample(plusPlusFull); } - public void runEmptyExample(GMMTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); @@ -186,7 +193,7 @@ public void testSetInvocationCount() { // The number of times to call train before final training. // Original trainer will be trained numOfInvocations + 1 times - // New trainer will have it's invocation count set to numOfInvocations then trained once + // New trainer will have its invocation count set to numOfInvocations then trained once int numOfInvocations = 2; // Create the first model and train it numOfInvocations + 1 times