diff --git a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java index 7efdd71da..e10ac708b 100644 --- a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java +++ b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,10 +42,6 @@ /** * A {@link Trainer} which wraps a liblinear-java anomaly detection trainer using a one-class SVM. *

- * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java. - * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but - * avoids locking on classes Tribuo does not control. - *

  * See:
  * <pre>
  * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -94,6 +90,8 @@ public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, doubl
 
     /**
      * Creates a trainer for a LibLinear model
+     * <p>
+ * Uses {@link Trainer#DEFAULT_SEED} as the RNG seed. * @param trainerType Loss function and optimisation method combination. * @param cost Cost penalty for each incorrectly classified training point. * @param maxIterations The maximum number of dataset iterations. @@ -101,7 +99,20 @@ public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, doubl * @param nu The nu parameter in the one-class SVM. */ public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, int maxIterations, double terminationCriterion, double nu) { - super(trainerType,cost,maxIterations,terminationCriterion); + this(trainerType,cost,maxIterations,terminationCriterion,nu,Trainer.DEFAULT_SEED); + } + + /** + * Creates a trainer for a LibLinear model + * @param trainerType Loss function and optimisation method combination. + * @param cost Cost penalty for each incorrectly classified training point. + * @param maxIterations The maximum number of dataset iterations. + * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1). + * @param nu The nu parameter in the one-class SVM. + * @param seed The RNG seed. + */ + public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, int maxIterations, double terminationCriterion, double nu, long seed) { + super(trainerType,cost,maxIterations,terminationCriterion,seed); this.nu = nu; } @@ -119,7 +130,7 @@ public void postConfig() { @Override protected Parameter setupParameters(ImmutableOutputInfo labelIDMap) { libLinearParams.setNu(nu); - return libLinearParams; + return libLinearParams.clone(); } @Override @@ -131,9 +142,6 @@ protected List trainModels(Parameter curParams, int numFeatures, FeatureN data.x = features; data.n = numFeatures; - // Note this isn't sufficient for reproducibility as it doesn't cope with concurrency. - // Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train. 
- Linear.resetRandom(); return Collections.singletonList(Linear.train(data,curParams)); } diff --git a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java index 0940a1817..1c10e96cd 100644 --- a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java +++ b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,10 +45,6 @@ /** * A {@link Trainer} which wraps a liblinear-java classifier trainer. *

- * Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java. - * This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but - * avoids locking on classes Tribuo does not control. - *

  * See:
  * <pre>
  * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -70,14 +66,16 @@ public class LibLinearClassificationTrainer extends LibLinearTrainer