Spark Dataframe count pushdown #29

Merged: 5 commits, Nov 16, 2018
4 changes: 2 additions & 2 deletions project/Dependencies.scala
@@ -11,8 +11,8 @@ object Dependencies {
private val TypesafeLoggingVersion = "3.7.2"

private val ScalaTestVersion = "3.0.5"
private val MockitoVersion = "2.22.0"
private val ContainersJdbcVersion = "1.8.3"
private val MockitoVersion = "2.23.0"
private val ContainersJdbcVersion = "1.10.1"
private val ContainersScalaVersion = "0.19.0"

private val sparkCurrentVersion =
4 changes: 2 additions & 2 deletions project/plugins.sbt
@@ -8,15 +8,15 @@ addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.3.7")

// Adds Contrib Warts
// http://github.com/wartremover/wartremover-contrib/
addSbtPlugin("org.wartremover" % "sbt-wartremover-contrib" % "1.2.3")
addSbtPlugin("org.wartremover" % "sbt-wartremover-contrib" % "1.2.4")

// Adds Extra Warts
// http://github.com/danielnixon/extrawarts
addSbtPlugin("org.danielnixon" % "sbt-extrawarts" % "1.0.3")

// Adds a `assembly` task to create a fat JAR with all of its dependencies
// https://github.com/sbt/sbt-assembly
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.7")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9")

// Adds a `BuildInfo` tasks
// https://github.com/sbt/sbt-buildinfo
4 changes: 2 additions & 2 deletions sbtx
@@ -9,8 +9,8 @@ set -o pipefail
declare -r sbt_release_version="0.13.17"
declare -r sbt_unreleased_version="0.13.17"

declare -r latest_213="2.13.0-M4"
declare -r latest_212="2.12.6"
declare -r latest_213="2.13.0-M5"
declare -r latest_212="2.12.7"
declare -r latest_211="2.11.12"
declare -r latest_210="2.10.7"
declare -r latest_29="2.9.3"
52 changes: 44 additions & 8 deletions src/main/scala/com/exasol/spark/ExasolRelation.scala
@@ -36,7 +36,10 @@ class ExasolRelation(
val stmt = conn.createStatement()
val resultSet = stmt.executeQuery(queryStringLimit)
val metadata = resultSet.getMetaData
- Types.createSparkStructType(metadata)
+ val sparkStruct = Types.createSparkStructType(metadata)
+ resultSet.close()
+ stmt.close()
+ sparkStruct
}
}

@@ -52,15 +55,48 @@
buildScan(requiredColumns, Array.empty)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] =
- new ExasolRDD(
-   sqlContext.sparkContext,
-   enrichQuery(requiredColumns, filters),
-   Types.selectColumns(requiredColumns, schema),
-   manager
- )
+ if (requiredColumns.isEmpty) {
+   makeEmptyRDD(filters)
+ } else {
+   new ExasolRDD(
+     sqlContext.sparkContext,
+     enrichQuery(requiredColumns, filters),
+     Types.selectColumns(requiredColumns, schema),
+     manager
+   )
+ }

/**
 * When a count action is run on a Spark dataframe, we do not have to read the actual data and
 * serialize it over the network. Instead, we can create an RDD of empty Row-s with the expected
 * number of rows from the actual query.
 *
 * This is also called count pushdown.
 *
 * @param filters A list of [[org.apache.spark.sql.sources.Filter]]-s that can be pushed down as
 * a where clause
 * @return An RDD of empty Row-s with as many elements as count(*) of the enriched query
 */
private[this] def makeEmptyRDD(filters: Array[Filter]): RDD[Row] = {
val cntQuery = enrichQuery(Array.empty[String], filters)
val cnt = manager.withCountQuery(cntQuery)
sqlContext.sparkContext.parallelize(1L to cnt, 4).map(r => Row.empty)
}

/**
 * Enriches the original query with column pushdown and predicate pushdown.
 *
 * It uses the provided column names to create a sub-select query and, similarly, adds a where
 * clause if filters are provided.
 *
 * Additionally, if no column names are provided, it creates a 'COUNT(*)' query.
 *
 * @param columns A list of column names
 * @param filters A list of Spark [[org.apache.spark.sql.sources.Filter]]-s
 * @return An enriched query with column selection and where clause
 */
private[this] def enrichQuery(columns: Array[String], filters: Array[Filter]): String = {
val columnStr = if (columns.isEmpty) "*" else columns.map(c => s"A.$c").mkString(", ")
val columnStr = if (columns.isEmpty) "COUNT(*)" else columns.map(c => s"A.$c").mkString(", ")
val filterStr = Filters.createWhereClause(schema, filters)
val whereClause = if (filterStr.trim.isEmpty) "" else s"WHERE $filterStr"
val enrichedQuery = s"SELECT $columnStr FROM ($queryString) A $whereClause"
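The makeEmptyRDD idea can also be reproduced with plain Spark. A minimal sketch, assuming a local SparkContext named `sc` and an already fetched count `cnt`: the driver only parallelizes a range and maps each element to Row.empty, so no Exasol rows cross the network while count() on the resulting RDD still returns the pushed-down value.

import org.apache.spark.sql.Row

val cnt = 42L // stand-in for manager.withCountQuery(cntQuery)
val emptyRows = sc.parallelize(1L to cnt, 4).map(_ => Row.empty)
assert(emptyRows.count() == cnt)          // same result as COUNT(*) on Exasol
assert(emptyRows.partitions.length == 4)  // fixed parallelism, as in makeEmptyRDD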
@@ -38,6 +38,19 @@ final case class ExasolConnectionManager(config: ExasolConfiguration) {
ExasolConnectionManager
.withConnection(mainConnectionUrl, config.username, config.password)(handle)

def withCountQuery(query: String): Long = withConnection[Long] { conn =>
val stmt = conn.createStatement()
val resultSet = stmt.executeQuery(query)
val cnt = if (resultSet.next()) {
resultSet.getLong(1)
} else {
throw new IllegalStateException("Could not query the count!")
}
resultSet.close()
stmt.close()
cnt
}

}

object ExasolConnectionManager extends LazyLogging {
@@ -70,7 +83,9 @@ object ExasolConnectionManager extends LazyLogging {
def makeConnection(url: String, username: String, password: String): EXAConnection = {
logger.debug(s"Making a connection using url = $url")
removeIfClosed(url)
- val _ = connections.putIfAbsent(url, createConnection(url, username, password))
+ if (!connections.containsKey(url)) {
+   val _ = connections.put(url, createConnection(url, username, password))
+ }
connections.get(url)
}

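A plausible reason for replacing putIfAbsent with the containsKey/put guard (not stated in the PR, so treat it as an assumption): the value argument of putIfAbsent is evaluated eagerly, so createConnection would open a fresh connection even when the URL is already cached. The guard skips that cost at the price of strict atomicity. A small REPL sketch with a ConcurrentHashMap shows the difference:

import java.util.concurrent.ConcurrentHashMap

val cache = new ConcurrentHashMap[String, String]()
def expensive(url: String): String = { println(s"opening connection to $url"); s"conn-$url" }

cache.put("jdbc:exa:host", expensive("jdbc:exa:host"))         // prints once, as expected
cache.putIfAbsent("jdbc:exa:host", expensive("jdbc:exa:host")) // prints again: argument is evaluated eagerly
if (!cache.containsKey("jdbc:exa:host")) {                     // guarded variant: expensive(...) is never called
  cache.put("jdbc:exa:host", expensive("jdbc:exa:host"))
}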
39 changes: 39 additions & 0 deletions src/test/scala/com/exasol/spark/ExasolRelationSuite.scala
@@ -0,0 +1,39 @@
package com.exasol.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType

import com.exasol.spark.util.ExasolConnectionManager

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.mockito.Mockito._
import org.scalatest.FunSuite
import org.scalatest.Matchers
import org.scalatest.mockito.MockitoSugar

class ExasolRelationSuite
extends FunSuite
with Matchers
with MockitoSugar
with DataFrameSuiteBase {

test("buildScan returns RDD of empty Row-s when requiredColumns is empty (count pushdown)") {
val query = "SELECT 1"
val cntQuery = "SELECT COUNT(*) FROM (SELECT 1) A "
val cnt = 5L

val manager = mock[ExasolConnectionManager]
when(manager.withCountQuery(cntQuery)).thenReturn(cnt)

val relation = new ExasolRelation(spark.sqlContext, query, Option(new StructType), manager)
val rdd = relation.buildScan()

assert(rdd.isInstanceOf[RDD[Row]])
assert(rdd.partitions.size === 4)
assert(rdd.count === cnt)
verify(manager, times(1)).withCountQuery(cntQuery)
}

}
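
For context, this is roughly how the pushdown is exercised from user code (a hypothetical sketch; the format name and option keys are assumptions, not taken from this PR). Spark calls buildScan with an empty requiredColumns array for a count, so only the COUNT(*) query runs on Exasol:

// Hypothetical end-to-end usage; connector option names are assumed for illustration.
val df = spark.read
  .format("exasol")
  .option("host", "10.0.0.11")
  .option("port", "8563")
  .option("query", "SELECT * FROM RETAIL.SALES")
  .load()

val rowCount = df.count() // pushed down as SELECT COUNT(*) FROM (SELECT * FROM RETAIL.SALES) A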