diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index d14b84ddc79a..b96a446d13d8 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -24,6 +24,8 @@ import scala.util.Random import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker} import ml.dmlc.xgboost4j.scala.rabit.RabitTracker +import ml.dmlc.xgboost4j.scala.spark.CheckpointManager.CheckpointParam +import ml.dmlc.xgboost4j.scala.spark.XGBoost.logger import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} @@ -55,6 +57,172 @@ object TrackerConf { def apply(): TrackerConf = TrackerConf(0L, "python") } +private[this] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int, + maximizeEvalMetrics: Boolean) + +private[this] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long) + +private[this] case class XGBoostExecutionParams( + numWorkers: Int, + round: Int, + useExternalMemory: Boolean, + obj: ObjectiveTrait, + eval: EvalTrait, + missing: Float, + trackerConf: TrackerConf, + timeoutRequestWorkers: Long, + checkpointParam: CheckpointParam, + xgbInputParams: XGBoostExecutionInputParams, + earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, + cacheTrainingSet: Boolean) { + + private var rawParamMap: Map[String, Any] = _ + + def setRawParamMap(inputMap: Map[String, Any]): Unit = { + rawParamMap = inputMap + } + + def toMap: Map[String, Any] = { + rawParamMap + } +} + +private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){ + + private val logger = LogFactory.getLog("XGBoostSpark") + + private val overridedParams = overrideParams(rawParams, sc) + + /** + * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true). + * If so, throw an exception unless this safety measure has been explicitly overridden + * via conf `xgboost.spark.ignoreSsl`. + */ + private def validateSparkSslConf: Unit = { + val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) = + SparkSession.getActiveSession match { + case Some(ss) => + (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean, + ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean) + case None => + (sc.getConf.getBoolean("spark.ssl.enabled", false), + sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false)) + } + if (sparkSslEnabled) { + if (xgboostSparkIgnoreSsl) { + logger.warn(s"spark-xgboost is being run without encrypting data in transit! " + + s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.") + } else { + throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " + + "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " + + "To override this protection and still use xgboost-spark at your own risk, " + + "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.") + } + } + } + + /** + * we should not include any nested structure in the output of this function as the map is + * eventually to be feed to xgboost4j layer + */ + private def overrideParams( + params: Map[String, Any], + sc: SparkContext): Map[String, Any] = { + val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1) + var overridedParams = params + if (overridedParams.contains("nthread")) { + val nThread = overridedParams("nthread").toString.toInt + require(nThread <= coresPerTask, + s"the nthread configuration ($nThread) must be no larger than " + + s"spark.task.cpus ($coresPerTask)") + } else { + overridedParams = overridedParams + ("nthread" -> coresPerTask) + } + + val numEarlyStoppingRounds = overridedParams.getOrElse( + "num_early_stopping_rounds", 0).asInstanceOf[Int] + overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds + if (numEarlyStoppingRounds > 0 && + !overridedParams.contains("maximize_evaluation_metrics")) { + if (overridedParams.contains("custom_eval")) { + throw new IllegalArgumentException("custom_eval does not support early stopping") + } + val eval_metric = overridedParams("eval_metric").toString + val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric + logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize) + overridedParams += ("maximize_evaluation_metrics" -> maximize) + } + overridedParams + } + + def buildXGBRuntimeParams: XGBoostExecutionParams = { + val nWorkers = overridedParams("num_workers").asInstanceOf[Int] + val round = overridedParams("num_round").asInstanceOf[Int] + val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean] + val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] + val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] + val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] + validateSparkSslConf + + if (overridedParams.contains("tree_method")) { + require(overridedParams("tree_method") == "hist" || + overridedParams("tree_method") == "approx" || + overridedParams("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" + + " 'hist', 'approx' and 'auto'") + } + if (overridedParams.contains("train_test_ratio")) { + logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + + "'eval_set_names'") + } + require(nWorkers > 0, "you must specify more than 0 workers") + if (obj != null) { + require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + + "is not defined, you have to specify the objective type as classification or regression" + + " with a customized objective function") + } + val trackerConf = overridedParams.get("tracker_conf") match { + case None => TrackerConf() + case Some(conf: TrackerConf) => conf + case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + + "instance of TrackerConf.") + } + val timeoutRequestWorkers: Long = overridedParams.get("timeout_request_workers") match { + case None => 0L + case Some(interval: Long) => interval + case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" + + " an instance of Long.") + } + val checkpointParam = + CheckpointManager.extractParams(overridedParams) + + val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) + .asInstanceOf[Double] + val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long] + val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed) + + val earlyStoppingRounds = overridedParams.getOrElse( + "num_early_stopping_rounds", 0).asInstanceOf[Int] + val maximizeEvalMetrics = overridedParams.getOrElse( + "maximize_evaluation_metrics", true).asInstanceOf[Boolean] + val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds, + maximizeEvalMetrics) + + val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false) + .asInstanceOf[Boolean] + + val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval, + missing, trackerConf, + timeoutRequestWorkers, + checkpointParam, + inputParams, + xgbExecEarlyStoppingParams, + cacheTrainingSet) + xgbExecParam.setRawParamMap(overridedParams) + xgbExecParam + } +} + /** * Traing data group in a RDD partition. * @param groupId The group id @@ -136,7 +304,7 @@ object XGBoost extends Serializable { private def buildDistributedBooster( watches: Watches, - params: Map[String, Any], + xgbExecutionParam: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], round: Int, obj: ObjectiveTrait, @@ -157,24 +325,9 @@ object XGBoost extends Serializable { try { Rabit.init(rabitEnv) - - val numEarlyStoppingRounds = params.get("num_early_stopping_rounds") - .map(_.toString.toInt).getOrElse(0) - val overridedParams = if (numEarlyStoppingRounds > 0 && - !params.contains("maximize_evaluation_metrics")) { - if (params.contains("custom_eval")) { - throw new IllegalArgumentException("maximize_evaluation_metrics has to be " - + "specified when custom_eval is set") - } - val eval_metric = params("eval_metric").toString - val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric - logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize) - params + ("maximize_evaluation_metrics" -> maximize) - } else { - params - } + val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](round)) - val booster = SXGBoost.train(watches.toMap("train"), overridedParams, round, + val booster = SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, round, watches.toMap, metrics, obj, eval, earlyStoppingRound = numEarlyStoppingRounds, prevBooster) Iterator(booster -> watches.toMap.keys.zip(metrics).toMap) @@ -188,22 +341,6 @@ object XGBoost extends Serializable { } } - private def overrideParamsAccordingToTaskCPUs( - params: Map[String, Any], - sc: SparkContext): Map[String, Any] = { - val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1) - var overridedParams = params - if (overridedParams.contains("nthread")) { - val nThread = overridedParams("nthread").toString.toInt - require(nThread <= coresPerTask, - s"the nthread configuration ($nThread) must be no larger than " + - s"spark.task.cpus ($coresPerTask)") - } else { - overridedParams = params + ("nthread" -> coresPerTask) - } - overridedParams - } - private def startTracker(nWorkers: Int, trackerConf: TrackerConf): IRabitTracker = { val tracker: IRabitTracker = trackerConf.trackerImpl match { case "scala" => new RabitTracker(nWorkers) @@ -261,138 +398,64 @@ object XGBoost extends Serializable { } } - /** - * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true). - * If so, throw an exception unless this safety measure has been explicitly overridden - * via conf `xgboost.spark.ignoreSsl`. - * - * @param sc SparkContext for the training dataset. When looking for the confs, this method - * first checks for an active SparkSession. If one is not available, it falls back - * to this SparkContext. - */ - private def validateSparkSslConf(sc: SparkContext): Unit = { - val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) = - SparkSession.getActiveSession match { - case Some(ss) => - (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean, - ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean) - case None => - (sc.getConf.getBoolean("spark.ssl.enabled", false), - sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false)) - } - if (sparkSslEnabled) { - if (xgboostSparkIgnoreSsl) { - logger.warn(s"spark-xgboost is being run without encrypting data in transit! " + - s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.") - } else { - throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " + - "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " + - "To override this protection and still use xgboost-spark at your own risk, " + - "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.") - } - } - } - - private def parameterFetchAndValidation(params: Map[String, Any], sparkContext: SparkContext) = { - val nWorkers = params("num_workers").asInstanceOf[Int] - val round = params("num_round").asInstanceOf[Int] - val useExternalMemory = params("use_external_memory").asInstanceOf[Boolean] - val obj = params.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] - val eval = params.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] - val missing = params.getOrElse("missing", Float.NaN).asInstanceOf[Float] - validateSparkSslConf(sparkContext) - - if (params.contains("tree_method")) { - require(params("tree_method") == "hist" || - params("tree_method") == "approx" || - params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as 'hist'," + - " 'approx' and 'auto'") - } - if (params.contains("train_test_ratio")) { - logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + - " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + - "'eval_set_names'") - } - require(nWorkers > 0, "you must specify more than 0 workers") - if (obj != null) { - require(params.get("objective_type").isDefined, "parameter \"objective_type\" is not" + - " defined, you have to specify the objective type as classification or regression" + - " with a customized objective function") - } - val trackerConf = params.get("tracker_conf") match { - case None => TrackerConf() - case Some(conf: TrackerConf) => conf - case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + - "instance of TrackerConf.") - } - val timeoutRequestWorkers: Long = params.get("timeout_request_workers") match { - case None => 0L - case Some(interval: Long) => interval - case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" + - " an instance of Long.") - } - val checkpointParam = - CheckpointManager.extractParams(params) - (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers, - checkpointParam.checkpointPath, checkpointParam.checkpointInterval, - checkpointParam.skipCleanCheckpoint) - } - private def trainForNonRanking( trainingData: RDD[XGBLabeledPoint], - params: Map[String, Any], + xgbExecutionParams: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], checkpointRound: Int, prevBooster: Booster, evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { - val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) = - parameterFetchAndValidation(params, trainingData.sparkContext) if (evalSetsMap.isEmpty) { trainingData.mapPartitions(labeledPoints => { - val watches = Watches.buildWatches(params, - processMissingValues(labeledPoints, missing), - getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, - obj, eval, prevBooster) + val watches = Watches.buildWatches(xgbExecutionParams, + processMissingValues(labeledPoints, xgbExecutionParams.missing), + getCacheDirName(xgbExecutionParams.useExternalMemory)) + buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, + xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) }).cache() } else { - coPartitionNoGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions { - nameAndLabeledPointSets => - val watches = Watches.buildWatches( - nameAndLabeledPointSets.map { - case (name, iter) => (name, processMissingValues(iter, missing))}, - getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, - obj, eval, prevBooster) - }.cache() + coPartitionNoGroupSets(trainingData, evalSetsMap, xgbExecutionParams.numWorkers). + mapPartitions { + nameAndLabeledPointSets => + val watches = Watches.buildWatches( + nameAndLabeledPointSets.map { + case (name, iter) => (name, processMissingValues(iter, + xgbExecutionParams.missing)) + }, + getCacheDirName(xgbExecutionParams.useExternalMemory)) + buildDistributedBooster(watches, xgbExecutionParams, rabitEnv, checkpointRound, + xgbExecutionParams.obj, xgbExecutionParams.eval, prevBooster) + }.cache() } } private def trainForRanking( trainingData: RDD[Array[XGBLabeledPoint]], - params: Map[String, Any], + xgbExecutionParam: XGBoostExecutionParams, rabitEnv: java.util.Map[String, String], checkpointRound: Int, prevBooster: Booster, evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = { - val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) = - parameterFetchAndValidation(params, trainingData.sparkContext) if (evalSetsMap.isEmpty) { trainingData.mapPartitions(labeledPointGroups => { - val watches = Watches.buildWatchesWithGroup(params, - processMissingValuesWithGroup(labeledPointGroups, missing), - getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, prevBooster) + val watches = Watches.buildWatchesWithGroup(xgbExecutionParam, + processMissingValuesWithGroup(labeledPointGroups, xgbExecutionParam.missing), + getCacheDirName(xgbExecutionParam.useExternalMemory)) + buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, + xgbExecutionParam.obj, xgbExecutionParam.eval, prevBooster) }).cache() } else { - coPartitionGroupSets(trainingData, evalSetsMap, nWorkers).mapPartitions( + coPartitionGroupSets(trainingData, evalSetsMap, xgbExecutionParam.numWorkers).mapPartitions( labeledPointGroupSets => { val watches = Watches.buildWatchesWithGroup( labeledPointGroupSets.map { - case (name, iter) => (name, processMissingValuesWithGroup(iter, missing)) + case (name, iter) => (name, processMissingValuesWithGroup(iter, + xgbExecutionParam.missing)) }, - getCacheDirName(useExternalMemory)) - buildDistributedBooster(watches, params, rabitEnv, checkpointRound, obj, eval, + getCacheDirName(xgbExecutionParam.useExternalMemory)) + buildDistributedBooster(watches, xgbExecutionParam, rabitEnv, checkpointRound, + xgbExecutionParam.obj, + xgbExecutionParam.eval, prevBooster) }).cache() } @@ -428,31 +491,32 @@ object XGBoost extends Serializable { evalSetsMap: Map[String, RDD[XGBLabeledPoint]] = Map()): (Booster, Map[String, Array[Float]]) = { logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}") - val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers, - checkpointPath, checkpointInterval, skipCleanCheckpoint) = - parameterFetchAndValidation(params, - trainingData.sparkContext) + val xgbExecParams = new XGBoostExecutionParamsFactory(params, trainingData.sparkContext). + buildXGBRuntimeParams val sc = trainingData.sparkContext - val checkpointManager = new CheckpointManager(sc, checkpointPath) - checkpointManager.cleanUpHigherVersions(round.asInstanceOf[Int]) - val transformedTrainingData = composeInputData(trainingData, - params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], hasGroup, nWorkers) + val checkpointManager = new CheckpointManager(sc, xgbExecParams.checkpointParam. + checkpointPath) + checkpointManager.cleanUpHigherVersions(xgbExecParams.round) + val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet, + hasGroup, xgbExecParams.numWorkers) var prevBooster = checkpointManager.loadCheckpointAsBooster try { // Train for every ${savingRound} rounds and save the partially completed booster - val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map { + val producedBooster = checkpointManager.getCheckpointRounds( + xgbExecParams.checkpointParam.checkpointInterval, + xgbExecParams.round).map { checkpointRound: Int => - val tracker = startTracker(nWorkers, trackerConf) + val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) try { - val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc) - val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, - nWorkers) + val parallelismTracker = new SparkParallelismTracker(sc, + xgbExecParams.timeoutRequestWorkers, + xgbExecParams.numWorkers) val rabitEnv = tracker.getWorkerEnvs val boostersAndMetrics = if (hasGroup) { - trainForRanking(transformedTrainingData.left.get, overriddenParams, rabitEnv, + trainForRanking(transformedTrainingData.left.get, xgbExecParams, rabitEnv, checkpointRound, prevBooster, evalSetsMap) } else { - trainForNonRanking(transformedTrainingData.right.get, overriddenParams, rabitEnv, + trainForNonRanking(transformedTrainingData.right.get, xgbExecParams, rabitEnv, checkpointRound, prevBooster, evalSetsMap) } val sparkJobThread = new Thread() { @@ -467,7 +531,7 @@ object XGBoost extends Serializable { logger.info(s"Rabit returns with exit code $trackerReturnVal") val (booster, metrics) = postTrackerReturnProcessing(trackerReturnVal, boostersAndMetrics, sparkJobThread) - if (checkpointRound < round) { + if (checkpointRound < xgbExecParams.round) { prevBooster = booster checkpointManager.updateCheckpoint(prevBooster) } @@ -477,7 +541,7 @@ object XGBoost extends Serializable { } }.last // we should delete the checkpoint directory after a successful training - if (!skipCleanCheckpoint) { + if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) { checkpointManager.cleanPath() } producedBooster @@ -488,8 +552,7 @@ object XGBoost extends Serializable { trainingData.sparkContext.stop() throw t } finally { - uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean], - transformedTrainingData) + uncacheTrainingData(xgbExecParams.cacheTrainingSet, transformedTrainingData) } } @@ -673,11 +736,11 @@ private object Watches { } def buildWatches( - params: Map[String, Any], + xgbExecutionParams: XGBoostExecutionParams, labeledPoints: Iterator[XGBLabeledPoint], cacheDirName: Option[String]): Watches = { - val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0) - val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) + val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio + val seed = xgbExecutionParams.xgbInputParams.seed val r = new Random(seed) val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint] val trainBaseMargins = new mutable.ArrayBuilder.ofFloat @@ -743,11 +806,11 @@ private object Watches { } def buildWatchesWithGroup( - params: Map[String, Any], + xgbExecutionParams: XGBoostExecutionParams, labeledPointGroups: Iterator[Array[XGBLabeledPoint]], cacheDirName: Option[String]): Watches = { - val trainTestRatio = params.get("train_test_ratio").map(_.toString.toDouble).getOrElse(1.0) - val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime()) + val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio + val seed = xgbExecutionParams.xgbInputParams.seed val r = new Random(seed) val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint] val trainBaseMargins = new mutable.ArrayBuilder.ofFloat