diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 461c07153edd..47a1b611a199 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -37,6 +37,7 @@ 2.4.3 2.12.8 2.12 + 2.7.3 diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala deleted file mode 100644 index e4705f181b8a..000000000000 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManager.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package ml.dmlc.xgboost4j.scala.spark - -import ml.dmlc.xgboost4j.scala.Booster -import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost} -import org.apache.commons.logging.LogFactory -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext - -/** - * A class which allows user to save checkpoints every a few rounds. If a previous job fails, - * the job can restart training from a saved checkpoints instead of from scratch. This class - * provides interface and helper methods for the checkpoint functionality. - * - * NOTE: This checkpoint is different from Rabit checkpoint. Rabit checkpoint is a native-level - * checkpoint stored in executor memory. This is a checkpoint which Spark driver store on HDFS - * for every a few iterations. - * - * @param sc the sparkContext object - * @param checkpointPath the hdfs path to store checkpoints - */ -private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String) { - private val logger = LogFactory.getLog("XGBoostSpark") - private val modelSuffix = ".model" - - private def getPath(version: Int) = { - s"$checkpointPath/$version$modelSuffix" - } - - private def getExistingVersions: Seq[Int] = { - val fs = FileSystem.get(sc.hadoopConfiguration) - if (checkpointPath.isEmpty || !fs.exists(new Path(checkpointPath))) { - Seq() - } else { - fs.listStatus(new Path(checkpointPath)).map(_.getPath.getName).collect { - case fileName if fileName.endsWith(modelSuffix) => fileName.stripSuffix(modelSuffix).toInt - } - } - } - - def cleanPath(): Unit = { - if (checkpointPath != "") { - FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true) - } - } - - /** - * Load existing checkpoint with the highest version as a Booster object - * - * @return the booster with the highest version, null if no checkpoints available. - */ - private[spark] def loadCheckpointAsBooster: Booster = { - val versions = getExistingVersions - if (versions.nonEmpty) { - val version = versions.max - val fullPath = getPath(version) - val inputStream = FileSystem.get(sc.hadoopConfiguration).open(new Path(fullPath)) - logger.info(s"Start training from previous booster at $fullPath") - val booster = SXGBoost.loadModel(inputStream) - booster.booster.setVersion(version) - booster - } else { - null - } - } - - /** - * Clean up all previous checkpoints and save a new checkpoint - * - * @param checkpoint the checkpoint to save as an XGBoostModel - */ - private[spark] def updateCheckpoint(checkpoint: Booster): Unit = { - val fs = FileSystem.get(sc.hadoopConfiguration) - val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version))) - val fullPath = getPath(checkpoint.getVersion) - val outputStream = fs.create(new Path(fullPath), true) - logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath") - checkpoint.saveModel(outputStream) - prevModelPaths.foreach(path => fs.delete(path, true)) - } - - /** - * Clean up checkpoint boosters with version higher than or equal to the round. - * - * @param round the number of rounds in the current training job - */ - private[spark] def cleanUpHigherVersions(round: Int): Unit = { - val higherVersions = getExistingVersions.filter(_ / 2 >= round) - higherVersions.foreach { version => - val fs = FileSystem.get(sc.hadoopConfiguration) - fs.delete(new Path(getPath(version)), true) - } - } - - /** - * Calculate a list of checkpoint rounds to save checkpoints based on the checkpointInterval - * and total number of rounds for the training. Concretely, the checkpoint rounds start with - * prevRounds + checkpointInterval, and increase by checkpointInterval in each step until it - * reaches total number of rounds. If checkpointInterval is 0, the checkpoint will be disabled - * and the method returns Seq(round) - * - * @param checkpointInterval Period (in iterations) between checkpoints. - * @param round the total number of rounds for the training - * @return a seq of integers, each represent the index of round to save the checkpoints - */ - private[spark] def getCheckpointRounds(checkpointInterval: Int, round: Int): Seq[Int] = { - if (checkpointPath.nonEmpty && checkpointInterval > 0) { - val prevRounds = getExistingVersions.map(_ / 2) - val firstCheckpointRound = (0 +: prevRounds).max + checkpointInterval - (firstCheckpointRound until round by checkpointInterval) :+ round - } else if (checkpointInterval <= 0) { - Seq(round) - } else { - throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.") - } - } -} - -object CheckpointManager { - - case class CheckpointParam( - checkpointPath: String, - checkpointInterval: Int, - skipCleanCheckpoint: Boolean) - - private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = { - val checkpointPath: String = params.get("checkpoint_path") match { - case None => "" - case Some(path: String) => path - case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" + - " an instance of String.") - } - - val checkpointInterval: Int = params.get("checkpoint_interval") match { - case None => 0 - case Some(freq: Int) => freq - case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" + - " an instance of Int.") - } - - val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match { - case None => false - case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint - case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" + - " an instance of Boolean") - } - CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile) - } -} diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index a63a3df03e76..c0354866e24b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -25,12 +25,13 @@ import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker -import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams +import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import org.apache.commons.io.FileUtils import org.apache.commons.logging.LogFactory +import org.apache.hadoop.fs.FileSystem import org.apache.spark.rdd.RDD import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener} @@ -64,7 +65,7 @@ private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, see private[this] case class XGBoostExecutionParams( numWorkers: Int, - round: Int, + numRounds: Int, useExternalMemory: Boolean, obj: ObjectiveTrait, eval: EvalTrait, @@ -72,7 +73,7 @@ private[this] case class XGBoostExecutionParams( allowNonZeroForMissing: Boolean, trackerConf: TrackerConf, timeoutRequestWorkers: Long, - checkpointParam: CheckpointParam, + checkpointParam: Option[ExternalCheckpointParams], xgbInputParams: XGBoostExecutionInputParams, earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, cacheTrainingSet: Boolean) { @@ -167,7 +168,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s .getOrElse("allow_non_zero_for_missing", false) .asInstanceOf[Boolean] validateSparkSslConf - if (overridedParams.contains("tree_method")) { require(overridedParams("tree_method") == "hist" || overridedParams("tree_method") == "approx" || @@ -198,7 +198,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s " an instance of Long.") } val checkpointParam = - CheckpointManager.extractParams(overridedParams) + ExternalCheckpointParams.extractParams(overridedParams) val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) .asInstanceOf[Double] @@ -339,11 +339,9 @@ object XGBoost extends Serializable { watches: Watches, xgbExecutionParam: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], - round: Int, obj: ObjectiveTrait, eval: EvalTrait, prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = { - // to workaround the empty partitions in training dataset, // this might not be the best efficient implementation, see // (https://github.com/dmlc/xgboost/issues/1277) @@ -357,14 +355,23 @@ object XGBoost extends Serializable { rabitEnv.put("DMLC_TASK_ID", taskId) rabitEnv.put("DMLC_NUM_ATTEMPT", attempt) rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false") - + val numRounds = xgbExecutionParam.numRounds + val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0 try { Rabit.init(rabitEnv) val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds - val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) - val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round, - watches.toMap, metrics, obj, eval, - earlyStoppingRound = numEarlyStoppingRounds, prevBooster) + val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds)) + val externalCheckpointParams = xgbExecutionParam.checkpointParam + val booster = if (makeCheckpoint) { + SXGBoost.trainAndSaveCheckpoint( + watches.toMap("train"), xgbExecutionParam.toMap, numRounds, + watches.toMap, metrics, obj, eval, + earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams) + } else { + SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds, + watches.toMap, metrics, obj, eval, + earlyStoppingRound = numEarlyStoppingRounds, prevBooster) + } Iterator(booster -> watches.toMap.keys.zip(metrics).toMap) } catch { case xgbException: XGBoostError => @@ -437,7 +444,6 @@ object XGBoost extends Serializable { trainingData: RDD[XGBLabeledPoint], xgbExecutionParams: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], - checkpointRound: Int, prevBooster: Booster, evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { if (evalSetsMap.isEmpty) { @@ -446,8 +452,8 @@ object XGBoost extends Serializable { processMissingValues(labeledPoints, xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing), getCacheDirName(xgbExecutionParams.useExternalMemory)) - buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, - xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) + buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj, + xgbExecutionParams.eval, prevBooster) }).cache() } else { coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers). @@ -459,8 +465,8 @@ object XGBoost extends Serializable { xgbExecutionParams.missing, xgbExecutionParams.allowNonZeroForMissing)) }, getCacheDirName(xgbExecutionParams.useExternalMemory)) - buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, - xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) + buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, xgbExecutionParams.obj, + xgbExecutionParams.eval, prevBooster) }.cache() } } @@ -469,7 +475,6 @@ object XGBoost extends Serializable { trainingData: RDD[Array[XGBLabeledPoint]], xgbExecutionParam: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], - checkpointRound: Int, prevBooster: Booster, evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { if (evalSetsMap.isEmpty) { @@ -478,7 +483,7 @@ object XGBoost extends Serializable { processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing), getCacheDirName(xgbExecutionParam.useExternalMemory)) - buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, + buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster) }).cache() } else { @@ -490,7 +495,7 @@ object XGBoost extends Serializable { xgbExecutionParam.missing, xgbExecutionParam.allowNonZeroForMissing)) }, getCacheDirName(xgbExecutionParam.useExternalMemory)) - buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, + buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster) @@ -529,60 +534,58 @@ object XGBoost extends Serializable { logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}") val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext) val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams - val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava val sc = trainingData.sparkContext - val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam. - checkpointPath) - checkpointManager.cleanUpHigherVersions(xgbExecParams.round) val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet, hasGroup, xgbExecParams.numWorkers) - var prevBooster = checkpointManager.loadCheckpointAsBooster + val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam => + val checkpointManager = new ExternalCheckpointManager( + checkpointParam.checkpointPath, + FileSystem.get(sc.hadoopConfiguration)) + checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds) + checkpointManager.loadCheckpointAsScalaBooster() + }.orNull try { // Train for every ${savingRound} rounds and save the partially completed booster - val producedBooster = checkpointManager.getCheckpointRounds( - xgbExecParams.checkpointParam.checkpointInterval, - xgbExecParams.round).map { - checkpointRound: Int => - val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) - try { - val parallelismTracker = new SparkParallelismTracker(sc, - xgbExecParams.timeoutRequestWorkers, - xgbExecParams.numWorkers) - - tracker.getWorkerEnvs().putAll(xgbRabitParams) - val boostersAndMetrics = if (hasGroup) { - trainForRanking(transformedTrainingData.left.get, xgbExecParams, - tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap) - } else { - trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, - tracker.getWorkerEnvs(), checkpointRound, prevBooster, evalSetsMap) - } - val sparkJobThread = new Thread() { - override def run() { - // force the job - boostersAndMetrics.foreachPartition(() => _) - } - } - sparkJobThread.setUncaughtExceptionHandler(tracker) - sparkJobThread.start() - val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L)) - logger.info(s"Rabit returns with exit code $trackerReturnVal") - val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, - boostersAndMetrics, sparkJobThread) - if (checkpointRound < xgbExecParams.round) { - prevBooster = booster - checkpointManager.updateCheckpoint(prevBooster) - } - (booster, metrics) - } finally { - tracker.stop() + val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) + val (booster, metrics) = try { + val parallelismTracker = new SparkParallelismTracker(sc, + xgbExecParams.timeoutRequestWorkers, + xgbExecParams.numWorkers) + val rabitEnv = tracker.getWorkerEnvs + val boostersAndMetrics = if (hasGroup) { + trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, prevBooster, + evalSetsMap) + } else { + trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv, + prevBooster, evalSetsMap) + } + val sparkJobThread = new Thread() { + override def run() { + // force the job + boostersAndMetrics.foreachPartition(() => _) } - }.last + } + sparkJobThread.setUncaughtExceptionHandler(tracker) + sparkJobThread.start() + val trackerReturnVal = parallelismTracker.execute(tracker.waitFor(0L)) + logger.info(s"Rabit returns with exit code $trackerReturnVal") + val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, + boostersAndMetrics, sparkJobThread) + (booster, metrics) + } finally { + tracker.stop() + } // we should delete the checkpoint directory after a successful training - if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) { - checkpointManager.cleanPath() + xgbExecParams.checkpointParam.foreach { + cpParam => + if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) { + val checkpointManager = new ExternalCheckpointManager( + cpParam.checkpointPath, + FileSystem.get(sc.hadoopConfiguration)) + checkpointManager.cleanPath() + } } - producedBooster + (booster, metrics) } catch { case t: Throwable => // if the job was aborted due to an exception diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala index a69097c805b4..1c8ddfb08967 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorCommon.scala @@ -24,7 +24,7 @@ private[spark] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables { def needDeterministicRepartitioning: Boolean = { - getCheckpointPath.nonEmpty && getCheckpointInterval > 0 + getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0 } } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala similarity index 57% rename from jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala rename to jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala index ddeb48241c86..5ef49431468f 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CheckpointManagerSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala @@ -18,54 +18,71 @@ package ml.dmlc.xgboost4j.scala.spark import java.io.File -import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost} -import org.scalatest.FunSuite +import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost} +import org.scalatest.{FunSuite, Ignore} import org.apache.hadoop.fs.{FileSystem, Path} -class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest { +class ExternalCheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest { - private lazy val (model4, model8) = { - val training = buildDataFrame(Classification.train) - val paramMap = Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", - "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism) - (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), - new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) + private def produceParamMap(checkpointPath: String, checkpointInterval: Int): + Map[String, Any] = { + Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism, + "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval) } - test("test update/load models") { + private def createNewModels(): + (String, XGBoostClassificationModel, XGBoostClassificationModel) = { val tmpPath = createTmpFolder("test").toAbsolutePath.toString - val manager = new CheckpointManager(sc, tmpPath) - manager.updateCheckpoint(model4._booster) + val (model4, model8) = { + val training = buildDataFrame(Classification.train) + val paramMap = produceParamMap(tmpPath, 2) + (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training), + new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training)) + } + (tmpPath, model4, model8) + } + + test("test update/load models") { + val (tmpPath, model4, model8) = createNewModels() + val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) + + manager.updateCheckpoint(model4._booster.booster) var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "4.model") - assert(manager.loadCheckpointAsBooster.booster.getVersion == 4) + assert(manager.loadCheckpointAsScalaBooster().getVersion == 4) manager.updateCheckpoint(model8._booster) files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath)) assert(files.length == 1) assert(files.head.getPath.getName == "8.model") - assert(manager.loadCheckpointAsBooster.booster.getVersion == 8) + assert(manager.loadCheckpointAsScalaBooster().getVersion == 8) } test("test cleanUpHigherVersions") { - val tmpPath = createTmpFolder("test").toAbsolutePath.toString - val manager = new CheckpointManager(sc, tmpPath) + val (tmpPath, model4, model8) = createNewModels() + + val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) manager.updateCheckpoint(model8._booster) - manager.cleanUpHigherVersions(round = 8) + manager.cleanUpHigherVersions(8) assert(new File(s"$tmpPath/8.model").exists()) - manager.cleanUpHigherVersions(round = 4) + manager.cleanUpHigherVersions(4) assert(!new File(s"$tmpPath/8.model").exists()) } test("test checkpoint rounds") { - val tmpPath = createTmpFolder("test").toAbsolutePath.toString - val manager = new CheckpointManager(sc, tmpPath) - assertResult(Seq(7))(manager.getCheckpointRounds(checkpointInterval = 0, round = 7)) - assertResult(Seq(2, 4, 6, 7))(manager.getCheckpointRounds(checkpointInterval = 2, round = 7)) + import scala.collection.JavaConverters._ + val (tmpPath, model4, model8) = createNewModels() + val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration)) + assertResult(Seq(7))( + manager.getCheckpointRounds(0, 7).asScala) + assertResult(Seq(2, 4, 6, 7))( + manager.getCheckpointRounds(2, 7).asScala) manager.updateCheckpoint(model4._booster) - assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7)) + assertResult(Seq(4, 6, 7))( + manager.getCheckpointRounds(2, 7).asScala) } @@ -75,17 +92,18 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes val testDM = new DMatrix(Classification.test.iterator) val tmpPath = createTmpFolder("model1").toAbsolutePath.toString + + val paramMap = produceParamMap(tmpPath, 2) + val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map() val skipCleanCheckpointMap = if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map() - val paramMap = Map("eta" -> "1", "max_depth" -> 2, - "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath, - "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++ - skipCleanCheckpointMap - val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training) - def error(model: Booster): Float = eval.eval( - model.predict(testDM, outPutMargin = true), testDM) + val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap + + val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training) + + def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM) if (skipCleanCheckpoint) { // Check only one model is kept after training @@ -95,7 +113,7 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTes val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model") // Train next model based on prev model val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training) - assert(error(tmpModel) > error(prevModel._booster)) + assert(error(tmpModel) >= error(prevModel._booster)) assert(error(prevModel._booster) > error(nextModel._booster)) assert(error(nextModel._booster) < 0.1) } else { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala index e387faa31782..22e35e847a0f 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala @@ -127,7 +127,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest { " stop the application") { val spark = ss import spark.implicits._ - ss.sparkContext.setLogLevel("INFO") // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense // vector, val testDF = Seq( @@ -155,7 +154,6 @@ class MissingValueHandlingSuite extends FunSuite with PerTest { "does not stop application") { val spark = ss import spark.implicits._ - ss.sparkContext.setLogLevel("INFO") // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense // vector, val testDF = Seq( diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index e0450f90dfd8..c802fc0b0448 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -17,7 +17,7 @@ package ml.dmlc.xgboost4j.scala.spark import ml.dmlc.xgboost4j.java.XGBoostError -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} import org.apache.spark.ml.param.ParamMap diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala index 9655104c35d7..e106883ca254 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/RabitRobustnessSuite.scala @@ -20,14 +20,12 @@ import java.util.concurrent.LinkedBlockingDeque import scala.util.Random -import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, RabitTracker => PyRabitTracker} +import ml.dmlc.xgboost4j.java.{Rabit, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.{RabitTracker => ScalaRabitTracker} import ml.dmlc.xgboost4j.java.IRabitTracker.TrackerStatus import ml.dmlc.xgboost4j.scala.DMatrix -import org.apache.spark.{SparkConf, SparkContext} -import org.scalatest.FunSuite - +import org.scalatest.{FunSuite, Ignore} class RabitRobustnessSuite extends FunSuite with PerTest { diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 99893582906b..a2576d1ba7a9 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -13,6 +13,18 @@ jar + + org.apache.hadoop + hadoop-hdfs + ${hadoop.version} + provided + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + provided + junit junit diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java new file mode 100644 index 000000000000..655b99020313 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ExternalCheckpointManager.java @@ -0,0 +1,117 @@ +package ml.dmlc.xgboost4j.java; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.*; +import java.util.stream.Collectors; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; + +public class ExternalCheckpointManager { + + private Log logger = LogFactory.getLog("ExternalCheckpointManager"); + private String modelSuffix = ".model"; + private Path checkpointPath; + private FileSystem fs; + + public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError { + if (checkpointPath == null || checkpointPath.isEmpty()) { + throw new XGBoostError("cannot create ExternalCheckpointManager with null or" + + " empty checkpoint path"); + } + this.checkpointPath = new Path(checkpointPath); + this.fs = fs; + } + + private String getPath(int version) { + return checkpointPath.toUri().getPath() + "/" + version + modelSuffix; + } + + private List getExistingVersions() throws IOException { + if (!fs.exists(checkpointPath)) { + return new ArrayList<>(); + } else { + return Arrays.stream(fs.listStatus(checkpointPath)) + .map(path -> path.getPath().getName()) + .filter(fileName -> fileName.endsWith(modelSuffix)) + .map(fileName -> Integer.valueOf( + fileName.substring(0, fileName.length() - modelSuffix.length()))) + .collect(Collectors.toList()); + } + } + + public void cleanPath() throws IOException { + fs.delete(checkpointPath, true); + } + + public Booster loadCheckpointAsBooster() throws IOException, XGBoostError { + List versions = getExistingVersions(); + if (versions.size() > 0) { + int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get(); + String checkpointPath = getPath(latestVersion); + InputStream in = fs.open(new Path(checkpointPath)); + logger.info("loaded checkpoint from " + checkpointPath); + Booster booster = XGBoost.loadModel(in); + booster.setVersion(latestVersion); + return booster; + } else { + return null; + } + } + + public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError { + List prevModelPaths = getExistingVersions().stream() + .map(this::getPath).collect(Collectors.toList()); + String eventualPath = getPath(boosterToCheckpoint.getVersion()); + String tempPath = eventualPath + "-" + UUID.randomUUID(); + try (OutputStream out = fs.create(new Path(tempPath), true)) { + boosterToCheckpoint.saveModel(out); + fs.rename(new Path(tempPath), new Path(eventualPath)); + logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion()); + prevModelPaths.stream().forEach(path -> { + try { + fs.delete(new Path(path), true); + } catch (IOException e) { + logger.error("failed to delete outdated checkpoint at " + path, e); + } + }); + } + } + + public void cleanUpHigherVersions(int currentRound) throws IOException { + getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> { + try { + fs.delete(new Path(getPath(v)), true); + } catch (IOException e) { + logger.error("failed to clean checkpoint from other training instance", e); + } + }); + } + + public List getCheckpointRounds(int checkpointInterval, int numOfRounds) + throws IOException { + if (checkpointInterval > 0) { + List prevRounds = + getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList()); + prevRounds.add(0); + int firstCheckpointRound = prevRounds.stream() + .max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval; + List arr = new ArrayList<>(); + for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) { + arr.add(i); + } + arr.add(numOfRounds); + return arr; + } else if (checkpointInterval <= 0) { + List l = new ArrayList(); + l.add(numOfRounds); + return l; + } else { + throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set."); + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index b6a173bd6bd1..8adf4c0aebd3 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -15,12 +15,16 @@ */ package ml.dmlc.xgboost4j.java; +import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.util.*; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; /** * trainer for xgboost @@ -108,35 +112,34 @@ public static Booster train( return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null); } - /** - * Train a booster given parameters. - * - * @param dtrain Data to be trained. - * @param params Parameters. - * @param round Number of boosting iterations. - * @param watches a group of items to be evaluated during training, this allows user to watch - * performance on the validation set. - * @param metrics array containing the evaluation metrics for each matrix in watches for each - * iteration - * @param earlyStoppingRounds if non-zero, training would be stopped - * after a specified number of consecutive - * goes to the unexpected direction in any evaluation metric. - * @param obj customized objective - * @param eval customized evaluation - * @param booster train from scratch if set to null; train from an existing booster if not null. - * @return The trained booster. - */ - public static Booster train( - DMatrix dtrain, - Map params, - int round, - Map watches, - float[][] metrics, - IObjective obj, - IEvaluation eval, - int earlyStoppingRounds, - Booster booster) throws XGBoostError { + private static void saveCheckpoint( + Booster booster, + int iter, + Set checkpointIterations, + ExternalCheckpointManager ecm) throws XGBoostError { + try { + if (checkpointIterations.contains(iter)) { + ecm.updateCheckpoint(booster); + } + } catch (Exception e) { + logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e); + throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e); + } + } + public static Booster trainAndSaveCheckpoint( + DMatrix dtrain, + Map params, + int numRounds, + Map watches, + float[][] metrics, + IObjective obj, + IEvaluation eval, + int earlyStoppingRounds, + Booster booster, + int checkpointInterval, + String checkpointPath, + FileSystem fs) throws XGBoostError, IOException { //collect eval matrixs String[] evalNames; DMatrix[] evalMats; @@ -144,6 +147,11 @@ public static Booster train( int bestIteration; List names = new ArrayList(); List mats = new ArrayList(); + Set checkpointIterations = new HashSet<>(); + ExternalCheckpointManager ecm = null; + if (checkpointPath != null) { + ecm = new ExternalCheckpointManager(checkpointPath, fs); + } for (Map.Entry evalEntry : watches.entrySet()) { names.add(evalEntry.getKey()); @@ -158,7 +166,7 @@ public static Booster train( bestScore = Float.MAX_VALUE; } bestIteration = 0; - metrics = metrics == null ? new float[evalNames.length][round] : metrics; + metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics; //collect all data matrixs DMatrix[] allMats; @@ -181,14 +189,19 @@ public static Booster train( booster.setParams(params); } - //begin to train - for (int iter = booster.getVersion() / 2; iter < round; iter++) { + if (ecm != null) { + checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); + } + + // begin to train + for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { if (booster.getVersion() % 2 == 0) { if (obj != null) { booster.update(dtrain, obj); } else { booster.update(dtrain, iter); } + saveCheckpoint(booster, iter, checkpointIterations, ecm); booster.saveRabitCheckpoint(); } @@ -224,7 +237,7 @@ public static Booster train( if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) { Rabit.trackerPrint(String.format( "early stopping after %d rounds away from the best iteration", - earlyStoppingRounds)); + earlyStoppingRounds)); break; } } @@ -239,6 +252,44 @@ public static Booster train( return booster; } + /** + * Train a booster given parameters. + * + * @param dtrain Data to be trained. + * @param params Parameters. + * @param round Number of boosting iterations. + * @param watches a group of items to be evaluated during training, this allows user to watch + * performance on the validation set. + * @param metrics array containing the evaluation metrics for each matrix in watches for each + * iteration + * @param earlyStoppingRounds if non-zero, training would be stopped + * after a specified number of consecutive + * goes to the unexpected direction in any evaluation metric. + * @param obj customized objective + * @param eval customized evaluation + * @param booster train from scratch if set to null; train from an existing booster if not null. + * @return The trained booster. + */ + public static Booster train( + DMatrix dtrain, + Map params, + int round, + Map watches, + float[][] metrics, + IObjective obj, + IEvaluation eval, + int earlyStoppingRounds, + Booster booster) throws XGBoostError { + try { + return trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, + earlyStoppingRounds, booster, + -1, null, null); + } catch (IOException e) { + logger.error("training failed in xgboost4j", e); + throw new XGBoostError("training failed in xgboost4j ", e); + } + } + private static Integer tryGetIntFromObject(Object o) { if (o instanceof Integer) { return (int)o; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java index 7da5c176d574..d113f5bfad3e 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostError.java @@ -24,4 +24,8 @@ public class XGBoostError extends Exception { public XGBoostError(String message) { super(message); } + + public XGBoostError(String message, Throwable cause) { + super(message, cause); + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala new file mode 100644 index 000000000000..240c23871362 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ExternalCheckpointManager.scala @@ -0,0 +1,37 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import ml.dmlc.xgboost4j.java.{ExternalCheckpointManager => JavaECM} +import org.apache.hadoop.fs.FileSystem + +class ExternalCheckpointManager(checkpointPath: String, fs: FileSystem) + extends JavaECM(checkpointPath, fs) { + + def updateCheckpoint(booster: Booster): Unit = { + super.updateCheckpoint(booster.booster) + } + + def loadCheckpointAsScalaBooster(): Booster = { + val loadedBooster = super.loadCheckpointAsBooster() + if (loadedBooster == null) { + null + } else { + new Booster(loadedBooster) + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala index 609d7b2cde8c..791f06b716f1 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -18,14 +18,60 @@ package ml.dmlc.xgboost4j.scala import java.io.InputStream -import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError} +import ml.dmlc.xgboost4j.java.{XGBoostError, Booster => JBooster, XGBoost => JXGBoost} import scala.collection.JavaConverters._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + /** * XGBoost Scala Training function. */ object XGBoost { + private[scala] def trainAndSaveCheckpoint( + dtrain: DMatrix, + params: Map[String, Any], + numRounds: Int, + watches: Map[String, DMatrix] = Map(), + metrics: Array[Array[Float]] = null, + obj: ObjectiveTrait = null, + eval: EvalTrait = null, + earlyStoppingRound: Int = 0, + prevBooster: Booster, + checkpointParams: Option[ExternalCheckpointParams]): Booster = { + val jWatches = watches.mapValues(_.jDMatrix).asJava + val jBooster = if (prevBooster == null) { + null + } else { + prevBooster.booster + } + val xgboostInJava = checkpointParams. + map(cp => { + JXGBoost.trainAndSaveCheckpoint( + dtrain.jDMatrix, + // we have to filter null value for customized obj and eval + params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, + numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster, + cp.checkpointInterval, + cp.checkpointPath, + new Path(cp.checkpointPath).getFileSystem(new Configuration())) + }). + getOrElse( + JXGBoost.train( + dtrain.jDMatrix, + // we have to filter null value for customized obj and eval + params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, + numRounds, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) + ) + if (prevBooster == null) { + new Booster(xgboostInJava) + } else { + // Avoid creating a new SBooster with the same JBooster + prevBooster + } + } + /** * Train a booster given parameters. * @@ -55,23 +101,8 @@ object XGBoost { eval: EvalTrait = null, earlyStoppingRound: Int = 0, booster: Booster = null): Booster = { - val jWatches = watches.mapValues(_.jDMatrix).asJava - val jBooster = if (booster == null) { - null - } else { - booster.booster - } - val xgboostInJava = JXGBoost.train( - dtrain.jDMatrix, - // we have to filter null value for customized obj and eval - params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, - round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) - if (booster == null) { - new Booster(xgboostInJava) - } else { - // Avoid creating a new SBooster with the same JBooster - booster - } + trainAndSaveCheckpoint(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, + booster, None) } /** @@ -126,3 +157,41 @@ object XGBoost { new Booster(xgboostInJava) } } + +private[scala] case class ExternalCheckpointParams( + checkpointInterval: Int, + checkpointPath: String, + skipCleanCheckpoint: Boolean) + +private[scala] object ExternalCheckpointParams { + + def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = { + val checkpointPath: String = params.get("checkpoint_path") match { + case None | Some(null) | Some("") => null + case Some(path: String) => path + case _ => throw new IllegalArgumentException("parameter \"checkpoint_path\" must be" + + s" an instance of String, but current value is ${params("checkpoint_path")}") + } + + val checkpointInterval: Int = params.get("checkpoint_interval") match { + case None => 0 + case Some(freq: Int) => freq + case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" + + " an instance of Int.") + } + + val skipCleanCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match { + case None => false + case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint + case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" + + " an instance of Boolean") + } + if (checkpointPath == null || checkpointInterval == 0) { + None + } else { + Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile)) + } + } +} + +