Skip to content

Commit

Permalink
[jvm-packages] delete all constraints from spark layer about obj and …
Browse files Browse the repository at this point in the history
…eval metrics and handle error in jvm layer (#4560)

* temp

* prediction part

* remove supported*

* add for test

* fix param name

* add rabit

* update rabit

* return value of rabit init

* eliminate compilation warnings

* update rabit

* shutdown

* update rabit again

* check sparkcontext shutdown

* fix logic

* sleep

* fix tests

* test with relaxed threshold

* create new thread each time

* stop for job quitting

* update rabit

* update rabit

* update rabit

* update git modules
  • Loading branch information
CodingCat authored Jun 27, 2019
1 parent dd01f7c commit abffbe0
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext, TaskFailedListener}
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -153,9 +153,11 @@ object XGBoost extends Serializable {
}
val taskId = TaskContext.getPartitionId().toString
rabitEnv.put("DMLC_TASK_ID", taskId)
Rabit.init(rabitEnv)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")

try {
Rabit.init(rabitEnv)

val numEarlyStoppingRounds = params.get("num_early_stopping_rounds")
.map(_.toString.toInt).getOrElse(0)
val overridedParams = if (numEarlyStoppingRounds > 0 &&
Expand All @@ -176,6 +178,10 @@ object XGBoost extends Serializable {
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
} catch {
case xgbException: XGBoostError =>
logger.error(s"XGBooster worker $taskId has failed due to ", xgbException)
throw xgbException
} finally {
Rabit.shutdown()
watches.delete()
Expand Down Expand Up @@ -467,6 +473,12 @@ object XGBoost extends Serializable {
tracker.stop()
}
}.last
} catch {
case t: Throwable =>
// if the job was aborted due to an exception
logger.error("the job was aborted due to ", t)
trainingData.sparkContext.stop()
throw t
} finally {
uncacheTrainingData(params.getOrElse("cacheTrainingSet", false).asInstanceOf[Boolean],
transformedTrainingData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ class XGBoostClassificationModel private[ml](

private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,9 @@ class XGBoostRegressionModel private[ml] (

private val batchIterImpl = rowIterator.grouped($(inferBatchSize)).flatMap { batchRow =>
if (batchCnt == 0) {
val rabitEnv = Array("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString).toMap
val rabitEnv = Array(
"DMLC_TASK_ID" -> TaskContext.getPartitionId().toString,
"DMLC_WORKER_STOP_PROCESS_ON_ERROR" -> "false").toMap
Rabit.init(rabitEnv.asJava)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ private[spark] trait LearningTaskParams extends Params {
* count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
* default: reg:squarederror
*/
final val objective = new Param[String](this, "objective", "objective function used for " +
s"training, options: {${LearningTaskParams.supportedObjective.mkString(",")}",
(value: String) => LearningTaskParams.supportedObjective.contains(value))
final val objective = new Param[String](this, "objective",
"objective function used for training")

final def getObjective: String = $(objective)

Expand Down Expand Up @@ -62,9 +61,7 @@ private[spark] trait LearningTaskParams extends Params {
*/
final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
"validation data, a default metric will be assigned according to objective " +
"(rmse for regression, and error for classification, mean average precision for ranking), " +
s"options: {${LearningTaskParams.supportedEvalMetrics.mkString(",")}}",
(value: String) => LearningTaskParams.supportedEvalMetrics.contains(value))
"(rmse for regression, and error for classification, mean average precision for ranking)")

final def getEvalMetric: String = $(evalMetric)

Expand Down Expand Up @@ -106,16 +103,11 @@ private[spark] trait LearningTaskParams extends Params {
}

private[spark] object LearningTaskParams {
val supportedObjective = HashSet("reg:linear", "reg:squarederror", "reg:logistic",
"reg:squaredlogerror", "binary:logistic", "binary:logitraw", "count:poisson", "multi:softmax",
"multi:softprob", "rank:pairwise", "rank:ndcg", "rank:map", "reg:gamma", "reg:tweedie")

val supportedObjectiveType = HashSet("regression", "classification")

val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map")

val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "logloss", "error", "merror",
"mlogloss", "gamma-deviance")

val supportedEvalMetrics = evalMetricsToMaximize union evalMetricsToMinimize
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark

import java.net.URL
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.commons.logging.LogFactory

Expand Down Expand Up @@ -123,18 +124,30 @@ private[spark] class TaskFailedListener extends SparkListener {
case taskEndReason: TaskFailedReason =>
logger.error(s"Training Task Failed during XGBoost Training: " +
s"$taskEndReason, stopping SparkContext")
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
// in a separate thread
val sparkContextKiller = new Thread() {
override def run(): Unit = {
LiveListenerBus.withinListenerThread.withValue(false) {
SparkContext.getOrCreate().stop()
}
TaskFailedListener.startedSparkContextKiller()
case _ =>
}
}
}

object TaskFailedListener {

var killerStarted = false

private def startedSparkContextKiller(): Unit = this.synchronized {
if (!killerStarted) {
// Spark does not allow ListenerThread to shutdown SparkContext so that we have to do it
// in a separate thread
val sparkContextKiller = new Thread() {
override def run(): Unit = {
LiveListenerBus.withinListenerThread.withValue(false) {
SparkContext.getOrCreate().stop()
}
}
sparkContextKiller.setDaemon(true)
sparkContextKiller.start()
case _ =>
}
sparkContextKiller.setDaemon(true)
sparkContextKiller.start()
killerStarted = true
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.java.XGBoostError
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.ml.param.ParamMap

/**
 * Tests that XGBoost-style string parameters and the Spark ML `Param`s stay in
 * sync in both directions, and that a job supplied with an objective / eval
 * metric the native layer rejects fails "elegantly": the worker error
 * propagates and the SparkContext is shut down instead of hanging.
 */
class ParameterSuite extends FunSuite with PerTest with BeforeAndAfterAll {

  test("XGBoost and Spark parameters synchronize correctly") {
    val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
      "objective_type" -> "classification")
    // from xgboost params to spark params
    val xgb = new XGBoostClassifier(xgbParamMap)
    assert(xgb.getEta === 1.0)
    assert(xgb.getObjective === "binary:logistic")
    assert(xgb.getObjectiveType === "classification")
    // from spark to xgboost params
    val xgbCopy = xgb.copy(ParamMap.empty)
    assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
    assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
    assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
    // a Param set after construction must also round-trip into the xgboost map
    val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
    assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
  }

  /**
   * Polls the shared SparkContext until it reports stopped, waiting at most
   * ~2 minutes (10s poll interval), then asserts that it actually stopped.
   * Used to verify the failure-handling path tears the context down.
   */
  private def waitForSparkContextShutdown(): Unit = {
    var totalWaitedTime = 0L
    while (!ss.sparkContext.isStopped && totalWaitedTime <= 120000) {
      Thread.sleep(10000)
      totalWaitedTime += 10000
    }
    // `=== true` was redundant; assert the boolean directly
    assert(ss.sparkContext.isStopped)
  }

  test("fail training elegantly with unsupported objective function") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers)
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    try {
      // The fit is expected to fail in the native layer; we deliberately do
      // not bind the result (avoids an unused-local warning).
      xgb.fit(trainingDF)
    } catch {
      // swallow anything — the assertion of interest is the context shutdown
      case _: Throwable =>
    } finally {
      waitForSparkContextShutdown()
    }
  }

  test("fail training elegantly with unsupported eval metrics") {
    val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
      "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
      "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
    val trainingDF = buildDataFrame(MultiClassification.train)
    val xgb = new XGBoostClassifier(paramMap)
    try {
      // Expected to fail in the native layer; result intentionally discarded.
      xgb.fit(trainingDF)
    } catch {
      // swallow anything — the assertion of interest is the context shutdown
      case _: Throwable =>
    } finally {
      waitForSparkContextShutdown()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File

import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.{SparkConf, SparkContext}

import org.apache.spark.{SparkConf, SparkContext, TaskFailedListener}
import org.apache.spark.sql._
import org.scalatest.{BeforeAndAfterEach, FunSuite}

import scala.util.Random

trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
Expand Down Expand Up @@ -50,6 +50,7 @@ trait PerTest extends BeforeAndAfterEach { self: FunSuite =>
cleanExternalCache(currentSession.sparkContext.appName)
currentSession = null
}
TaskFailedListener.killerStarted = false
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.FunSuite


class RabitSuite extends FunSuite with PerTest {
class RabitRobustnessSuite extends FunSuite with PerTest {

test("training with Scala-implemented Rabit tracker") {
val eval = new EvalError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,6 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
assert(model.summary.validationObjectiveHistory.isEmpty)
}

test("XGBoost and Spark parameters synchronize correctly") {
val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
"objective_type" -> "classification")
// from xgboost params to spark params
val xgb = new XGBoostClassifier(xgbParamMap)
assert(xgb.getEta === 1.0)
assert(xgb.getObjective === "binary:logistic")
assert(xgb.getObjectiveType === "classification")
// from spark to xgboost params
val xgbCopy = xgb.copy(ParamMap.empty)
assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
}

test("multi class classification") {
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
Expand Down
14 changes: 10 additions & 4 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,8 +818,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
argv.push_back(&args[i][0]);
}

RabitInit(args.size(), dmlc::BeginPtr(argv));
return 0;
if (RabitInit(args.size(), dmlc::BeginPtr(argv))) {
return 0;
} else {
return 1;
}
}

/*
Expand All @@ -829,8 +832,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitInit
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
if (RabitFinalize()) {
return 0;
} else {
return 1;
}
}

/*
Expand Down

0 comments on commit abffbe0

Please sign in to comment.