Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
Signed-off-by: Zach Puller <zpuller@nvidia.com>
  • Loading branch information
zpuller committed Sep 24, 2024
1 parent b6a9fd7 commit ebb0803
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ private final class GpuSemaphore() extends Logging {
def completeTask(context: TaskContext): Unit = {
val taskAttemptId = context.taskAttemptId()
GpuTaskMetrics.get.updateRetry(taskAttemptId)
GpuTaskMetrics.get.updateMaxMemory(taskAttemptId)
GpuTaskMetrics.get.updateMaxGpuMemory(taskAttemptId)
val refs = tasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ class NanoSecondAccumulator extends AccumulatorV2[jl.Long, NanoTime] {
override def value: NanoTime = NanoTime(_sum)
}

class WatermarkAccumulator extends AccumulatorV2[jl.Long, Long] {
class HighWatermarkAccumulator extends AccumulatorV2[jl.Long, Long] {
private var _value = 0L
override def isZero: Boolean = _value == 0

override def copy(): WatermarkAccumulator = {
val newAcc = new WatermarkAccumulator
override def copy(): HighWatermarkAccumulator = {
val newAcc = new HighWatermarkAccumulator
newAcc._value = this._value
newAcc
}
Expand All @@ -97,7 +97,7 @@ class WatermarkAccumulator extends AccumulatorV2[jl.Long, Long] {
}

override def merge(other: AccumulatorV2[jl.Long, Long]): Unit = other match {
case wa: WatermarkAccumulator =>
case wa: HighWatermarkAccumulator =>
_value = _value.max(wa._value)
case _ =>
throw new UnsupportedOperationException(
Expand All @@ -120,7 +120,7 @@ class GpuTaskMetrics extends Serializable {
private val readSpillFromHostTimeNs = new NanoSecondAccumulator
private val readSpillFromDiskTimeNs = new NanoSecondAccumulator

private val maxDeviceMemoryBytes = new WatermarkAccumulator
private val maxDeviceMemoryBytes = new HighWatermarkAccumulator

private val metrics = Map[String, AccumulatorV2[_, _]](
"gpuSemaphoreWait" -> semWaitTimeNs,
Expand Down Expand Up @@ -211,7 +211,7 @@ class GpuTaskMetrics extends Serializable {
}
}

def updateMaxMemory(taskAttemptId: Long): Unit = {
def updateMaxGpuMemory(taskAttemptId: Long): Unit = {
val maxMem = RmmSpark.getAndResetGpuMaxMemoryAllocated(taskAttemptId)
if (maxMem > 0) {
maxDeviceMemoryBytes.add(maxMem)
Expand Down

0 comments on commit ebb0803

Please sign in to comment.