-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jvm-packages] cancel job instead of killing SparkContext #6019
Changes from 4 commits
2a2b47a
f834537
8852f59
f1df786
2824d66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,18 +19,22 @@ package org.apache.spark | |
import org.apache.commons.logging.LogFactory | ||
import org.apache.spark.scheduler._ | ||
|
||
import scala.collection.mutable.{HashMap, HashSet} | ||
|
||
/** | ||
* A tracker that ensures enough number of executor cores are alive. | ||
* Throws an exception when the number of alive cores is less than nWorkers. | ||
* | ||
* @param sc The SparkContext object | ||
* @param timeout The maximum time to wait for enough number of workers. | ||
* @param numWorkers nWorkers used in an XGBoost Job | ||
* @param killSparkContext kill SparkContext or not when task fails | ||
*/ | ||
class SparkParallelismTracker( | ||
val sc: SparkContext, | ||
timeout: Long, | ||
numWorkers: Int) { | ||
numWorkers: Int, | ||
killSparkContext: Boolean = true) { | ||
|
||
private[this] val requestedCores = numWorkers * sc.conf.getInt("spark.task.cpus", 1) | ||
private[this] val logger = LogFactory.getLog("XGBoostSpark") | ||
|
@@ -58,7 +62,7 @@ class SparkParallelismTracker( | |
} | ||
|
||
private[this] def safeExecute[T](body: => T): T = { | ||
val listener = new TaskFailedListener | ||
val listener = new TaskFailedListener(killSparkContext) | ||
sc.addSparkListener(listener) | ||
try { | ||
body | ||
|
@@ -79,7 +83,7 @@ class SparkParallelismTracker( | |
def execute[T](body: => T): T = { | ||
if (timeout <= 0) { | ||
logger.info("starting training without setting timeout for waiting for resources") | ||
body | ||
safeExecute(body) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we have to change this to safeExecute()? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The safeExecute wraps TaskFailedListener inside. I don't know why the body was not executed in safeExecute in the previous version, since it may hang forever if no TaskFailedListener is registered. |
||
} else { | ||
logger.info(s"starting training with timeout set as $timeout ms for waiting for resources") | ||
if (!waitForCondition(numAliveCores >= requestedCores, timeout)) { | ||
|
@@ -90,23 +94,58 @@ class SparkParallelismTracker( | |
} | ||
} | ||
|
||
/**
 * A SparkListener that reacts to XGBoost training task failures.
 *
 * Depending on `killSparkContext`, a failed training task either stops the
 * whole SparkContext (legacy behavior) or cancels only the Spark job(s)
 * whose stages contain the failed task, leaving the SparkContext alive.
 *
 * @param killSparkContext when true, stop the SparkContext on task failure;
 *                         when false, cancel only the affected job(s)
 */
class TaskFailedListener(killSparkContext: Boolean = true) extends SparkListener {

  private[this] val logger = LogFactory.getLog("XGBoostTaskFailedListener")

  // {jobId, [stageId0, stageId1, ...] }
  // Keeps track of the mapping of job id and stage ids so that, when a task
  // fails, the job owning the failed task's stage can be located and
  // cancelled. Only maintained when killSparkContext is false.
  private val jobIdToStageIds: HashMap[Int, HashSet[Int]] = HashMap.empty

  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    if (!killSparkContext) {
      // Record every stage of the starting job for later lookup on failure.
      val stages = jobIdToStageIds.getOrElseUpdate(jobStart.jobId, new HashSet[Int]())
      jobStart.stageIds.foreach(stages += _)
    }
  }

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    if (!killSparkContext) {
      // Drop bookkeeping for finished jobs to avoid unbounded map growth.
      jobIdToStageIds.remove(jobEnd.jobId)
    }
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case taskEndReason: TaskFailedReason =>
        logger.error(s"Training Task Failed during XGBoost Training: " +
          s"$taskEndReason")
        if (killSparkContext) {
          logger.error("killing SparkContext")
          TaskFailedListener.startedSparkContextKiller()
        } else {
          val stageId = taskEnd.stageId
          // Find job ids according to the stage id and then cancel the jobs.
          // BUG FIX: the original removed entries from jobIdToStageIds while
          // iterating it with foreach; mutating a scala.collection.mutable
          // HashMap during iteration is undefined behavior. Collect the
          // matching job ids first, then remove and cancel.
          val jobsToCancel = jobIdToStageIds.collect {
            case (jobId, stageIds) if stageIds.contains(stageId) => jobId
          }.toList
          jobsToCancel.foreach { jobId =>
            logger.error("Cancelling jobId:" + jobId)
            jobIdToStageIds.remove(jobId)
            SparkContext.getOrCreate().cancelJob(jobId)
          }
        }
      case _ =>
    }
  }
}
|
||
object TaskFailedListener { | ||
|
||
var killerStarted = false | ||
|
||
private def startedSparkContextKiller(): Unit = this.synchronized { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kill_spark_context_on_worker_failure?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks for the naming suggestion.