Skip to content

Commit

Permalink
[jvm-packages] remove default parameters (#7938)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 authored May 28, 2022
1 parent 47224dd commit fbc3d86
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
def buildXGBRuntimeParams: XGBoostExecutionParams = {
val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
val round = overridedParams("num_round").asInstanceOf[Int]
val useExternalMemory = overridedParams("use_external_memory").asInstanceOf[Boolean]
val useExternalMemory = overridedParams
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014,2021 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -261,6 +261,7 @@ private[spark] trait BoosterParams extends Params {

final val treeLimit = new IntParam(this, name = "treeLimit",
doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
setDefault(treeLimit, 0)

final def getTreeLimit: Int = $(treeLimit)

Expand All @@ -280,13 +281,6 @@ private[spark] trait BoosterParams extends Params {

final def getInteractionConstraints: String = $(interactionConstraints)

setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
growPolicy -> "depthwise", maxBins -> 256,
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}

private[scala] object BoosterParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ private[spark] trait GeneralParams extends Params {
*/
final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
ParamValidators.gtEq(1))
setDefault(numRound, 1)

final def getNumRound: Int = $(numRound)

Expand All @@ -37,6 +38,7 @@ private[spark] trait GeneralParams extends Params {
*/
final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
ParamValidators.gtEq(1))
setDefault(numWorkers, 1)

final def getNumWorkers: Int = $(numWorkers)

Expand All @@ -45,6 +47,7 @@ private[spark] trait GeneralParams extends Params {
*/
final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
ParamValidators.gtEq(1))
setDefault(nthread, 1)

final def getNthread: Int = $(nthread)

Expand All @@ -53,6 +56,7 @@ private[spark] trait GeneralParams extends Params {
*/
final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
"whether to use external memory as cache")
setDefault(useExternalMemory, false)

final def getUseExternalMemory: Boolean = $(useExternalMemory)

Expand Down Expand Up @@ -94,6 +98,7 @@ private[spark] trait GeneralParams extends Params {
* the value treated as missing. default: Float.NaN
*/
final val missing = new FloatParam(this, "missing", "the value treated as missing")
setDefault(missing, Float.NaN)

final def getMissing: Float = $(missing)

Expand All @@ -109,6 +114,7 @@ private[spark] trait GeneralParams extends Params {
"not use Spark's VectorAssembler class to construct the feature vector " +
"but instead used a method that preserves zeros in your vector."
)
setDefault(allowNonZeroForMissing, false)

final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)

Expand Down Expand Up @@ -163,18 +169,14 @@ private[spark] trait GeneralParams extends Params {
* Ignored if the tracker implementation is "python".
*/
final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
setDefault(trackerConf, TrackerConf())

/** Random seed for the C++ part of XGBoost and train/test splitting. */
final val seed = new LongParam(this, "seed", "random seed")
setDefault(seed, 0L)

final def getSeed: Long = $(seed)

setDefault(numRound -> 1, numWorkers -> 1, nthread -> 1,
useExternalMemory -> false, silent -> 0, verbosity -> 1,
customObj -> null, customEval -> null, missing -> Float.NaN,
trackerConf -> TrackerConf(), seed -> 0,
checkpointPath -> "", checkpointInterval -> -1,
allowNonZeroForMissing -> false)
}

trait HasLeafPredictionCol extends Params {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -26,7 +26,7 @@ private[spark] trait InferenceParams extends Params {
final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")

/** @group getParam */
final def getInferBatchSize: Int = ${inferBatchSize}
final def getInferBatchSize: Int = $(inferBatchSize)

setDefault(inferBatchSize, 32 << 10)
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ private[spark] trait LearningTaskParams extends Params {
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
"fraction of training points to use for testing",
ParamValidators.inRange(0, 1))
setDefault(trainTestRatio, 1.0)

final def getTrainTestRatio: Double = $(trainTestRatio)

Expand Down Expand Up @@ -105,8 +106,6 @@ private[spark] trait LearningTaskParams extends Params {

final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)

setDefault(baseScore -> 0.5, trainTestRatio -> 1.0,
numEarlyStoppingRounds -> 0, cacheTrainingSet -> false)
}

private[spark] object LearningTaskParams {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014 by Contributors
Copyright (c) 2014-2022 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,12 +29,14 @@ private[spark] trait RabitParams extends Params {
final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
"threshold count to enable allreduce/broadcast with ring based topology",
ParamValidators.gtEq(1))
setDefault(rabitRingReduceThreshold, (32 << 10))

final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
"timeout threshold after rabit observed failures")
setDefault(rabitTimeout, -1)

final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
"number of retry worker do before fail", ParamValidators.gtEq(1))
setDefault(rabitConnectRetry, 5)

setDefault(rabitRingReduceThreshold -> (32 << 10), rabitConnectRetry -> 5, rabitTimeout -> -1)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

def needDeterministicRepartitioning: Boolean = {
getCheckpointPath != null && getCheckpointPath.nonEmpty && getCheckpointInterval > 0
isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
isDefined(checkpointInterval) && getCheckpointInterval > 0
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,11 @@ class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {

new XGBoostClassifier(paramMap).fit(trainingDF)
}

test("Default parameters") {
val classifier = new XGBoostClassifier()
intercept[NoSuchElementException] {
classifier.getBaseScore
}
}
}

0 comments on commit fbc3d86

Please sign in to comment.