Skip to content

Commit

Permalink
Fix bugs in CAST string to integer (#2919)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jul 16, 2021
1 parent 15fe52f commit 65eef79
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 66 deletions.
109 changes: 73 additions & 36 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,69 @@ object GpuCast extends Arm {
cv.stringReplaceWithBackrefs(rule.search, rule.replace)
})
}

/**
 * Sanitizes a column of strings before it is cast to an integral type.
 *
 * The input is assumed to already have been stripped of leading and trailing whitespace, so
 * any whitespace still present is treated as invalid input.
 *
 * With `ansiEnabled = false`:
 *  - strings containing whitespace are replaced with null;
 *  - strings containing a dot are reduced to the part before the dot (for example
 *    "321.123" becomes "321"), and strings with no digits before the dot (".", ".2", "-.2")
 *    become "0";
 *  - strings without a dot are passed through unchanged.
 *
 * With `ansiEnabled = true` only plain, optionally signed, integers are accepted and the
 * values themselves are returned unchanged.
 *
 * @param input strings to sanitize; this method does not close it
 * @param ansiEnabled whether invalid input raises an error instead of being sanitized
 * @return a column reference that the caller is responsible for closing
 * @throws NumberFormatException in ANSI mode, if any string contains whitespace or any
 *                               non-null string is not a plain integer
 */
def sanitizeStringToIntegralType(input: ColumnVector, ansiEnabled: Boolean): ColumnVector = {
  // Convert any strings containing whitespace to null values. The input is assumed to already
  // have been stripped of leading and trailing whitespace
  val sanitized = withResource(input.containsRe("\\s")) { hasWhitespace =>
    withResource(hasWhitespace.any()) { any =>
      if (any.getBoolean) {
        if (ansiEnabled) {
          throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
        } else {
          withResource(GpuScalar.from(null, DataTypes.StringType)) { nullVal =>
            hasWhitespace.ifElse(nullVal, input)
          }
        }
      } else {
        input.incRefCount()
      }
    }
  }

  // From here on this method owns a reference to `sanitized`. Holding it in withResource for
  // both branches means the reference is released even when validation throws (previously the
  // ANSI branch leaked the column on the invalid-input path). Branches that want to hand
  // `sanitized` itself back to the caller take an extra reference with incRefCount().
  withResource(sanitized) { _ =>
    if (ansiEnabled) {
      // ansi mode only supports simple integers, so no exponents or decimal places
      val regex = "^[+\\-]?[0-9]+$"
      withResource(sanitized.matchesRe(regex)) { isInt =>
        withResource(isInt.all()) { allInts =>
          // Check that all non-null values are valid integers. Note that allInts will be false
          // if all rows are null so we need to check for that condition.
          if (!allInts.getBoolean && sanitized.getNullCount != sanitized.getRowCount) {
            throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
          }
        }
      }
      sanitized.incRefCount()
    } else {
      // truncate strings that represent decimals, so that we just look at the string before
      // the dot
      withResource(Scalar.fromString(".")) { dot =>
        withResource(sanitized.stringContains(dot)) { hasDot =>
          // only do the decimal sanitization if any strings do contain dot
          withResource(hasDot.any(DType.BOOL8)) { anyDot =>
            if (anyDot.getBoolean) {
              // Special handling for strings that have no numeric value before the dot, such
              // as "." and ".1" because extractsRe returns null for the capture group
              // for these values and it also returns null for invalid inputs so we need this
              // explicit check
              withResource(sanitized.matchesRe("^[+\\-]?\\.[0-9]*$")) { startsWithDot =>
                withResource(sanitized.extractRe("^([+\\-]?[0-9]*)\\.[0-9]*$")) { table =>
                  withResource(Scalar.fromString("0")) { zero =>
                    withResource(startsWithDot.ifElse(zero, table.getColumn(0))) { decimal =>
                      hasDot.ifElse(decimal, sanitized)
                    }
                  }
                }
              }
            } else {
              sanitized.incRefCount()
            }
          }
        }
      }
    }
  }
}

/**
Expand Down Expand Up @@ -685,47 +748,21 @@ case class GpuCast(
input: ColumnVector,
ansiEnabled: Boolean,
dType: DType): ColumnVector = {
val cleaned = if (!ansiEnabled) {
// TODO would be great to get rid of this regex, but the overflow checks don't work
// on the more lenient pattern.
// To avoid doing the expensive regex all the time, we will first check to see if we need
// to do it. The only time we do need to do it is when we have a '.' in any of the strings.
val data = input.getData
val hasDot = if (data != null) {
withResource(
ColumnView.fromDeviceBuffer(data, 0, DType.INT8, data.getLength.toInt)) { childData =>
withResource(GpuScalar.from('.'.toByte, ByteType)) { dot =>
childData.contains(dot)
}
}
} else {
false
}

if (hasDot) {
withResource(input.extractRe("^([+\\-]?[0-9]+)(?:\\.[0-9]*)?$")) { table =>
table.getColumn(0).incRefCount()
}
} else {
input.incRefCount()
}
} else {
input.incRefCount()
}
withResource(cleaned) { cleaned =>
withResource(cleaned.isInteger(dType)) { isInt =>
withResource(GpuCast.sanitizeStringToIntegralType(input, ansiEnabled)) { sanitized =>
withResource(sanitized.isInteger(dType)) { isInt =>
if (ansiEnabled) {
withResource(isInt.all()) { allInts =>
if (!allInts.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_INPUT_MESSAGE)
// Check that all non-null values are valid integers. Note that allInts will be false
// if all rows are null so we need to check for that condition.
if (!allInts.getBoolean && sanitized.getNullCount != sanitized.getRowCount) {
throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
}
}
cleaned.castTo(dType)
} else {
withResource(cleaned.castTo(dType)) { parsedInt =>
withResource(GpuScalar.from(null, dataType)) { nullVal =>
isInt.ifElse(parsedInt, nullVal)
}
}
withResource(sanitized.castTo(dType)) { parsedInt =>
withResource(GpuScalar.from(null, dataType)) { nullVal =>
isInt.ifElse(parsedInt, nullVal)
}
}
}
Expand Down
134 changes: 104 additions & 30 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ import java.util.TimeZone

import ai.rapids.cudf.ColumnVector
import scala.collection.JavaConverters._
import scala.util.Random
import scala.util.{Failure, Random, Success, Try}

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class CastOpSuite extends GpuExpressionTestSuite {
Expand Down Expand Up @@ -74,27 +75,50 @@ class CastOpSuite extends GpuExpressionTestSuite {
"f", "F", "True", "TRUE", "true", "tRuE", "t", "T", "Y", "y", "10", "01", "0", "1"))
}

ignore("Cast from string to byte using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
test("Cast from string to byte using random inputs") {
testCastStringTo(DataTypes.ByteType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

ignore("Cast from string to short using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
test("Cast from string to short using random inputs") {
testCastStringTo(DataTypes.ShortType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

ignore("Cast from string to int using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
test("Cast from string to int using random inputs") {
testCastStringTo(DataTypes.IntegerType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

ignore("Cast from string to long using random inputs") {
// Test ignored due to known issues
// https://github.com/NVIDIA/spark-rapids/issues/2899
// Non-ANSI cast over hand-picked edge cases: malformed text, embedded whitespace, values with
// no digits before the dot, signed decimals and exponent notation.
test("Cast from string to int using hand-picked values") {
  val handPicked: Seq[String] = Seq(
    ".--e-37602.n",
    "\r\r\t\n11.12380",
    "-.2",
    ".3",
    ".",
    "+1.2",
    "\n123\n456\n",
    "1e+4")
  testCastStringTo(DataTypes.IntegerType, handPicked)
}

// ANSI cast over a batch that mixes invalid strings with a null and a valid integer; the
// query is expected to fail.
test("Cast from string to int ANSI mode with mix of valid and invalid values") {
  val mixed: Seq[String] = Seq(
    ".--e-37602.n",
    "\r\r\t\n11.12380",
    "-.2",
    ".3",
    ".",
    "+1.2",
    "\n123\n456\n",
    "1 2",
    null,
    "123")
  testCastStringTo(DataTypes.IntegerType, mixed, ansiMode = AnsiExpectFailure)
}

// ANSI cast where every input is a plain signed integer; the query is expected to succeed.
test("Cast from string to int ANSI mode with valid values") {
  val validInts: Seq[String] = Seq("1", "-1")
  testCastStringTo(DataTypes.IntegerType, validInts, ansiMode = AnsiExpectSuccess)
}

// ANSI cast of inputs that are not plain integers (exponent, "Inf", decimal).
test("Cast from string to int ANSI mode with invalid values") {
  // Each value is cast in its own single-row batch so that every one of them has to
  // trigger the failure by itself.
  Seq("1e4", "Inf", "1.2").foreach { invalid =>
    testCastStringTo(DataTypes.IntegerType, Seq(invalid), ansiMode = AnsiExpectFailure)
  }
}

// ANSI cast of a batch containing only nulls; the query is expected to succeed.
test("Cast from string to int ANSI mode with nulls") {
  val allNulls: Seq[String] = Seq(null, null, null)
  testCastStringTo(DataTypes.IntegerType, allNulls, ansiMode = AnsiExpectSuccess)
}

// ANSI cast of a value with a newline between digits; the query is expected to fail.
test("Cast from string to int ANSI mode with newline in string") {
  val withNewline: Seq[String] = Seq("1\n2")
  testCastStringTo(DataTypes.IntegerType, withNewline, ansiMode = AnsiExpectFailure)
}

test("Cast from string to long using random inputs") {
testCastStringTo(DataTypes.LongType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

Expand Down Expand Up @@ -147,7 +171,10 @@ class CastOpSuite extends GpuExpressionTestSuite {
.map(_ => prefix.getOrElse("") + r.nextString())
}

private def testCastStringTo(toType: DataType, strings: Seq[String]) {
private def testCastStringTo(
toType: DataType,
strings: Seq[String],
ansiMode: AnsiTestMode = AnsiDisabled) {

def castDf(spark: SparkSession): Seq[Row] = {
import spark.implicits._
Expand All @@ -160,30 +187,48 @@ class CastOpSuite extends GpuExpressionTestSuite {
val INDEX_C0 = 0
val INDEX_C1 = 2

val cpu = withCpuSparkSession(castDf)
.sortBy(_.getInt(INDEX_ID))
val ansiModeBoolString = (ansiMode != AnsiDisabled).toString

val cpuConf = new SparkConf()
.set(SQLConf.ANSI_ENABLED.key, ansiModeBoolString)

val tryCpu = Try(withCpuSparkSession(castDf, cpuConf)
.sortBy(_.getInt(INDEX_ID)))

val conf = new SparkConf()
val gpuConf = new SparkConf()
.set(SQLConf.ANSI_ENABLED.key, ansiModeBoolString)
.set(RapidsConf.EXPLAIN.key, "ALL")
.set(RapidsConf.INCOMPATIBLE_DATE_FORMATS.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_FLOAT.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_DECIMAL.key, "true")
.set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")

val gpu = withGpuSparkSession(castDf, conf)
.sortBy(_.getInt(INDEX_ID))

for ((cpuRow, gpuRow) <- cpu.zip(gpu)) {
assert(cpuRow.getString(INDEX_C0) === gpuRow.getString(INDEX_C0))
assert(cpuRow.getInt(INDEX_ID) === gpuRow.getInt(INDEX_ID))
val cpuValue = cpuRow.get(INDEX_C1)
val gpuValue = gpuRow.get(INDEX_C1)
if (!compare(cpuValue, gpuValue)) {
val inputValue = cpuRow.getString(INDEX_C0)
fail(s"Mismatch casting string [$inputValue] " +
s"to $toType. CPU: $cpuValue; GPU: $gpuValue")
}
val tryGpu = Try(withGpuSparkSession(castDf, gpuConf)
.sortBy(_.getInt(INDEX_ID)))

(tryCpu, tryGpu) match {
case (Success(cpu), Success(gpu)) if ansiMode != AnsiExpectFailure =>
for ((cpuRow, gpuRow) <- cpu.zip(gpu)) {
assert(cpuRow.getString(INDEX_C0) === gpuRow.getString(INDEX_C0))
assert(cpuRow.getInt(INDEX_ID) === gpuRow.getInt(INDEX_ID))
val cpuValue = cpuRow.get(INDEX_C1)
val gpuValue = gpuRow.get(INDEX_C1)
if (!compare(cpuValue, gpuValue)) {
val inputValue = cpuRow.getString(INDEX_C0)
fail(s"Mismatch casting string [$inputValue] " +
s"to $toType. CPU: $cpuValue; GPU: $gpuValue")
}
}

case (Failure(_), Failure(_)) if ansiMode == AnsiExpectFailure =>
// this is fine

case (Success(_), Failure(gpu)) =>
fail(s"Query succeeded on CPU but failed on GPU: $gpu")

case (Failure(cpu), Success(_)) =>
fail(s"Query succeeded on GPU but failed on CPU: $cpu")
}
}

Expand Down Expand Up @@ -827,6 +872,30 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

// Checks GpuCast.sanitizeStringToIntegralType directly, with ANSI mode disabled, against a
// table of (input, expected sanitized output) pairs.
test("CAST string to integer - sanitize step") {
  val cases: Seq[(String, String)] = Seq(
    (null, null),
    ("1e4", "1e4"),
    ("123", "123"),
    (".", "0"),
    (".2", "0"),
    ("-.2", "0"),
    ("0.123", "0"),
    ("321.123", "321"),
    ("0.123\r123", null),
    (".\r123", null)
  )
  val (inputStrings, expectedStrings) = cases.unzip
  withResource(ColumnVector.fromStrings(inputStrings: _*)) { inputCol =>
    withResource(ColumnVector.fromStrings(expectedStrings: _*)) { expectedCol =>
      val sanitizedCol = GpuCast.sanitizeStringToIntegralType(inputCol, ansiEnabled = false)
      withResource(sanitizedCol) { actualCol =>
        CudfTestHelper.assertColumnsAreEqual(expectedCol, actualCol)
      }
    }
  }
}

test("CAST string to date - sanitize step") {
val testPairs = Seq(
("2001-1", "2001-01"),
Expand Down Expand Up @@ -1290,3 +1359,8 @@ object CastOpSuite {
private val timestampValues: Seq[Long] = Seq(6321706291000L)

}

/**
 * Selects how a cast test configures ANSI mode and which outcome it expects.
 */
sealed trait AnsiTestMode

/** ANSI mode is turned off for the test. */
case object AnsiDisabled extends AnsiTestMode

/** ANSI mode is turned on and the cast is expected to succeed. */
case object AnsiExpectSuccess extends AnsiTestMode

/** ANSI mode is turned on and the cast is expected to fail. */
case object AnsiExpectFailure extends AnsiTestMode

0 comments on commit 65eef79

Please sign in to comment.