Integrate GenerateDot and add integration test
Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove committed May 25, 2021
1 parent 5dee53e commit cd675b5
Showing 4 changed files with 152 additions and 10 deletions.
CollectInformation.scala
@@ -15,10 +15,16 @@
*/
package com.nvidia.spark.rapids.tool.profiling

import java.io.File
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo


/**
* CollectInformation mainly prints information from the event log,
* such as executors, parameters, etc.
@@ -72,4 +78,39 @@ class CollectInformation(apps: ArrayBuffer[ApplicationInfo]) {
messageHeader = messageHeader)
}
}

def generateDot(): Unit = {
for (app <- apps) {
val requiredDataFrames = Seq("sqlMetricsDF", "driverAccumDF", "taskStageAccumDF")
.map(name => s"${name}_${app.index}")
if (requiredDataFrames.forall(app.allDataFrames.contains)) {
val accums = app.runQuery(app.generateSQLAccums)
val start = System.nanoTime()
val accumSummary = accums
.select(col("sqlId"), col("accumulatorId"), col("max_value"))
.collect()
// group the accumulator (accumulatorId, max_value) pairs by SQL ID
val map = new mutable.HashMap[Long, ArrayBuffer[(Long, Long)]]()
for (row <- accumSummary) {
val list = map.getOrElseUpdate(row.getLong(0), new ArrayBuffer[(Long, Long)]())
list += row.getLong(1) -> row.getLong(2)
}
val outDir = new File(app.args.outputDirectory())
// write one DOT file per SQL plan, under <outputDirectory>/<appId>-query-<sqlID>/
for ((sqlID, planInfo) <- app.sqlPlan) {
val fileDir = new File(outDir, s"${app.appId}-query-$sqlID")
fileDir.mkdirs()
val metrics = map.getOrElse(sqlID, Seq.empty).toMap
GenerateDot.generateDotGraph(
QueryPlanWithMetrics(planInfo, metrics), None, fileDir, sqlID + ".dot")
}
val duration = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
fileWriter.write(s"Generated DOT graphs for app ${app.appId} " +
s"to ${outDir.getAbsolutePath} in $duration second(s)\n")
} else {
val missingDataFrames = requiredDataFrames.filterNot(app.allDataFrames.contains)
fileWriter.write(s"Could not generate DOT graph for app ${app.appId} " +
s"because of missing data frames: ${missingDataFrames.mkString(", ")}\n")
}
}
}
}
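
The generateDot method writes one Graphviz file per SQL query, at <outputDirectory>/<appId>-query-<sqlID>/<sqlID>.dot. DOT is a plain-text graph format, so turning the output into images is left to external tooling. A minimal sketch of that last step follows: RenderDotFiles is a hypothetical helper, not part of this commit, and it assumes the Graphviz dot binary is installed and on the PATH.

import java.io.File

import scala.sys.process._

// Hypothetical helper (not part of this commit): render every generated
// <appId>-query-<sqlID>/<sqlID>.dot file to SVG using Graphviz.
object RenderDotFiles {
  def render(outputDirectory: File): Unit = {
    val queryDirs = Option(outputDirectory.listFiles()).getOrElse(Array.empty[File])
      .filter(_.isDirectory)
    for {
      dir <- queryDirs
      dotFile <- Option(dir.listFiles()).getOrElse(Array.empty[File])
      if dotFile.getName.endsWith(".dot")
    } {
      val svgFile = new File(dir, dotFile.getName.stripSuffix(".dot") + ".svg")
      // equivalent to running: dot -Tsvg <query>.dot -o <query>.svg
      val exitCode = Seq("dot", "-Tsvg", dotFile.getAbsolutePath,
        "-o", svgFile.getAbsolutePath).!
      require(exitCode == 0, s"dot failed for ${dotFile.getName}")
    }
  }
}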
ProfileArgs.scala
@@ -59,5 +59,9 @@ For usage see below:
val numOutputRows: ScallopOption[Int] =
opt[Int](required = false,
descr = "Number of output rows for each Application. Default is 1000")
val generateDot: ScallopOption[Boolean] =
opt[Boolean](required = false,
descr = "Generate query visualizations in DOT format. Default is false")

verify()
}
ProfileMain.scala
@@ -17,32 +17,41 @@
package com.nvidia.spark.rapids.tool.profiling

import java.io.FileWriter

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.rapids.tool.profiling._

/**
* A profiling tool to parse Spark event logs.
* This is the main entry point.
*/
object ProfileMain extends Logging {
/**
* Entry point from spark-submit running this as the driver.
*/
def main(args: Array[String]) {
val sparkSession = ProfileUtils.createSparkSession
val exitCode = mainInternal(sparkSession, new ProfileArgs(args))
if (exitCode != 0) {
System.exit(exitCode)
}
}

/**
* Entry point for tests
*/
def mainInternal(sparkSession: SparkSession, appArgs: ProfileArgs): Int = {

// This tool's output log file name
val logFileName = "rapids_4_spark_tools_output.log"

// Parsing args
val appArgs = new ProfileArgs(args)
val eventlogPaths = appArgs.eventlog()
val outputDirectory = appArgs.outputDirectory().stripSuffix("/")

// Create the FileWriter and sparkSession used for ALL Applications.
val fileWriter = new FileWriter(s"$outputDirectory/$logFileName")
val sparkSession = ProfileUtils.createSparkSession
logInfo(s"Output directory: $outputDirectory")

// Convert the input path string to Path(s)
@@ -68,9 +77,9 @@ object ProfileMain extends Logging {
// Exit if there are no applications to process.
if (apps.isEmpty) {
logInfo("No application to process. Exiting")
System.exit(0)
return 0
}
processApps(apps)
processApps(apps, generateDot = false)
// Show the application Id <-> appIndex mapping.
for (app <- apps) {
logApplicationInfo(app)
Expand All @@ -84,7 +93,7 @@ object ProfileMain extends Logging {
val app = new ApplicationInfo(appArgs, sparkSession, fileWriter, path, index)
apps += app
logApplicationInfo(app)
processApps(apps)
processApps(apps, appArgs.generateDot())
app.dropAllTempViews()
index += 1
}
@@ -100,7 +109,7 @@ object ProfileMain extends Logging {
* evaluated at once and the output is one row per application. Else each eventlog is parsed one
* at a time.
*/
def processApps(apps: ArrayBuffer[ApplicationInfo]): Unit = {
def processApps(apps: ArrayBuffer[ApplicationInfo], generateDot: Boolean): Unit = {
if (appArgs.compare()) { // Compare Applications
logInfo(s"### A. Compare Information Collected ###")
val compare = new CompareApplications(apps)
@@ -113,6 +122,9 @@
collect.printAppInfo()
collect.printExecutorInfo()
collect.printRapidsProperties()
if (generateDot) {
collect.generateDot()
}
}

logInfo(s"### B. Analysis ###")
@@ -133,5 +145,7 @@
logInfo(s"============== ${app.appId} (index=${app.index}) ==============")
logInfo("========================================================================")
}

0
}
}
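
The split between main and mainInternal is what enables the new integration test: the SparkSession and parsed arguments are passed in, and failure is reported through the Int return value rather than a System.exit call, so the whole tool can be driven in-process. A minimal sketch of that usage follows; the local-mode settings, paths, and the RunProfilerInProcess wrapper are illustrative only (GenerateDotSuite below exercises the same entry point, without checking the returned code).

import com.nvidia.spark.rapids.tool.profiling.{ProfileArgs, ProfileMain}

import org.apache.spark.sql.SparkSession

// Sketch only: run the profiling tool in-process and assert on its status code.
object RunProfilerInProcess {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("profiler-in-process")
      .getOrCreate()
    // illustrative paths: an existing output directory and a Spark event log
    val exitCode = ProfileMain.mainInternal(spark, new ProfileArgs(Array(
      "--output-directory", "/tmp/profile-output",
      "--generate-dot",
      "/tmp/eventlogs/application_1621900000000_0001")))
    assert(exitCode == 0, "profiling tool reported a failure")
    spark.close()
  }
}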
GenerateDotSuite.scala (new file)
@@ -0,0 +1,83 @@
package com.nvidia.spark.rapids.tool.profiling

import java.io.{File, FilenameFilter}

import scala.io.Source

import com.google.common.io.Files
import org.scalatest.FunSuite

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

class GenerateDotSuite extends FunSuite with Logging {

test("Generate DOT") {
val eventLogDir = Files.createTempDir()
eventLogDir.deleteOnExit()

val spark = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.config("spark.eventLog.enabled", "true")
.config("spark.eventLog.dir", eventLogDir.getAbsolutePath)
.getOrCreate()

// generate some events
import spark.implicits._
val t1 = Seq((1, 2), (3, 4)).toDF("a", "b")
t1.createOrReplaceTempView("t1")
val df = spark.sql("SELECT a, MAX(b) FROM t1 GROUP BY a ORDER BY a")
df.collect()

// close the event log
spark.close()

val files = eventLogDir.listFiles(new FilenameFilter {
override def accept(file: File, s: String): Boolean = !s.startsWith(".")
})
assert(files.length === 1)

// create new session for tool to use
val spark2 = SparkSession
.builder()
.master("local[*]")
.appName("Rapids Spark Profiling Tool Unit Tests")
.getOrCreate()

val dotFileDir = Files.createTempDir()
dotFileDir.deleteOnExit()

val appArgs = new ProfileArgs(Array(
"--output-directory",
dotFileDir.getAbsolutePath,
"--generate-dot",
files.head.getAbsolutePath
))
ProfileMain.mainInternal(spark2, appArgs)

// assert that a file was generated
val dotDirs = listFilesEnding(dotFileDir, "-1")
assert(dotDirs.length === 1)
val dotFiles = listFilesEnding(dotDirs.head, ".dot")
assert(dotFiles.length === 1)

// assert that the generated file looks something like what we expect
val source = Source.fromFile(dotFiles.head)
try {
val lines = source.getLines().toArray
assert(lines.head === "digraph G {")
assert(lines.last === "}")
assert(lines.count(_.contains("HashAggregate")) === 2)
} finally {
source.close()
}
}

private def listFilesEnding(dir: File, pattern: String): Array[File] = {
dir.listFiles(new FilenameFilter {
override def accept(file: File, s: String): Boolean = s.endsWith(pattern)
})
}

}
