From b6346cc28942fed7200d06b931f9f4e7ab294e56 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Sat, 5 Jan 2019 16:14:34 -0800 Subject: [PATCH] fix safe execution --- .../apache/spark/SparkParallelismTracker.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala index 10c6167ae12c..403ea73f330a 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/SparkParallelismTracker.scala @@ -20,7 +20,7 @@ import java.net.URL import org.apache.commons.logging.LogFactory -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorRemoved, SparkListenerTaskEnd} import org.codehaus.jackson.map.ObjectMapper import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext.Implicits.global @@ -98,9 +98,11 @@ class SparkParallelismTracker( */ def execute[T](body: => T): T = { if (timeout <= 0) { + logger.info("starting training without setting timeout for waiting for resources") body } else { try { + logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") waitForCondition(numAliveCores >= requestedCores, timeout) } catch { case _: TimeoutException => @@ -119,9 +121,14 @@ private class ErrorInXGBoostTraining(msg: String) extends ControlThrowable { private[spark] class TaskFailedListener extends SparkListener { override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { taskEnd.reason match { - case reason: TaskFailedReason => - throw new ErrorInXGBoostTraining(s"ExecutorLost during XGBoost Training: " + - s"${reason.toErrorString}") + case taskEnd: SparkListenerTaskEnd => + if (taskEnd.reason.isInstanceOf[TaskFailedReason]) { + throw new ErrorInXGBoostTraining(s"TaskFailed during XGBoost Training: " + + s"${taskEnd.reason}") + } + case executorRemoved: SparkListenerExecutorRemoved => + throw new ErrorInXGBoostTraining(s"Executor lost during XGBoost Training: " + + s"${executorRemoved.reason}") case _ => } }