Skip to content

Commit

Permalink
[breaking] [jvm-packages] Remove rabit check point. (#9599)
Browse files Browse the repository at this point in the history
- Add `numBoostedRound` to jvm packages
- Remove rabit checkpoint version.
- Change the starting version of training continuation in JVM [breaking].
- Redefine the checkpoint version policy in jvm package. [breaking]
- Rename the Python checkpoint callback parameter. [breaking]
- Unify the checkpoint policy between Python and JVM.
  • Loading branch information
trivialfis authored Sep 26, 2023
1 parent 7901a29 commit c75a3bc
Show file tree
Hide file tree
Showing 15 changed files with 138 additions and 229 deletions.
4 changes: 2 additions & 2 deletions demo/guide-python/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def check(as_pickle):
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model"
directory=tmpdir, interval=rounds, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
Expand All @@ -118,7 +118,7 @@ def check(as_pickle):
# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
directory=tmpdir, interval=rounds, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
Expand Down
18 changes: 0 additions & 18 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);

/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);

/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);


/*!
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
* support is experimental, function signature may change in the future without
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
}

private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
(tmpPath, model2, model4)
}

test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))

manager.updateCheckpoint(model4._booster.booster)
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)

manager.updateCheckpoint(model8._booster)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}

test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()

val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())

manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
}

test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}


Expand All @@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -787,35 +787,6 @@ private Map<String, Double> getFeatureImportanceFromModel(
return importanceMap;
}

/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0];
}

public int getVersion() {
return this.version;
}

public void setVersion(int version) {
this.version = version;
}

/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
Expand All @@ -841,29 +812,6 @@ public byte[] toByteArray(String format) throws XGBoostError {
return bytes[0];
}

/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}

/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}

/**
* Get number of model features.
* @return the number of features.
Expand All @@ -874,6 +822,11 @@ public long getNumFeature() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
/**
 * Get the number of boosted rounds of this booster.
 * The value is queried from the native library through the booster handle.
 *
 * @return the number of boosted rounds reported by the native library.
 * @throws XGBoostError native error
 */
public int getNumBoostedRound() throws XGBoostError {
  // The native call writes its result into a single-element output array.
  int[] boostedRounds = new int[1];
  XGBoostJNI.checkCall(
      XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, boostedRounds));
  return boostedRounds[0];
}

/**
* Internal initialization function.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
Expand All @@ -15,7 +30,7 @@ public class ExternalCheckpointManager {

private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;

public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
Expand All @@ -35,6 +50,7 @@ private List<Integer> getExistingVersions() throws IOException {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
Expand All @@ -44,19 +60,23 @@ private List<Integer> getExistingVersions() throws IOException {
}
}

/**
 * Pick the largest version number out of a list of checkpoint versions.
 *
 * @param versions checkpoint versions to search; must not be empty.
 * @return the maximum value contained in {@code versions}.
 * @throws java.util.NoSuchElementException if {@code versions} is empty.
 */
private Integer latest(List<Integer> versions) {
  // Natural ordering of Integer is plain numeric ordering.
  return versions.stream().max(Comparator.naturalOrder()).get();
}

/**
 * Delete the checkpoint directory together with everything stored inside it.
 *
 * @throws IOException if the underlying file system fails to delete the path.
 */
public void cleanPath() throws IOException {
  final boolean recursive = true;
  fs.delete(checkpointPath, recursive);
}

public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
Expand All @@ -65,13 +85,16 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {

public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
.map(this::getPath).collect(Collectors.toList());
// checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
Expand All @@ -83,35 +106,34 @@ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XG
}

public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}

public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException {
int end = firstRound + numOfRounds; // exclusive
int lastRound = end - 1;
if (end - 1 < 0) {
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
}

List<Integer> arr = new ArrayList<>();
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}

if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
}
}
Loading

0 comments on commit c75a3bc

Please sign in to comment.