Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Jun 26, 2024
1 parent 468b812 commit e3d61cd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ private[spark] def run(spark: SparkSession, inputPath: String,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> numWorkers,
"device" -> device,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
"device" -> device)
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
.setNumWorkers(numWorkers)
.setNumRound(10)
.setEvalDataset(eval1)
val xgbClassificationModel = xgbClassifier.fit(train)
xgbClassificationModel.transform(test)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,22 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu

test("RuntimeParameter") {
var runtimeParams = new XGBoostClassifier(
Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1))
Map("device" -> "cpu"))
.getRuntimeParameters(true)
assert(!runtimeParams.runOnGpu)

runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1))
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)

runtimeParams = new XGBoostClassifier(
Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1))
Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)

runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda", "tree_method" -> "gpu_hist",
"num_workers" -> 1, "num_round" -> 1))
Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
val rdd = df.rdd

val runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1))
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)

Expand Down

0 comments on commit e3d61cd

Please sign in to comment.