diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 461c07153edd..47a1b611a199 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -37,6 +37,7 @@
<spark.version>2.4.3</spark.version>
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
+ <hadoop.version>2.7.3</hadoop.version>
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
deleted file mode 100644
index e4705f181b8a..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.Booster
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.SparkContext
-
-/**
- * A class which allows user to save checkpoints every a few rounds. If a previous job fails,
- * the job can restart training from a saved checkpoints instead of from scratch. This class
- * provides interface and helper methods for the checkpoint functionality.
- *
- * NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level
- * checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS
- * for every a few iterations.
- *
- * @param sc the sparkContext object
- * @param checkpointPath the hdfs path to store checkpoints
- */
-private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) {
- private val logger = LogFactory.getLog("XGBoostSpark")
- private val modelSuffix = ".model"
-
- private def getPath(version: Int) = {
- s"$checkpointPath/$version$modelSuffix"
- }
-
- private def getExistingVersions: Seq[Int] = {
- val fs = FileSystem.get(sc.hadoopConfiguration)
- if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) {
- Seq()
- } else {
- fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect {
- case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt
- }
- }
- }
-
- def cleanPath(): Unit = {
- if (checkpointPath != "") {
- FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
- }
- }
-
- /**
- * Load existing checkpoint with the highest version as a Booster object
- *
- * @return the booster with the highest version, null if no checkpoints available.
- */
- private[spark] def loadCheckpointAsBooster: Booster = {
- val versions = getExistingVersions
- if (versions.nonEmpty) {
- val version = versions.max
- val fullPath = getPath(version)
- val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath))
- logger.info(s"Start training from previous booster at $fullPath")
- val booster = SXGBoost.loadModel(inputStream)
- booster.booster.setVersion(version)
- booster
- } else {
- null
- }
- }
-
- /**
- * Clean up all previous checkpoints and save a new checkpoint
- *
- * @param checkpoint the checkpoint to save as an XGBoostModel
- */
- private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
- val fs = FileSystem.get(sc.hadoopConfiguration)
- val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
- val fullPath = getPath(checkpoint.getVersion)
- val outputStream = fs.create(new Path(fullPath), true)
- logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
- checkpoint.saveModel(outputStream)
- prevModelPaths.foreach(path => fs.delete(path, true))
- }
-
- /**
- * Clean up checkpoint boosters with version higher than or equal to the round.
- *
- * @param round the number of rounds in the current training job
- */
- private[spark] def cleanUpHigherVersions(round: Int): Unit = {
- val higherVersions = getExistingVersions.filter(_ / 2 >= round)
- higherVersions.foreach { version =>
- val fs = FileSystem.get(sc.hadoopConfiguration)
- fs.delete(new Path(getPath(version)), true)
- }
- }
-
- /**
- * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval
- * and total number of rounds for the training. Concretely, the checkpoint rounds start with
- * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it
- * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled
- * and the method returns Seq(round)
- *
- * @param checkpointInterval Period (in iterations) between checkpoints.
- * @param round the total number of rounds for the training
- * @return a seq of integers, each represent the index of round to save the checkpoints
- */
- private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = {
- if (checkpointPath.nonEmpty && checkpointInterval > 0) {
- val prevRounds = getExistingVersions.map(_ / 2)
- val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval
- (firstCheckpointRound until round by checkpointInterval) :+ round
- } else if (checkpointInterval <= 0) {
- Seq(round)
- } else {
- throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.")
- }
- }
-}
-
-object CheckpointManager {
-
- case class CheckpointParam(
- checkpointPath: String,
- checkpointInterval: Int,
- skipCleanCheckpoint: Boolean)
-
- private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
- val checkpointPath: String = params.get("checkpoint_path") match {
- case None => ""
- case Some(path: String) => path
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
- " an instance of String.")
- }
-
- val checkpointInterval: Int = params.get("checkpoint_interval") match {
- case None => 0
- case Some(freq: Int) => freq
- case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
- " an instance of Int.")
- }
-
- val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
- case None => false
- case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
- case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
- " an instance of Boolean")
- }
- CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
- }
-}
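The `_ / 2` arithmetic in the deleted manager survives unchanged in its replacement below, and it encodes a convention that is easy to miss: the booster's internal version counter advances twice per boosting round (once around the model update, once around the Rabit checkpoint), so a checkpoint file name records twice the round number. A minimal sketch of the mapping, not part of the patch:

```scala
// "8.model" holds the booster as of round 4; cleanUpHigherVersions(round)
// therefore filters on version / 2, and checkpoint file names map back with * 2.
def versionToRound(version: Int): Int = version / 2
def roundToVersion(round: Int): Int = round * 2

assert(versionToRound(8) == 4 && roundToVersion(4) == 8)
```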
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index a63a3df03e76..c0354866e24b 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -25,12 +25,13 @@ import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
-import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam
import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams
+import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.FileSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
@@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see
private[this] case class XGBoostExecutionParams(
numWorkers: Int,
- round: Int,
+ numRounds: Int,
useExternalMemory: Boolean,
obj: ObjectiveTrait,
eval: EvalTrait,
@@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams(
allowNonZeroForMissing: Boolean,
trackerConf: TrackerConf,
timeoutRequestWorkers: Long,
- checkpointParam: CheckpointParam,
+ checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean) {
@@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
validateSparkSslConf
-
if (overridedParams.contains("tree_method")) {
require(overridedParams("tree_method") == "hist" ||
overridedParams("tree_method") == "approx" ||
@@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
" an instance of Long.")
}
val checkpointParam =
- CheckpointManager.extractParams(overridedParams)
+ ExternalCheckpointParams.extractParams(overridedParams)
val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
.asInstanceOf[Double]
@@ -339,11 +339,9 @@ object XGBoost extends Serializable {
watches: Watches,
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
- round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
-
// to workaround the empty partitions in training dataset,
// this might not be the best efficient implementation, see
// (https://github.com/dmlc/xgboost/issues/1277)
@@ -357,14 +355,23 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_TASK_ID", taskId)
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
-
+ val numRounds = xgbExecutionParam.numRounds
+ val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try {
Rabit.init(rabitEnv)
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
- val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round))
- val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round,
- watches.toMap, metrics, obj, eval,
- earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+ val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
+ val externalCheckpointParams = xgbExecutionParam.checkpointParam
+ val booster = if (makeCheckpoint) {
+ SXGBoost.trainAndSaveCheckpoint(
+ watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+ watches.toMap, metrics, obj, eval,
+ earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
+ } else {
+ SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
+ watches.toMap, metrics, obj, eval,
+ earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
+ }
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch {
case xgbException: XGBoostError =>
@@ -437,7 +444,6 @@ object XGBoost extends Serializable {
trainingData: RDD[XGBLabeledPoint],
xgbExecutionParams: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
- checkpointRound: Int,
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
@@ -446,8 +452,8 @@ object XGBoost extends Serializable {
processMissingValues(labeledPoints, xgbExecutionParams.missing,
xgbExecutionParams.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParams.useExternalMemory))
- buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
- xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+ buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+ xgbExecutionParams.eval, prevBooster)
}).cache()
} else {
coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers).
@@ -459,8 +465,8 @@ object XGBoost extends Serializable {
xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParams.useExternalMemory))
- buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound,
- xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster)
+ buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj,
+ xgbExecutionParams.eval, prevBooster)
}.cache()
}
}
@@ -469,7 +475,6 @@ object XGBoost extends Serializable {
trainingData: RDD[Array[XGBLabeledPoint]],
xgbExecutionParam: XGBoostExecutionParams,
rabitEnv: java.util.Map[String, String],
- checkpointRound: Int,
prevBooster: Booster,
evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
if (evalSetsMap.isEmpty) {
@@ -478,7 +483,7 @@ object XGBoost extends Serializable {
processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing,
xgbExecutionParam.allowNonZeroForMissing),
getCacheDirName(xgbExecutionParam.useExternalMemory))
- buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+ buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster)
}).cache()
} else {
@@ -490,7 +495,7 @@ object XGBoost extends Serializable {
xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing))
},
getCacheDirName(xgbExecutionParam.useExternalMemory))
- buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound,
+ buildDistributedBooster(watches, xgbExecutionParam, rabitEnv,
xgbExecutionParam.obj,
xgbExecutionParam.eval,
prevBooster)
@@ -529,60 +534,58 @@ object XGBoost extends Serializable {
logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext)
val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams
- val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
- val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam.
- checkpointPath)
- checkpointManager.cleanUpHigherVersions(xgbExecParams.round)
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers)
- var prevBooster = checkpointManager.loadCheckpointAsBooster
+ val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam =>
+ val checkpointManager = new ExternalCheckpointManager(
+ checkpointParam.checkpointPath,
+ FileSystem.get(sc.hadoopConfiguration))
+ checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
+ checkpointManager.loadCheckpointAsScalaBooster()
+ }.orNull
try {
// Train for every ${savingRound} rounds and save the partially completed booster
- val producedBooster = checkpointManager.getCheckpointRounds(
- xgbExecParams.checkpointParam.checkpointInterval,
- xgbExecParams.round).map {
- checkpointRound: Int =>
- val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
- try {
- val parallelismTracker = new SparkParallelismTracker(sc,
- xgbExecParams.timeoutRequestWorkers,
- xgbExecParams.numWorkers)
-
- tracker.getWorkerEnvs().putAll(xgbRabitParams)
- val boostersAndMetrics = if (hasGroup) {
- trainForRanking(transformedTrainingData.left.get, xgbExecParams,
- tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
- } else {
- trainForNonRanking(transformedTrainingData.right.get, xgbExecParams,
- tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap)
- }
- val sparkJobThread = new Thread() {
- override def run() {
- // force the job
- boostersAndMetrics.foreachPartition(() => _)
- }
- }
- sparkJobThread.setUncaughtExceptionHandler(tracker)
- sparkJobThread.start()
- val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
- logger.info(s"Rabit returns with exit code $trackerReturnVal")
- val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
- boostersAndMetrics, sparkJobThread)
- if (checkpointRound < xgbExecParams.round) {
- prevBooster = booster
- checkpointManager.updateCheckpoint(prevBooster)
- }
- (booster, metrics)
- } finally {
- tracker.stop()
+ val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf)
+ val (booster, metrics) = try {
+ val parallelismTracker = new SparkParallelismTracker(sc,
+ xgbExecParams.timeoutRequestWorkers,
+ xgbExecParams.numWorkers)
+ val rabitEnv = tracker.getWorkerEnvs
+ val boostersAndMetrics = if (hasGroup) {
+ trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster,
+ evalSetsMap)
+ } else {
+ trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv,
+ prevBooster, evalSetsMap)
+ }
+ val sparkJobThread = new Thread() {
+ override def run() {
+ // force the job
+ boostersAndMetrics.foreachPartition(() => _)
}
- }.last
+ }
+ sparkJobThread.setUncaughtExceptionHandler(tracker)
+ sparkJobThread.start()
+ val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L))
+ logger.info(s"Rabit returns with exit code $trackerReturnVal")
+ val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal,
+ boostersAndMetrics, sparkJobThread)
+ (booster, metrics)
+ } finally {
+ tracker.stop()
+ }
// we should delete the checkpoint directory after a successful training
- if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
- checkpointManager.cleanPath()
+ xgbExecParams.checkpointParam.foreach {
+ cpParam =>
+ if (!cpParam.skipCleanCheckpoint) {
+ val checkpointManager = new ExternalCheckpointManager(
+ cpParam.checkpointPath,
+ FileSystem.get(sc.hadoopConfiguration))
+ checkpointManager.cleanPath()
+ }
}
- producedBooster
+ (booster, metrics)
} catch {
case t: Throwable =>
// if the job was aborted due to an exception
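The net effect of the hunks above: the old driver loop that launched one Spark job per checkpoint interval is replaced by a single job, with checkpoint writes moved into the workers (`trainAndSaveCheckpoint` on task 0). The driver's remaining checkpoint duties are resolving a resume booster up front and cleaning the directory afterwards. A condensed sketch of the resume step, mirroring the patch with simplified names:

```scala
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.SparkContext
import ml.dmlc.xgboost4j.scala.{Booster, ExternalCheckpointManager}

// Returns the booster to resume from, or null to train from scratch.
def resumeBooster(sc: SparkContext, checkpointPath: Option[String], numRounds: Int): Booster =
  checkpointPath.map { path =>
    val ecm = new ExternalCheckpointManager(path, FileSystem.get(sc.hadoopConfiguration))
    ecm.cleanUpHigherVersions(numRounds)  // drop checkpoints left by a longer, stale run
    ecm.loadCheckpointAsScalaBooster()    // null when no checkpoint exists yet
  }.orNull
```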
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
index a69097c805b4..1c8ddfb08967 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala
@@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables {
def needDeterministicRepartitioning: Boolean = {
- getCheckpointPath.nonEmpty && getCheckpointInterval > 0
+ getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
}
}
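The extra null test is needed because checkpointing is now optional end to end: `checkpointPath` may be unset (null) rather than an empty string, and calling `nonEmpty` on a null `String` throws a `NullPointerException` through the `StringOps` conversion. A sketch of the guarded predicate with illustrative inputs:

```scala
// Short-circuit on null before touching nonEmpty.
def needDeterministicRepartitioning(path: String, interval: Int): Boolean =
  path != null && path.nonEmpty && interval > 0

needDeterministicRepartitioning(null, 2)    // false, no NPE
needDeterministicRepartitioning("/ckpt", 2) // true
```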
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
similarity index 57%
rename from jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
rename to jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
index ddeb48241c86..5ef49431468f 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
@@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import org.scalatest.FunSuite
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
+import org.scalatest.{FunSuite, Ignore}
import org.apache.hadoop.fs.{FileSystem, Path}
-class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
+class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
- private lazy val (model4, model8) = {
- val training = buildDataFrame(Classification.train)
- val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
- "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism)
- (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
- new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
+ private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
+ Map[String, Any] = {
+ Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+ "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
+ "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
}
- test("test update/load models") {
+ private def createNewModels():
+ (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
- val manager = new CheckpointManager(sc, tmpPath)
- manager.updateCheckpoint(model4._booster)
+ val (model4, model8) = {
+ val training = buildDataFrame(Classification.train)
+ val paramMap = produceParamMap(tmpPath, 2)
+ (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
+ new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
+ }
+ (tmpPath, model4, model8)
+ }
+
+ test("test update/load models") {
+ val (tmpPath, model4, model8) = createNewModels()
+ val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+
+ manager.updateCheckpoint(model4._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
- assert(manager.loadCheckpointAsBooster.booster.getVersion == 4)
+ assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
manager.updateCheckpoint(model8._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
- assert(manager.loadCheckpointAsBooster.booster.getVersion == 8)
+ assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
}
test("test cleanUpHigherVersions") {
- val tmpPath = createTmpFolder("test").toAbsolutePath.toString
- val manager = new CheckpointManager(sc, tmpPath)
+ val (tmpPath, model4, model8) = createNewModels()
+
+ val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
- manager.cleanUpHigherVersions(round = 8)
+ manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
- manager.cleanUpHigherVersions(round = 4)
+ manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
}
test("test checkpoint rounds") {
- val tmpPath = createTmpFolder("test").toAbsolutePath.toString
- val manager = new CheckpointManager(sc, tmpPath)
- assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7))
- assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7))
+ import scala.collection.JavaConverters._
+ val (tmpPath, model4, model8) = createNewModels()
+ val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
+ assertResult(Seq(7))(
+ manager.getCheckpointRounds(0, 7).asScala)
+ assertResult(Seq(2, 4, 6, 7))(
+ manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
- assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
+ assertResult(Seq(4, 6, 7))(
+ manager.getCheckpointRounds(2, 7).asScala)
}
@@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val testDM = new DMatrix(Classification.test.iterator)
val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+
+ val paramMap = produceParamMap(tmpPath, 2)
+
val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
val skipCleanCheckpointMap =
if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
- val paramMap = Map("eta" -> "1", "max_depth" -> 2,
- "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
- "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
- skipCleanCheckpointMap
- val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
- def error(model: Booster): Float = eval.eval(
- model.predict(testDM, outPutMargin = true), testDM)
+ val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
+
+ val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
+
+ def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
if (skipCleanCheckpoint) {
// Check only one model is kept after training
@@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
- assert(error(tmpModel) > error(prevModel._booster))
+ assert(error(tmpModel) >= error(prevModel._booster))
assert(error(prevModel._booster) > error(nextModel._booster))
assert(error(nextModel._booster) < 0.1)
} else {
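From the user's side the whole mechanism is driven by the three params this suite exercises. A hedged usage sketch; the HDFS path and the `training` DataFrame are placeholders:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// checkpoint_path plus a positive checkpoint_interval enable external
// checkpoints; skip_clean_checkpoint keeps the N.model files after success.
val classifier = new XGBoostClassifier(Map(
  "eta" -> "1", "max_depth" -> "2",
  "objective" -> "binary:logistic",
  "num_round" -> 8,
  "num_workers" -> 2,
  "checkpoint_path" -> "hdfs:///tmp/xgb-checkpoints",
  "checkpoint_interval" -> 2,
  "skip_clean_checkpoint" -> true))
val model = classifier.fit(training)
```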
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
index e387faa31782..22e35e847a0f 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
@@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
" stop the application") {
val spark = ss
import spark.implicits._
- ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(
@@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest {
"does not stop application") {
val spark = ss
import spark.implicits._
- ss.sparkContext.setLogLevel("INFO")
// spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
// vector,
val testDF = Seq(
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
index e0450f90dfd8..c802fc0b0448 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
@@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark
import ml.dmlc.xgboost4j.java.XGBoostError
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.apache.spark.ml.param.ParamMap
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala
index 9655104c35d7..e106883ca254 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala
@@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque
import scala.util.Random
-import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker}
+import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker}
import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus
import ml.dmlc.xgboost4j.scala.DMatrix
-import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.FunSuite
-
+import org.scalatest.{FunSuite, Ignore}
class RabitRobustnessSuite extends FunSuite with PerTest {
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 99893582906b..a2576d1ba7a9 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -13,6 +13,18 @@
<packaging>jar</packaging>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-hdfs</artifactId>
+ <version>${hadoop.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <version>${hadoop.version}</version>
+ <scope>provided</scope>
+ </dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
new file mode 100644
index 000000000000..655b99020313
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java
@@ -0,0 +1,117 @@
+package ml.dmlc.xgboost4j.java;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.*;
+import java.util.stream.Collectors;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+public class ExternalCheckpointManager {
+
+ private Log logger = LogFactory.getLog("ExternalCheckpointManager");
+ private String modelSuffix = ".model";
+ private Path checkpointPath;
+ private FileSystem fs;
+
+ public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
+ if (checkpointPath == null || checkpointPath.isEmpty()) {
+ throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
+ " empty checkpoint path");
+ }
+ this.checkpointPath = new Path(checkpointPath);
+ this.fs = fs;
+ }
+
+ private String getPath(int version) {
+ return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
+ }
+
+ private List<Integer> getExistingVersions() throws IOException {
+ if (!fs.exists(checkpointPath)) {
+ return new ArrayList<>();
+ } else {
+ return Arrays.stream(fs.listStatus(checkpointPath))
+ .map(path -> path.getPath().getName())
+ .filter(fileName -> fileName.endsWith(modelSuffix))
+ .map(fileName -> Integer.valueOf(
+ fileName.substring(0, fileName.length() - modelSuffix.length())))
+ .collect(Collectors.toList());
+ }
+ }
+
+ public void cleanPath() throws IOException {
+ fs.delete(checkpointPath, true);
+ }
+
+ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
+ List<Integer> versions = getExistingVersions();
+ if (versions.size() > 0) {
+ int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
+ String checkpointPath = getPath(latestVersion);
+ InputStream in = fs.open(new Path(checkpointPath));
+ logger.info("loaded checkpoint from " + checkpointPath);
+ Booster booster = XGBoost.loadModel(in);
+ booster.setVersion(latestVersion);
+ return booster;
+ } else {
+ return null;
+ }
+ }
+
+ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
+ List<String> prevModelPaths = getExistingVersions().stream()
+ .map(this::getPath).collect(Collectors.toList());
+ String eventualPath = getPath(boosterToCheckpoint.getVersion());
+ String tempPath = eventualPath + "-" + UUID.randomUUID();
+ try (OutputStream out = fs.create(new Path(tempPath), true)) {
+ boosterToCheckpoint.saveModel(out);
+ fs.rename(new Path(tempPath), new Path(eventualPath));
+ logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
+ prevModelPaths.stream().forEach(path -> {
+ try {
+ fs.delete(new Path(path), true);
+ } catch (IOException e) {
+ logger.error("failed to delete outdated checkpoint at " + path, e);
+ }
+ });
+ }
+ }
+
+ public void cleanUpHigherVersions(int currentRound) throws IOException {
+ getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
+ try {
+ fs.delete(new Path(getPath(v)), true);
+ } catch (IOException e) {
+ logger.error("failed to clean checkpoint from other training instance", e);
+ }
+ });
+ }
+
+ public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
+ throws IOException {
+ if (checkpointInterval > 0) {
+ List<Integer> prevRounds =
+ getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
+ prevRounds.add(0);
+ int firstCheckpointRound = prevRounds.stream()
+ .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
+ List<Integer> arr = new ArrayList<>();
+ for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
+ arr.add(i);
+ }
+ arr.add(numOfRounds);
+ return arr;
+ } else {
+ // A non-positive interval disables external checkpointing; return only
+ // the final round so training still runs end to end.
+ List<Integer> l = new ArrayList<>();
+ l.add(numOfRounds);
+ return l;
+ }
+ }
+}
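A worked example of `getCheckpointRounds`, consistent with the assertions in the updated suite earlier in this patch (Scala, via the wrapper introduced below; the local path is a placeholder):

```scala
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager

val ecm = new ExternalCheckpointManager(
  "/tmp/xgb-ckpt", FileSystem.get(new Configuration()))

ecm.getCheckpointRounds(0, 7).asScala  // Buffer(7): interval 0 disables checkpoints
ecm.getCheckpointRounds(2, 7).asScala  // Buffer(2, 4, 6, 7) against an empty directory
// Once "4.model" (round 2) exists, the first checkpoint round shifts to 2 + 2,
// and the same call yields Buffer(4, 6, 7).
```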
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
index b6a173bd6bd1..8adf4c0aebd3 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
@@ -15,12 +15,16 @@
*/
package ml.dmlc.xgboost4j.java;
+import java.io.File;
import java.io.IOException;
import java.io.InputStream;
+import java.io.OutputStream;
import java.util.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
/**
* trainer for xgboost
@@ -108,35 +112,34 @@ public static Booster train(
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
}
- /**
- * Train a booster given parameters.
- *
- * @param dtrain Data to be trained.
- * @param params Parameters.
- * @param round Number of boosting iterations.
- * @param watches a group of items to be evaluated during training, this allows user to watch
- * performance on the validation set.
- * @param metrics array containing the evaluation metrics for each matrix in watches for each
- * iteration
- * @param earlyStoppingRounds if non-zero, training would be stopped
- * after a specified number of consecutive
- * goes to the unexpected direction in any evaluation metric.
- * @param obj customized objective
- * @param eval customized evaluation
- * @param booster train from scratch if set to null; train from an existing booster if not null.
- * @return The trained booster.
- */
- public static Booster train(
- DMatrix dtrain,
- Map<String, Object> params,
- int round,
- Map<String, DMatrix> watches,
- float[][] metrics,
- IObjective obj,
- IEvaluation eval,
- int earlyStoppingRounds,
- Booster booster) throws XGBoostError {
+ private static void saveCheckpoint(
+ Booster booster,
+ int iter,
+ Set<Integer> checkpointIterations,
+ ExternalCheckpointManager ecm) throws XGBoostError {
+ try {
+ if (checkpointIterations.contains(iter)) {
+ ecm.updateCheckpoint(booster);
+ }
+ } catch (Exception e) {
+ logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
+ throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
+ }
+ }
+ public static Booster trainAndSaveCheckpoint(
+ DMatrix dtrain,
+ Map<String, Object> params,
+ int numRounds,
+ Map<String, DMatrix> watches,
+ float[][] metrics,
+ IObjective obj,
+ IEvaluation eval,
+ int earlyStoppingRounds,
+ Booster booster,
+ int checkpointInterval,
+ String checkpointPath,
+ FileSystem fs) throws XGBoostError, IOException {
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
@@ -144,6 +147,11 @@ public static Booster train(
int bestIteration;
List<String> names = new ArrayList<>();
List<DMatrix> mats = new ArrayList<>();
+ Set<Integer> checkpointIterations = new HashSet<>();
+ ExternalCheckpointManager ecm = null;
+ if (checkpointPath != null) {
+ ecm = new ExternalCheckpointManager(checkpointPath, fs);
+ }
for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
@@ -158,7 +166,7 @@ public static Booster train(
bestScore = Float.MAX_VALUE;
}
bestIteration = 0;
- metrics = metrics == null ? new float[evalNames.length][round] : metrics;
+ metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;
//collect all data matrixs
DMatrix[] allMats;
@@ -181,14 +189,19 @@ public static Booster train(
booster.setParams(params);
}
- //begin to train
- for (int iter = booster.getVersion() / 2; iter < round; iter++) {
+ if (ecm != null) {
+ checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
+ }
+
+ // begin to train
+ for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
if (obj != null) {
booster.update(dtrain, obj);
} else {
booster.update(dtrain, iter);
}
+ saveCheckpoint(booster, iter, checkpointIterations, ecm);
booster.saveRabitCheckpoint();
}
@@ -224,7 +237,7 @@ public static Booster train(
if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
Rabit.trackerPrint(String.format(
"early stopping after %d rounds away from the best iteration",
- earlyStoppingRounds));
+ earlyStoppingRounds));
break;
}
}
@@ -239,6 +252,44 @@ public static Booster train(
return booster;
}
+ /**
+ * Train a booster given parameters.
+ *
+ * @param dtrain Data to be trained.
+ * @param params Parameters.
+ * @param round Number of boosting iterations.
+ * @param watches a group of items to be evaluated during training; this allows the user to watch
+ * performance on the validation set.
+ * @param metrics array containing the evaluation metrics for each matrix in watches for each
+ * iteration
+ * @param earlyStoppingRounds if non-zero, training is stopped after this many
+ * consecutive rounds in which any evaluation metric
+ * moves in the unexpected direction.
+ * @param obj customized objective
+ * @param eval customized evaluation
+ * @param booster train from scratch if set to null; train from an existing booster if not null.
+ * @return The trained booster.
+ */
+ public static Booster train(
+ DMatrix dtrain,
+ Map<String, Object> params,
+ int round,
+ Map<String, DMatrix> watches,
+ float[][] metrics,
+ IObjective obj,
+ IEvaluation eval,
+ int earlyStoppingRounds,
+ Booster booster) throws XGBoostError {
+ try {
+ return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval,
+ earlyStoppingRounds, booster,
+ -1, null, null);
+ } catch (IOException e) {
+ logger.error("training failed in xgboost4j", e);
+ throw new XGBoostError("training failed in xgboost4j ", e);
+ }
+ }
+
private static Integer tryGetIntFromObject(Object o) {
if (o instanceof Integer) {
return (int)o;
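A subtlety worth noting above: the training loop starts at `booster.getVersion() / 2`, so resuming from a checkpoint re-runs only the remaining rounds rather than restarting at zero. A sketch under that assumption, using the public Scala `train` entry point with a previously loaded booster at version 8:

```scala
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}

// prev.getVersion == 8 means rounds 0..3 are already in the model, so asking
// for 8 rounds in total executes only rounds 4..7 before returning.
def resume(train: DMatrix, params: Map[String, Any], prev: Booster): Booster =
  SXGBoost.train(train, params, round = 8, booster = prev)
```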
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java
index 7da5c176d574..d113f5bfad3e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java
@@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
public XGBoostError(String message) {
super(message);
}
+
+ public XGBoostError(String message, Throwable cause) {
+ super(message, cause);
+ }
}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
new file mode 100644
index 000000000000..240c23871362
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala
@@ -0,0 +1,37 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM}
+import org.apache.hadoop.fs.FileSystem
+
+class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem)
+ extends JavaECM(checkpointPath, fs) {
+
+ def updateCheckpoint(booster: Booster): Unit = {
+ super.updateCheckpoint(booster.booster)
+ }
+
+ def loadCheckpointAsScalaBooster(): Booster = {
+ val loadedBooster = super.loadCheckpointAsBooster()
+ if (loadedBooster == null) {
+ null
+ } else {
+ new Booster(loadedBooster)
+ }
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
index 609d7b2cde8c..791f06b716f1 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala
@@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala
import java.io.InputStream
-import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
+import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost}
import scala.collection.JavaConverters._
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
/**
* XGBoost Scala Training function.
*/
object XGBoost {
+ private[scala] def trainAndSaveCheckpoint(
+ dtrain: DMatrix,
+ params: Map[String, Any],
+ numRounds: Int,
+ watches: Map[String, DMatrix] = Map(),
+ metrics: Array[Array[Float]] = null,
+ obj: ObjectiveTrait = null,
+ eval: EvalTrait = null,
+ earlyStoppingRound: Int = 0,
+ prevBooster: Booster,
+ checkpointParams: Option[ExternalCheckpointParams]): Booster = {
+ val jWatches = watches.mapValues(_.jDMatrix).asJava
+ val jBooster = if (prevBooster == null) {
+ null
+ } else {
+ prevBooster.booster
+ }
+ val xgboostInJava = checkpointParams.
+ map(cp => {
+ JXGBoost.trainAndSaveCheckpoint(
+ dtrain.jDMatrix,
+ // we have to filter null value for customized obj and eval
+ params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+ numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster,
+ cp.checkpointInterval,
+ cp.checkpointPath,
+ new Path(cp.checkpointPath).getFileSystem(new Configuration()))
+ }).
+ getOrElse(
+ JXGBoost.train(
+ dtrain.jDMatrix,
+ // we have to filter null value for customized obj and eval
+ params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
+ numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
+ )
+ if (prevBooster == null) {
+ new Booster(xgboostInJava)
+ } else {
+ // Avoid creating a new SBooster with the same JBooster
+ prevBooster
+ }
+ }
+
/**
* Train a booster given parameters.
*
@@ -55,23 +101,8 @@ object XGBoost {
eval: EvalTrait = null,
earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
- val jWatches = watches.mapValues(_.jDMatrix).asJava
- val jBooster = if (booster == null) {
- null
- } else {
- booster.booster
- }
- val xgboostInJava = JXGBoost.train(
- dtrain.jDMatrix,
- // we have to filter null value for customized obj and eval
- params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
- round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
- if (booster == null) {
- new Booster(xgboostInJava)
- } else {
- // Avoid creating a new SBooster with the same JBooster
- booster
- }
+ trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound,
+ booster, None)
}
/**
@@ -126,3 +157,41 @@ object XGBoost {
new Booster(xgboostInJava)
}
}
+
+private[scala] case class ExternalCheckpointParams(
+ checkpointInterval: Int,
+ checkpointPath: String,
+ skipCleanCheckpoint: Boolean)
+
+private[scala] object ExternalCheckpointParams {
+
+ def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
+ val checkpointPath: String = params.get("checkpoint_path") match {
+ case None | Some(null) | Some("") => null
+ case Some(path: String) => path
+ case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" +
+ s" an instance of String, but current value is ${params("checkpoint_path")}")
+ }
+
+ val checkpointInterval: Int = params.get("checkpoint_interval") match {
+ case None => 0
+ case Some(freq: Int) => freq
+ case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
+ " an instance of Int.")
+ }
+
+ val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+ case None => false
+ case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+ case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+ " an instance of Boolean")
+ }
+ if (checkpointPath == null || checkpointInterval == 0) {
+ None
+ } else {
+ Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
+ }
+ }
+}
+
+
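To make the `Option` contract concrete, here is how `extractParams` behaves for a few inputs (a behavior sketch; the object is `private[scala]`, so these calls are only legal inside the package):

```scala
ExternalCheckpointParams.extractParams(Map(
  "checkpoint_path" -> "hdfs:///tmp/ckpt", "checkpoint_interval" -> 4))
// => Some(ExternalCheckpointParams(4, "hdfs:///tmp/ckpt", false))

ExternalCheckpointParams.extractParams(Map("checkpoint_path" -> "hdfs:///tmp/ckpt"))
// => None: checkpoint_interval defaults to 0, which disables checkpointing

ExternalCheckpointParams.extractParams(Map.empty)
// => None: no path configured
```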