Bumping to liblinear-java 2.44 #228

Merged: 1 commit, Apr 15, 2022
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -42,10 +42,6 @@
/**
* A {@link Trainer} which wraps a liblinear-java anomaly detection trainer using a one-class SVM.
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -94,14 +90,29 @@ public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, doubl

/**
* Creates a trainer for a LibLinear model
* <p>
* Uses {@link Trainer#DEFAULT_SEED} as the RNG seed.
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
* @param nu The nu parameter in the one-class SVM.
*/
public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, int maxIterations, double terminationCriterion, double nu) {
super(trainerType,cost,maxIterations,terminationCriterion);
this(trainerType,cost,maxIterations,terminationCriterion,nu,Trainer.DEFAULT_SEED);
}

/**
* Creates a trainer for a LibLinear model
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
* @param nu The nu parameter in the one-class SVM.
* @param seed The RNG seed.
*/
public LibLinearAnomalyTrainer(LinearAnomalyType trainerType, double cost, int maxIterations, double terminationCriterion, double nu, long seed) {
super(trainerType,cost,maxIterations,terminationCriterion,seed);
this.nu = nu;
}

@@ -119,7 +130,7 @@ public void postConfig() {
@Override
protected Parameter setupParameters(ImmutableOutputInfo<Event> labelIDMap) {
libLinearParams.setNu(nu);
return libLinearParams;
return libLinearParams.clone();
}

@Override
@@ -131,9 +142,6 @@ protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureN
data.x = features;
data.n = numFeatures;

// Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
// Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}

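For context, the new seed parameter on the anomaly trainer can be exercised as in the sketch below. Only the six-argument constructor is confirmed by this diff; the package imports, the ONECLASS_SVM enum value, and the concrete hyperparameter values are assumptions for illustration.

import org.tribuo.anomaly.liblinear.LibLinearAnomalyTrainer;
import org.tribuo.anomaly.liblinear.LinearAnomalyType;

public class SeededAnomalyTrainerSketch {
    public static void main(String[] args) {
        // Hypothetical construction; the six-argument constructor matches the one added above.
        LibLinearAnomalyTrainer trainer = new LibLinearAnomalyTrainer(
                new LinearAnomalyType(LinearAnomalyType.LinearType.ONECLASS_SVM), // assumed enum value
                1.0,   // cost
                1000,  // maxIterations
                0.1,   // terminationCriterion
                0.5,   // nu (example value)
                42L);  // seed, new in this PR; the five-argument constructor now delegates with Trainer.DEFAULT_SEED
        System.out.println(trainer);
    }
}
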
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -45,10 +45,6 @@
/**
* A {@link Trainer} which wraps a liblinear-java classifier trainer.
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -70,14 +66,16 @@ public class LibLinearClassificationTrainer extends LibLinearTrainer<Label> impl
private Map<String,Float> labelWeights = Collections.emptyMap();

/**
* Creates a trainer using the default values (L2R_L2LOSS_SVC_DUAL, 1, 0.1).
* Creates a trainer using the default values ({@link LinearType#L2R_L2LOSS_SVC_DUAL}, 1, 0.1, {@link Trainer#DEFAULT_SEED}).
*/
public LibLinearClassificationTrainer() {
this(new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL),1,1000,0.1);
}

/**
* Creates a trainer for a LibLinearClassificationModel. Sets maxIterations to 1000.
* Creates a trainer for a LibLinearClassificationModel.
* <p>
* Uses {@link Trainer#DEFAULT_SEED} as the RNG seed. Sets maxIterations to 1000.
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
@@ -88,13 +86,27 @@ public LibLinearClassificationTrainer(LinearClassificationType trainerType, doub

/**
* Creates a trainer for a LibLinear model
* <p>
* Uses {@link Trainer#DEFAULT_SEED} as the RNG seed.
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
*/
public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, int maxIterations, double terminationCriterion) {
super(trainerType,cost,maxIterations,terminationCriterion);
this(trainerType,cost,maxIterations,terminationCriterion,Trainer.DEFAULT_SEED);
}

/**
* Creates a trainer for a LibLinear model
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
* @param seed The RNG seed.
*/
public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, int maxIterations, double terminationCriterion, long seed) {
super(trainerType,cost,maxIterations,terminationCriterion,seed);
}

/**
@@ -118,9 +130,6 @@ protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureN
data.n = numFeatures;
data.bias = 1.0;

// Note this isn't sufficient for reproducibility as it doesn't cope with concurrency.
// Concurrency safety is handled by the global lock on LibLinearTrainer.class in LibLinearTrainer.train.
Linear.resetRandom();
return Collections.singletonList(Linear.train(data,curParams));
}

@@ -148,9 +157,8 @@ protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Label> data, Imm

@Override
protected Parameter setupParameters(ImmutableOutputInfo<Label> labelIDMap) {
Parameter curParams;
Parameter curParams = libLinearParams.clone();
if (!labelWeights.isEmpty()) {
curParams = new Parameter(libLinearParams.getSolverType(),libLinearParams.getC(),libLinearParams.getEps());
double[] weights = new double[labelIDMap.size()];
int[] indices = new int[labelIDMap.size()];
int i = 0;
Expand All @@ -167,8 +175,6 @@ protected Parameter setupParameters(ImmutableOutputInfo<Label> labelIDMap) {
}
curParams.setWeights(weights,indices);
//logger.info("Weights = " + Arrays.toString(weights) + ", labels = " + Arrays.toString(indices) + ", outputIDInfo = " + outputIDInfo);
} else {
curParams = libLinearParams;
}
return curParams;
}
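Because setupParameters now always returns a clone, per-training-run label weights no longer mutate the shared Parameter held by the trainer. A minimal sketch of the new seeded classification constructor follows; the constructor signature and the L2R_L2LOSS_SVC_DUAL default come from the diff above, while the import paths and the nesting of LinearType are assumptions.

import org.tribuo.classification.liblinear.LibLinearClassificationTrainer;
import org.tribuo.classification.liblinear.LinearClassificationType;
import org.tribuo.classification.liblinear.LinearClassificationType.LinearType;

public class SeededClassificationTrainerSketch {
    public static void main(String[] args) {
        // Same hyperparameters as the no-argument constructor, but with an explicit RNG seed.
        LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer(
                new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL),
                1.0,    // cost
                1000,   // maxIterations
                0.1,    // terminationCriterion
                42L);   // seed, new parameter in this PR
        System.out.println(trainer);
    }
}
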
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -64,6 +64,7 @@

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -187,15 +188,21 @@ public Model<Label> testLibLinear(Pair<Dataset<Label>,Dataset<Label>> p) {

@Test
public void testReproducible() {
// Note this test will need to change if LibLinearTrainer grows a per Problem RNG.
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
t.setInvocationCount(0);
Model<Label> m = t.train(p.getA());
Map<String, List<Pair<String,Double>>> mFeatures = m.getTopFeatures(-1);

t.setInvocationCount(0);
Model<Label> mTwo = t.train(p.getA());
Map<String, List<Pair<String,Double>>> mTwoFeatures = mTwo.getTopFeatures(-1);

assertEquals(mFeatures,mTwoFeatures);

Model<Label> mThree = t.train(p.getA());
Map<String, List<Pair<String,Double>>> mThreeFeatures = mThree.getTopFeatures(-1);

assertNotEquals(mFeatures, mThreeFeatures);
}

@Test
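The updated test encodes the intended semantics: rewinding the invocation count replays the same RNG split, while a further call on the same trainer consumes a fresh split and is expected to produce a different model. A hedged usage sketch, mirroring the calls in testReproducible above; the package locations of LabelledDataGenerator and Pair are assumptions.

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.classification.Label;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.classification.liblinear.LibLinearClassificationTrainer;

public class ReproducibilitySketch {
    public static void main(String[] args) {
        Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();

        LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer();
        Model<Label> first = trainer.train(p.getA());       // consumes the RNG split for invocation 0
        Model<Label> second = trainer.train(p.getA());      // invocation 1 uses a fresh split, so generally different weights

        trainer.setInvocationCount(1);                       // rebuild the RNG from the seed and fast-forward one split
        Model<Label> secondAgain = trainer.train(p.getA());  // replays invocation 1; expected to match 'second'

        System.out.println(second.getTopFeatures(-1).equals(secondAgain.getTopFeatures(-1))); // expected: true
    }
}
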
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -39,15 +39,13 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
* A {@link Trainer} which wraps a liblinear-java trainer.
* <p>
* Note the train method is synchronized on {@code LibLinearTrainer.class} due to a global RNG in liblinear-java.
* This is insufficient to ensure reproducibility if liblinear-java is used directly in the same JVM as Tribuo, but
* avoids locking on classes Tribuo does not control.
* <p>
* See:
* <pre>
* Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
@@ -82,12 +80,22 @@ public abstract class LibLinearTrainer<T extends Output<T>> implements Trainer<T
@Config(description="Epsilon insensitivity in the regression cost function.")
protected double epsilon = 0.1;

@Config(description="RNG seed.")
protected long seed = Trainer.DEFAULT_SEED;

private SplittableRandom rng;

private int trainInvocationCount = 0;

/**
* For OLCUT
*/
protected LibLinearTrainer() {}

/**
* Creates a trainer for a LibLinear model
* <p>
* Uses {@link Trainer#DEFAULT_SEED} as the RNG seed, and 0.1 as epsilon.
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
@@ -103,14 +111,41 @@ protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIte
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
*/
protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, long seed) {
this(trainerType,cost,maxIterations,terminationCriterion,0.1, seed);
}

/**
* Creates a trainer for a LibLinear model
* <p>
* Uses {@link Trainer#DEFAULT_SEED} as the RNG seed.
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
* @param epsilon The insensitivity of the regression loss to small differences.
*/
protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
this(trainerType,cost,maxIterations,terminationCriterion,epsilon,Trainer.DEFAULT_SEED);
}

/**
* Creates a trainer for a LibLinear model
* @param trainerType Loss function and optimisation method combination.
* @param cost Cost penalty for each incorrectly classified training point.
* @param maxIterations The maximum number of dataset iterations.
* @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
* @param epsilon The insensitivity of the regression loss to small differences.
* @param seed The RNG seed.
*/
protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon, long seed) {
this.trainerType = trainerType;
this.cost = cost;
this.maxIterations = maxIterations;
this.terminationCriterion = terminationCriterion;
this.epsilon = epsilon;
this.seed = seed;
postConfig();
}

@@ -120,6 +155,7 @@ protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIte
@Override
public void postConfig() {
libLinearParams = new Parameter(trainerType.getSolverType(),cost,terminationCriterion,maxIterations,epsilon);
rng = new SplittableRandom(seed);
Linear.disableDebugOutput();
}

@@ -134,27 +170,35 @@ public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runP
}

@Override
public synchronized LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}

// Creates a new RNG, adds one to the invocation count.
TrainerProvenance trainerProvenance;
SplittableRandom localRNG;
synchronized(this) {
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
localRNG = rng.split();
trainerProvenance = getProvenance();
trainInvocationCount++;
}

ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
TrainerProvenance trainerProvenance = getProvenance();
ModelProvenance provenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
trainInvocationCount++;

// Setup parameters and RNG
Parameter curParams = setupParameters(outputIDInfo);
curParams.setRandom(new Random(localRNG.nextLong()));

ModelProvenance provenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);

Pair<FeatureNode[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);

List<de.bwaldvogel.liblinear.Model> models;
synchronized (LibLinearTrainer.class) {
models = trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB());
}
List<de.bwaldvogel.liblinear.Model> models = trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB());

return createModel(provenance,featureIDMap,outputIDInfo,models);
}
@@ -170,7 +214,11 @@ public synchronized void setInvocationCount(int invocationCount) {
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.trainInvocationCount = invocationCount;
rng = new SplittableRandom(seed);

for (trainInvocationCount = 0; trainInvocationCount < invocationCount; trainInvocationCount++){
SplittableRandom localRNG = rng.split();
}
}

@Override
@@ -188,6 +236,8 @@ public String toString() {
buffer.append(libLinearParams.getMaxIters());
buffer.append(",regression-epsilon=");
buffer.append(libLinearParams.getP());
buffer.append(",seed=");
buffer.append(seed);
buffer.append(')');

return buffer.toString();
@@ -223,13 +273,13 @@ public String toString() {
protected abstract Pair<FeatureNode[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);

/**
* Constructs the parameters. Most of the time this is a no-op, but
* Constructs the parameters. Most of the time this just clones the existing ones, but
* classification overrides it to incorporate label weights if they exist.
* @param info The output info.
* @return The Parameters to use for training.
*/
protected Parameter setupParameters(ImmutableOutputInfo<T> info) {
return libLinearParams;
return libLinearParams.clone();
}

/**
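The core of the change is visible in the train method above: a trainer-level SplittableRandom created from the configured seed, split once per invocation inside the synchronized block, with the split seeding the java.util.Random handed to liblinear via Parameter.setRandom. The following is a standalone sketch, not Tribuo code, using only java.util.SplittableRandom, of why splitting from a fixed seed yields replayable per-invocation RNG streams, which is the property the rewritten setInvocationCount relies on.

import java.util.SplittableRandom;

public class SplitRngSketch {
    public static void main(String[] args) {
        // One RNG per trainer, seeded at construction.
        SplittableRandom rng = new SplittableRandom(12345L);
        SplittableRandom firstCall = rng.split();   // what a train() call would use on invocation 0
        SplittableRandom secondCall = rng.split();  // invocation 1

        // Re-seeding and splitting again replays the same per-invocation streams.
        SplittableRandom replay = new SplittableRandom(12345L);
        System.out.println(firstCall.nextLong() == replay.split().nextLong());  // true
        System.out.println(secondCall.nextLong() == replay.split().nextLong()); // true
    }
}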