Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 11, 2024
1 parent 0feb48f commit b8d0deb
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class XGBoostClassifier(override val uid: String,
private var numberClasses = 0

private def validateObjective(dataset: Dataset[_]): Unit = {
// If the objective is set explicitly, it must be in binaryClassificationObjs and
// multiClassificationObjs
// If the objective is set explicitly, it must be in BINARY_CLASSIFICATION_OBJS and
// MULTICLASSIFICATION_OBJS
val obj = if (isSet(objective)) {
val tmpObj = getObjective
val supportedObjs = BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
val predContrib = addToSchema(contribPredictionCol)

var predRaw = false
// For classification case, the tranformed col is probability,
// For classification case, the transformed col is probability,
// while for others, it's the prediction value.
var predTmp = false
this match {
Expand All @@ -503,7 +503,7 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
if (isDefinedNonEmpty(predictionCol)) {
// Let's use transformed col to calculate the prediction
if (!predTmp) {
// Add the transformed col for predition
// Add the transformed col for prediction
schema = schema.add(
StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType)))
predTmp = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ class XGBoostRegressor(override val uid: String,
override protected[spark] def validate(dataset: Dataset[_]): Unit = {
super.validate(dataset)

// If the objective is set explicitly, it must be in binaryClassificationObjs and
// multiClassificationObjs
// If the objective is set explicitly, it must be in REGRESSION_OBJS
if (isSet(objective)) {
val tmpObj = getObjective
require(REGRESSION_OBJS.contains(tmpObj),
Expand Down

0 comments on commit b8d0deb

Please sign in to comment.