diff --git a/Core/src/main/java/org/tribuo/util/Util.java b/Core/src/main/java/org/tribuo/util/Util.java
index 2db850193..06db14fd3 100644
--- a/Core/src/main/java/org/tribuo/util/Util.java
+++ b/Core/src/main/java/org/tribuo/util/Util.java
@@ -429,6 +429,25 @@ public static double[] generateCDF(double[] pmf) {
return cumulativeSum(pmf);
}
+ /**
+ * Validates that the supplied double array is a probability mass function.
+ *
+ * That is, each element is bounded 0,1 and all elements sum to 1.
+ * @param pmf The PMF to check.
+ * @return True if it's a valid pmf.
+ */
+ public static boolean validatePMF(double[] pmf) {
+ double total = 0.0;
+ for (double v : pmf) {
+ if ((v < 0) || (v > 1.0)) {
+ return false;
+ } else {
+ total += v;
+ }
+ }
+ return !(Math.abs(total - 1.0) > 1e-10);
+ }
+
/**
* Produces a cumulative sum array.
* @param input The input to sum.
diff --git a/Math/src/main/java/org/tribuo/math/distributions/Distribution.java b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java
new file mode 100644
index 000000000..944e0d5f4
--- /dev/null
+++ b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tribuo.math.distributions;
+
+import org.tribuo.math.la.DenseVector;
+
+import java.util.random.RandomGenerator;
+
+/**
+ * Interface for probability distributions which can be sampled from.
+ *
+ * The vector sampled represents a single sample from that (possibly multivariate)
+ * distribution rather than a sequence of samples.
+ */
+public interface Distribution {
+
+ /**
+ * Sample a single vector from this probability distribution.
+ * @return A vector sampled from the distribution.
+ */
+ DenseVector sampleVector();
+
+ /**
+ * Sample a single vector from this probability distribution using the supplied RNG.
+ * @param otherRNG The RNG to use.
+ * @return A vector sampled from this distribution.
+ */
+ DenseVector sampleVector(RandomGenerator otherRNG);
+
+ /**
+ * Sample a vector from this probability distribution and return it as an array.
+ * @return An array sampled from this distribution.
+ */
+ default double[] sampleArray() {
+ return sampleVector().toArray();
+ }
+}
diff --git a/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java
new file mode 100644
index 000000000..7e1f7f70a
--- /dev/null
+++ b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tribuo.math.distributions;
+
+import org.tribuo.math.la.DenseVector;
+import org.tribuo.util.Util;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.SplittableRandom;
+import java.util.random.RandomGenerator;
+
+/**
+ * A mixture distribution which samples from a set of internal distributions mixed by some probability distribution.
+ * @param The inner distribution type.
+ */
+public final class MixtureDistribution implements Distribution {
+
+ private final List dists;
+
+ private final double[] mixingDistribution;
+
+ private final double[] cdf;
+
+ private final RandomGenerator rng;
+
+ private final long seed;
+
+ /**
+ * Construct a mixture distribution over the supplied components.
+ * @param distributions The distribution components.
+ * @param mixingDistribution The mixing distribution, must be a valid PMF.
+ * @param seed The RNG seed.
+ */
+ public MixtureDistribution(List distributions, DenseVector mixingDistribution, long seed) {
+ this(distributions, mixingDistribution.toArray(), seed);
+ }
+
+ /**
+ * Construct a mixture distribution over the supplied components.
+ * @param distributions The distribution components.
+ * @param mixingDistribution The mixing distribution, must be a valid PMF.
+ * @param seed The RNG seed.
+ */
+ public MixtureDistribution(List distributions, double[] mixingDistribution, long seed) {
+ this.dists = List.copyOf(distributions);
+ this.mixingDistribution = Arrays.copyOf(mixingDistribution, mixingDistribution.length);
+ this.seed = seed;
+ this.rng = new SplittableRandom(seed);
+ if (dists.size() != this.mixingDistribution.length) {
+ throw new IllegalArgumentException("Invalid distribution, expected the same number of components as probabilities, found " + dists.size() + " components, and " + this.mixingDistribution.length + " probabilities");
+ }
+ if (!Util.validatePMF(this.mixingDistribution)) {
+ throw new IllegalArgumentException("Invalid mixing distribution, was not a valid PMF, found " + Arrays.toString(this.mixingDistribution));
+ }
+ this.cdf = Util.generateCDF(this.mixingDistribution);
+ }
+
+ /**
+ * Returns the number of distributions.
+ * @return The number of distributions.
+ */
+ public int getNumComponents() {
+ return dists.size();
+ }
+
+ /**
+ * Return a mixture component.
+ * @param i The index of the mixture component.
+ * @return The ith component.
+ */
+ public T getComponent(int i) {
+ return dists.get(i);
+ }
+
+ /**
+ * Returns a copy of the mixing distribution.
+ * @return A copy of the mixing distribution.
+ */
+ public double[] getMixingDistribution() {
+ return Arrays.copyOf(mixingDistribution, mixingDistribution.length);
+ }
+
+ @Override
+ public DenseVector sampleVector() {
+ return sampleVector(rng);
+ }
+
+ @Override
+ public DenseVector sampleVector(RandomGenerator otherRNG) {
+ int idx = Util.sampleFromCDF(cdf, otherRNG);
+ return dists.get(idx).sampleVector();
+ }
+
+ @Override
+ public String toString() {
+ return "Mixture(seed="+seed+",mixingDistribution="+ Arrays.toString(mixingDistribution) +",components="+dists+")";
+ }
+}
diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java
index b05e8c892..bd4e87fd3 100644
--- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java
+++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java
@@ -30,7 +30,7 @@
/**
* A class for sampling from multivariate normal distributions.
*/
-public final class MultivariateNormalDistribution {
+public final class MultivariateNormalDistribution implements Distribution {
private final long seed;
private final Random rng;
@@ -231,6 +231,7 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova
* Sample a vector from this multivariate normal distribution.
* @return A sample from this distribution.
*/
+ @Override
public DenseVector sampleVector() {
return sampleVector(rng);
}
@@ -239,6 +240,7 @@ public DenseVector sampleVector() {
* Sample a vector from this multivariate normal distribution.
* @return A sample from this distribution.
*/
+ @Override
public DenseVector sampleVector(RandomGenerator otherRNG) {
DenseVector sampled = new DenseVector(means.size());
for (int i = 0; i < means.size(); i++) {
@@ -256,14 +258,6 @@ public DenseVector sampleVector(RandomGenerator otherRNG) {
return sampled;
}
- /**
- * Sample a vector from this multivariate normal distribution.
- * @return A sample from this distribution.
- */
- public double[] sampleArray() {
- return sampleVector().toArray();
- }
-
/**
* Gets a copy of the mean vector.
* @return A copy of the mean vector.