Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix regression in cast from string to float that caused signed NaN to be considered valid #4259

Merged
merged 1 commit into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,26 +173,30 @@ object GpuCast extends Arm {
input: ColumnVector,
ansiEnabled: Boolean): ColumnVector = {

// This regex gets applied after the transformation to normalize use of Inf and is
// just strict enough to filter out known edge cases that would result in incorrect
// values. We further filter out invalid values using the cuDF isFloat method.
// This regex is just strict enough to filter out known edge cases that would result
// in incorrect values. We further filter out invalid values using the cuDF isFloat method.
val VALID_FLOAT_REGEX =
"^" + // start of line
"[+\\-]?" + // optional + or - at start of string
"(" +
"^" + // start of line
"[Nn][Aa][Nn]" + // NaN
"|" +
"(" +
"[+\\-]?" + // optional sign preceding Inf or numeric
"(" +
"([0-9]+)|" + // digits, OR
"([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR
"([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing
"([Ii][Nn][Ff]" + // Inf, Infinity
"([Ii][Nn][Ii][Tt][Yy])?)" +
"|" +
"(" +
"(" +
"([0-9]+)|" + // digits, OR
"([0-9]*\\.[0-9]+)|" + // decimal with optional leading and mandatory trailing, OR
"([0-9]+\\.[0-9]*)" + // decimal with mandatory leading and optional trailing
")" +
"([eE][+\\-]?[0-9]+)?" + // exponent
"[fFdD]?" + // floating-point designator
")" +
")" +
"([eE][+\\-]?[0-9]+)?" + // exponent
"[fFdD]?" + // floating-point designator
")" +
"|Inf" + // Infinity
"|[nN][aA][nN]" + // NaN
")" +
"$" // end of line
"$" // end of line

withResource(input.lstrip()) { stripped =>
withResource(GpuScalar.from(null, DataTypes.StringType)) { nullString =>
Expand All @@ -203,39 +207,26 @@ object GpuCast extends Arm {
_.ifElse(nullString, stripped)
}
}
// replace all possible versions of "Inf" and "Infinity" with "Inf"
val inf = withResource(withoutWhitespace) { _ =>
withoutWhitespace.stringReplaceWithBackrefs(
"(?:[iI][nN][fF])" + "(?:[iI][nN][iI][tT][yY])?", "Inf")
}
// replace "+Inf" with "Inf" because cuDF only supports "Inf" and "-Inf"
val infWithoutPlus = withResource(inf) { _ =>
withResource(GpuScalar.from("+Inf", DataTypes.StringType)) { search =>
withResource(GpuScalar.from("Inf", DataTypes.StringType)) { replace =>
inf.stringReplace(search, replace)
}
}
}
// filter out any strings that are not valid floating point numbers according
// to the regex pattern
val floatOrNull = withResource(infWithoutPlus) { _ =>
withResource(infWithoutPlus.matchesRe(VALID_FLOAT_REGEX)) { isFloat =>
val floatOrNull = withResource(withoutWhitespace) { _ =>
withResource(withoutWhitespace.matchesRe(VALID_FLOAT_REGEX)) { isFloat =>
if (ansiEnabled) {
withResource(isFloat.all()) { allMatch =>
// Check that all non-null values are valid floats.
if (allMatch.isValid && !allMatch.getBoolean) {
throw new NumberFormatException(GpuCast.INVALID_NUMBER_MSG)
}
infWithoutPlus.incRefCount()
withoutWhitespace.incRefCount()
}
} else {
isFloat.ifElse(infWithoutPlus, nullString)
isFloat.ifElse(withoutWhitespace, nullString)
}
}
}
// strip a trailing floating-point designator 'f' or 'd', but don't strip the final 'f'
// of "Inf" (in any letter case)
withResource(floatOrNull) {
_.stringReplaceWithBackrefs("([^n])[fFdD]$", "\\1")
_.stringReplaceWithBackrefs("([^nN])[fFdD]$", "\\1")
}
}
}
Expand Down Expand Up @@ -831,13 +822,11 @@ object GpuCast extends Arm {
input: ColumnVector,
ansiEnabled: Boolean,
dType: DType): ColumnVector = {
// 1. convert the different infinities to "Inf"/"-Inf" which is the only variation cudf
// understands
// 2. identify the nans
// 3. identify the floats. "nan", "null" and letters are not considered floats
// 4. if ansi is enabled we want to throw an exception if the string is neither float nor nan
// 5. convert everything that's not floats to null
// 6. set the indices where we originally had nans to Float.NaN
// 1. identify the nans
// 2. identify the floats. "null" and letters are not considered floats
// 3. if ansi is enabled we want to throw an exception if the string is neither float nor nan
// 4. convert everything that's not floats to null
// 5. set the indices where we originally had nans to Float.NaN
//
// NOTE Limitation: "1.7976931348623159E308" and "-1.7976931348623159E308" are not treated
// as Inf here, even though Spark treats them as Inf
Expand Down
12 changes: 6 additions & 6 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
testCastStringTo(DataTypes.FloatType, generateRandomStrings(Some(NUMERIC_CHARS)))
}

// https://github.com/NVIDIA/spark-rapids/issues/4246
ignore("Cast from string to float using hand-picked values") {
test("Cast from string to float using hand-picked values") {
testCastStringTo(DataTypes.FloatType, Seq(".", "e", "Infinity", "+Infinity", "-Infinity",
"+nAn", "-naN", "Nan", "5f", "1.2f", "\riNf", null))
}
Expand Down Expand Up @@ -933,10 +932,11 @@ class CastOpSuite extends GpuExpressionTestSuite {

test("CAST string to float - sanitize step") {
val testPairs = Seq(
("\tinf", "Inf"),
("\t+InFinITy", "Inf"),
("\tInFinITy", "Inf"),
("\t-InFinITy", "-Inf"),
("\tinf", "inf"),
("\riNf", "iNf"),
("\t+InFinITy", "+InFinITy"),
("\tInFinITy", "InFinITy"),
("\t-InFinITy", "-InFinITy"),
("\t61f", "61"),
(".8E4f", ".8E4")
)
Expand Down