diff --git a/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/AnomalyDataGenerator.java b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/AnomalyDataGenerator.java
index fda55af83..4ba031a5d 100644
--- a/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/AnomalyDataGenerator.java
+++ b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/AnomalyDataGenerator.java
@@ -43,6 +43,10 @@
* Also has a dataset generator which returns a training dataset
* with no anomalies sampled from a single gaussian, and a test dataset
* sampled from two gaussians where the second is labelled anomalous.
+ *
+ * For most use cases that are not unit tests, it is recommended to use
+ * {@link GaussianAnomalyDataSource} which has similar functionality but
+ * is more flexible and configurable.
*/
public abstract class AnomalyDataGenerator {
diff --git a/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/GaussianAnomalyDataSource.java b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/GaussianAnomalyDataSource.java
new file mode 100644
index 000000000..b76093a7b
--- /dev/null
+++ b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/GaussianAnomalyDataSource.java
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tribuo.anomaly.example;
+
+import com.oracle.labs.mlrg.olcut.config.Config;
+import com.oracle.labs.mlrg.olcut.config.PropertyException;
+import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
+import com.oracle.labs.mlrg.olcut.provenance.Provenance;
+import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
+import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
+import org.tribuo.ConfigurableDataSource;
+import org.tribuo.Dataset;
+import org.tribuo.Example;
+import org.tribuo.Feature;
+import org.tribuo.MutableDataset;
+import org.tribuo.OutputFactory;
+import org.tribuo.Trainer;
+import org.tribuo.anomaly.AnomalyFactory;
+import org.tribuo.anomaly.Event;
+import org.tribuo.impl.ArrayExample;
+import org.tribuo.provenance.ConfiguredDataSourceProvenance;
+import org.tribuo.provenance.DataSourceProvenance;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.tribuo.anomaly.AnomalyFactory.ANOMALOUS_EVENT;
+import static org.tribuo.anomaly.AnomalyFactory.EXPECTED_EVENT;
+
+/**
+ * Generates an anomaly detection dataset sampling each feature uniformly from a univariate Gaussian.
+ *
+ * Or equivalently sampling all the features from a spherical Gaussian.
+ * Can accept at most 26 features.
+ *
+ * By default the expected means are (1.0, 2.0, 1.0, 2.0, 5.0), with variances
+ * (1.0, 0.5, 0.25, 1.0, 0.1).
+ * The anomalous means are (-2.0, 2.0, -2.0, 2.0, -10.0), with variances (1.0, 0.5, 0.25, 1.0, 0.1)
+ * which are the same as the default expected variances.
+ */
+public final class GaussianAnomalyDataSource implements ConfigurableDataSource {
+
+ private static final AnomalyFactory factory = new AnomalyFactory();
+
+ private static final String[] allFeatureNames = new String[]{
+ "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
+ };
+
+ @Config(mandatory = true, description = "The number of samples to draw.")
+ private int numSamples;
+
+ @Config(description = "Means of the expected events.")
+ private double[] expectedMeans = new double[]{1.0, 2.0, 1.0, 2.0, 5.0};
+
+ @Config(description = "Variances of the expected events.")
+ private double[] expectedVariances = new double[]{1.0, 0.5, 0.25, 1.0, 0.1};
+
+ @Config(description = "Means of the anomalous events.")
+ private double[] anomalousMeans = new double[]{-2.0, 2.0, -2.0, 2.0, -10.0};
+
+ @Config(description = "Variances of the anomalous events.")
+ private double[] anomalousVariances = new double[]{1.0, 0.5, 0.25, 1.0, 0.1};
+
+ @Config(description = "The RNG seed.")
+ private long seed = Trainer.DEFAULT_SEED;
+
+ @Config(mandatory = true, description = "The fraction of anomalous events.")
+ private float fractionAnomalous = 0.3f;
+
+ private List> examples;
+
+ /**
+ * For OLCUT.
+ */
+ private GaussianAnomalyDataSource() { }
+
+ /**
+ * Generates anomaly detection examples sampling each feature uniformly from a univariate Gaussian.
+ *
+ * Or equivalently sampling all the features from a spherical Gaussian.
+ *
+ * Can accept at most 26 features.
+ *
+ * @param numSamples The size of the output dataset.
+ * @param fractionAnomalous The fraction of anomalies in the generated data.
+ * @param seed The rng seed to use.
+ * @return Examples drawn from a gaussian.
+ */
+ public GaussianAnomalyDataSource(int numSamples, float fractionAnomalous, long seed) {
+ this.numSamples = numSamples;
+ this.fractionAnomalous = fractionAnomalous;
+ this.seed = seed;
+ postConfig();
+ }
+
+ /**
+ * Generates anomaly detection examples sampling each feature uniformly from a univariate Gaussian.
+ *
+ * Or equivalently sampling all the features from a spherical Gaussian.
+ *
+ * Can accept at most 26 features.
+ *
+ * @param numSamples The size of the output dataset.
+ * @param expectedMeans The means of the expected event features.
+ * @param expectedVariances The variances of the expected event features.
+ * @param anomalousMeans The means of the anomalous event features.
+ * @param anomalousVariances The variances of the anomalous event features.
+ * @param fractionAnomalous The fraction of anomalies to generate.
+ * @param seed The rng seed to use.
+ * @return Examples drawn from a gaussian.
+ */
+ public GaussianAnomalyDataSource(int numSamples, double[] expectedMeans, double[] expectedVariances,
+ double[] anomalousMeans, double[] anomalousVariances,
+ float fractionAnomalous, long seed) {
+ this.numSamples = numSamples;
+ this.expectedMeans = expectedMeans;
+ this.expectedVariances = expectedVariances;
+ this.anomalousMeans = anomalousMeans;
+ this.anomalousVariances = anomalousVariances;
+ this.fractionAnomalous = fractionAnomalous;
+ this.seed = seed;
+ postConfig();
+ }
+
+ /**
+ * Used by the OLCUT configuration system, and should not be called by external code.
+ */
+ @Override
+ public void postConfig() {
+ if (numSamples < 1) {
+ throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
+ }
+ if ((expectedMeans.length > allFeatureNames.length) || (expectedMeans.length == 0)) {
+ throw new PropertyException("", "expectedMeans", "Must have 1-26 features, found " + expectedMeans.length);
+ }
+ if (expectedMeans.length != expectedVariances.length) {
+ throw new PropertyException("", "expectedMeans", "Must supply the same number of expected means and variances." +
+ " expectedMeans.length = " + expectedMeans.length +
+ " expectedVariances.length = " + expectedVariances.length);
+ }
+ if (anomalousMeans.length != anomalousVariances.length) {
+ throw new PropertyException("", "anomalousMeans", "Must supply the same number of anomalous means and variances." +
+ " anomalousMeans.length = " + anomalousMeans.length +
+ " anomalousVariances.length = " + anomalousVariances.length);
+ }
+ if (fractionAnomalous < 0.0f || fractionAnomalous > 1.0f) {
+ throw new PropertyException("", "fractionAnomalous", "fractionAnomalous must be between 0.0 and 1.0, found " + fractionAnomalous);
+ }
+ if ((fractionAnomalous != 0.0) && (anomalousMeans.length != expectedMeans.length)) {
+ throw new PropertyException("", "anomalousMeans", "When sampling anomalous data there must be the same number " +
+ "of anomalous features as expected features. anomalousMeans.length = " + anomalousMeans.length +
+ ", expectedMeans.length = " + expectedMeans.length);
+
+ }
+ String[] featureNames = Arrays.copyOf(allFeatureNames, expectedMeans.length);
+ // We use java.util.Random here because SplittableRandom doesn't have nextGaussian yet.
+ // Once we adopt Java 17 we may switch to SplittableRandom.
+ Random rng = new Random(seed);
+ List> examples = new ArrayList<>(numSamples);
+ for (int i = 0; i < numSamples; i++) {
+ double draw = rng.nextDouble();
+ if (draw < fractionAnomalous) {
+ List featureList = generateFeatures(rng, featureNames, anomalousMeans, anomalousVariances);
+ examples.add(new ArrayExample<>(ANOMALOUS_EVENT, featureList));
+ } else {
+ List featureList = generateFeatures(rng, featureNames, expectedMeans, expectedVariances);
+ examples.add(new ArrayExample<>(EXPECTED_EVENT, featureList));
+ }
+ }
+ this.examples = Collections.unmodifiableList(examples);
+ }
+
+ @Override
+ public OutputFactory getOutputFactory() {
+ return factory;
+ }
+
+ @Override
+ public DataSourceProvenance getProvenance() {
+ return new GaussianAnomalyDataSourceProvenance(this);
+ }
+
+ @Override
+ public Iterator> iterator() {
+ return examples.iterator();
+ }
+
+ /**
+ * Generates the features based on the RNG, the means and the variances.
+ *
+ * @param rng The RNG to use.
+ * @param names The feature names.
+ * @param means The feature means.
+ * @param variances The feature variances.
+ * @return A sampled feature list.
+ */
+ private static List generateFeatures(Random rng, String[] names, double[] means, double[] variances) {
+ if ((names.length != means.length) || (names.length != variances.length)) {
+ throw new IllegalArgumentException("Names, means and variances must be the same length");
+ }
+
+ List features = new ArrayList<>();
+
+ for (int i = 0; i < names.length; i++) {
+ double value = (rng.nextGaussian() * Math.sqrt(variances[i])) + means[i];
+ features.add(new Feature(names[i], value));
+ }
+
+ return features;
+ }
+
+ /**
+ * Generates an anomaly detection dataset sampling each feature uniformly from a univariate Gaussian.
+ *
+ * Or equivalently sampling all the features from a spherical Gaussian.
+ *
+ * Can accept at most 26 features.
+ *
+ * @param numSamples The size of the output dataset.
+ * @param expectedMeans The means of the expected event features.
+ * @param expectedVariances The variances of the expected event features.
+ * @param anomalousMeans The means of the anomalous event features.
+ * @param anomalousVariances The variances of the anomalous event features.
+ * @param fractionAnomalous The fraction of anomalies to generate.
+ * @param seed The rng seed to use.
+ * @return A dataset drawn from a gaussian.
+ */
+ public static Dataset generateDataset(int numSamples, double[] expectedMeans, double[] expectedVariances,
+ double[] anomalousMeans, double[] anomalousVariances,
+ float fractionAnomalous, long seed) {
+ GaussianAnomalyDataSource source = new GaussianAnomalyDataSource(numSamples, expectedMeans, expectedVariances, anomalousMeans, anomalousVariances, fractionAnomalous, seed);
+ return new MutableDataset<>(source);
+ }
+
+ /**
+ * Provenance for {@link GaussianAnomalyDataSource}.
+ */
+ public static final class GaussianAnomalyDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Constructs a provenance from the host data source.
+ *
+ * @param host The host to read.
+ */
+ GaussianAnomalyDataSourceProvenance(GaussianAnomalyDataSource host) {
+ super(host, "DataSource");
+ }
+
+ /**
+ * Constructs a provenance from the marshalled form.
+ *
+ * @param map The map of field values.
+ */
+ public GaussianAnomalyDataSourceProvenance(Map map) {
+ this(extractProvenanceInfo(map));
+ }
+
+ private GaussianAnomalyDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
+ super(info);
+ }
+
+ /**
+ * Extracts the relevant provenance information fields for this class.
+ *
+ * @param map The map to remove values from.
+ * @return The extracted information.
+ */
+ protected static ExtractedInfo extractProvenanceInfo(Map map) {
+ Map configuredParameters = new HashMap<>(map);
+ String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, GaussianAnomalyDataSourceProvenance.class.getSimpleName()).getValue();
+ String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, GaussianAnomalyDataSourceProvenance.class.getSimpleName()).getValue();
+
+ return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
+ }
+ }
+}
diff --git a/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/package-info.java b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/package-info.java
index cf8e93eba..ae435cc1b 100644
--- a/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/package-info.java
+++ b/AnomalyDetection/Core/src/main/java/org/tribuo/anomaly/example/package-info.java
@@ -15,6 +15,6 @@
*/
/**
- * Provides a anomaly data generator used for testing implementations.
+ * Provides anomaly data generators used for demos and testing implementations.
*/
package org.tribuo.anomaly.example;
\ No newline at end of file
diff --git a/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java b/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java
index 600715df3..ac05399b8 100644
--- a/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java
+++ b/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java
@@ -16,13 +16,14 @@
package org.tribuo.anomaly.liblinear;
-import com.oracle.labs.mlrg.olcut.util.Pair;
import org.junit.jupiter.api.Test;
+import org.tribuo.DataSource;
import org.tribuo.Dataset;
+import org.tribuo.MutableDataset;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.evaluation.AnomalyEvaluation;
import org.tribuo.anomaly.evaluation.AnomalyEvaluator;
-import org.tribuo.anomaly.example.AnomalyDataGenerator;
+import org.tribuo.anomaly.example.GaussianAnomalyDataSource;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.test.Helpers;
@@ -35,22 +36,26 @@ public class LibLinearAnomalyTrainerTest {
@Test
public void gaussianDataTest() {
- Pair,Dataset> pair = AnomalyDataGenerator.gaussianAnomaly(1000,0.2);
+ DataSource trainSource = new GaussianAnomalyDataSource(1000, 0.0f, 1);
+ DataSource testSource = new GaussianAnomalyDataSource(1000, 0.2f, 1);
+
+ Dataset trainData = new MutableDataset<>(trainSource);
+ Dataset testData = new MutableDataset<>(testSource);
LinearAnomalyType type = new LinearAnomalyType(LinearAnomalyType.LinearType.ONECLASS_SVM);
LibLinearAnomalyTrainer trainer = new LibLinearAnomalyTrainer(type,1.0,1000,0.01,0.05);
- LibLinearModel model = trainer.train(pair.getA());
+ LibLinearModel model = trainer.train(trainData);
AnomalyEvaluator evaluator = new AnomalyEvaluator();
- AnomalyEvaluation evaluation = evaluator.evaluate(model,pair.getB());
+ AnomalyEvaluation evaluation = evaluator.evaluate(model,testData);
- assertEquals(200,evaluation.getTruePositives());
- assertEquals(764,evaluation.getTrueNegatives());
+ assertEquals(196,evaluation.getTruePositives());
+ assertEquals(761,evaluation.getTrueNegatives());
assertEquals(0,evaluation.getFalseNegatives());
- assertEquals(36,evaluation.getFalsePositives());
+ assertEquals(43,evaluation.getFalsePositives());
String confusion = evaluation.confusionString();
String output = evaluation.toString();
diff --git a/tutorials/anomaly-tribuo-v4.ipynb b/tutorials/anomaly-tribuo-v4.ipynb
index e8262528f..24fe2e523 100644
--- a/tutorials/anomaly-tribuo-v4.ipynb
+++ b/tutorials/anomaly-tribuo-v4.ipynb
@@ -32,7 +32,7 @@
"import org.tribuo.util.Util;\n",
"import org.tribuo.anomaly.*;\n",
"import org.tribuo.anomaly.evaluation.*;\n",
- "import org.tribuo.anomaly.example.AnomalyDataGenerator;\n",
+ "import org.tribuo.anomaly.example.GaussianAnomalyDataSource;\n",
"import org.tribuo.anomaly.libsvm.*;\n",
"import org.tribuo.common.libsvm.*;"
]
@@ -51,11 +51,11 @@
"metadata": {},
"source": [
"## Dataset\n",
- "Tribuo's anomaly detection package comes with a simple data generator that emits pairs of datasets containing 5 features. The training data is free from anomalies, and each example is sampled from a 5 dimensional gaussian with fixed mean and diagonal covariance. The test data is sampled from a mixture of two distributions, the first is the same as the training distribution, and the second uses a different mean for the gaussian (keeping the same covariance for simplicity). All the data points sampled from the second distribution are marked `ANOMALOUS`, and the other points are marked `EXPECTED`. These form the two classes for Tribuo's anomaly detection system. You can also use any of the standard data loaders to pull in anomaly detection data.\n",
+ "Tribuo's anomaly detection package comes with a simple data source that samples data points from a mixture of two spherical gaussians. One gaussian is expected, and the other is anomalous. The fraction of each present in any given data source is controllable with the `fractionAnomalous` constructor argument. The means and variances of the expected and anomalous distributions are also controllable on construction or via configuration (see the Configuration tutorial for more details on Tribuo's configuration system). All the data points sampled from the second distribution are marked `ANOMALOUS`, and the other points are marked `EXPECTED`. These form the two classes for Tribuo's anomaly detection system. You can also use any of the standard data loaders to pull in anomaly detection data.\n",
"\n",
- "The libsvm anomaly detection algorithm requires there are no anomalies in the training data, but this is not required in general for Tribuo's anomaly detection infrastructure.\n",
+ "The LibSVM anomaly detection algorithm requires there are no anomalies in the training data, but this is not required in general for Tribuo's anomaly detection infrastructure.\n",
"\n",
- "We'll sample 2000 points for each dataset, and 20% of the test set will be anomalies."
+ "We'll sample 2000 points for each dataset, the training data will be free of anomalies to make LibSVM happy and 20% of the test set will be anomalies."
]
},
{
@@ -64,9 +64,10 @@
"metadata": {},
"outputs": [],
"source": [
- "var pair = AnomalyDataGenerator.gaussianAnomaly(2000,0.2);\n",
- "var data = pair.getA();\n",
- "var test = pair.getB();"
+ "var data = new MutableDataset<>(new GaussianAnomalyDataSource(2000,/* number of examples */\n",
+ " 0.0f,/*fraction anomalous */\n",
+ " 1L/* RNG seed */));\n",
+ "var test = new MutableDataset<>(new GaussianAnomalyDataSource(2000,0.2f,2L));"
]
},
{
@@ -106,11 +107,11 @@
"output_type": "stream",
"text": [
"*\n",
- "optimization finished, #iter = 692\n",
- "obj = 293.8182352369252, rho = 3.201748862633537\n",
- "nSV = 301, nBSV = 120\n",
+ "optimization finished, #iter = 653\n",
+ "obj = 289.5926348816893, rho = 3.144570476807895\n",
+ "nSV = 296, nBSV = 114\n",
"\n",
- "Training took (00:00:00:149)\n"
+ "Training took (00:00:00:179)\n"
]
}
],
@@ -126,7 +127,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Unfortunately the LibSVM implementation is a little chatty and insists on writing to standard out, but after that we can see it took about 120ms to run (on my 2020 16\" Macbook Pro, you may get slightly different runtimes). We can check how many support vectors are used by the SVM, from the training set of 2000:"
+ "Unfortunately the upstream LibSVM implementation is a little chatty and insists on writing to standard out, but after that we can see it took about 150ms to run (on my 2020 16\" Macbook Pro, you may get slightly different runtimes). We can check how many support vectors are used by the SVM, from the training set of 2000:"
]
},
{
@@ -137,7 +138,7 @@
{
"data": {
"text/plain": [
- "301"
+ "296"
]
},
"execution_count": 7,
@@ -153,7 +154,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "So we used 301 datapoints to model the density of the expected data."
+ "So we used 296 datapoints to model the density of the expected data."
]
},
{
@@ -173,10 +174,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "AnomalyEvaluation(tp=421 fp=250 tn=1329 fn=0 precision=0.627422 recall=1.000000 f1=0.771062)\n",
+ "AnomalyEvaluation(tp=405 fp=232 tn=1363 fn=0 precision=0.635793 recall=1.000000 f1=0.777351)\n",
" EXPECTED ANOMALOUS\n",
- "EXPECTED 1,329 250\n",
- "ANOMALOUS 0 421\n",
+ "EXPECTED 1,363 232\n",
+ "ANOMALOUS 0 405\n",
"\n"
]
}
@@ -191,9 +192,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can see that the model has no false negatives, and so perfect recall, but has a precision of 0.62, so approximately 62% of the positive predictions are true anomalies. This can be tuned by changing the width of the gaussian kernel which changes the range of values which are considered to be expected. The confusion matrix presents the same results in a more common form for classification tasks.\n",
+ "We can see that the model has no false negatives, and so perfect recall, but has a precision of 0.63, so approximately 63% of the positive predictions are true anomalies. This can be tuned by changing the width of the gaussian kernel which changes the range of values which are considered to be expected. The confusion matrix presents the same results in a more common form for classification tasks.\n",
"\n",
- "The 4.1 release has support for liblinear's anomaly detection, which is similar to libSVM's anomaly detector using a linear kernel. We expect to add to Tribuo's set of anomaly detection algorithms over time, and we welcome contributions to expand them on our [Github page](https://github.com/oracle/tribuo)."
+ "The 4.1 release added support for liblinear's anomaly detection, which is similar to LibSVM's anomaly detector using a linear kernel. We expect to add to Tribuo's set of anomaly detection algorithms over time, and we welcome contributions to expand them on our [Github page](https://github.com/oracle/tribuo)."
]
}
],
@@ -209,7 +210,7 @@
"mimetype": "text/x-java-source",
"name": "Java",
"pygments_lexer": "java",
- "version": "17-ea+22-1964"
+ "version": "17+0"
}
},
"nbformat": 4,