Skip to content

Commit

Permalink
Remove includeCodeGen option and address other PR feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove committed May 20, 2021
1 parent e666aa7 commit 5fe4ea2
Showing 1 changed file with 55 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ object GenerateDot {
plan: QueryPlanWithMetrics,
comparisonPlan: Option[QueryPlanWithMetrics],
dir: File,
filename: String,
includeCodegen: Boolean): Unit = {
filename: String): Unit = {

var nextId = 1

Expand All @@ -68,19 +67,6 @@ object GenerateDot {
}
}

/**
* Optionally remove code-gen nodes.
*/
def normalize(plan: SparkPlanInfo): SparkPlanInfo = {
if (!includeCodegen
&& (plan.nodeName.startsWith("WholeStageCodegen")
|| plan.nodeName.startsWith("InputAdapter"))) {
plan.children.head
} else {
plan
}
}

def formatMetric(m: SQLMetricInfo, value: Long): String = {
val formatter = java.text.NumberFormat.getIntegerInstance
m.metricType match {
Expand All @@ -102,56 +88,47 @@ object GenerateDot {
comparisonNode: QueryPlanWithMetrics,
id: Int = 0): Unit = {

val nodeNormalized = normalize(node.plan)
val comparisonNodeNormalized = normalize(comparisonNode.plan)
if (nodeNormalized.nodeName == comparisonNodeNormalized.nodeName &&
nodeNormalized.children.length == comparisonNodeNormalized.children.length) {
val nodePlan = node.plan
val comparisonPlan = comparisonNode.plan
if (nodePlan.nodeName == comparisonPlan.nodeName &&
nodePlan.children.length == comparisonPlan.children.length) {

val metricNames = (nodeNormalized.metrics.map(_.name) ++
comparisonNodeNormalized.metrics.map(_.name)).distinct.sorted
val metricNames = (nodePlan.metrics.map(_.name) ++
comparisonPlan.metrics.map(_.name)).distinct.sorted

val metrics = metricNames.flatMap(name => {
val l = nodeNormalized.metrics.find(_.name == name)
val r = comparisonNodeNormalized.metrics.find(_.name == name)
if (l.isDefined && r.isDefined) {
val metric1 = l.get
val metric2 = r.get
(node.metrics.get(metric1.accumulatorId),
comparisonNode.metrics.get(metric1.accumulatorId)) match {
case (Some(value1), Some(value2)) =>
if (value1 == value2) {
Some(s"$name: ${formatMetric(metric1, value1)}")
} else {
metric1.metricType match {
case "nsTiming" | "timing" =>
val n1 = value1
val n2 = value2
val pct = (n2 - n1) * 100.0 / n1
val pctStr = if (pct < 0) {
f"$pct%.1f"
} else {
f"+$pct%.1f"
}
Some(s"$name: ${formatMetric(metric1, value1)} / " +
s"${formatMetric(metric2, value2)} ($pctStr %)")
case _ =>
Some(s"$name: ${formatMetric(metric1, value1)} / " +
s"${formatMetric(metric2, value2)}")
val l = nodePlan.metrics.find(_.name == name)
val r = comparisonPlan.metrics.find(_.name == name)
(l, r) match {
case (Some(metric1), Some(metric2)) =>
(node.metrics.get(metric1.accumulatorId),
comparisonNode.metrics.get(metric1.accumulatorId)) match {
case (Some(value1), Some(value2)) =>
if (value1 == value2) {
Some(s"$name: ${formatMetric(metric1, value1)}")
} else {
metric1.metricType match {
case "nsTiming" | "timing" =>
val pctStr = createPercentDiffString(value1, value2)
Some(s"$name: ${formatMetric(metric1, value1)} / " +
s"${formatMetric(metric2, value2)} ($pctStr %)")
case _ =>
Some(s"$name: ${formatMetric(metric1, value1)} / " +
s"${formatMetric(metric2, value2)}")
}
}
}
case _ => None
}
} else {
None
case _ => None
}
case _ => None
}
}).mkString("\n")

val color = if (isGpuPlan(nodeNormalized)) { GPU_COLOR } else { CPU_COLOR }
val color = if (isGpuPlan(nodePlan)) { GPU_COLOR } else { CPU_COLOR }

val label = if (nodeNormalized.nodeName.contains("QueryStage")) {
nodeNormalized.simpleString
val label = if (nodePlan.nodeName.contains("QueryStage")) {
nodePlan.simpleString
} else {
nodeNormalized.nodeName
nodePlan.nodeName
}

val nodeText =
Expand All @@ -161,16 +138,16 @@ object GenerateDot {
|""".stripMargin

w.write(nodeText)
nodeNormalized.children.indices.foreach(i => {
nodePlan.children.indices.foreach(i => {
val childId = nextId
nextId += 1
writeGraph(
w,
QueryPlanWithMetrics(nodeNormalized.children(i), node.metrics),
QueryPlanWithMetrics(comparisonNodeNormalized.children(i), comparisonNode.metrics),
QueryPlanWithMetrics(nodePlan.children(i), node.metrics),
QueryPlanWithMetrics(comparisonPlan.children(i), comparisonNode.metrics),
childId);

val style = (isGpuPlan(nodeNormalized), isGpuPlan(nodeNormalized.children(i))) match {
val style = (isGpuPlan(nodePlan), isGpuPlan(nodePlan.children(i))) match {
case (true, true) => s"""color="$GPU_COLOR""""
case (false, false) => s"""color="$CPU_COLOR""""
case _ =>
Expand All @@ -184,18 +161,31 @@ object GenerateDot {
w.write(
s"""node$id [shape=box, color=red,
|label = "plans diverge here:
|${nodeNormalized.nodeName} vs ${comparisonNodeNormalized.nodeName}"];\n""".stripMargin)
|${nodePlan.nodeName} vs ${comparisonPlan.nodeName}"];\n""".stripMargin)
}
}

// write the dot graph to a file
val file = new File(dir, filename)
println(s"Writing ${file.getAbsolutePath}")
val w = new FileWriter(file)
w.write("digraph G {\n")
writeGraph(w, plan, comparisonPlan.getOrElse(plan), 0)
w.write("}\n")
w.close()
try {
w.write("digraph G {\n")
writeGraph(w, plan, comparisonPlan.getOrElse(plan), 0)
w.write("}\n")
} finally {
w.close()
}
}

private def createPercentDiffString(n1: Long, n2: Long) = {
val pct = (n2 - n1) * 100.0 / n1
val pctStr = if (pct < 0) {
f"$pct%.1f"
} else {
f"+$pct%.1f"
}
pctStr
}
}

Expand Down

0 comments on commit 5fe4ea2

Please sign in to comment.