diff --git a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java
index 5caf5a719..7efdd71da 100644
--- a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java
+++ b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java
@@ -41,7 +41,11 @@
/**
* A {@link Trainer} which wraps a liblinear-java anomaly detection trainer using a one-class SVM.
- *
+ *
+ * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
+ * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
+ * avoids locking on classes Tribuo does not control.
+ *
* See:
*
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -127,6 +131,9 @@ protected List trainModels(Parameter curParams, int numFeatures, FeatureN
data.x = features;
data.n = numFeatures;
+ // Note resetting liblinear-java's global RNG isn't sufficient for reproducibility on its own as it doesn't cope with concurrency.
+ // Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
+ Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}
diff --git a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
index 17227476b..7e0982716 100644
--- a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
+++ b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
@@ -22,6 +22,7 @@
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.Trainer;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.Event.EventType;
import org.tribuo.common.libsvm.LibSVMModel;
@@ -38,11 +39,16 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.SplittableRandom;
import java.util.logging.Logger;
/**
* A trainer for anomaly models that uses LibSVM.
*
+ * Note the train method is synchronized on {@code LibSVMTrainer.class} due to a global RNG in LibSVM.
+ * This is insufficient to ensure reproducibility if LibSVM is used directly in the same JVM as Tribuo, but
+ * avoids locking on classes Tribuo does not control.
+ *
* See:
*
* Chang CC, Lin CJ.
@@ -66,11 +72,20 @@ public class LibSVMAnomalyTrainer extends LibSVMTrainer {
protected LibSVMAnomalyTrainer() {}
/**
- * Creates a one-class LibSVM trainer using the supplied parameters.
- * @param parameters The training parameters.
+ * Creates a one-class LibSVM trainer using the supplied parameters and {@link Trainer#DEFAULT_SEED}.
+ * @param parameters The SVM training parameters.
*/
public LibSVMAnomalyTrainer(SVMParameters parameters) {
- super(parameters);
+ this(parameters, Trainer.DEFAULT_SEED);
+ }
+
+ /**
+ * Creates a one-class LibSVM trainer using the supplied parameters and RNG seed.
+ * @param parameters The SVM training parameters.
+ * @param seed The RNG seed for LibSVM's internal RNG.
+ */
+ public LibSVMAnomalyTrainer(SVMParameters parameters, long seed) {
+ super(parameters,seed);
}
/**
@@ -100,7 +115,7 @@ protected LibSVMModel createModel(ModelProvenance provenance, ImmutableFe
}
@Override
- protected List trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
+ protected List trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG) {
svm_problem problem = new svm_problem();
problem.l = outputs[0].length;
problem.x = features;
@@ -112,6 +127,9 @@ protected List trainModels(svm_parameter curParams, int numFeatures,
if(checkString != null) {
throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
}
+ // Seeding LibSVM's shared RNG is safe because the train method synchronizes on LibSVMTrainer.class,
+ // which ensures there is no concurrent use of the RNG from within Tribuo.
+ svm.rand.setSeed(localRNG.nextLong());
return Collections.singletonList(svm.svm_train(problem, curParams));
}
diff --git a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
index cbf45e5a0..e80261470 100644
--- a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
+++ b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
@@ -64,7 +64,7 @@ public SVMAnomalyType(SVMMode type) {
@Override
public boolean isClassification() {
- return true;
+ return false;
}
@Override
@@ -74,7 +74,7 @@ public boolean isRegression() {
@Override
public boolean isAnomaly() {
- return false;
+ return true;
}
@Override
diff --git a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
index d3583e9f8..0940a1817 100644
--- a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
+++ b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
@@ -44,7 +44,11 @@
/**
* A {@link Trainer} which wraps a liblinear-java classifier trainer.
- *
+ *
+ * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
+ * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
+ * avoids locking on classes Tribuo does not control.
+ *
* See:
*
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -114,6 +118,9 @@ protected List trainModels(Parameter curParams, int numFeatures, FeatureN
data.n = numFeatures;
data.bias = 1.0;
+ // Note resetting liblinear-java's global RNG isn't sufficient for reproducibility on its own as it doesn't cope with concurrency.
+ // Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
+ Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}
diff --git a/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java b/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
index 5c9277b3a..9272585df 100644
--- a/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
+++ b/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
@@ -112,7 +112,6 @@ public void testMulticlass() throws IOException, ClassNotFoundException {
for (Example