diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/CheckerboardDataSource.java b/Classification/Core/src/main/java/org/tribuo/classification/example/CheckerboardDataSource.java new file mode 100644 index 000000000..5c9378808 --- /dev/null +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/CheckerboardDataSource.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import org.tribuo.Example; +import org.tribuo.classification.Label; +import org.tribuo.impl.ArrayExample; + +import java.util.ArrayList; +import java.util.List; + +/** + * Creates a data source using a 2d checkerboard of alternating classes. + */ +public final class CheckerboardDataSource extends DemoLabelDataSource { + + @Config(description = "The number of squares on each side.") + private int numSquares = 5; + + @Config(description = "The minimum feature value.") + private double min = 0.0; + + @Config(description = "The maximum feature value.") + private double max = 10.0; + + private double range; + + private double tileWidth; + + /** + * For OLCUT. + */ + private CheckerboardDataSource() { + super(); + } + + /** + * Creates a checkboard with the required number of squares per dimension, where each feature value lies between min and max. + * + * @param numSamples The number of samples to generate. + * @param seed The RNG seed. + * @param numSquares The number of squares. + * @param min The minimum feature value. + * @param max The maximum feature value. + */ + public CheckerboardDataSource(int numSamples, long seed, int numSquares, double min, double max) { + super(numSamples, seed); + this.numSquares = numSquares; + this.min = min; + this.max = max; + postConfig(); + } + + /** + * Used by the OLCUT configuration system, and should not be called by external code. + */ + @Override + public void postConfig() { + if (max <= min) { + throw new PropertyException("", "min", "min must be strictly less than max, min = " + min + ", max = " + max); + } + if (numSquares < 2) { + throw new PropertyException("", "numSquares", "numSquares must be 2 or greater, found " + numSquares); + } + range = Math.abs(max - min); + tileWidth = range / numSquares; + super.postConfig(); + } + + @Override + protected List> generate() { + List> list = new ArrayList<>(); + + for (int i = 0; i < numSamples; i++) { + double[] values = new double[2]; + values[0] = (rng.nextDouble() * range); + values[1] = (rng.nextDouble() * range); + + int modX1 = ((int) Math.floor(values[0] / tileWidth)) % 2; + int modX2 = ((int) Math.floor(values[1] / tileWidth)) % 2; + + Label label; + if (modX1 == modX2) { + label = FIRST_CLASS; + } else { + label = SECOND_CLASS; + } + + // Update the minimums after computing the label so we don't have to + // deal with tricky negative issues interacting with Math.floor(). + values[0] += min; + values[1] += min; + + list.add(new ArrayExample<>(label, FEATURE_NAMES, values)); + } + + return list; + } + + @Override + public String toString() { + return "Checkerboard(numSamples=" + numSamples + ",seed=" + seed + ",numSquares=" + numSquares + ",min=" + min + ",max=" + max + ')'; + } +} diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/ConcentricCirclesDataSource.java b/Classification/Core/src/main/java/org/tribuo/classification/example/ConcentricCirclesDataSource.java new file mode 100644 index 000000000..ab58fa554 --- /dev/null +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/ConcentricCirclesDataSource.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import org.tribuo.Example; +import org.tribuo.classification.Label; +import org.tribuo.impl.ArrayExample; + +import java.util.ArrayList; +import java.util.List; + +/** + * A data source for two concentric circles, one per class. + */ +public final class ConcentricCirclesDataSource extends DemoLabelDataSource { + + @Config(description = "The radius of the outer circle.") + private double radius = 2; + + @Config(description = "The proportion of the circle radius that forms class one.") + private double classProportion = 0.5; + + /** + * For OLCUT. + */ + private ConcentricCirclesDataSource() { + super(); + } + + /** + * Constructs a data source for two concentric circles, one per class. + * + * @param numSamples The number of samples to generate. + * @param seed The RNG seed. + * @param radius The radius of the outer circle. + * @param classProportion The proportion of the circle area that forms class 1. + */ + public ConcentricCirclesDataSource(int numSamples, long seed, double radius, double classProportion) { + super(numSamples, seed); + this.radius = radius; + this.classProportion = classProportion; + postConfig(); + } + + /** + * Used by the OLCUT configuration system, and should not be called by external code. + */ + @Override + public void postConfig() { + if ((classProportion <= 0.0) || (classProportion >= 1.0)) { + throw new PropertyException("", "classProportion", "Class proportion must be between zero and one, found " + classProportion); + } + if (radius <= 0) { + throw new PropertyException("", "radius", "Radius must be positive, found " + radius); + } + super.postConfig(); + } + + @Override + protected List> generate() { + List> list = new ArrayList<>(); + + for (int i = 0; i < numSamples; i++) { + double rotation = rng.nextDouble() * 2 * Math.PI; + double distance = Math.sqrt(rng.nextDouble()) * radius; + double[] values = new double[2]; + values[0] = distance * Math.cos(rotation); + values[1] = distance * Math.sin(rotation); + + double labelDistance = (values[0] * values[0]) + (values[1] * values[1]); + Label label; + if (labelDistance < classProportion * radius * radius) { + label = FIRST_CLASS; + } else { + label = SECOND_CLASS; + } + + list.add(new ArrayExample<>(label, FEATURE_NAMES, values)); + } + + return list; + } + + @Override + public String toString() { + return "ConcentricCircles(numSamples=" + numSamples + ",seed=" + seed + ",radius=" + radius + ",classProportion=" + classProportion + ")"; + } +} diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/DemoLabelDataSource.java b/Classification/Core/src/main/java/org/tribuo/classification/example/DemoLabelDataSource.java new file mode 100644 index 000000000..eddb0f4d7 --- /dev/null +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/DemoLabelDataSource.java @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; +import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; +import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; +import org.tribuo.ConfigurableDataSource; +import org.tribuo.Example; +import org.tribuo.classification.Label; +import org.tribuo.classification.LabelFactory; +import org.tribuo.provenance.ConfiguredDataSourceProvenance; +import org.tribuo.provenance.DataSourceProvenance; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * The base class for the 2d binary classification data sources in {@link org.tribuo.classification.example}. + *

+ * The feature names are {@link #X1} and {@link #X2} and the labels are {@link #FIRST_CLASS} and {@link #SECOND_CLASS}. + *

+ * Likely to be sealed to the classes in this package when we adopt Java 17. + */ +public abstract class DemoLabelDataSource implements ConfigurableDataSource

+ * Note does not call {@link #postConfig} to generate the examples, + * this must be called by the subclass's constructor. + * @param numSamples The number of samples to generate. + * @param seed The RNG seed. + */ + DemoLabelDataSource(int numSamples, long seed) { + this.numSamples = numSamples; + this.seed = seed; + } + + /** + * Configures the class. Should be called in sub-classes' postConfigs + * after they've validated their parameters. + */ + @Override + public void postConfig() { + if (numSamples < 1) { + throw new PropertyException("","numSamples","Number of samples must be positive, found " + numSamples); + } + this.rng = new Random(seed); + this.examples = Collections.unmodifiableList(generate()); + } + + /** + * Generates the examples using the configured fields. + *

+ * Is called internally by {@link #postConfig}. + * @return The generated examples. + */ + protected abstract List> generate(); + + @Override + public LabelFactory getOutputFactory() { + return factory; + } + + @Override + public DataSourceProvenance getProvenance() { + return new DemoLabelDataSourceProvenance(this); + } + + @Override + public Iterator> iterator() { + return examples.iterator(); + } + + /** + * Provenance for {@link DemoLabelDataSource}. + */ + public static final class DemoLabelDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { + private static final long serialVersionUID = 1L; + + /** + * Constructs a provenance from the host data source. + * + * @param host The host to read. + */ + DemoLabelDataSourceProvenance(DemoLabelDataSource host) { + super(host, "DataSource"); + } + + /** + * Constructs a provenance from the marshalled form. + * + * @param map The map of field values. + */ + public DemoLabelDataSourceProvenance(Map map) { + this(extractProvenanceInfo(map)); + } + + private DemoLabelDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) { + super(info); + } + + /** + * Extracts the relevant provenance information fields for this class. + * + * @param map The map to remove values from. + * @return The extracted information. + */ + static ExtractedInfo extractProvenanceInfo(Map map) { + Map configuredParameters = new HashMap<>(map); + String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue(); + String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, DemoLabelDataSourceProvenance.class.getSimpleName()).getValue(); + + return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap()); + } + } +} diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/GaussianLabelDataSource.java b/Classification/Core/src/main/java/org/tribuo/classification/example/GaussianLabelDataSource.java new file mode 100644 index 000000000..f020226de --- /dev/null +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/GaussianLabelDataSource.java @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import org.tribuo.Example; +import org.tribuo.classification.Label; +import org.tribuo.impl.ArrayExample; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +/** + * A data source for two classes generated from separate Gaussians. + */ +public final class GaussianLabelDataSource extends DemoLabelDataSource { + + @Config(mandatory = true, description = "2d mean of the first Gaussian.") + public double[] firstMean; + + @Config(mandatory = true, description = "4 element covariance matrix of the first Gaussian.") + public double[] firstCovarianceMatrix; + + @Config(mandatory = true, description = "2d mean of the second Gaussian.") + private double[] secondMean; + + @Config(mandatory = true, description = "4 element covariance matrix of the second Gaussian.") + private double[] secondCovarianceMatrix; + + private double[] firstCholesky; + + private double[] secondCholesky; + + /** + * For OLCUT. + */ + private GaussianLabelDataSource() { + super(); + } + + /** + * Constructs a data source which contains two classes where each class is sampled from a 2d Gaussian with + * the specified parameters. + * + * @param numSamples The number of samples to draw. + * @param seed The RNG seed. + * @param firstMean The mean of class one's Gaussian. + * @param firstCovarianceMatrix The covariance matrix of class one's Gaussian. + * @param secondMean The mean of class two's Gaussian. + * @param secondCovarianceMatrix The covariance matrix of class two's Gaussian. + */ + public GaussianLabelDataSource(int numSamples, long seed, double[] firstMean, double[] firstCovarianceMatrix, double[] secondMean, double[] secondCovarianceMatrix) { + super(numSamples, seed); + this.firstMean = firstMean; + this.firstCovarianceMatrix = firstCovarianceMatrix; + this.secondMean = secondMean; + this.secondCovarianceMatrix = secondCovarianceMatrix; + postConfig(); + } + + /** + * Used by the OLCUT configuration system, and should not be called by external code. + */ + @Override + public void postConfig() { + if (firstMean.length != 2) { + throw new PropertyException("", "firstMean", "firstMean is not the right length"); + } + if (secondMean.length != 2) { + throw new PropertyException("", "secondMean", "secondMean is not the right length"); + } + if (firstCovarianceMatrix.length != 4) { + throw new PropertyException("", "firstCovarianceMatrix", "firstCovarianceMatrix is not the right length"); + } + if (secondCovarianceMatrix.length != 4) { + throw new PropertyException("", "secondCovarianceMatrix", "secondCovarianceMatrix is not the right length"); + } + + for (int i = 0; i < firstCovarianceMatrix.length; i++) { + if (firstCovarianceMatrix[i] < 0) { + throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not positive semi-definite"); + } + if (secondCovarianceMatrix[i] < 0) { + throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not positive semi-definite"); + } + } + + if (firstCovarianceMatrix[1] != firstCovarianceMatrix[2]) { + throw new PropertyException("", "firstCovarianceMatrix", "First covariance matrix is not a covariance matrix"); + } + + if (secondCovarianceMatrix[1] != secondCovarianceMatrix[2]) { + throw new PropertyException("", "secondCovarianceMatrix", "Second covariance matrix is not a covariance matrix"); + } + + firstCholesky = new double[3]; + firstCholesky[0] = Math.sqrt(firstCovarianceMatrix[0]); + firstCholesky[1] = firstCovarianceMatrix[1] / Math.sqrt(firstCovarianceMatrix[0]); + firstCholesky[2] = Math.sqrt(firstCovarianceMatrix[3] * firstCovarianceMatrix[0] - firstCovarianceMatrix[1] * firstCovarianceMatrix[1]) / Math.sqrt(firstCovarianceMatrix[0]); + + secondCholesky = new double[3]; + secondCholesky[0] = Math.sqrt(secondCovarianceMatrix[0]); + secondCholesky[1] = secondCovarianceMatrix[1] / Math.sqrt(secondCovarianceMatrix[0]); + secondCholesky[2] = Math.sqrt(secondCovarianceMatrix[3] * secondCovarianceMatrix[0] - secondCovarianceMatrix[1] * secondCovarianceMatrix[1]) / Math.sqrt(secondCovarianceMatrix[0]); + super.postConfig(); + } + + @Override + protected List> generate() { + List> list = new ArrayList<>(); + + for (int i = 0; i < numSamples / 2; i++) { + double[] sample = sampleGaussian(rng, firstMean, firstCholesky); + ArrayExample

+ * Also known as the two moons dataset. + */ +public final class InterlockingCrescentsDataSource extends DemoLabelDataSource { + + /** + * For OLCUT. + */ + private InterlockingCrescentsDataSource() { + super(); + } + + /** + * Constructs an interlocking crescents data source. + * + * @param numSamples The number of samples to generate. + */ + public InterlockingCrescentsDataSource(int numSamples) { + super(numSamples, Trainer.DEFAULT_SEED); + postConfig(); + } + + @Override + protected List> generate() { + List> list = new ArrayList<>(); + + for (int i = 0; i < numSamples / 2; i++) { + double[] values = new double[2]; + values[0] = Math.cos(Math.PI * ((double) i) / ((numSamples / 2) - 1)); + values[1] = Math.sin(Math.PI * ((double) i) / ((numSamples / 2) - 1)); + list.add(new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, values)); + } + + for (int i = numSamples / 2; i < numSamples; i++) { + int j = i - numSamples / 2; + double[] values = new double[2]; + + values[0] = 1 - Math.cos(Math.PI * ((double) j) / ((numSamples / 2) - 1)); + values[1] = 0.5 - Math.sin(Math.PI * ((double) j) / ((numSamples / 2) - 1)); + list.add(new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, values)); + } + + return list; + } + + @Override + public String toString() { + return "InterlockingCrescents(numSamples=" + numSamples + ")"; + } +} diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/NoisyInterlockingCrescentsDataSource.java b/Classification/Core/src/main/java/org/tribuo/classification/example/NoisyInterlockingCrescentsDataSource.java new file mode 100644 index 000000000..a1b4cac69 --- /dev/null +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/NoisyInterlockingCrescentsDataSource.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import org.tribuo.Example; +import org.tribuo.classification.Label; +import org.tribuo.impl.ArrayExample; + +import java.util.ArrayList; +import java.util.List; + +/** + * A data source of two interleaved half circles with some zero mean Gaussian noise applied to each point. + */ +public final class NoisyInterlockingCrescentsDataSource extends DemoLabelDataSource { + + @Config(description = "Variance of the Gaussian noise") + private double variance = 0.1; + + /** + * For OLCUT. + */ + private NoisyInterlockingCrescentsDataSource() { + super(); + } + + /** + * Constructs a noisy interlocking crescents data source. + *

+ * It's the same as {@link InterlockingCrescentsDataSource} but each point has Gaussian + * noise with zero mean and the specified variance added to it. + * + * @param numSamples The number of samples to generate. + * @param seed The RNG seed. + * @param variance The variance of the Gaussian noise. + */ + public NoisyInterlockingCrescentsDataSource(int numSamples, long seed, double variance) { + super(numSamples, seed); + this.variance = variance; + postConfig(); + } + + /** + * Used by the OLCUT configuration system, and should not be called by external code. + */ + @Override + public void postConfig() { + if (variance <= 0.0) { + throw new PropertyException("", "variance", "Variance must be positive, found " + variance); + } + super.postConfig(); + } + + @Override + protected List> generate() { + List> list = new ArrayList<>(); + + for (int i = 0; i < numSamples / 2; i++) { + double[] values = new double[2]; + double u = rng.nextDouble(); + values[0] = Math.cos(Math.PI * u) + rng.nextGaussian() * variance; + values[1] = Math.sin(Math.PI * u) + rng.nextGaussian() * variance; + list.add(new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, values)); + } + + for (int i = numSamples / 2; i < numSamples; i++) { + double[] values = new double[2]; + double u = rng.nextDouble(); + values[0] = (1 - Math.cos(Math.PI * u)) + rng.nextGaussian() * variance; + values[1] = (0.5 - Math.sin(Math.PI * u)) + rng.nextGaussian() * variance; + list.add(new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, values)); + } + + return list; + } + + @Override + public String toString() { + return "NoisyInterlockingCrescents(numSamples=" + numSamples + ",seed=" + seed + ",variance=" + variance + ')'; + } +} diff --git a/Classification/Core/src/main/java/org/tribuo/classification/example/package-info.java b/Classification/Core/src/main/java/org/tribuo/classification/example/package-info.java index fb494c39a..4cb7c5a1b 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/example/package-info.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/example/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ /** - * Provides a multiclass data generator used for testing implementations. Not designed for synthesising data for real uses. + * Provides a multiclass data generator used for testing implementations, along with several synthetic data generators + * for 2d binary classification problems to be used in demos or tutorials. */ package org.tribuo.classification.example; \ No newline at end of file diff --git a/Classification/Core/src/test/java/org/tribuo/classification/example/DemoLabelDataSourceTest.java b/Classification/Core/src/test/java/org/tribuo/classification/example/DemoLabelDataSourceTest.java new file mode 100644 index 000000000..6e422bfa2 --- /dev/null +++ b/Classification/Core/src/test/java/org/tribuo/classification/example/DemoLabelDataSourceTest.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.example; + +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import org.junit.jupiter.api.Test; +import org.tribuo.Dataset; +import org.tribuo.MutableDataset; +import org.tribuo.classification.Label; +import org.tribuo.test.Helpers; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DemoLabelDataSourceTest { + + @Test + public void testCheckerboard() { + // Check zero samples throws + assertThrows(PropertyException.class, () -> new CheckerboardDataSource(0, 1, 10, 0.0, 1.0)); + // Check invalid numSquares throws + assertThrows(PropertyException.class, () -> new CheckerboardDataSource(200, 1, 1, 0.0, 1.0)); + // Check invalid min & max + assertThrows(PropertyException.class, () -> new CheckerboardDataSource(200, 1, 10, 0.0, 0.0)); + assertThrows(PropertyException.class, () -> new CheckerboardDataSource(200, 1, 10, 0.0, -1.0)); + + // Check valid parameters work + CheckerboardDataSource source = new CheckerboardDataSource(2000, 1, 10, -1.0, 1.0); + assertEquals(2000, source.examples.size()); + Dataset