diff --git a/Core/src/main/java/org/tribuo/util/Util.java b/Core/src/main/java/org/tribuo/util/Util.java index 2db850193..06db14fd3 100644 --- a/Core/src/main/java/org/tribuo/util/Util.java +++ b/Core/src/main/java/org/tribuo/util/Util.java @@ -429,6 +429,25 @@ public static double[] generateCDF(double[] pmf) { return cumulativeSum(pmf); } + /** + * Validates that the supplied double array is a probability mass function. + *

+ * That is, each element is bounded 0,1 and all elements sum to 1. + * @param pmf The PMF to check. + * @return True if it's a valid pmf. + */ + public static boolean validatePMF(double[] pmf) { + double total = 0.0; + for (double v : pmf) { + if ((v < 0) || (v > 1.0)) { + return false; + } else { + total += v; + } + } + return !(Math.abs(total - 1.0) > 1e-10); + } + /** * Produces a cumulative sum array. * @param input The input to sum. diff --git a/Math/src/main/java/org/tribuo/math/distributions/Distribution.java b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java new file mode 100644 index 000000000..944e0d5f4 --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.distributions; + +import org.tribuo.math.la.DenseVector; + +import java.util.random.RandomGenerator; + +/** + * Interface for probability distributions which can be sampled from. + *

+ * The vector sampled represents a single sample from that (possibly multivariate) + * distribution rather than a sequence of samples. + */ +public interface Distribution { + + /** + * Sample a single vector from this probability distribution. + * @return A vector sampled from the distribution. + */ + DenseVector sampleVector(); + + /** + * Sample a single vector from this probability distribution using the supplied RNG. + * @param otherRNG The RNG to use. + * @return A vector sampled from this distribution. + */ + DenseVector sampleVector(RandomGenerator otherRNG); + + /** + * Sample a vector from this probability distribution and return it as an array. + * @return An array sampled from this distribution. + */ + default double[] sampleArray() { + return sampleVector().toArray(); + } +} diff --git a/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java new file mode 100644 index 000000000..7e1f7f70a --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.distributions; + +import org.tribuo.math.la.DenseVector; +import org.tribuo.util.Util; + +import java.util.Arrays; +import java.util.List; +import java.util.SplittableRandom; +import java.util.random.RandomGenerator; + +/** + * A mixture distribution which samples from a set of internal distributions mixed by some probability distribution. + * @param The inner distribution type. + */ +public final class MixtureDistribution implements Distribution { + + private final List dists; + + private final double[] mixingDistribution; + + private final double[] cdf; + + private final RandomGenerator rng; + + private final long seed; + + /** + * Construct a mixture distribution over the supplied components. + * @param distributions The distribution components. + * @param mixingDistribution The mixing distribution, must be a valid PMF. + * @param seed The RNG seed. + */ + public MixtureDistribution(List distributions, DenseVector mixingDistribution, long seed) { + this(distributions, mixingDistribution.toArray(), seed); + } + + /** + * Construct a mixture distribution over the supplied components. + * @param distributions The distribution components. + * @param mixingDistribution The mixing distribution, must be a valid PMF. + * @param seed The RNG seed. + */ + public MixtureDistribution(List distributions, double[] mixingDistribution, long seed) { + this.dists = List.copyOf(distributions); + this.mixingDistribution = Arrays.copyOf(mixingDistribution, mixingDistribution.length); + this.seed = seed; + this.rng = new SplittableRandom(seed); + if (dists.size() != this.mixingDistribution.length) { + throw new IllegalArgumentException("Invalid distribution, expected the same number of components as probabilities, found " + dists.size() + " components, and " + this.mixingDistribution.length + " probabilities"); + } + if (!Util.validatePMF(this.mixingDistribution)) { + throw new IllegalArgumentException("Invalid mixing distribution, was not a valid PMF, found " + Arrays.toString(this.mixingDistribution)); + } + this.cdf = Util.generateCDF(this.mixingDistribution); + } + + /** + * Returns the number of distributions. + * @return The number of distributions. + */ + public int getNumComponents() { + return dists.size(); + } + + /** + * Return a mixture component. + * @param i The index of the mixture component. + * @return The ith component. + */ + public T getComponent(int i) { + return dists.get(i); + } + + /** + * Returns a copy of the mixing distribution. + * @return A copy of the mixing distribution. + */ + public double[] getMixingDistribution() { + return Arrays.copyOf(mixingDistribution, mixingDistribution.length); + } + + @Override + public DenseVector sampleVector() { + return sampleVector(rng); + } + + @Override + public DenseVector sampleVector(RandomGenerator otherRNG) { + int idx = Util.sampleFromCDF(cdf, otherRNG); + return dists.get(idx).sampleVector(); + } + + @Override + public String toString() { + return "Mixture(seed="+seed+",mixingDistribution="+ Arrays.toString(mixingDistribution) +",components="+dists+")"; + } +} diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index b05e8c892..bd4e87fd3 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -30,7 +30,7 @@ /** * A class for sampling from multivariate normal distributions. */ -public final class MultivariateNormalDistribution { +public final class MultivariateNormalDistribution implements Distribution { private final long seed; private final Random rng; @@ -231,6 +231,7 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova * Sample a vector from this multivariate normal distribution. * @return A sample from this distribution. */ + @Override public DenseVector sampleVector() { return sampleVector(rng); } @@ -239,6 +240,7 @@ public DenseVector sampleVector() { * Sample a vector from this multivariate normal distribution. * @return A sample from this distribution. */ + @Override public DenseVector sampleVector(RandomGenerator otherRNG) { DenseVector sampled = new DenseVector(means.size()); for (int i = 0; i < means.size(); i++) { @@ -256,14 +258,6 @@ public DenseVector sampleVector(RandomGenerator otherRNG) { return sampled; } - /** - * Sample a vector from this multivariate normal distribution. - * @return A sample from this distribution. - */ - public double[] sampleArray() { - return sampleVector().toArray(); - } - /** * Gets a copy of the mean vector. * @return A copy of the mean vector.