Fix node constructor for DB platforms #761

Merged
merged 1 commit into from Jan 30, 2024
@@ -19,10 +19,31 @@ package org.apache.spark.sql.rapids.tool.util
import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable
import scala.reflect.runtime.universe._

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLPlanMetric}


// Container class to hold a snapshot of the reflection fields instead of recalculating
// them every time the constructor is called.
case class DBReflectionContainer() {
  val mirror = runtimeMirror(getClass.getClassLoader)
  // Get the class symbol
  val classSymbol = mirror.staticClass("org.apache.spark.sql.execution.ui.SparkPlanGraphNode")
  // Get the primary-constructor method symbol
  val constructor = classSymbol.primaryConstructor.asMethod

  def constructNode(id: Long, name: String, desc: String,
      metrics: collection.Seq[SQLPlanMetric]): SparkPlanGraphNode = {
    // Argument values for the Databricks constructor; the extra trailing parameters
    // get neutral defaults ("", false, None, None).
    val argValues = List(id, name, desc, metrics, "", false, None, None)
    mirror.reflectClass(classSymbol)
      .reflectConstructor(constructor)(argValues: _*)
      .asInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphNode]
  }
}
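
// Illustrative sketch (not part of this change): assuming a Databricks build whose
// SparkPlanGraphNode primary constructor takes the extra arguments, the container above
// would be used like this. The mirror and constructor symbol are resolved once when the
// container is instantiated, so repeated constructNode calls only pay for the reflective
// invocation itself:
//
//   val db = DBReflectionContainer()
//   val node = db.constructNode(0L, "Exchange", "Exchange hashpartitioning...", Seq.empty)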

/**
 * This code is mostly copied from org.apache.spark.sql.execution.ui.SparkPlanGraph
 * with changes to handle GPU nodes. Without this special handling, the default SparkPlanGraph
@@ -31,6 +52,28 @@ import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster,
 * Build a SparkPlanGraph from the root of a SparkPlan tree.
 */
object ToolsPlanGraph {
  // TODO: We should have a util to detect if the runtime is Databricks.
  //       This can be achieved by checking for the Spark property
  //       spark.databricks.clusterUsageTags.clusterAllTags.
  private lazy val dbRuntimeReflection = DBReflectionContainer()

  // By default, call the Spark constructor. If this fails, fall back to the DB constructor.
  def constructGraphNode(id: Long, name: String, desc: String,
      metrics: collection.Seq[SQLPlanMetric]): SparkPlanGraphNode = {
    try {
      new SparkPlanGraphNode(id, name, desc, metrics)
    } catch {
      case _: java.lang.NoSuchMethodError =>
        // Databricks has a different constructor for SparkPlanGraphNode:
        // [(long, java.lang.String, java.lang.String, scala.collection.Seq,
        //   java.lang.String, boolean, scala.Option, scala.Option)] and
        // [final long id, final java.lang.String name, final java.lang.String desc,
        //  final scala.collection.Seq<org.apache.spark.sql.execution.ui.SQLPlanMetric> metrics,
        //  final java.lang.String rddScopeId, final boolean started,
        //  final scala.Option<scala.math.BigInt> estRowCount)]
        dbRuntimeReflection.constructNode(id, name, desc, metrics)
    }
  }
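
  // Illustrative sketch (not part of this change): the TODO above could be addressed with
  // a hypothetical helper that inspects the Spark property mentioned there, letting the
  // code choose a constructor upfront instead of catching NoSuchMethodError:
  //
  //   def isDatabricksRuntime(sparkProperties: Map[String, String]): Boolean =
  //     sparkProperties.contains("spark.databricks.clusterUsageTags.clusterAllTags")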

  /**
   * Build a SparkPlanGraph from the root of a SparkPlan tree.
   */
@@ -110,9 +153,8 @@ object ToolsPlanGraph {
    val metrics = planInfo.metrics.map { metric =>
      SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
    }
-   val node = new SparkPlanGraphNode(
-     nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-     planInfo.simpleString, metrics)
+   val node = constructGraphNode(nodeIdGenerator.getAndIncrement(),
+     planInfo.nodeName, planInfo.simpleString, metrics)
    if (subgraph == null) {
      nodes += node
    } else {