Upgrades XGBoost to 1.3.1 (#107)
* Initial upgrade of XGBoost to 1.3.1.

* Bumping XGBoost to v1.3.2

* Finishing plumbing through the extra functionality.

* Reverting back to 1.3.1 as that contains a macOS binary.

* Fixing options in line with the review comments.

* Fixing infinity in the options usage.
Craigacp authored Mar 23, 2021
1 parent 4001680 commit 892afa1
Showing 7 changed files with 214 additions and 25 deletions.
XGBoostClassificationTrainer.java
@@ -102,6 +102,32 @@ public XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int
postConfig();
}

+ /**
+  * Create an XGBoost trainer.
+  *
+  * @param boosterType The base learning algorithm.
+  * @param treeMethod The tree building algorithm if using a tree booster.
+  * @param numTrees Number of trees to boost.
+  * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
+  * @param gamma Minimum loss reduction to make a split (default 0, range [0,inf]).
+  * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
+  * @param minChildWeight Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
+  * @param subsample Subsample size for each tree (default 1, range (0,1]).
+  * @param featureSubsample Subsample features for each tree (default 1, range (0,1]).
+  * @param lambda L2 regularization term on weights (default 1).
+  * @param alpha L1 regularization term on weights (default 0).
+  * @param nThread Number of threads to use (default 4).
+  * @param verbosity Set the logging verbosity of the native library.
+  * @param seed RNG seed.
+  */
+ public XGBoostClassificationTrainer(BoosterType boosterType, TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, LoggingVerbosity verbosity, long seed) {
+     super(boosterType,treeMethod,numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,verbosity,seed);
+     postConfig();
+ }
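For orientation, a minimal sketch of calling the new constructor (the argument values here are illustrative, not taken from this commit):

```java
import org.tribuo.classification.xgboost.XGBoostClassificationTrainer;
import org.tribuo.common.xgboost.XGBoostTrainer.BoosterType;
import org.tribuo.common.xgboost.XGBoostTrainer.LoggingVerbosity;
import org.tribuo.common.xgboost.XGBoostTrainer.TreeMethod;

// A dart booster built with the histogram algorithm, otherwise the documented defaults.
XGBoostClassificationTrainer trainer = new XGBoostClassificationTrainer(
        BoosterType.DART,         // base learning algorithm
        TreeMethod.HIST,          // tree building algorithm
        50,                       // numTrees
        0.3,                      // eta
        0.0,                      // gamma
        6,                        // maxDepth
        1.0,                      // minChildWeight
        1.0,                      // subsample
        1.0,                      // featureSubsample
        1.0,                      // lambda
        0.0,                      // alpha
        4,                        // nThread
        LoggingVerbosity.WARNING, // native library log level
        42L);                     // seed
```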

/**
* This gives direct access to the XGBoost parameter map.
* <p>
@@ -128,7 +154,7 @@ protected XGBoostClassificationTrainer() { }
public void postConfig() {
super.postConfig();
parameters.put("objective", "multi:softprob");
- if(!evalMetric.isEmpty()) {
+ if (!evalMetric.isEmpty()) {
parameters.put("eval_metric", evalMetric);
}
}
XGBoostOptions.java
@@ -19,33 +19,45 @@
import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;
+ import org.tribuo.common.xgboost.XGBoostTrainer;

+ import java.util.logging.Level;
+ import java.util.logging.Logger;

/**
* CLI options for training an XGBoost classifier.
*/
public class XGBoostOptions implements ClassificationOptions<XGBoostClassificationTrainer> {
+ private static final Logger logger = Logger.getLogger(XGBoostOptions.class.getName());

@Option(longName = "xgb-booster-type", usage = "Weak learning algorithm.")
public XGBoostTrainer.BoosterType xgbBoosterType = XGBoostTrainer.BoosterType.GBTREE;
@Option(longName = "xgb-tree-method", usage = "Tree building algorithm.")
public XGBoostTrainer.TreeMethod xgbTreeMethod = XGBoostTrainer.TreeMethod.AUTO;
@Option(longName = "xgb-ensemble-size", usage = "Number of trees in the ensemble.")
public int xgbEnsembleSize = -1;
@Option(longName = "xgb-alpha", usage = "L1 regularization term for weights (default 0).")
@Option(longName = "xgb-alpha", usage = "L1 regularization term for weights.")
public float xbgAlpha = 0.0f;
@Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
@Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (range [0,Infinity]).")
public float xgbMinWeight = 1;
@Option(longName = "xgb-max-depth", usage = "Max tree depth (default 6, range (0,inf]).")
@Option(longName = "xgb-max-depth", usage = "Max tree depth (range (0,Integer.MAX_VALUE]).")
public int xgbMaxDepth = 6;
@Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (default 0.3, range [0,1]).")
@Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (range [0,1]).")
public float xgbEta = 0.3f;
@Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (default 1, range (0,1]).")
public float xgbSubsampleFeatures;
@Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (default 0, range [0,inf]).")
@Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (range (0,1]).")
public float xgbSubsampleFeatures = 0.0f;
@Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (range [0,Infinity]).")
public float xgbGamma = 0.0f;
@Option(longName = "xgb-lambda", usage = "L2 regularization term for weights (default 1).")
@Option(longName = "xgb-lambda", usage = "L2 regularization term for weights.")
public float xgbLambda = 1.0f;
@Option(longName = "xgb-quiet", usage = "Make the XGBoost training procedure quiet.")
@Option(longName = "xgb-quiet", usage = "Deprecated, use xgb-loglevel.")
public boolean xgbQuiet;
@Option(longName = "xgb-subsample", usage = "Subsample size for each tree (default 1, range (0,1]).")
@Option(longName = "xgb-loglevel", usage = "Make the XGBoost training procedure quiet.")
public XGBoostTrainer.LoggingVerbosity xgbLogLevel = XGBoostTrainer.LoggingVerbosity.WARNING;
@Option(longName = "xgb-subsample", usage = "Subsample size for each tree (range (0,1]).")
public float xgbSubsample = 1.0f;
@Option(longName = "xgb-num-threads", usage = "Number of threads to use (default 4, range (1, num hw threads)).")
public int xgbNumThreads;
@Option(longName = "xgb-num-threads", usage = "Number of threads to use (range (1, num hw threads)). The default of 0 means use all hw threads.")
public int xgbNumThreads = 0;
@Option(longName = "xgb-seed", usage = "Sets the random seed for XGBoost.")
private long xgbSeed = Trainer.DEFAULT_SEED;

@@ -54,6 +66,10 @@ public XGBoostClassificationTrainer getTrainer() {
if (xgbEnsembleSize == -1) {
throw new IllegalArgumentException("Please supply the number of trees.");
}
- return new XGBoostClassificationTrainer(xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xbgAlpha, xgbNumThreads, xgbQuiet, xgbSeed);
+ if (xgbQuiet) {
+     logger.log(Level.WARNING, "Silencing XGBoost, overriding logging verbosity. Please switch to the 'xgb-loglevel' argument.");
+     xgbLogLevel = XGBoostTrainer.LoggingVerbosity.SILENT;
+ }
+ return new XGBoostClassificationTrainer(xgbBoosterType, xgbTreeMethod, xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xbgAlpha, xgbNumThreads, xgbLogLevel, xgbSeed);
}
}
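The xgb-quiet flag is now a shim over the new verbosity support; a sketch of the resulting behaviour, with the fields set directly rather than by OLCUT's command-line parsing:

```java
XGBoostOptions opts = new XGBoostOptions();
opts.xgbEnsembleSize = 50; // mandatory, getTrainer() throws without it
opts.xgbLogLevel = XGBoostTrainer.LoggingVerbosity.INFO;
opts.xgbQuiet = true;      // legacy flag

// Logs a warning, overrides xgbLogLevel to SILENT, then calls the
// new 14-argument XGBoostClassificationTrainer constructor.
XGBoostClassificationTrainer trainer = opts.getTrainer();
```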
TestXGBoost.java
@@ -30,6 +30,7 @@
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.common.xgboost.XGBoostFeatureImportance;
import org.tribuo.common.xgboost.XGBoostModel;
+ import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.BasicPipeline;
@@ -61,6 +62,15 @@ public class TestXGBoost {

private static final XGBoostClassificationTrainer t = new XGBoostClassificationTrainer(50);

+ private static final XGBoostClassificationTrainer dart = new XGBoostClassificationTrainer(
+         XGBoostTrainer.BoosterType.DART,XGBoostTrainer.TreeMethod.AUTO,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);
+
+ private static final XGBoostClassificationTrainer linear = new XGBoostClassificationTrainer(
+         XGBoostTrainer.BoosterType.LINEAR,XGBoostTrainer.TreeMethod.AUTO,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);
+
+ private static final XGBoostClassificationTrainer gbtree = new XGBoostClassificationTrainer(
+         XGBoostTrainer.BoosterType.GBTREE,XGBoostTrainer.TreeMethod.HIST,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);

private static final int[] NUM_TREES = new int[]{1,5,10,50};

//on Windows, this resolves to some nonsense like this: /C:/workspace/Classification/XGBoost/target/test-classes/test_input.tribuo
@@ -168,8 +178,8 @@ private void checkPrediction(String msgPrefix, XGBoostModel<Label> model, Predic
}
}

- public Model<Label> testXGBoost(Pair<Dataset<Label>,Dataset<Label>> p) {
-     Model<Label> m = t.train(p.getA());
+ public static Model<Label> testXGBoost(XGBoostClassificationTrainer trainer, Pair<Dataset<Label>,Dataset<Label>> p) {
+     Model<Label> m = trainer.train(p.getA());
LabelEvaluator e = new LabelEvaluator();
LabelEvaluation evaluation = e.evaluate(m,p.getB());
Map<String, List<Pair<String,Double>>> features = m.getTopFeatures(3);
@@ -205,20 +215,23 @@ public void testFeatureImportanceSmokeTest() {
@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
- Model<Label> model = testXGBoost(p);
+ Model<Label> model = testXGBoost(t,p);
Helpers.testModelSerialization(model,Label.class);
+ testXGBoost(dart,p);
+ testXGBoost(linear,p);
+ testXGBoost(gbtree,p);
}

@Test
public void testSparseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.sparseTrainTest();
- testXGBoost(p);
+ testXGBoost(t,p);
}

@Test
public void testSparseBinaryData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.binarySparseTrainTest();
- testXGBoost(p);
+ testXGBoost(t,p);
}

@Test
Common/XGBoost/pom.xml (5 additions, 0 deletions)
@@ -31,6 +31,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
+
<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
@@ -68,6 +69,10 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_2.12</artifactId>
</exclusion>
+ <exclusion>
+     <groupId>org.scalatest</groupId>
+     <artifactId>scalatest_2.12</artifactId>
+ </exclusion>
<exclusion>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-java8-compat_2.12</artifactId>
XGBoostTrainer.java
@@ -78,7 +78,56 @@ public abstract class XGBoostTrainer<T extends Output<T>> implements Trainer<T>,

private static final Logger logger = Logger.getLogger(XGBoostTrainer.class.getName());

- protected final Map<String, Object> parameters = new HashMap<>();
+ /**
+  * The tree building algorithm.
+  */
+ public enum TreeMethod {
+     /**
+      * XGBoost chooses between {@link TreeMethod#EXACT} and {@link TreeMethod#APPROX}
+      * depending on dataset size.
+      */
+     AUTO("auto"),
+     /**
+      * Exact greedy algorithm, enumerates all split candidates.
+      */
+     EXACT("exact"),
+     /**
+      * Approximate greedy algorithm, using a quantile sketch of the data and a gradient histogram.
+      */
+     APPROX("approx"),
+     /**
+      * Faster histogram optimized approximate algorithm.
+      */
+     HIST("hist"),
+     /**
+      * GPU implementation of the {@link TreeMethod#HIST} algorithm.
+      * <p>
+      * Note: GPU computation may not be supported on all platforms, and Tribuo is not tested with XGBoost GPU support.
+      */
+     GPU_HIST("gpu_hist");
+
+     public final String paramName;
+
+     TreeMethod(String paramName) {
+         this.paramName = paramName;
+     }
+ }
+
+ /**
+  * The logging verbosity of the native library.
+  */
+ public enum LoggingVerbosity {
+     SILENT(0),
+     WARNING(1),
+     INFO(2),
+     DEBUG(3);
+
+     public final int value;
+
+     LoggingVerbosity(int value) {
+         this.value = value;
+     }
+ }
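Both new enums carry the value the native XGBoost parameter map expects, which postConfig() (further down) writes through. A quick sketch of the mapping, reusing this class's Map/HashMap imports:

```java
Map<String, Object> params = new HashMap<>();
params.put("tree_method", TreeMethod.HIST.paramName);  // -> "hist"
params.put("verbosity", LoggingVerbosity.DEBUG.value); // -> 3
```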

/**
* The type of XGBoost model.
@@ -104,6 +153,8 @@ public enum BoosterType {
}
}

+ protected final Map<String, Object> parameters = new HashMap<>();

@Config(mandatory = true,description="The number of trees to build.")
protected int numTrees;

@@ -134,12 +185,22 @@ public enum BoosterType {
@Config(description="The number of threads to use at training time.")
private int nThread = 4;

@Config(description="Quiesce all the logging output from the XGBoost C library.")
/**
* Deprecated by XGBoost in favour of the verbosity field.
*/
@Deprecated
@Config(description="Quiesce all the logging output from the XGBoost C library. Deprecated in favour of 'verbosity'.")
private int silent = 1;

@Config(description="Logging verbosity, 0 is silent, 3 is debug.")
private LoggingVerbosity verbosity = LoggingVerbosity.SILENT;

@Config(description="Type of the weak learner.")
private BoosterType booster = BoosterType.GBTREE;

@Config(description="The tree building algorithm to use.")
private TreeMethod treeMethod = TreeMethod.AUTO;

@Config(description="The RNG seed.")
private long seed = Trainer.DEFAULT_SEED;

@@ -155,6 +216,8 @@ protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {

/**
* Create an XGBoost trainer.
+  * <p>
+  * Sets the boosting algorithm to {@link BoosterType#GBTREE} and the tree building algorithm to {@link TreeMethod#AUTO}.
*
* @param numTrees Number of trees to boost.
* @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
@@ -173,9 +236,36 @@ protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {
* @param seed RNG seed.
*/
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
+     this(BoosterType.GBTREE,TreeMethod.AUTO,numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent ? LoggingVerbosity.SILENT : LoggingVerbosity.INFO,seed);
+ }
+
+ /**
+  * Create an XGBoost trainer.
+  *
+  * @param boosterType The base learning algorithm.
+  * @param treeMethod The tree building algorithm if using a tree booster.
+  * @param numTrees Number of trees to boost.
+  * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
+  * @param gamma Minimum loss reduction to make a split (default 0, range [0,inf]).
+  * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
+  * @param minChildWeight Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
+  * @param subsample Subsample size for each tree (default 1, range (0,1]).
+  * @param featureSubsample Subsample features for each tree (default 1, range (0,1]).
+  * @param lambda L2 regularization term on weights (default 1).
+  * @param alpha L1 regularization term on weights (default 0).
+  * @param nThread Number of threads to use (default 4).
+  * @param verbosity Set the logging verbosity of the native library.
+  * @param seed RNG seed.
+  */
+ protected XGBoostTrainer(BoosterType boosterType, TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, LoggingVerbosity verbosity, long seed) {
+     if (numTrees < 1) {
+         throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
+     }
+     this.booster = boosterType;
+     this.treeMethod = treeMethod;
this.numTrees = numTrees;
this.eta = eta;
this.gamma = gamma;
Expand All @@ -186,7 +276,8 @@ protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, d
this.lambda = lambda;
this.alpha = alpha;
this.nThread = nThread;
-     this.silent = silent ? 1 : 0;
+     this.verbosity = verbosity;
+     this.silent = 0; // silent is deprecated
this.seed = seed;
}
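Since the deprecated constructor now delegates to the new one, both entry points pick up the numTrees validation. A sketch, using the classification subclass from this commit since XGBoostTrainer itself is abstract:

```java
try {
    new XGBoostClassificationTrainer(BoosterType.GBTREE, TreeMethod.AUTO,
            0, 0.3, 0.0, 6, 1.0, 1.0, 1.0, 1.0, 0.0, 4,
            LoggingVerbosity.WARNING, 1L);
} catch (IllegalArgumentException e) {
    // "Must supply a positive number of trees. Received 0"
}
```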

@@ -227,8 +318,13 @@ public void postConfig() {
parameters.put("alpha", alpha);
parameters.put("nthread", nThread);
parameters.put("seed", seed);
parameters.put("silent", silent);
if (silent == 1) {
parameters.put("verbosity", 0);
} else {
parameters.put("verbosity", verbosity.value);
}
parameters.put("booster", booster.paramName);
parameters.put("tree_method", treeMethod.paramName);
}
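Putting it together, a sketch of the native parameters this produces; the entries are taken from the postConfig() code above, and the objective comes from the classification subclass:

```java
XGBoostClassificationTrainer t = new XGBoostClassificationTrainer(
        BoosterType.GBTREE, TreeMethod.AUTO, 10, 0.3, 0.0, 6, 1.0,
        1.0, 1.0, 1.0, 0.0, 4, LoggingVerbosity.WARNING, 1L);
// The constructor calls postConfig(), which populates (among others):
//   "booster"     -> "gbtree"          (BoosterType.GBTREE.paramName)
//   "tree_method" -> "auto"            (TreeMethod.AUTO.paramName)
//   "verbosity"   -> 1                 (LoggingVerbosity.WARNING.value; forced to 0
//                                       only if the deprecated silent config is set)
//   "objective"   -> "multi:softprob"  (added by XGBoostClassificationTrainer)
```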

@Override