Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SQLPLanMetric constructor for DB Platforms #763

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,34 @@ import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster,
// Container class to hold snapshot of the reflection fields instead of recalculating them every
// time we call the constructor
case class DBReflectionContainer() {
val mirror = runtimeMirror(getClass.getClassLoader)
// Get the class symbol
val classSymbol = mirror.staticClass("org.apache.spark.sql.execution.ui.SparkPlanGraphNode")
// Get the constructor method symbol
val constructor = classSymbol.primaryConstructor.asMethod
private val mirror = runtimeMirror(getClass.getClassLoader)
// Get the node class symbol
private val nodeClassSymbol = mirror.staticClass("org.apache.spark.sql.execution.ui.SparkPlanGraphNode")
// Get the node constructor method symbol
private val nodeConstr = nodeClassSymbol.primaryConstructor.asMethod
// Get the SQL class symbol
private val metricClassSymbol = mirror.staticClass("org.apache.spark.sql.execution.ui.SQLPlanMetric")
// Get the metric constructor method symbol
private val metricConstr = metricClassSymbol.primaryConstructor.asMethod

def constructNode(id: Long, name: String, desc: String,
metrics: collection.Seq[SQLPlanMetric]): SparkPlanGraphNode = {
// Define argument values
val argValues = List(id, name, desc, metrics, "", false, None, None)
mirror.reflectClass(classSymbol)
.reflectConstructor(constructor)(argValues: _*)
mirror.reflectClass(nodeClassSymbol)
.reflectConstructor(nodeConstr)(argValues: _*)
.asInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphNode]
}

def constructSQLPlanMetric(name: String,
accumulatorId: Long,
metricType: String): SQLPlanMetric = {
// Define argument values
val argValues = List(name, accumulatorId, metricType, false)
mirror.reflectClass(metricClassSymbol)
.reflectConstructor(metricConstr)(argValues: _*)
.asInstanceOf[org.apache.spark.sql.execution.ui.SQLPlanMetric]
}
}

/**
Expand All @@ -57,7 +71,7 @@ object ToolsPlanGraph {
// spark.databricks.clusterUsageTags.clusterAllTags
private lazy val dbRuntimeReflection = DBReflectionContainer()
// By default call the Spark constructor. If this fails, we fall back to the DB constructor
def constructGraphNode(id: Long, name: String, desc: String,
private def constructGraphNode(id: Long, name: String, desc: String,
metrics: collection.Seq[SQLPlanMetric]): SparkPlanGraphNode = {
try {
new SparkPlanGraphNode(id, name, desc, metrics)
Expand All @@ -72,6 +86,21 @@ object ToolsPlanGraph {
// final scala.Option<scala.math.BigInt> estRowCount)
dbRuntimeReflection.constructNode(id, name, desc, metrics)
}

}

private def constructSQLPlanMetric(name: String,
accumulatorId: Long,
metricType: String): SQLPlanMetric = {
try {
SQLPlanMetric(name, accumulatorId, metricType)
} catch {
case _: java.lang.NoSuchMethodError =>
// DataBricks has different constructor of the sparkPlanGraphNode
//Array(final java.lang.String name, final long accumulatorId,
// final java.lang.String metricType, final boolean experimental)
dbRuntimeReflection.constructSQLPlanMetric(name, accumulatorId, metricType)
}
}

/**
Expand Down Expand Up @@ -151,7 +180,7 @@ object ToolsPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
constructSQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val node = constructGraphNode(nodeIdGenerator.getAndIncrement(),
planInfo.nodeName, planInfo.simpleString, metrics)
Expand Down
Loading