diff --git a/demo/guide-python/callbacks.py b/demo/guide-python/callbacks.py index be03b1693d99..9c12f70de03c 100644 --- a/demo/guide-python/callbacks.py +++ b/demo/guide-python/callbacks.py @@ -104,7 +104,7 @@ def check(as_pickle): # Use callback class from xgboost.callback # Feel free to subclass/customize it to suit your need. check_point = xgb.callback.TrainingCheckPoint( - directory=tmpdir, iterations=rounds, name="model" + directory=tmpdir, interval=rounds, name="model" ) xgb.train( {"objective": "binary:logistic"}, @@ -118,7 +118,7 @@ def check(as_pickle): # This version of checkpoint saves everything including parameters and # model. See: doc/tutorials/saving_model.rst check_point = xgb.callback.TrainingCheckPoint( - directory=tmpdir, iterations=rounds, as_pickle=True, name="model" + directory=tmpdir, interval=rounds, as_pickle=True, name="model" ) xgb.train( {"objective": "binary:logistic"}, diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 9bce616efb84..5df62df55017 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len, XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, const void *buf, bst_ulong len); -/*! - * \brief Initialize the booster from rabit checkpoint. - * This is used in distributed training API. - * \param handle handle - * \param version The output version of the model. - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, - int* version); - -/*! - * \brief Save the current checkpoint to rabit. - * \param handle handle - * \return 0 when success, -1 when failure happens - */ -XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle); - - /*! * \brief Save XGBoost's internal configuration into a JSON document. Currently the * support is experimental, function signature may change in the future without diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala index adc9c10687be..e6835158d4b7 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite } private def createNewModels(): - (String, XGBoostClassificationModel, XGBoostClassificationModel) = { + (String, XGBoostClassificationModel, XGBoostClassificationModel) = { val tmpPath = createTmpFolder("test").toAbsolutePath.toString - val (model4, model8) = { + val (model2, model4) = { val training = buildDataFrame(Classification.train) val paramMap = produceParamMap(tmpPath, 2) (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) } - (tmpPath, model4, model8) + (tmpPath, model2, model4) } test("test update/load models") { - val (tmpPath, model4, model8) = createNewModels() + val (tmpPath, model2, model4) = createNewModels() val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) - manager.updateCheckpoint(model4._booster.booster) + manager.updateCheckpoint(model2._booster.booster) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "4.model") - assert(manager.loadCheckpointAsScalaBooster().getVersion == 4) + assert(files.head.getPath.getName == "1.model") + assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2) - manager.updateCheckpoint(model8._booster) + manager.updateCheckpoint(model4._booster) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "8.model") - assert(manager.loadCheckpointAsScalaBooster().getVersion == 8) + assert(files.head.getPath.getName == "3.model") + assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4) } test("test cleanUpHigherVersions") { - val (tmpPath, model4, model8) = createNewModels() + val (tmpPath, model2, model4) = createNewModels() val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) - manager.updateCheckpoint(model8._booster) - manager.cleanUpHigherVersions(8) - assert(new File(s"$tmpPath/8.model").exists()) + manager.updateCheckpoint(model4._booster) + manager.cleanUpHigherVersions(3) + assert(new File(s"$tmpPath/3.model").exists()) - manager.cleanUpHigherVersions(4) - assert(!new File(s"$tmpPath/8.model").exists()) + manager.cleanUpHigherVersions(2) + assert(!new File(s"$tmpPath/3.model").exists()) } test("test checkpoint rounds") { import scala.collection.JavaConverters._ - val (tmpPath, model4, model8) = createNewModels() + val (tmpPath, model2, model4) = createNewModels() val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) - assertResult(Seq(7))( - manager.getCheckpointRounds(0, 7).asScala) - assertResult(Seq(2, 4, 6, 7))( - manager.getCheckpointRounds(2, 7).asScala) - manager.updateCheckpoint(model4._booster) - assertResult(Seq(4, 6, 7))( - manager.getCheckpointRounds(2, 7).asScala) + assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala) + assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala) + assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala) } @@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite // Check only one model is kept after training val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) - assert(files.head.getPath.getName == "8.model") - val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") + assert(files.head.getPath.getName == "4.model") + val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model") // Train next model based on prev model val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) assert(error(tmpModel) >= error(prevModel._booster)) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 7ed12c704a9f..51959ce0cfb1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -787,35 +787,6 @@ private Map getFeatureImportanceFromModel( return importanceMap; } - /** - * Save the model as byte array representation. - * Write these bytes to a file will give compatible format with other xgboost bindings. - * - * If java natively support HDFS file API, use toByteArray and write the ByteArray - * - * @param withStats Controls whether the split statistics are output. - * @return dumped model information - * @throws XGBoostError native error - */ - private String[] getDumpInfo(boolean withStats) throws XGBoostError { - int statsFlag = 0; - if (withStats) { - statsFlag = 1; - } - String[][] modelInfos = new String[1][]; - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text", - modelInfos)); - return modelInfos[0]; - } - - public int getVersion() { - return this.version; - } - - public void setVersion(int version) { - this.version = version; - } - /** * Save model into raw byte array. Currently it's using the deprecated format as * default, which will be changed into `ubj` in future releases. @@ -841,29 +812,6 @@ public byte[] toByteArray(String format) throws XGBoostError { return bytes[0]; } - /** - * Load the booster model from thread-local rabit checkpoint. - * This is only used in distributed training. - * @return the stored version number of the checkpoint. - * @throws XGBoostError - */ - int loadRabitCheckpoint() throws XGBoostError { - int[] out = new int[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); - version = out[0]; - return version; - } - - /** - * Save the booster model into thread-local rabit checkpoint and increment the version. - * This is only used in distributed training. - * @throws XGBoostError - */ - void saveRabitCheckpoint() throws XGBoostError { - XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); - version += 1; - } - /** * Get number of model features. * @return the number of features. @@ -874,6 +822,11 @@ public long getNumFeature() throws XGBoostError { XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature)); return numFeature[0]; } + public int getNumBoostedRound() throws XGBoostError { + int[] numRound = new int[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound)); + return numRound[0]; + } /** * Internal initialization function. diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java index 655b99020313..3d794756daa5 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java @@ -1,3 +1,18 @@ +/* + Copyright (c) 2014-2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ package ml.dmlc.xgboost4j.java; import java.io.IOException; @@ -15,7 +30,7 @@ public class ExternalCheckpointManager { private Log logger = LogFactory.getLog("ExternalCheckpointManager"); private String modelSuffix = ".model"; - private Path checkpointPath; + private Path checkpointPath; // directory for checkpoints private FileSystem fs; public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError { @@ -35,6 +50,7 @@ private List getExistingVersions() throws IOException { if (!fs.exists(checkpointPath)) { return new ArrayList<>(); } else { + // Get integer versions from a list of checkpoint files. return Arrays.stream(fs.listStatus(checkpointPath)) .map(path -> path.getPath().getName()) .filter(fileName -> fileName.endsWith(modelSuffix)) @@ -44,6 +60,11 @@ private List getExistingVersions() throws IOException { } } + private Integer latest(List versions) { + return versions.stream() + .max(Comparator.comparing(Integer::valueOf)).get(); + } + public void cleanPath() throws IOException { fs.delete(checkpointPath, true); } @@ -51,12 +72,11 @@ public void cleanPath() throws IOException { public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { List versions = getExistingVersions(); if (versions.size() > 0) { - int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get(); + int latestVersion = this.latest(versions); String checkpointPath = getPath(latestVersion); InputStream in = fs.open(new Path(checkpointPath)); logger.info("loaded checkpoint from " + checkpointPath); Booster booster = XGBoost.loadModel(in); - booster.setVersion(latestVersion); return booster; } else { return null; @@ -65,13 +85,16 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError { List prevModelPaths = getExistingVersions().stream() - .map(this::getPath).collect(Collectors.toList()); - String eventualPath = getPath(boosterToCheckpoint.getVersion()); + .map(this::getPath).collect(Collectors.toList()); + // checkpointing is done after update, so n_rounds - 1 is the current iteration + // accounting for training continuation. + Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1; + String eventualPath = getPath(iter); String tempPath = eventualPath + "-" + UUID.randomUUID(); try (OutputStream out = fs.create(new Path(tempPath), true)) { boosterToCheckpoint.saveModel(out); fs.rename(new Path(tempPath), new Path(eventualPath)); - logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion()); + logger.info("saving checkpoint with version " + iter); prevModelPaths.stream().forEach(path -> { try { fs.delete(new Path(path), true); @@ -83,7 +106,7 @@ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XG } public void cleanUpHigherVersions(int currentRound) throws IOException { - getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> { + getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> { try { fs.delete(new Path(getPath(v)), true); } catch (IOException e) { @@ -91,27 +114,26 @@ public void cleanUpHigherVersions(int currentRound) throws IOException { } }); } - - public List getCheckpointRounds(int checkpointInterval, int numOfRounds) + // Get a list of iterations that need checkpointing. + public List getCheckpointRounds( + int firstRound, int checkpointInterval, int numOfRounds) throws IOException { + int end = firstRound + numOfRounds; // exclusive + int lastRound = end - 1; + if (end - 1 < 0) { + throw new IllegalArgumentException("Inavlid `numOfRounds`."); + } + + List arr = new ArrayList<>(); if (checkpointInterval > 0) { - List prevRounds = - getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList()); - prevRounds.add(0); - int firstCheckpointRound = prevRounds.stream() - .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval; - List arr = new ArrayList<>(); - for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) { + for (int i = firstRound; i < end; i += checkpointInterval) { arr.add(i); } - arr.add(numOfRounds); - return arr; - } else if (checkpointInterval <= 0) { - List l = new ArrayList(); - l.add(numOfRounds); - return l; - } else { - throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set."); } + + if (!arr.contains(lastRound)) { + arr.add(lastRound); + } + return arr; } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index bcd0b1b11d2f..2be62a3437d6 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014,2021 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -133,7 +133,7 @@ public static Booster train( int earlyStoppingRound) throws XGBoostError { return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null); } - + // save checkpoint if iter is in checkpointIterations private static void saveCheckpoint( Booster booster, int iter, @@ -169,7 +169,6 @@ public static Booster trainAndSaveCheckpoint( int bestIteration; List names = new ArrayList(); List mats = new ArrayList(); - Set checkpointIterations = new HashSet<>(); ExternalCheckpointManager ecm = null; if (checkpointPath != null) { ecm = new ExternalCheckpointManager(checkpointPath, fs); @@ -203,32 +202,30 @@ public static Booster trainAndSaveCheckpoint( booster = new Booster(params, allMats); booster.setFeatureNames(dtrain.getFeatureNames()); booster.setFeatureTypes(dtrain.getFeatureTypes()); - booster.loadRabitCheckpoint(); } else { // Start training on an existing booster booster.setParams(params); } + Set checkpointIterations = new HashSet<>(); if (ecm != null) { - checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); + checkpointIterations = new HashSet<>( + ecm.getCheckpointRounds(booster.getNumBoostedRound(), checkpointInterval, numRounds)); } boolean initial_best_score_flag = false; boolean max_direction = false; // begin to train - for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { - if (booster.getVersion() % 2 == 0) { - if (obj != null) { - booster.update(dtrain, obj); - } else { - booster.update(dtrain, iter); - } - saveCheckpoint(booster, iter, checkpointIterations, ecm); - booster.saveRabitCheckpoint(); + for (int iter = 0; iter < numRounds; iter++) { + if (obj != null) { + booster.update(dtrain, iter, obj); + } else { + booster.update(dtrain, iter); } + saveCheckpoint(booster, iter, checkpointIterations, ecm); - //evaluation + // evaluation if (evalMats.length > 0) { float[] metricsOut = new float[evalMats.length]; String evalInfo; @@ -285,7 +282,6 @@ public static Booster trainAndSaveCheckpoint( Communicator.communicatorPrint(evalInfo + '\n'); } } - booster.saveRabitCheckpoint(); } return booster; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index eabbf29ba945..236d53e900a9 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -140,10 +140,11 @@ public final static native int XGBoosterDumpModelExWithFeatures( public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterSetAttr(long handle, String key, String value); - public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); - public final static native int XGBoosterSaveRabitCheckpoint(long handle); + public final static native int XGBoosterGetNumFeature(long handle, long[] feature); + public final static native int XGBoosterGetNumBoostedRound(long handle, int[] rounds); + // communicator functions public final static native int CommunicatorInit(String[] args); public final static native int CommunicatorFinalize(); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 31be86898e5a..c288bfab19fb 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -326,7 +326,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) @throws(classOf[XGBoostError]) def getNumFeature: Long = booster.getNumFeature - def getVersion: Int = booster.getVersion + def getNumBoostedRound: Long = booster.getNumBoostedRound /** * Save model into a raw byte array. Available options are "json", "ubj" and "deprecated". diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 821b1ebff054..332b1a12774b 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -984,33 +984,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr return ret; } -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterLoadRabitCheckpoint - * Signature: (J[I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint - (JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - int version; - int ret = XGBoosterLoadRabitCheckpoint(handle, &version); - JVM_CHECK_CALL(ret); - jint jversion = version; - jenv->SetIntArrayRegion(jout, 0, 1, &jversion); - return ret; -} - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterSaveRabitCheckpoint - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - BoosterHandle handle = (BoosterHandle) jhandle; - return XGBoosterSaveRabitCheckpoint(handle); -} - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetNumFeature @@ -1027,6 +1000,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound( + JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) { + BoosterHandle handle = (BoosterHandle)jhandle; + std::int32_t n_rounds{0}; + auto ret = XGBoosterBoostedRounds(handle, &n_rounds); + JVM_CHECK_CALL(ret); + jint jn_rounds = n_rounds; + jenv->SetIntArrayRegion(jout, 0, 1, &jn_rounds); + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 87ff6d30db6a..cc4ad53d4e4c 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -287,22 +287,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr (JNIEnv *, jclass, jlong, jstring, jstring); -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterLoadRabitCheckpoint - * Signature: (J[I)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabitCheckpoint - (JNIEnv *, jclass, jlong, jintArray); - -/* - * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterSaveRabitCheckpoint - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabitCheckpoint - (JNIEnv *, jclass, jlong); - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetNumFeature @@ -311,6 +295,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature (JNIEnv *, jclass, jlong, jlongArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterGetNumBoostedRound + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound + (JNIEnv *, jclass, jlong, jintArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: CommunicatorInit diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index c7508b20d8ea..b686ddbed858 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package ml.dmlc.xgboost4j.java; import junit.framework.TestCase; -import org.junit.Assert; import org.junit.Test; import java.io.ByteArrayInputStream; @@ -31,7 +30,7 @@ /** * test cases for Booster Inplace Predict - * + * * @author hzx and Sovrn */ public class BoosterImplTest { @@ -845,14 +844,12 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException { float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat); // Save tempBooster to bytestream and load back - int prevVersion = tempBooster.getVersion(); ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray()); tempBooster = XGBoost.loadModel(in); in.close(); - tempBooster.setVersion(prevVersion); // Continue training using tempBooster - round = 4; + round = 2; Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); TestCase.assertTrue(booster1error == booster2error); diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 6077aa1e3188..29d880539df6 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -540,7 +540,10 @@ def after_training(self, model: _Model) -> _Model: class TrainingCheckPoint(TrainingCallback): - """Checkpointing operation. + """Checkpointing operation. Users are encouraged to create their own callbacks for + checkpoint as XGBoost doesn't handle distributed file systems. When checkpointing on + distributed systems, be sure to know the rank of the worker to avoid multiple + workers checkpointing to the same place. .. versionadded:: 1.3.0 @@ -553,9 +556,9 @@ class TrainingCheckPoint(TrainingCallback): pattern of output model file. Models will be saved as name_0.json, name_1.json, name_2.json .... as_pickle : - When set to True, all training parameters will be saved in pickle format, instead - of saving only the model. - iterations : + When set to True, all training parameters will be saved in pickle format, + instead of saving only the model. + interval : Interval of checkpointing. Checkpointing is slow so setting a larger number can reduce performance hit. @@ -566,15 +569,20 @@ def __init__( directory: Union[str, os.PathLike], name: str = "model", as_pickle: bool = False, - iterations: int = 100, + interval: int = 100, ) -> None: self._path = os.fspath(directory) self._name = name self._as_pickle = as_pickle - self._iterations = iterations - self._epoch = 0 + self._iterations = interval + self._epoch = 0 # counter for iterval + self._start = 0 # beginning iteration super().__init__() + def before_training(self, model: _Model) -> _Model: + self._start = model.num_boosted_rounds() + return model + def after_iteration( self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog ) -> bool: @@ -583,11 +591,12 @@ def after_iteration( self._path, self._name + "_" - + str(epoch) + + (str(epoch + self._start)) + (".pkl" if self._as_pickle else ".json"), ) - self._epoch = 0 + self._epoch = 0 # reset counter if collective.get_rank() == 0: + # checkpoint using the first worker if self._as_pickle: with open(path, "wb") as fd: pickle.dump(model, fd) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 2b0862d4945a..f6ab8d4dfe32 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1430,36 +1430,13 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle, API_END(); } -XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, - int* version) { - API_BEGIN(); - CHECK_HANDLE(); - auto* bst = static_cast(handle); - xgboost_CHECK_C_ARG_PTR(version); - *version = rabit::LoadCheckPoint(); - if (*version != 0) { - bst->Configure(); - } - API_END(); -} - -XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) { - API_BEGIN(); - CHECK_HANDLE(); - auto *learner = static_cast(handle); - learner->Configure(); - rabit::CheckPoint(); - API_END(); -} - -XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, - int end_layer, int step, +XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step, BoosterHandle *out) { API_BEGIN(); CHECK_HANDLE(); xgboost_CHECK_C_ARG_PTR(out); - auto* learner = static_cast(handle); + auto *learner = static_cast(handle); bool out_of_bound = false; auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound); if (out_of_bound) { diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 56c9fdabdde3..262c09c99503 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -443,7 +443,7 @@ def test_check_point(self): m = xgb.DMatrix(X, y) with tempfile.TemporaryDirectory() as tmpdir: check_point = xgb.callback.TrainingCheckPoint( - directory=tmpdir, iterations=1, name="model" + directory=tmpdir, interval=1, name="model" ) xgb.train( {"objective": "binary:logistic"}, @@ -456,7 +456,7 @@ def test_check_point(self): assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json")) check_point = xgb.callback.TrainingCheckPoint( - directory=tmpdir, iterations=1, as_pickle=True, name="model" + directory=tmpdir, interval=1, as_pickle=True, name="model" ) xgb.train( {"objective": "binary:logistic"}, diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index ae8d24139579..3510dff7bf8a 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -2238,7 +2238,7 @@ def test_callback(self, client: "Client") -> None: y, callbacks=[ xgb.callback.TrainingCheckPoint( - directory=Path(tmpdir), iterations=1, name="model" + directory=Path(tmpdir), interval=1, name="model" ) ], )