diff --git a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java index 5caf5a719..7efdd71da 100644 --- a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java +++ b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java @@ -41,7 +41,11 @@ /** * A {@link Trainer} which wraps a liblinear-java anomaly detection trainer using a one-class SVM. - * + *

+ * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java. + * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but + * avoids locking on classes Tribuo does not control. + *

* See: *

  * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -127,6 +131,9 @@ protected List trainModels(Parameter curParams, int numFeatures, FeatureN
         data.x = features;
         data.n = numFeatures;
 
+        // Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
+        // Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
+        Linear.resetRandom();
         return Collections.singletonList(Linear.train(data,curParams));
     }
 
diff --git a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
index 17227476b..7e0982716 100644
--- a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
+++ b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/LibSVMAnomalyTrainer.java
@@ -22,6 +22,7 @@
 import org.tribuo.Example;
 import org.tribuo.ImmutableFeatureMap;
 import org.tribuo.ImmutableOutputInfo;
+import org.tribuo.Trainer;
 import org.tribuo.anomaly.Event;
 import org.tribuo.anomaly.Event.EventType;
 import org.tribuo.common.libsvm.LibSVMModel;
@@ -38,11 +39,16 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.SplittableRandom;
 import java.util.logging.Logger;
 
 /**
  * A trainer for anomaly models that uses LibSVM.
  * 

+ * Note the train method is synchronized on {@code LibSVMTrainer.class} due to a global RNG in LibSVM. + * This is insufficient to ensure reproducibility if LibSVM is used directly in the same JVM as Tribuo, but + * avoids locking on classes Tribuo does not control. + *

* See: *

  * Chang CC, Lin CJ.
@@ -66,11 +72,20 @@ public class LibSVMAnomalyTrainer extends LibSVMTrainer {
     protected LibSVMAnomalyTrainer() {}
 
     /**
-     * Creates a one-class LibSVM trainer using the supplied parameters.
-     * @param parameters The training parameters.
+     * Creates a one-class LibSVM trainer using the supplied parameters and {@link Trainer#DEFAULT_SEED}.
+     * @param parameters The SVM training parameters.
      */
     public LibSVMAnomalyTrainer(SVMParameters parameters) {
-        super(parameters);
+        this(parameters, Trainer.DEFAULT_SEED);
+    }
+
+    /**
+     * Creates a one-class LibSVM trainer using the supplied parameters and RNG seed.
+     * @param parameters The SVM parameters.
+     * @param seed The RNG seed for LibSVM's internal RNG.
+     */
+    public LibSVMAnomalyTrainer(SVMParameters parameters, long seed) {
+        super(parameters,seed);
     }
 
     /**
@@ -100,7 +115,7 @@ protected LibSVMModel createModel(ModelProvenance provenance, ImmutableFe
     }
 
     @Override
-    protected List trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
+    protected List trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs, SplittableRandom localRNG) {
         svm_problem problem = new svm_problem();
         problem.l = outputs[0].length;
         problem.x = features;
@@ -112,6 +127,9 @@ protected List trainModels(svm_parameter curParams, int numFeatures,
         if(checkString != null) {
             throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
         }
+        // This is safe because we synchronize on LibSVMTrainer.class in the train method to
+        // ensure there is no concurrent use of the rng.
+        svm.rand.setSeed(localRNG.nextLong());
         return Collections.singletonList(svm.svm_train(problem, curParams));
     }
 
diff --git a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
index cbf45e5a0..e80261470 100644
--- a/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
+++ b/AnomalyDetection/LibSVM/src/main/java/org/tribuo/anomaly/libsvm/SVMAnomalyType.java
@@ -64,7 +64,7 @@ public SVMAnomalyType(SVMMode type) {
 
     @Override
     public boolean isClassification() {
-        return true;
+        return false;
     }
 
     @Override
@@ -74,7 +74,7 @@ public boolean isRegression() {
 
     @Override
     public boolean isAnomaly() {
-        return false;
+        return true;
     }
 
     @Override
diff --git a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
index d3583e9f8..0940a1817 100644
--- a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
+++ b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java
@@ -44,7 +44,11 @@
 
 /**
  * A {@link Trainer} which wraps a liblinear-java classifier trainer.
- *
+ * 

+ * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java. + * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but + * avoids locking on classes Tribuo does not control. + *

* See: *

  * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -114,6 +118,9 @@ protected List trainModels(Parameter curParams, int numFeatures, FeatureN
         data.n = numFeatures;
         data.bias = 1.0;
 
+        // Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
+        // Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
+        Linear.resetRandom();
         return Collections.singletonList(Linear.train(data,curParams));
     }
 
diff --git a/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java b/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
index 5c9277b3a..9272585df 100644
--- a/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
+++ b/Classification/LibLinear/src/test/java/org/tribuo/classification/liblinear/TestLibLinearModel.java
@@ -112,7 +112,6 @@ public void testMulticlass() throws IOException, ClassNotFoundException {
         for (Example