[SPARK-50372][CONNECT][SQL] Make all DF execution path collect observed metrics

### What changes were proposed in this pull request?

This PR fixes an issue where some DataFrame execution paths did not process `ObservedMetrics`. The fix injects lazy processing logic into the result iterator, as sketched below.
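Conceptually, the change wraps the streaming response iterator so that registered observed metrics are handled lazily, as each response is pulled by the consumer. A simplified sketch of the idea, following the client-side code in the diff below (details elided):

```scala
// Sketch only; the real change lives in SparkSession.executeInternal (see the diff below).
private[sql] def executeInternal(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = {
  client
    .execute(plan)
    .map { response =>
      // This map() is lazy: metrics are processed only when the iterator is consumed.
      processRegisteredObservedMetrics(response.getObservedMetricsList)
      response
    }
}
```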

The following private execution APIs are affected by this issue:

- `SparkSession.execute(proto.Relation.Builder)`
- `SparkSession.execute(proto.Command)`
- `SparkSession.execute(proto.Plan)`

The following user-facing API is affected by this issue (see the example below):
- `DataFrame.write.format("...").mode("...").save()`
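For example, an `Observation` attached upstream of a write is now fulfilled once the save finishes. This mirrors the new test case added in this PR; the `noop` format is used there purely as a sink:

```scala
val ob = new Observation("ob")
val observed = spark.range(99).observe(ob, min("id"), avg("id"), max("id"))
observed.write.format("noop").mode("append").save()
// Before this fix, ob.get could block indefinitely on this path because the
// observed metrics carried in the responses were never processed.
val metrics = ob.get
```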

This PR also fixes a server-side issue where two observed metrics could be assigned the same plan ID when they appear in the same plan (e.g., one observation is used as the input of another). The fix traverses the plan and collects every observation together with its correct plan ID.
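The traversal is a recursion over the protobuf plan that records each `CollectMetrics` node's name together with that node's own plan ID (rather than the root plan's ID). A condensed sketch of the helper added to `ExecuteHolder` in this PR:

```scala
// Collect (observation name -> plan ID) for every CollectMetrics node in the plan.
private def collectAllObservationAndPlanIds(
    planOrMessage: GeneratedMessage,
    collected: mutable.Map[String, Long] = mutable.Map.empty): mutable.Map[String, Long] = {
  planOrMessage match {
    case relation: proto.Relation if relation.hasCollectMetrics =>
      // Key each observation by its own node's plan ID, not the root plan ID.
      collected += relation.getCollectMetrics.getName -> relation.getCommon.getPlanId
      collectAllObservationAndPlanIds(relation.getCollectMetrics.getInput, collected)
    case _ =>
      // Recurse into nested message fields to find CollectMetrics anywhere in the plan.
      planOrMessage.getAllFields.values().asScala.foreach {
        case message: GeneratedMessage => collectAllObservationAndPlanIds(message, collected)
        case _ => // primitive field, nothing to do
      }
  }
  collected
}
```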

Another bug was discovered as a byproduct of the new test case. The relevant review comment on SparkConnectPlanner.scala is copied here:

> This fixes a bug where the input of a `CollectMetrics` could be processed twice, once at line 1190 and once here/below.
>
> When the `input` contains another `CollectMetrics`, transforming it twice causes two `Observation` objects (in the input) to be initialised and registered twice with the system. Since only one of them is fulfilled when the query finishes, the one we end up looking at may not contain any data.
>
> This issue is highlighted in the test case `Observation.get is blocked until the query is finished ...`, where we specifically execute `observedObservedDf`, a `CollectMetrics` whose input is another `CollectMetrics`.
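For reference, the nested-observation shape referred to above is constructed like this in the new test (sketch based on the test code in this PR):

```scala
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val ob2 = new Observation("ob2")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
// observedObservedDf is a CollectMetrics whose input is another CollectMetrics; before
// this fix, transforming the inner node twice registered two Observation objects for it,
// and the one handed back to the caller might never be fulfilled.
val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
```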

### Why are the changes needed?

To fix a bug.

### Does this PR introduce _any_ user-facing change?

Yes, this bug is user-facing.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48920 from xupefei/observation-notify-fix.

Authored-by: Paddy Xu <xupaddy@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
xupefei authored and HyukjinKwon committed Dec 2, 2024
1 parent 6c84f15 commit d0e2c06
Showing 8 changed files with 137 additions and 64 deletions.
@@ -34,6 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalog.Catalog
@@ -371,13 +372,8 @@ class SparkSession private[sql] (
private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)

private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
new SparkResult(
value,
allocator,
encoder,
timeZoneId,
Some(setMetricsAndUnregisterObservation))
val value = executeInternal(plan)
new SparkResult(value, allocator, encoder, timeZoneId)
}

private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
@@ -386,7 +382,7 @@
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
// .foreach forces that the iterator is consumed and closed
client.execute(plan).foreach(_ => ())
executeInternal(plan).foreach(_ => ())
}

@Since("4.0.0")
@@ -395,11 +391,26 @@
val plan = proto.Plan.newBuilder().setCommand(command).build()
// .toSeq forces that the iterator is consumed and closed. On top, ignore all
// progress messages.
client.execute(plan).filter(!_.hasExecutionProgress).toSeq
executeInternal(plan).filter(!_.hasExecutionProgress).toSeq
}

private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
client.execute(plan)
/**
* The real `execute` method that calls into `SparkConnectClient`.
*
* Here we inject a lazy map to process registered observed metrics, so consumers of the
* returned iterator does not need to worry about it.
*
* Please make sure all `execute` methods call this method.
*/
private[sql] def executeInternal(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = {
client
.execute(plan)
.map { response =>
// Note, this map() is lazy.
processRegisteredObservedMetrics(response.getObservedMetricsList)
response
}
}

private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
@@ -541,10 +552,14 @@
observationRegistry.putIfAbsent(planId, observation)
}

private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(metrics)
private def processRegisteredObservedMetrics(metrics: java.util.List[ObservedMetrics]): Unit = {
metrics.asScala.map { metric =>
// Here we only process metrics that belong to a registered Observation object.
// All metrics, whether registered or not, will be collected by `SparkResult`.
val observationOrNull = observationRegistry.remove(metric.getPlanId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(SparkResult.transformObservedMetrics(metric))
}
}
}

@@ -1536,28 +1536,49 @@ class ClientE2ETestSuite
val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema))
val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), ob2Schema))

val obMetrics = observedDf.collectResult().getObservedMetrics
assert(df.collectResult().getObservedMetrics === Map.empty)
assert(observedDf.collectResult().getObservedMetrics === ob1Metrics)
assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics)
}

test("Observation.get is blocked until the query is finished") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val observation = new Observation("ob1")
val observedDf = df.observe(observation, min("id"), avg("id"), max("id"))

// Start a new thread to get the observation
val future = Future(observation.get)(ExecutionContext.global)
// make sure the thread is blocked right now
val e = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future, 2.seconds)
assert(obMetrics.map(_._2.schema) === Seq(ob1Schema))

val obObMetrics = observedObservedDf.collectResult().getObservedMetrics
assert(obObMetrics === ob1Metrics ++ ob2Metrics)
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob1Schema)))
assert(obObMetrics.map(_._2.schema).exists(_.equals(ob2Schema)))
}

for (collectFunc <- Seq(
("collect", (df: DataFrame) => df.collect()),
("collectAsList", (df: DataFrame) => df.collectAsList()),
("collectResult", (df: DataFrame) => df.collectResult().length),
("write", (df: DataFrame) => df.write.format("noop").mode("append").save())))
test(
"Observation.get is blocked until the query is finished, " +
s"collect using method ${collectFunc._1}") {
val df = spark.range(99).withColumn("extra", col("id") - 1)
val ob1 = new Observation("ob1")
val ob2 = new Observation("ob2")
val observedDf = df.observe(ob1, min("id"), avg("id"), max("id"))
val observedObservedDf = observedDf.observe(ob2, min("extra"), avg("extra"), max("extra"))
// Start new threads to get observations
val future1 = Future(ob1.get)(ExecutionContext.global)
val future2 = Future(ob2.get)(ExecutionContext.global)
// make sure the threads are blocked right now
val e1 = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future1, 2.seconds)
}
assert(e1.getMessage.contains("timed out after"))
val e2 = intercept[java.util.concurrent.TimeoutException] {
SparkThreadUtils.awaitResult(future2, 2.seconds)
}
assert(e2.getMessage.contains("timed out after"))
collectFunc._2(observedObservedDf)
// make sure the threads are unblocked after the query is finished
val metrics1 = SparkThreadUtils.awaitResult(future1, 5.seconds)
assert(metrics1 === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
val metrics2 = SparkThreadUtils.awaitResult(future2, 5.seconds)
assert(metrics2 === Map("min(extra)" -> -1, "avg(extra)" -> 48, "max(extra)" -> 97))
}
assert(e.getMessage.contains("Future timed out"))
observedDf.collect()
// make sure the thread is unblocked after the query is finished
val metrics = SparkThreadUtils.awaitResult(future, 2.seconds)
assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98))
}

test("SPARK-48852: trim function on a string column returns correct results") {
val session: SparkSession = spark
@@ -25,6 +25,16 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable {

override def close() = self.close()
}

override def map[B](f: E => B): CloseableIterator[B] = {
new CloseableIterator[B] {
override def next(): B = f(self.next())

override def hasNext: Boolean = self.hasNext

override def close(): Unit = self.close()
}
}
}

private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {
@@ -41,8 +41,7 @@ private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String,
setObservationMetricsOpt: Option[(Long, Row) => Unit] = None)
timeZoneId: String)
extends AutoCloseable { self =>

case class StageInfo(
@@ -122,7 +121,8 @@
while (!stop && responses.hasNext) {
val response = responses.next()

// Collect metrics for this response
// Collect **all** metrics for this response, whether or not registered to an Observation
// object.
observedMetrics ++= processObservedMetrics(response.getObservedMetricsList)

// Save and validate operationId
@@ -209,23 +209,7 @@
private def processObservedMetrics(
metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
metrics.asScala.map { metric =>
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val values = mutable.ArrayBuilder.make[Any]
values.sizeHint(metric.getKeysCount)
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
values += value
}
val row = new GenericRowWithSchema(values.result(), schema)
// If the metrics is registered by an Observation object, attach them and unblock any
// blocked thread.
setObservationMetricsOpt.foreach { setObservationMetrics =>
setObservationMetrics(metric.getPlanId, row)
}
metric.getName -> row
metric.getName -> SparkResult.transformObservedMetrics(metric)
}
}

@@ -387,8 +371,23 @@
}
}

private object SparkResult {
private[sql] object SparkResult {
private val cleaner: Cleaner = Cleaner.create()

/** Return value is a Seq of pairs, to preserve the order of values. */
private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
assert(metric.getKeysCount == metric.getValuesCount)
var schema = new StructType()
val values = mutable.ArrayBuilder.make[Any]
values.sizeHint(metric.getKeysCount)
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toCatalystValue(metric.getValues(i))
schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
values += value
}
new GenericRowWithSchema(values.result(), schema)
}
}

private[client] class SparkResultCloseable(
@@ -245,7 +245,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
.createObservedMetricsResponse(
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
executeHolder.request.getPlan.getRoot.getCommon.getPlanId,
executeHolder.allObservationAndPlanIds,
observedMetrics ++ accumulatedInPython))
}

@@ -77,8 +77,10 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
responseObserver.onNext)
createObservedMetricsResponse(
request.getSessionId,
executeHolder.allObservationAndPlanIds,
dataframe).foreach(responseObserver.onNext)
}

type Batch = (Array[Byte], Long)
@@ -255,6 +257,7 @@

private def createObservedMetricsResponse(
sessionId: String,
observationAndPlanIds: Map[String, Long],
dataframe: DataFrame): Option[ExecutePlanResponse] = {
val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
case (name, row) if !executeHolder.observations.contains(name) =>
@@ -264,13 +267,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
name -> values
}
if (observedMetrics.nonEmpty) {
val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId
Some(
SparkConnectPlanExecution
.createObservedMetricsResponse(
sessionId,
sessionHolder.serverSessionId,
planId,
observationAndPlanIds,
observedMetrics))
} else None
}
@@ -280,17 +282,17 @@ object SparkConnectPlanExecution {
def createObservedMetricsResponse(
sessionId: String,
serverSessionId: String,
planId: Long,
observationAndPlanIds: Map[String, Long],
metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
val observedMetrics = metrics.map { case (name, values) =>
val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
.setPlanId(planId)
values.foreach { case (key, value) =>
metrics.addValues(toLiteralProto(value))
key.foreach(metrics.addKeys)
}
observationAndPlanIds.get(name).foreach(metrics.setPlanId)
metrics.build()
}
// Prepare a response with the observed metrics.
@@ -1190,14 +1190,14 @@ class SparkConnectPlanner(
val input = transformRelation(rel.getInput)

if (input.isStreaming || executeHolderOpt.isEmpty) {
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
CollectMetrics(name, metrics.map(_.named), input, planId)
} else {
// TODO this might be too complex for no good reason. It might
// be easier to inspect the plan after it completes.
val observation = Observation(name)
session.observationManager.register(observation, planId)
executeHolderOpt.get.addObservation(name, observation)
CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
CollectMetrics(name, metrics.map(_.named), input, planId)
}
}

@@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.jdk.CollectionConverters._

import com.google.protobuf.GeneratedMessage

import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
@@ -81,6 +83,10 @@

val observations: mutable.Map[String, Observation] = mutable.Map.empty

lazy val allObservationAndPlanIds: Map[String, Long] = {
ExecuteHolder.collectAllObservationAndPlanIds(request.getPlan).toMap
}

private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)

/** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -289,6 +295,26 @@
def operationId: String = key.operationId
}

private object ExecuteHolder {
private def collectAllObservationAndPlanIds(
planOrMessage: GeneratedMessage,
collected: mutable.Map[String, Long] = mutable.Map.empty): mutable.Map[String, Long] = {
planOrMessage match {
case relation: proto.Relation if relation.hasCollectMetrics =>
collected += relation.getCollectMetrics.getName -> relation.getCommon.getPlanId
collectAllObservationAndPlanIds(relation.getCollectMetrics.getInput, collected)
case _ =>
planOrMessage.getAllFields.values().asScala.foreach {
case message: GeneratedMessage =>
collectAllObservationAndPlanIds(message, collected)
case _ =>
// not a message (probably a primitive type), do nothing
}
}
collected
}
}

/** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
object ExecuteJobTag {
private val prefix = "SparkConnect_OperationTag"
