Add test for encrypted ORC write [databricks] #5764

Merged
tests/pom.xml (76 additions, 0 deletions)

@@ -148,6 +148,25 @@
          </exclusions>
        </dependency>
      </dependencies>
      <build>
        <plugins>
          <plugin>
            <groupId>org.codehaus.mojo</groupId>
            <artifactId>build-helper-maven-plugin</artifactId>
            <executions>
              <execution>
                <id>add-321cdh-test-src</id>
                <goals><goal>add-test-source</goal></goals>
                <configuration>
                  <sources>
                    <source>${project.basedir}/src/test/320+/scala</source>
                  </sources>
                </configuration>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </profile>
    <profile>
      <!--
@@ -273,6 +292,62 @@
        </dependency>
      </dependencies>
    </profile>
    <profile>
      <id>release321</id>
      <activation>
        <property>
          <name>buildver</name>
          <value>321</value>
        </property>
      </activation>
      <build>
        <plugins>
          <plugin>
            <groupId>org.codehaus.mojo</groupId>
            <artifactId>build-helper-maven-plugin</artifactId>
            <executions>
              <execution>
                <id>add-321-test-src</id>
                <goals><goal>add-test-source</goal></goals>
                <configuration>
                  <sources>
                    <source>${project.basedir}/src/test/320+/scala</source>
                  </sources>
                </configuration>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </profile>
    <profile>
      <id>release322</id>
      <activation>
        <property>
          <name>buildver</name>
          <value>322</value>
        </property>
      </activation>
      <build>
        <plugins>
          <plugin>
            <groupId>org.codehaus.mojo</groupId>
            <artifactId>build-helper-maven-plugin</artifactId>
            <executions>
              <execution>
                <id>add-322-test-src</id>
                <goals><goal>add-test-source</goal></goals>
                <configuration>
                  <sources>
                    <source>${project.basedir}/src/test/320+/scala</source>
                  </sources>
                </configuration>
              </execution>
            </executions>
          </plugin>
        </plugins>
      </build>
    </profile>
    <profile>
      <id>release330</id>
      <activation>
@@ -293,6 +368,7 @@
                <configuration>
                  <sources>
                    <source>${project.basedir}/src/test/330/scala</source>
                    <source>${project.basedir}/src/test/320+/scala</source>
                  </sources>
                </configuration>
              </execution>
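Each of these profiles follows the same pattern: `build-helper-maven-plugin` registers the shared `${project.basedir}/src/test/320+/scala` directory as an additional test-source root, so the `OrcEncryptionSuite` added below is compiled only for Spark 3.2.0+ builds (selected with the `buildver` property, e.g. `mvn test -Dbuildver=321`).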
tests/src/test/320+/scala/com/nvidia/spark/rapids/OrcEncryptionSuite.scala (new file, 63 additions)

@@ -0,0 +1,63 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.nvidia.spark.rapids

import java.io.File
import java.security.SecureRandom

import com.nvidia.spark.rapids.shims.SparkShimImpl
import org.apache.hadoop.conf.Configuration
import org.apache.orc.{EncryptionAlgorithm, InMemoryKeystore}
import org.apache.orc.impl.CryptoUtils

class OrcEncryptionSuite extends SparkQueryCompareTestSuite {

  // Create an InMemoryKeystore provider and add the `pii` and `top_secret` keys to it.
  // CryptoUtils caches the provider so the test below can use it.
  val hadoopConf = new Configuration()
  hadoopConf.set("orc.key.provider", "memory")
  val random = new SecureRandom()
  val keystore: InMemoryKeystore =
    CryptoUtils.getKeyProvider(hadoopConf, random).asInstanceOf[InMemoryKeystore]
  val algorithm: EncryptionAlgorithm = EncryptionAlgorithm.AES_CTR_128
  val piiKey = new Array[Byte](algorithm.keyLength)
  val topSecretKey = new Array[Byte](algorithm.keyLength)
  random.nextBytes(piiKey)
  random.nextBytes(topSecretKey)
  keystore.addKey("pii", algorithm, piiKey).addKey("top_secret", algorithm, topSecretKey)

  testGpuWriteFallback(
      "Write encrypted ORC fallback",
      "DataWritingCommandExec",
      intsDf,
      execsAllowedNonGpu = Seq("ShuffleExchangeExec", "DataWritingCommandExec")) {
    frame =>
      // ORC column encryption is only supported on Spark 3.2+, so the check below
      // skips the 3.1.x shims
      val isValidTestForSparkVersion = SparkShimImpl.getSparkShimVersion match {
        case SparkShimVersion(major, minor, _) => major == 3 && minor != 1
        case DatabricksShimVersion(major, minor, _, _) => major == 3 && minor != 1
        case ClouderaShimVersion(major, minor, _, _) => major == 3 && minor != 1
        case _ => true
      }
      assume(isValidTestForSparkVersion)

      val tempFile = File.createTempFile("orc-encryption-test", "")
      frame.write.options(Map("orc.key.provider" -> "memory",
        "orc.encrypt" -> "pii:ints,more_ints",
        "orc.mask" -> "sha256:ints,more_ints")).mode("overwrite").orc(tempFile.getAbsolutePath)
  }
}
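For context, ORC column encryption writes an encrypted copy of the data alongside a statically masked copy, and a reader without the key only sees the masked values. A minimal read-back sketch, assuming a plain `spark` session with no `orc.key.provider` configured:

  // Hypothetical read-back: without the in-memory `pii` key registered,
  // `ints` and `more_ints` surface as their sha256-masked copies rather
  // than the original values.
  val readBack = spark.read.orc(tempFile.getAbsolutePath)
  readBack.show(5)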
SparkQueryCompareTestSuite.scala

@@ -249,6 +249,57 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
    }
  }

  def writeOnCpuAndGpuWithCapture(df: SparkSession => DataFrame,
      fun: DataFrame => Unit,
      conf: SparkConf = new SparkConf(),
      repart: Integer = 1): (SparkPlan, SparkPlan) = {
    conf.setIfMissing("spark.sql.shuffle.partitions", "2")

    // force a new session to avoid accidentally capturing a late callback from a previous query
    TrampolineUtil.cleanupAnyExistingSession()
    ExecutionPlanCaptureCallback.startCapture()
    var cpuPlan: Option[SparkPlan] = None
    try {
      withCpuSparkSession(session => {
        var data = df(session)
        if (repart > 0) {
          // repartition the data so it is turned into a projection,
          // not folded into the table scan exec
          data = data.repartition(repart)
        }
        fun(data)
      }, conf)
    } finally {
      cpuPlan = ExecutionPlanCaptureCallback.getResultWithTimeout()
    }
    if (cpuPlan.isEmpty) {
      throw new RuntimeException("Did not capture CPU plan")
    }

    ExecutionPlanCaptureCallback.startCapture()
    var gpuPlan: Option[SparkPlan] = None
    try {
      withGpuSparkSession(session => {
        var data = df(session)
        if (repart > 0) {
          // repartition the data so it is turned into a projection,
          // not folded into the table scan exec
          data = data.repartition(repart)
        }
        fun(data)
      }, conf)
    } finally {
      gpuPlan = ExecutionPlanCaptureCallback.getResultWithTimeout()
    }

    if (gpuPlan.isEmpty) {
      throw new RuntimeException("Did not capture GPU plan")
    }

    (cpuPlan.get, gpuPlan.get)
  }
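A hedged usage sketch of this helper (`intsDf` comes from the surrounding suite; the output path is a placeholder):

  // Run the same write action on a CPU session and then a GPU session,
  // capturing the final physical plan from each run.
  val (cpuPlan, gpuPlan) = writeOnCpuAndGpuWithCapture(
    intsDf,
    frame => frame.write.mode("overwrite").orc("/tmp/capture-test"))
  // Both plans are then available for assertions, for example:
  // ExecutionPlanCaptureCallback.assertDidFallBack(gpuPlan, "DataWritingCommandExec")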

  def runOnCpuAndGpuWithCapture(df: SparkSession => DataFrame,
      fun: DataFrame => DataFrame,
      conf: SparkConf = new SparkConf(),

@@ -302,6 +353,29 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
    (fromCpu, cpuPlan.get, fromGpu, gpuPlan.get)
  }

  def testGpuWriteFallback(testName: String,
      fallbackCpuClass: String,
      df: SparkSession => DataFrame,
      conf: SparkConf = new SparkConf(),
      repart: Integer = 1,
      sort: Boolean = false,
      maxFloatDiff: Double = 0.0,
      incompat: Boolean = false,
      execsAllowedNonGpu: Seq[String] = Seq.empty,
      sortBeforeRepart: Boolean = false)
      (fun: DataFrame => Unit): Unit = {
    val (testConf, qualifiedTestName) =
      setupTestConfAndQualifierName(testName, incompat, sort, conf, execsAllowedNonGpu,
        maxFloatDiff, sortBeforeRepart)
    test(qualifiedTestName) {
      val (_, gpuPlan) = writeOnCpuAndGpuWithCapture(df, fun,
        conf = testConf,
        repart = repart)
      // Now check that the GPU plan fell back to the expected CPU exec
      ExecutionPlanCaptureCallback.assertDidFallBack(gpuPlan, fallbackCpuClass)
    }
  }
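A sketch of another way a suite might invoke this helper, assuming (as an assumption about the plugin, not something this PR states) that CSV writes are not GPU-accelerated and therefore fall back:

  // Hypothetical: a write feature the GPU cannot handle should surface as a
  // DataWritingCommandExec that stayed on the CPU.
  testGpuWriteFallback(
      "CSV write fallback",
      "DataWritingCommandExec",
      intsDf,
      execsAllowedNonGpu = Seq("ShuffleExchangeExec", "DataWritingCommandExec")) { frame =>
    frame.write.mode("overwrite").csv("/tmp/csv-fallback-test")
  }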

  def testGpuFallback(testName: String,
      fallbackCpuClass: String,
      df: SparkSession => DataFrame,