-
Notifications
You must be signed in to change notification settings - Fork 178
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding a clustering data source for use in demos and tests.
- Loading branch information
Showing
4 changed files
with
380 additions
and
23 deletions.
There are no files selected for viewing
350 changes: 350 additions & 0 deletions
350
Clustering/Core/src/main/java/org/tribuo/clustering/example/GaussianClusterDataSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,350 @@ | ||
/* | ||
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.tribuo.clustering.example; | ||
|
||
import com.oracle.labs.mlrg.olcut.config.Config; | ||
import com.oracle.labs.mlrg.olcut.config.PropertyException; | ||
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; | ||
import com.oracle.labs.mlrg.olcut.provenance.Provenance; | ||
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; | ||
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; | ||
import org.apache.commons.math3.distribution.MultivariateNormalDistribution; | ||
import org.apache.commons.math3.random.JDKRandomGenerator; | ||
import org.tribuo.ConfigurableDataSource; | ||
import org.tribuo.Dataset; | ||
import org.tribuo.Example; | ||
import org.tribuo.MutableDataset; | ||
import org.tribuo.OutputFactory; | ||
import org.tribuo.Trainer; | ||
import org.tribuo.clustering.ClusterID; | ||
import org.tribuo.clustering.ClusteringFactory; | ||
import org.tribuo.impl.ArrayExample; | ||
import org.tribuo.provenance.ConfiguredDataSourceProvenance; | ||
import org.tribuo.provenance.DataSourceProvenance; | ||
import org.tribuo.util.Util; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.Collections; | ||
import java.util.HashMap; | ||
import java.util.Iterator; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Random; | ||
|
||
/** | ||
* Generates a clustering dataset drawn from a mixture of 5 Gaussians. | ||
* <p> | ||
* The Gaussians can be at most 4 dimensional, resulting in 4 features. | ||
* <p> | ||
* By default the Gaussians are 2-dimensional with the following means and variances: | ||
* <ul> | ||
* <li>Mean = (), variance = ()</li> | ||
* </ul> | ||
*/ | ||
public final class GaussianClusterDataSource implements ConfigurableDataSource<ClusterID> { | ||
|
||
private static final ClusteringFactory factory = new ClusteringFactory(); | ||
|
||
private static final String[] allFeatureNames = new String[]{ | ||
"A", "B", "C", "D", | ||
}; | ||
|
||
@Config(mandatory = true, description = "The number of samples to draw.") | ||
private int numSamples; | ||
|
||
@Config(description = "The probability of sampling from each Gaussian, must sum to 1.0.") | ||
private double[] mixingPMF = new double[]{0.1, 0.35, 0.05, 0.25, 0.25}; | ||
|
||
@Config(description = "The mean of the first Gaussian.") | ||
private double[] firstMean = new double[]{0.0, 0.0}; | ||
|
||
@Config(description = "A vector representing the first Gaussian's covariance matrix.") | ||
private double[] firstVariance = new double[]{1.0, 0.0, 0.0, 1.0}; | ||
|
||
@Config(description = "The mean of the second Gaussian.") | ||
private double[] secondMean = new double[]{5.0, 5.0}; | ||
|
||
@Config(description = "A vector representing the second Gaussian's covariance matrix.") | ||
private double[] secondVariance = new double[]{1.0, 0.0, 0.0, 1.0}; | ||
|
||
@Config(description = "The mean of the third Gaussian.") | ||
private double[] thirdMean = new double[]{2.5, 2.5}; | ||
|
||
@Config(description = "A vector representing the third Gaussian's covariance matrix.") | ||
private double[] thirdVariance = new double[]{1.0, 0.5, 0.5, 1.0}; | ||
|
||
@Config(description = "The mean of the fourth Gaussian.") | ||
private double[] fourthMean = new double[]{10.0, 0.0}; | ||
|
||
@Config(description = "A vector representing the fourth Gaussian's covariance matrix.") | ||
private double[] fourthVariance = new double[]{0.1, 0.0, 0.0, 0.1}; | ||
|
||
@Config(description = "The mean of the fifth Gaussian.") | ||
private double[] fifthMean = new double[]{-1.0, 0.0}; | ||
|
||
@Config(description = "A vector representing the fifth Gaussian's covariance matrix.") | ||
private double[] fifthVariance = new double[]{1.0, 0.0, 0.0, 0.1}; | ||
|
||
@Config(description = "The RNG seed.") | ||
private long seed = Trainer.DEFAULT_SEED; | ||
|
||
private List<Example<ClusterID>> examples; | ||
|
||
/** | ||
* For OLCUT. | ||
*/ | ||
private GaussianClusterDataSource() { | ||
} | ||
|
||
/** | ||
* Generates a clustering dataset drawn from a mixture of 5 Gaussians. | ||
* | ||
* @param numSamples The size of the output dataset. | ||
* @param seed The rng seed to use. | ||
* @return Examples drawn from a mixture of Gaussians. | ||
*/ | ||
public GaussianClusterDataSource(int numSamples, long seed) { | ||
this.numSamples = numSamples; | ||
this.seed = seed; | ||
postConfig(); | ||
} | ||
|
||
/** | ||
* Generates a clustering dataset drawn from a mixture of 5 Gaussians. | ||
* <p> | ||
* The Gaussians can be at most 4 dimensional, resulting in 4 features. | ||
* | ||
* @param numSamples The size of the output dataset. | ||
* @param mixingPMF The probability of each cluster. | ||
* @param firstMean The mean of the first Gaussian. | ||
* @param firstVariance The variance of the first Gaussian, linearised from a row-major matrix. | ||
* @param secondMean The mean of the second Gaussian. | ||
* @param secondVariance The variance of the second Gaussian, linearised from a row-major matrix. | ||
* @param thirdMean The mean of the third Gaussian. | ||
* @param thirdVariance The variance of the third Gaussian, linearised from a row-major matrix. | ||
* @param fourthMean The mean of the fourth Gaussian. | ||
* @param fourthVariance The variance of the fourth Gaussian, linearised from a row-major matrix. | ||
* @param fifthMean The mean of the fifth Gaussian. | ||
* @param fifthVariance The variance of the fifth Gaussian, linearised from a row-major matrix. | ||
* @param seed The rng seed to use. | ||
*/ | ||
public GaussianClusterDataSource(int numSamples, double[] mixingPMF, | ||
double[] firstMean, double[] firstVariance, | ||
double[] secondMean, double[] secondVariance, | ||
double[] thirdMean, double[] thirdVariance, | ||
double[] fourthMean, double[] fourthVariance, | ||
double[] fifthMean, double[] fifthVariance, | ||
long seed) { | ||
this.numSamples = numSamples; | ||
this.mixingPMF = mixingPMF; | ||
this.firstMean = firstMean; | ||
this.firstVariance = firstVariance; | ||
this.secondMean = secondMean; | ||
this.secondVariance = secondVariance; | ||
this.thirdMean = thirdMean; | ||
this.thirdVariance = thirdVariance; | ||
this.fourthMean = fourthMean; | ||
this.fourthVariance = fourthVariance; | ||
this.fifthMean = fifthMean; | ||
this.fifthVariance = fifthVariance; | ||
this.seed = seed; | ||
postConfig(); | ||
} | ||
|
||
/** | ||
* Used by the OLCUT configuration system, and should not be called by external code. | ||
*/ | ||
@Override | ||
public void postConfig() { | ||
if (numSamples < 1) { | ||
throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples); | ||
} | ||
if (mixingPMF.length != 5) { | ||
throw new PropertyException("", "mixingPMF", "mixingPMF must have 5 elements, found " + mixingPMF.length); | ||
} | ||
if (Math.abs(Util.sum(mixingPMF) - 1.0) > 1e-10) { | ||
throw new PropertyException("", "mixingPMF", "mixingPMF must sum to 1.0, found " + Util.sum(mixingPMF)); | ||
} | ||
if ((firstMean.length > allFeatureNames.length) || (firstMean.length == 0)) { | ||
throw new PropertyException("", "firstMean", "Must have 1-4 features, found " + firstMean.length); | ||
} | ||
int covarianceSize = firstMean.length * firstMean.length; | ||
if (firstVariance.length != (covarianceSize)) { | ||
throw new PropertyException("", "firstVariance", "Invalid first covariance matrix, expected " + covarianceSize + " elements, found " + firstVariance.length); | ||
} | ||
if (secondMean.length != firstMean.length) { | ||
throw new PropertyException("", "secondMean", "All Gaussians must have the same number of dimensions, expected " + firstMean.length + ", found " + secondMean.length); | ||
} | ||
if (secondVariance.length != firstVariance.length) { | ||
throw new PropertyException("", "secondVariance", "secondVariance is invalid, expected " + covarianceSize + ", found " + secondVariance.length); | ||
} | ||
if (thirdMean.length != firstMean.length) { | ||
throw new PropertyException("", "thirdMean", "All Gaussians must have the same number of dimensions, expected " + firstMean.length + ", found " + thirdMean.length); | ||
} | ||
if (thirdVariance.length != firstVariance.length) { | ||
throw new PropertyException("", "thirdVariance", "thirdVariance is invalid, expected " + covarianceSize + ", found " + thirdVariance.length); | ||
} | ||
if (fourthMean.length != firstMean.length) { | ||
throw new PropertyException("", "fourthMean", "All Gaussians must have the same number of dimensions, expected " + firstMean.length + ", found " + fourthMean.length); | ||
} | ||
if (fourthVariance.length != firstVariance.length) { | ||
throw new PropertyException("", "fourthVariance", "fourthVariance is invalid, expected " + covarianceSize + ", found " + fourthVariance.length); | ||
} | ||
if (fifthMean.length != firstMean.length) { | ||
throw new PropertyException("", "fifthMean", "All Gaussians must have the same number of dimensions, expected " + firstMean.length + ", found " + fifthMean.length); | ||
} | ||
if (fifthVariance.length != firstVariance.length) { | ||
throw new PropertyException("", "fifthVariance", "fifthVariance is invalid, expected " + covarianceSize + ", found " + fifthVariance.length); | ||
} | ||
double[] mixingCDF = Util.generateCDF(mixingPMF); | ||
String[] featureNames = Arrays.copyOf(allFeatureNames, firstMean.length); | ||
Random rng = new Random(seed); | ||
MultivariateNormalDistribution first = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), | ||
firstMean, reshape(firstVariance) | ||
); | ||
MultivariateNormalDistribution second = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), | ||
secondMean, reshape(secondVariance) | ||
); | ||
MultivariateNormalDistribution third = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), | ||
thirdMean, reshape(thirdVariance) | ||
); | ||
MultivariateNormalDistribution fourth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), | ||
fourthMean, reshape(fourthVariance) | ||
); | ||
MultivariateNormalDistribution fifth = new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), | ||
fifthMean, reshape(fifthVariance) | ||
); | ||
MultivariateNormalDistribution[] Gaussians = new MultivariateNormalDistribution[]{first, second, third, fourth, fifth}; | ||
List<Example<ClusterID>> examples = new ArrayList<>(numSamples); | ||
for (int i = 0; i < numSamples; i++) { | ||
int centroid = Util.sampleFromCDF(mixingCDF, rng); | ||
double[] sample = Gaussians[centroid].sample(); | ||
examples.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample)); | ||
} | ||
this.examples = Collections.unmodifiableList(examples); | ||
} | ||
|
||
@Override | ||
public OutputFactory<ClusterID> getOutputFactory() { | ||
return factory; | ||
} | ||
|
||
@Override | ||
public DataSourceProvenance getProvenance() { | ||
return new GaussianClusterDataSourceProvenance(this); | ||
} | ||
|
||
@Override | ||
public Iterator<Example<ClusterID>> iterator() { | ||
return examples.iterator(); | ||
} | ||
|
||
/** | ||
* Reshapes the vector into a matrix. | ||
* | ||
* @param vector The vector. | ||
* @return The matrix assuming the vector is linearised in row-major order. | ||
*/ | ||
private static double[][] reshape(double[] vector) { | ||
int length = (int) Math.sqrt(vector.length); | ||
if (length * length != vector.length) { | ||
throw new IllegalArgumentException("The vector does not represent a square matrix, found " + vector.length + " elements, which is not square."); | ||
} | ||
double[][] matrix = new double[length][length]; | ||
for (int i = 0; i < vector.length; i++) { | ||
matrix[i / length][i % length] = vector[i]; | ||
} | ||
return matrix; | ||
} | ||
|
||
/** | ||
* Generates a clustering dataset drawn from a mixture of 5 Gaussians. | ||
* <p> | ||
* The Gaussians can be at most 4 dimensional, resulting in 4 features. | ||
* | ||
* @param numSamples The size of the output dataset. | ||
* @param mixingPMF The probability of each cluster. | ||
* @param firstMean The mean of the first Gaussian. | ||
* @param firstVariance The variance of the first Gaussian, linearised from a row-major matrix. | ||
* @param secondMean The mean of the second Gaussian. | ||
* @param secondVariance The variance of the second Gaussian, linearised from a row-major matrix. | ||
* @param thirdMean The mean of the third Gaussian. | ||
* @param thirdVariance The variance of the third Gaussian, linearised from a row-major matrix. | ||
* @param fourthMean The mean of the fourth Gaussian. | ||
* @param fourthVariance The variance of the fourth Gaussian, linearised from a row-major matrix. | ||
* @param fifthMean The mean of the fifth Gaussian. | ||
* @param fifthVariance The variance of the fifth Gaussian, linearised from a row-major matrix. | ||
* @param seed The rng seed to use. | ||
* @return A dataset drawn from a mixture of Gaussians. | ||
*/ | ||
public static Dataset<ClusterID> generateDataset(int numSamples, double[] mixingPMF, | ||
double[] firstMean, double[] firstVariance, | ||
double[] secondMean, double[] secondVariance, | ||
double[] thirdMean, double[] thirdVariance, | ||
double[] fourthMean, double[] fourthVariance, | ||
double[] fifthMean, double[] fifthVariance, | ||
long seed) { | ||
GaussianClusterDataSource source = new GaussianClusterDataSource(numSamples, mixingPMF, | ||
firstMean, firstVariance, secondMean, secondVariance, thirdMean, thirdVariance, fourthMean, fourthVariance, | ||
fifthMean, fifthVariance, seed); | ||
return new MutableDataset<>(source); | ||
} | ||
|
||
/** | ||
* Provenance for {@link GaussianClusterDataSource}. | ||
*/ | ||
public static final class GaussianClusterDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { | ||
private static final long serialVersionUID = 1L; | ||
|
||
/** | ||
* Constructs a provenance from the host data source. | ||
* | ||
* @param host The host to read. | ||
*/ | ||
GaussianClusterDataSourceProvenance(GaussianClusterDataSource host) { | ||
super(host, "DataSource"); | ||
} | ||
|
||
/** | ||
* Constructs a provenance from the marshalled form. | ||
* | ||
* @param map The map of field values. | ||
*/ | ||
public GaussianClusterDataSourceProvenance(Map<String, Provenance> map) { | ||
this(extractProvenanceInfo(map)); | ||
} | ||
|
||
private GaussianClusterDataSourceProvenance(ExtractedInfo info) { | ||
super(info); | ||
} | ||
|
||
/** | ||
* Extracts the relevant provenance information fields for this class. | ||
* | ||
* @param map The map to remove values from. | ||
* @return The extracted information. | ||
*/ | ||
protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) { | ||
Map<String, Provenance> configuredParameters = new HashMap<>(map); | ||
String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, GaussianClusterDataSourceProvenance.class.getSimpleName()).getValue(); | ||
String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, GaussianClusterDataSourceProvenance.class.getSimpleName()).getValue(); | ||
|
||
return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.