Support missing value and add tests. (#11)
wbo4958 authored Jun 19, 2024
1 parent 9470497 commit a563bb0
Showing 5 changed files with 418 additions and 25 deletions.
@@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters.iterableAsScalaIterableConverter
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.Path
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
import org.apache.spark.ml.xgboost.SparkUtils
@@ -170,7 +170,7 @@ private[spark] abstract class XGBoostEstimator[
* @param dataset
* @return
*/
private def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {

// Columns to be selected for XGBoost training
val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
@@ -198,34 +198,66 @@ private[spark] abstract class XGBoostEstimator[
(input, columnIndices)
}

private def toXGBLabeledPoint(dataset: Dataset[_],
columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
dataset.rdd.map {
case row: Row =>
val label = row.getFloat(columnIndexes.labelId)
val features = row.getAs[Vector](columnIndexes.featureId.get)
val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
val group = columnIndexes.groupId.map(row.getFloat).getOrElse(-1.0f)

// TODO support sparse vector.
// TODO support array
val values = features.toArray.map(_.toFloat)
XGBLabeledPoint(label, values.length, null, values, weight, group.toInt, baseMargin)
/** visible for testing */
private[spark] def toXGBLabeledPoint(dataset: Dataset[_],
columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
val missing = getMissing
dataset.toDF().rdd.mapPartitions { input: Iterator[Row] =>

def isMissing(values: Array[Double]): Boolean = {
if (missing.isNaN) {
values.exists(_.toFloat.isNaN)
} else {
values.exists(_.toFloat == missing)
}
}

new Iterator[XGBLabeledPoint] {
private var tmp: Option[XGBLabeledPoint] = None

override def hasNext: Boolean = {
if (tmp.isDefined) {
return true
}
while (input.hasNext) {
val row = input.next()
val features = row.getAs[Vector](columnIndexes.featureId.get)
if (!isMissing(features.toArray)) {
val label = row.getFloat(columnIndexes.labelId)
val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
val group = columnIndexes.groupId.map(row.getFloat).getOrElse(-1.0f)
val (size, indices, values) = features match {
case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
}
tmp = Some(XGBLabeledPoint(label, size, indices, values, weight,
group.toInt, baseMargin))
return true
}
}
false
}

override def next(): XGBLabeledPoint = {
val xgbLabeledPoint = tmp.get
tmp = None
xgbLabeledPoint
}
}
}
}

/**
* Convert the dataframe to RDD
* Convert the dataframe to RDD, visible for testing
*
* @param dataset
* @param columnsOrder the order of columns including weight/group/base margin ...
* @return RDD
*/
def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = {
private[spark] def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = {
val trainRDD = toXGBLabeledPoint(dataset, columnIndices)

val x = getEvalDataset()
getEvalDataset().map { eval =>
val (evalDf, _) = preprocess(eval)
val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
@@ -310,10 +342,9 @@ private[spark] abstract class XGBoostEstimator[

val taskCpus = dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", 1)
if (isDefined(nthread)) {
if (getNthread > taskCpus) {
logger.warn("nthread must be smaller or equal to spark.task.cpus.")
setNthread(taskCpus)
}
require(getNthread <= taskCpus,
s"the nthread configuration ($getNthread) must be no larger than " +
s"spark.task.cpus ($taskCpus)")
} else {
setNthread(taskCpus)
}
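Note on the conversion above: toXGBLabeledPoint now skips any row whose feature vector contains the configured missing value (or any NaN when the missing value itself is NaN), and it handles both dense and sparse vectors. A minimal standalone sketch of that predicate, using illustrative sample vectors rather than the estimator itself:

import org.apache.spark.ml.linalg.{Vector, Vectors}

object MissingFilterSketch {
  // Mirrors the isMissing check in toXGBLabeledPoint: a row is dropped when any
  // feature value equals the `missing` sentinel (or is NaN when the sentinel is NaN).
  def isMissing(values: Array[Double], missing: Float): Boolean =
    if (missing.isNaN) values.exists(_.toFloat.isNaN)
    else values.exists(_.toFloat == missing)

  def main(args: Array[String]): Unit = {
    val rows: Seq[Vector] = Seq(
      Vectors.dense(1.0, 2.0, 3.0),        // kept
      Vectors.dense(0.0, Double.NaN, 0.0)  // dropped when missing = NaN
    )
    val missing = Float.NaN
    val kept = rows.filterNot(v => isMissing(v.toArray, missing))
    println(s"kept ${kept.size} of ${rows.size} rows") // kept 1 of 2 rows
  }
}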
@@ -186,6 +186,8 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T]

def setMissing(value: Float): T = set(missing, value).asInstanceOf[T]

def setRabitTrackerTimeout(value: Int): T = set(rabitTrackerTimeout, value).asInstanceOf[T]

def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T]
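The new setMissing setter composes with the builder-style API exercised by the test suite below. A hedged usage sketch; the fitWithMissing helper is hypothetical, the XGBoostClassificationModel return type is an assumption, and the DataFrame is assumed to carry the default "label" and "features" columns:

import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.sql.DataFrame

object SetMissingSketch {
  // Hypothetical helper: trainingDF is assumed to provide "label" and "features" columns.
  def fitWithMissing(trainingDF: DataFrame, missing: Float): XGBoostClassificationModel =
    new XGBoostClassifier(Map("objective" -> "binary:logistic"))
      .setNumWorkers(2)
      .setNumRound(10)
      .setMissing(missing) // rows whose feature vector contains `missing` are dropped by toXGBLabeledPoint
      .fit(trainingDF)
}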
@@ -0,0 +1,119 @@
/*
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}

class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {

test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
same thread pool spawned by the local mode of Spark. As these tests simulate worker crashes
by throwing exceptions, the crashed worker thread never calls Rabit.shutdown, and therefore
corrupts the internal state of the native Rabit C++ code. Calling Rabit.init() in subsequent
tests on a reentrant thread will crash the entire Spark application, an undesired side-effect
that should be avoided.
*/
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()

val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.getWorkerArgs

val workerCount: Int = numWorkers
/*
Simulate worker crash events by creating dummy Rabit workers, and throw exceptions in the
last created worker. A cascading event chain will be triggered once the RuntimeException is
thrown: the thread running the dummy spark job (sparkThread) catches the exception and
delegates it to the UnCaughtExceptionHandler, which is the Rabit tracker itself.
To prevent unit tests from crashing, deterministic delays were introduced to make sure that
the exception is thrown at last, ideally after all worker connections have been established.
*/
val dummyTasks = rdd.mapPartitions { iter =>
Communicator.init(trackerEnvs)
val index = iter.next()
Thread.sleep(100 + index * 10)
if (index == workerCount) {
// kill the worker by throwing an exception
throw new RuntimeException("Worker exception.")
}
Communicator.shutdown()
Iterator(index)
}.cache()

val sparkThread = new Thread() {
override def run(): Unit = {
// forces a Spark job.
dummyTasks.foreachPartition(() => _)
}
}

sparkThread.setUncaughtExceptionHandler(tracker)
sparkThread.start()
}

test("Communicator allreduce works.") {
val rdd = sc.parallelize(1 to numWorkers, numWorkers).cache()
val tracker = new RabitTracker(numWorkers)
tracker.start()
val trackerEnvs = tracker.getWorkerArgs

val workerCount: Int = numWorkers

rdd.mapPartitions { iter =>
val index = iter.next()
Communicator.init(trackerEnvs)
val a = Array(1.0f, 2.0f, 3.0f)
System.out.println(a.mkString(", "))
val b = Communicator.allReduce(a, Communicator.OpType.SUM)
for (i <- 0 to 2) {
assert(a(i) * workerCount == b(i))
}
val c = Communicator.allReduce(a, Communicator.OpType.MIN);
for (i <- 0 to 2) {
assert(a(i) == c(i))
}
Communicator.shutdown()
Iterator(index)
}.collect()
}

test("should allow the dataframe containing communicator calls to be partially evaluated for" +
" multiple times (ISSUE-4406)") {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
val trainingDF = smallBinaryClassificationVector
val model = new XGBoostClassifier(paramMap)
.setNumWorkers(numWorkers)
.setNumRound(10)
.fit(trainingDF)
val prediction = model.transform(trainingDF)
// a partial evaluation of the dataframe will leave rabit initialized but not shut down in
// some threads
prediction.show()
// a full evaluation here will re-run init and shutdown for all rabit proxies,
// expecting no error
prediction.collect()
}
}
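The allreduce assertions above follow from simple arithmetic: every worker contributes the same Array(1.0f, 2.0f, 3.0f), so SUM scales each element by the worker count while MIN leaves it unchanged. A worked sketch, assuming the numWorkers value from PerTest below (4):

// Expected allreduce results when all workers contribute identical arrays.
val workerCount = 4                                   // numWorkers from PerTest below (assumed)
val contribution = Array(1.0f, 2.0f, 3.0f)
val expectedSum = contribution.map(_ * workerCount)   // Array(4.0f, 8.0f, 12.0f)
val expectedMin = contribution                        // MIN over identical inputs is unchanged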
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,14 +20,15 @@ import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.apache.spark.SparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

trait PerTest extends BeforeAndAfterEach {
self: AnyFunSuite =>

protected val numWorkers: Int = 1
protected val numWorkers: Int = 4

@transient private var currentSession: SparkSession = _

@@ -84,4 +85,33 @@ trait PerTest extends BeforeAndAfterEach {
r.close()
}
}

def smallBinaryClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
(0.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
(0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
(1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
))).toDF("label", "margin", "weight", "features")

def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
(2.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
(0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
(2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
))).toDF("label", "margin", "weight", "features")


def smallGroupVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
(1.0, 0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
(0.0, 1, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
(2.0, 1, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
(1.0, 0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
(0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
(2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7)),
))).toDF("label", "group", "margin", "weight", "features")

}
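The new helper DataFrames pair label, margin, and weight columns (plus a group column for the ranking case) with a features vector. A hedged sketch of how a suite extending AnyFunSuite with PerTest might consume them together with the new missing parameter; setWeightCol and setBaseMarginCol are assumed setter names mirroring the column names above, while numWorkers, setMissing, and smallBinaryClassificationVector come from this commit:

test("binary classification with weight, margin and missing value (sketch)") {
  val df = smallBinaryClassificationVector
  val model = new XGBoostClassifier()
    .setNumWorkers(numWorkers)
    .setNumRound(5)
    .setWeightCol("weight")       // assumed setter name
    .setBaseMarginCol("margin")   // assumed setter name
    .setMissing(Float.NaN)        // setter added in this commit
    .fit(df)
  assert(model.transform(df).count() === df.count())
}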