From d3bdf99e184e17fa384b4f7492270ad2b38fa541 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 20 May 2024 19:08:34 +0800 Subject: [PATCH] [jvm-packages] refine tracker --- jvm-packages/pom.xml | 14 ++-- .../ml/dmlc/xgboost4j/java/flink/XGBoost.java | 2 +- .../scala/rapids/spark/GpuPreXGBoost.scala | 8 +- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 76 +++++++------------ .../spark/CommunicatorRobustnessSuite.scala | 4 +- jvm-packages/xgboost4j/pom.xml | 6 ++ .../java/ml/dmlc/xgboost4j/java/ITracker.java | 17 +---- .../ml/dmlc/xgboost4j/java/RabitTracker.java | 38 ++++++---- 8 files changed, 72 insertions(+), 93 deletions(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index c5354aad7591..f4a26be1f040 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -35,16 +35,17 @@ 1.8 1.19.0 4.13.2 - 3.4.1 - 3.4.1 + 3.5.1 + 3.5.1 + 2.15.2 2.12.18 2.12 3.4.0 5 OFF OFF - 23.12.1 - 23.12.1 + 24.04.0 + 24.04.0 cuda12 3.2.18 2.12.0 @@ -489,11 +490,6 @@ kryo 5.6.0 - - com.fasterxml.jackson.core - jackson-databind - 2.14.2 - commons-logging commons-logging diff --git a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java index 99608b927489..a660bca8806c 100644 --- a/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java +++ b/jvm-packages/xgboost4j-flink/src/main/java/ml/dmlc/xgboost4j/java/flink/XGBoost.java @@ -176,7 +176,7 @@ public static XGBoostModel train(DataSet> dtrain, new RabitTracker(dtrain.getExecutionEnvironment().getParallelism()); if (tracker.start()) { return dtrain - .mapPartition(new MapFunction(params, numBoostRound, tracker.workerArgs())) + .mapPartition(new MapFunction(params, numBoostRound, tracker.getWorkerArgs())) .reduce((x, y) -> x) .collect() .get(0); diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala index 7e83dc6f17b0..7a562b91d375 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuPreXGBoost.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2021-2022 by Contributors + Copyright (c) 2021-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,13 +23,13 @@ import ml.dmlc.xgboost4j.java.nvidia.spark.GpuColumnBatch import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, QuantileDMatrix} import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon import ml.dmlc.xgboost4j.scala.spark.{PreXGBoost, PreXGBoostProvider, Watches, XGBoost, XGBoostClassificationModel, XGBoostClassifier, XGBoostExecutionParams, XGBoostRegressionModel, XGBoostRegressor} -import org.apache.commons.logging.LogFactory +import org.apache.commons.logging.LogFactory import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.functions.{col, collect_list, struct} import org.apache.spark.sql.types.{ArrayType, FloatType, StructField, StructType} @@ -444,7 +444,7 @@ object GpuPreXGBoost extends PreXGBoostProvider { .groupBy(groupName) .agg(collect_list(struct(schema.fieldNames.map(col): _*)) as "list") - implicit val encoder = RowEncoder(schema) + implicit val encoder = ExpressionEncoder(RowEncoder.encoderFor(schema, false)) // Expand the grouped rows after repartition repartitionInputData(groupedDF, nWorkers).mapPartitions(iter => { new Iterator[Row] { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index e17c68355c5b..10c4b5a72992 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -233,24 +233,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s xgbExecParam.setRawParamMap(overridedParams) xgbExecParam } - - private[spark] def buildRabitParams : Map[String, String] = Map( - "rabit_reduce_ring_mincount" -> - overridedParams.getOrElse("rabit_ring_reduce_threshold", 32 << 10).toString, - "rabit_debug" -> - (overridedParams.getOrElse("verbosity", 0).toString.toInt == 3).toString, - "rabit_timeout" -> - (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0).toString, - "rabit_timeout_sec" -> { - if (overridedParams.getOrElse("rabit_timeout", -1).toString.toInt >= 0) { - overridedParams.get("rabit_timeout").toString - } else { - "1800" - } - }, - "DMLC_WORKER_CONNECT_RETRY" -> - overridedParams.getOrElse("dmlc_worker_connect_retry", 5).toString - ) } /** @@ -475,17 +457,15 @@ object XGBoost extends XGBoostStageLevel { } } - /** visiable for testing */ - private[scala] def getTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { - val tracker: ITracker = new RabitTracker( - nWorkers, trackerConf.hostIp, trackerConf.port, trackerConf.timeout) - tracker - } - - private def startTracker(nWorkers: Int, trackerConf: TrackerConf): ITracker = { - val tracker = getTracker(nWorkers, trackerConf) + // Executes the provided code block inside a tracker and then stops the tracker + private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = { + val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout) require(tracker.start(), "FAULT: Failed to start tracker") - tracker + try { + block(tracker) + } finally { + tracker.stop() + } } /** @@ -501,28 +481,27 @@ object XGBoost extends XGBoostStageLevel { logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}") val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc) - val xgbExecParams = xgbParamsFactory.buildXGBRuntimeParams - val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava + val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams - val prevBooster = xgbExecParams.checkpointParam.map { checkpointParam => + val prevBooster = runtimeParams.checkpointParam.map { checkpointParam => val checkpointManager = new ExternalCheckpointManager( checkpointParam.checkpointPath, FileSystem.get(sc.hadoopConfiguration)) - checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds) + checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds) checkpointManager.loadCheckpointAsScalaBooster() }.orNull // Get the training data RDD and the cachedRDD - val (trainingRDD, optionalCachedRDD) = buildTrainingData(xgbExecParams) + val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams) try { - // Train for every ${savingRound} rounds and save the partially completed booster - val tracker = startTracker(xgbExecParams.numWorkers, xgbExecParams.trackerConf) - val (booster, metrics) = try { - tracker.workerArgs().putAll(xgbRabitParams) - val rabitEnv = tracker.workerArgs + val (booster, metrics) = withTracker( + runtimeParams.numWorkers, + runtimeParams.trackerConf + ) { tracker => + val rabitEnv = tracker.getWorkerArgs() - val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => { + val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter => var optionWatches: Option[() => Watches] = None // take the first Watches to train @@ -530,26 +509,25 @@ object XGBoost extends XGBoostStageLevel { optionWatches = Some(iter.next()) } - optionWatches.map { buildWatches => buildDistributedBooster(buildWatches, - xgbExecParams, rabitEnv, xgbExecParams.obj, xgbExecParams.eval, prevBooster)} - .getOrElse(throw new RuntimeException("No Watches to train")) - - }} + optionWatches.map { buildWatches => + buildDistributedBooster(buildWatches, + runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster) + }.getOrElse(throw new RuntimeException("No Watches to train")) + } - val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, xgbExecParams, + val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams, boostersAndMetrics) // The repartition step is to make training stage as ShuffleMapStage, so that when one // of the training task fails the training stage can retry. ResultStage won't retry when // it fails. val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0) (booster, metrics) - } finally { - tracker.stop() } + // we should delete the checkpoint directory after a successful training - xgbExecParams.checkpointParam.foreach { + runtimeParams.checkpointParam.foreach { cpParam => - if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) { + if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) { val checkpointManager = new ExternalCheckpointManager( cpParam.checkpointPath, FileSystem.get(sc.hadoopConfiguration)) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala index 108053af5d76..d3f3901ad704 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala @@ -45,7 +45,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { val tracker = new RabitTracker(numWorkers) tracker.start() - val trackerEnvs = tracker. workerArgs + val trackerEnvs = tracker.getWorkerArgs val workerCount: Int = numWorkers /* @@ -84,7 +84,7 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest { val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache() val tracker = new RabitTracker(numWorkers) tracker.start() - val trackerEnvs = tracker.workerArgs + val trackerEnvs = tracker.getWorkerArgs val workerCount: Int = numWorkers diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml index 5a83a400c50b..e1e750866c28 100644 --- a/jvm-packages/xgboost4j/pom.xml +++ b/jvm-packages/xgboost4j/pom.xml @@ -53,6 +53,12 @@ ${scalatest.version} provided + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.jackson.version} + provided + diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java index 1bfef677d45c..84e535a269e2 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ITracker.java @@ -7,7 +7,7 @@ * * - start(timeout): Start the tracker awaiting for worker connections, with a given * timeout value (in seconds). - * - workerArgs(): Return the arguments needed to initialize Rabit clients. + * - getWorkerArgs(): Return the arguments needed to initialize Rabit clients. * - waitFor(timeout): Wait for the task execution by the worker nodes for at most `timeout` * milliseconds. * @@ -21,21 +21,8 @@ * brokers connections between workers. */ public interface ITracker extends Thread.UncaughtExceptionHandler { - enum TrackerStatus { - SUCCESS(0), INTERRUPTED(1), TIMEOUT(2), FAILURE(3); - private int statusCode; - - TrackerStatus(int statusCode) { - this.statusCode = statusCode; - } - - public int getStatusCode() { - return this.statusCode; - } - } - - Map workerArgs() throws XGBoostError; + Map getWorkerArgs() throws XGBoostError; boolean start() throws XGBoostError; diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java index 914a493cc8d1..48b163a7753b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java @@ -1,3 +1,19 @@ +/* + Copyright (c) 2014-2024 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + package ml.dmlc.xgboost4j.java; import java.util.Map; @@ -10,14 +26,12 @@ /** * Java implementation of the Rabit tracker to coordinate distributed workers. - * - * The tracker must be started on driver node before running distributed jobs. */ public class RabitTracker implements ITracker { // Maybe per tracker logger? private static final Log logger = LogFactory.getLog(RabitTracker.class); private long handle = 0; - private Thread tracker_daemon; + private Thread trackerDaemon; public RabitTracker(int numWorkers) throws XGBoostError { this(numWorkers, ""); @@ -44,7 +58,7 @@ public void uncaughtException(Thread t, Throwable e) { } catch (InterruptedException ex) { logger.error(ex); } finally { - this.tracker_daemon.interrupt(); + this.trackerDaemon.interrupt(); } } @@ -52,16 +66,14 @@ public void uncaughtException(Thread t, Throwable e) { * Get environments that can be used to pass to worker. * @return The environment settings. */ - public Map workerArgs() throws XGBoostError { + public Map getWorkerArgs() throws XGBoostError { // fixme: timeout String[] args = new String[1]; XGBoostJNI.checkCall(XGBoostJNI.TrackerWorkerArgs(this.handle, 0, args)); ObjectMapper mapper = new ObjectMapper(); - TypeReference> typeRef = new TypeReference>() { - }; Map config; try { - config = mapper.readValue(args[0], typeRef); + config = mapper.readValue(args[0], new TypeReference>() {}); } catch (JsonProcessingException ex) { throw new XGBoostError("Failed to get worker arguments.", ex); } @@ -74,18 +86,18 @@ public void stop() throws XGBoostError { public boolean start() throws XGBoostError { XGBoostJNI.checkCall(XGBoostJNI.TrackerRun(this.handle)); - this.tracker_daemon = new Thread(() -> { + this.trackerDaemon = new Thread(() -> { try { - XGBoostJNI.checkCall(XGBoostJNI.TrackerWaitFor(this.handle, 0)); + waitFor(0); } catch (XGBoostError ex) { logger.error(ex); return; // exit the thread } }); - this.tracker_daemon.setDaemon(true); - this.tracker_daemon.start(); + this.trackerDaemon.setDaemon(true); + this.trackerDaemon.start(); - return this.tracker_daemon.isAlive(); + return this.trackerDaemon.isAlive(); } public void waitFor(long timeout) throws XGBoostError {