From e3d61cdc7941b63cc7eb8a7271fb292a6d525938 Mon Sep 17 00:00:00 2001 From: Bobby Date: Wed, 26 Jun 2024 15:23:04 +0800 Subject: [PATCH] update --- .../xgboost4j/scala/example/spark/SparkTraining.scala | 8 ++++---- .../xgboost4j/scala/spark/XGBoostEstimatorSuite.scala | 9 ++++----- .../ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala index 173c8f432982..5be641fef773 100644 --- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala +++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala @@ -80,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String, "max_depth" -> 2, "objective" -> "multi:softprob", "num_class" -> 3, - "num_round" -> 100, - "num_workers" -> numWorkers, - "device" -> device, - "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) + "device" -> device) val xgbClassifier = new XGBoostClassifier(xgbParam). setFeaturesCol("features"). setLabelCol("classIndex") + .setNumWorkers(numWorkers) + .setNumRound(10) + .setEvalDataset(eval1) val xgbClassificationModel = xgbClassifier.fit(train) xgbClassificationModel.transform(test) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala index 484673922d74..c6c8c7f2e73c 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala @@ -39,23 +39,22 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu test("RuntimeParameter") { var runtimeParams = new XGBoostClassifier( - Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cpu")) .getRuntimeParameters(true) assert(!runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "tree_method" -> "gpu_hist", - "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu) } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala index fcfa53c07582..3a45cf4448c0 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala @@ -108,7 +108,7 @@ class XGBoostSuite extends AnyFunSuite with PerTest { val rdd = df.rdd val runtimeParams = new XGBoostClassifier( - Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1)) + Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1) .getRuntimeParameters(true) assert(runtimeParams.runOnGpu)