[jvm-packages]support multiple validation datasets in Spark (#3910)
* add back train method but mark as deprecated

* fix scalastyle error

* wrap iterators

* enable copartition training and validation set

* add parameters

* converge code path and have init unit test

* enable multi evals for ranking

* unit test and doc

* update example

* fix early stopping

* address the offline comments

* update doc

* test eval metrics

* fix compilation issue

* fix example
CodingCat authored Dec 18, 2018
1 parent c8c7b96 commit c055a32
Showing 14 changed files with 477 additions and 136 deletions.
5 changes: 5 additions & 0 deletions doc/jvm/xgboost4j_spark_tutorial.rst
@@ -200,6 +200,11 @@ In addition to ``num_early_stopping_rounds``, you also need to define ``maximize_evaluation_metrics``

After specifying these two parameters, training stops when the evaluation metric moves in the direction opposite to the one specified by ``maximize_evaluation_metrics`` for ``num_early_stopping_rounds`` consecutive iterations.
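
For illustration, a parameter map along the following lines stops training once the metric has worsened for 5 consecutive rounds (a sketch only; the values here are arbitrary, not recommendations):

.. code-block:: scala

  val xgbParam = Map(
    "objective" -> "multi:softprob",
    "num_class" -> 3,
    "num_round" -> 100,
    "num_workers" -> 2,
    // stop once the metric moves the wrong way for 5 consecutive rounds
    "num_early_stopping_rounds" -> 5,
    // metrics such as mlogloss decrease as the model improves
    "maximize_evaluation_metrics" -> false)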

Training with Evaluation Sets
-----------------------------

You can also monitor the performance of the model during training with multiple evaluation datasets. By specifying ``eval_sets`` or calling ``setEvalSets`` on an XGBoostClassifier or XGBoostRegressor, you can pass in multiple evaluation datasets as a Map from String to DataFrame, as sketched below.
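
As an illustrative sketch (assuming ``eval1`` and ``eval2`` are DataFrames with the same schema as the training data, for example produced by ``randomSplit``), the parameter and the setter are interchangeable:

.. code-block:: scala

  // via the parameter map
  val classifier = new XGBoostClassifier(xgbParam ++
    Map("eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)))

  // or via the setter
  classifier.setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))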

Prediction
==========

@@ -40,7 +40,7 @@ object SparkTraining {
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
-    val rawInput = spark.read.schema(schema).csv(args(0))
+    val rawInput = spark.read.schema(schema).csv(inputPath)

// transform class to index to make xgboost happy
val stringIndexer = new StringIndexer()
@@ -55,6 +55,8 @@
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
"classIndex")

+    val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
+
/**
 * setup "timeout_request_workers" -> 60000L to make this application fail if it cannot get enough
 * resources to get 2 workers within 60000 ms
@@ -67,12 +69,13 @@
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2)
"num_workers" -> 2,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
-    val xgbClassificationModel = xgbClassifier.fit(xgbInput)
-    val results = xgbClassificationModel.transform(xgbInput)
+    val xgbClassificationModel = xgbClassifier.fit(train)
+    val results = xgbClassificationModel.transform(test)
results.show()
}
}
@@ -20,6 +20,11 @@ import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Column, DataFrame, Row}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.{FloatType, IntegerType}

object DataUtils extends Serializable {
private[spark] implicit class XGBLabeledPointFeatures(
@@ -67,4 +72,38 @@ object DataUtils extends Serializable {
XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
}
}

+  private[spark] def convertDataFrameToXGBLabeledPointRDDs(
+      labelCol: Column,
+      featuresCol: Column,
+      weight: Column,
+      baseMargin: Column,
+      group: Option[Column],
+      dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
+    val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
+      featuresCol,
+      weight.cast(FloatType),
+      groupCol.cast(IntegerType),
+      baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
+      featuresCol,
+      weight.cast(FloatType),
+      baseMargin.cast(FloatType)))
+    dataFrames.toArray.map {
+      df => df.select(selectedColumns: _*).rdd.map {
+        case Row(label: Float, features: Vector, weight: Float, group: Int, baseMargin: Float) =>
+          val (indices, values) = features match {
+            case v: SparseVector => (v.indices, v.values.map(_.toFloat))
+            case v: DenseVector => (null, v.values.map(_.toFloat))
+          }
+          XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+        case Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
+          val (indices, values) = features match {
+            case v: SparseVector => (v.indices, v.values.map(_.toFloat))
+            case v: DenseVector => (null, v.values.map(_.toFloat))
+          }
+          XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+      }
+    }
+  }

}
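
A rough usage sketch for the new helper (illustrative only: the column names and the ``lit``-based defaults are assumptions, and since the method is ``private[spark]`` it is callable only from code inside the package):

  import org.apache.spark.sql.functions.{col, lit}

  // one RDD[XGBLabeledPoint] per input DataFrame, in the same order
  val Array(trainRDD, evalRDD) = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
    col("classIndex"),   // label column, cast to FloatType internally
    col("features"),     // ML Vector column
    lit(1.0f),           // uniform instance weight
    lit(Float.NaN),      // no base margin
    None,                // no group column (not a ranking task)
    trainDF, evalDF)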