From 397515ee817dca27d376e469996119f49b1bde06 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Dec 2021 09:16:29 -0700 Subject: [PATCH 1/4] Make regular expression pattern `[^a]` consistent with Spark with respect to newline characters Signed-off-by: Andy Grove --- .../src/main/python/string_test.py | 15 +++++++ .../com/nvidia/spark/rapids/RegexParser.scala | 39 ++++++++++++++++++- .../rapids/RegularExpressionParserSuite.scala | 13 +++++++ .../RegularExpressionTranspilerSuite.scala | 8 +++- 4 files changed, 72 insertions(+), 3 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index f89d0a7f96d..dfde5d30467 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -494,6 +494,21 @@ def test_regexp_replace_null_pattern_fallback(): 'RegExpReplace', conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) +def test_regexp_replace_character_set_negated(): + gen = mk_str_gen('[abcd]{0,3}[\r\n]{0,2}[abcd]{0,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'regexp_replace(a, "([^a])|([^b])", "1")', + 'regexp_replace(a, "[^a]", "1")', + 'regexp_replace(a, "([^a]|[\r\n])", "1")', + 'regexp_replace(a, "[^a\r\n]", "1")', + 'regexp_replace(a, "[^a\r]", "1")', + 'regexp_replace(a, "[^a\n]", "1")', + 'regexp_replace(a, "[^\r\n]", "1")', + 'regexp_replace(a, "[^\r]", "1")', + 'regexp_replace(a, "[^\n]", "1")'), + conf={'spark.rapids.sql.expression.RegExpReplace': 'true'}) + def test_rlike(): gen = mk_str_gen('[abcd]{1,3}') assert_gpu_and_cpu_are_equal_collect( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 86ba30d9c83..77a035c37b5 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -150,9 +150,11 @@ class RegexParser(pattern: String) { case '[' => // treat as a literal character and add to the character class characterClass.append(ch) - case ']' if pos > start + 1 => + case ']' if (!characterClass.negated && pos > start + 1) || + (characterClass.negated && pos > start + 2) => // "[]" is not a valid character class // "[]a]" is a valid character class containing the characters "]" and "a" + // "[^]a]" is a valid negated character class containing the characters "]" and "a" characterClassComplete = true case '^' if pos == start + 1 => // Negates the character class, causing it to match a single character not listed in @@ -492,7 +494,40 @@ class CudfRegexTranspiler(replace: Boolean) { } val components: Seq[RegexCharacterClassComponent] = characters .map(x => rewrite(x).asInstanceOf[RegexCharacterClassComponent]) - RegexCharacterClass(negated, ListBuffer(components: _*)) + + if (negated) { + // There are differences between cuDF and Java handling of newlines + // for negative character matches. The expression `[^a]` will match + // `\r` and `\n` in Java but not in cuDF, so we replace `[^a]` with + // `(?:[\r\n]|[^a])`. We also have to take into account whether any + // newline characters are included in the character range. 
+ // + // Examples: + // + // `[^a]` => `(?:[\r\n]|[^a])` + // `[^a\r]` => `(?:[\n]|[^a])` + // `[^a\n]` => `(?:[\r]|[^a])` + // `[^a\r\n]` => `[^a]` + + val newlineCharsInClass = characters.flatMap { + case RegexChar(ch) if ch == '\n' || ch == '\r' => + Seq(ch) + case _ => + Seq.empty + } + val negatedNewlines = Seq('\r', '\n').diff(newlineCharsInClass) + if (negatedNewlines.isEmpty) { + RegexCharacterClass(negated, ListBuffer(components: _*)) + } else { + RegexGroup(capture = false, + RegexChoice( + RegexCharacterClass(negated = false, + characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), + RegexCharacterClass(negated, ListBuffer(components: _*)))) + } + } else { + RegexCharacterClass(negated, ListBuffer(components: _*)) + } case RegexSequence(parts) => if (parts.isEmpty) { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index aac297eea0a..f205a1d5658 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -68,12 +68,25 @@ class RegularExpressionParserSuite extends FunSuite { RegexCharacterRange('A', 'Z')))))) } + test("foo") { + assert(parse("[^]+d]+") === RegexSequence(ListBuffer( + RegexRepetition( + RegexCharacterClass(negated = true, + ListBuffer(RegexChar(']'), RegexChar('+'), RegexChar('d'))), + SimpleQuantifier('+'))))) + } + test("character classes containing ']'") { // "[]a]" is a valid character class containing ']' and 'a' assert(parse("[]a]") === RegexSequence(ListBuffer( RegexCharacterClass(negated = false, ListBuffer(RegexChar(']'), RegexChar('a')))))) + // "[^]a]" is a valid negated character class containing ']' and 'a' + assert(parse("[^]a]") === + RegexSequence(ListBuffer( + RegexCharacterClass(negated = true, + ListBuffer(RegexChar(']'), RegexChar('a')))))) // "[a]]" is a valid 
character class "[a]" followed by character ']' assert(parse("[a]]") === RegexSequence(ListBuffer( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 240ba7d6eab..777ff1cff25 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -165,7 +165,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { test("transpile character class unescaped range symbol") { val patterns = Seq("a[-b]", "a[+-]", "a[-+]", "a[-]", "a[^-]") - val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", raw"a[^\-]") + val expected = Seq(raw"a[\-b]", raw"a[+\-]", raw"a[\-+]", raw"a[\-]", "a(?:[\r\n]|[^\\-])") val transpiler = new CudfRegexTranspiler(replace=false) val transpiled = patterns.map(transpiler.transpile) assert(transpiled === expected) @@ -248,6 +248,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { assertCpuGpuMatchesRegexpReplace(patterns, inputs) } + test("compare CPU and GPU: regexp replace negated character class") { + val inputs = Seq("a", "b", "a\nb") + val patterns = Seq("[^z]") + assertCpuGpuMatchesRegexpReplace(patterns, inputs) + } + test("compare CPU and GPU: regexp replace fuzz test with limited chars") { // testing with this limited set of characters finds issues much // faster than using the full ASCII set From a0160544cab30f3e096e4a6db29388e3b7358dd7 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Dec 2021 14:00:19 -0700 Subject: [PATCH 2/4] update test name --- .../com/nvidia/spark/rapids/RegularExpressionParserSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala 
b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala index f205a1d5658..4836c8f15cb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionParserSuite.scala @@ -68,7 +68,7 @@ class RegularExpressionParserSuite extends FunSuite { RegexCharacterRange('A', 'Z')))))) } - test("foo") { + test("character class complex example") { assert(parse("[^]+d]+") === RegexSequence(ListBuffer( RegexRepetition( RegexCharacterClass(negated = true, From b1a3ef37b38982fb98d5269c3c067df0a06ab9e4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 3 Dec 2021 17:59:50 -0700 Subject: [PATCH 3/4] Handle edge cases --- .../com/nvidia/spark/rapids/RegexParser.scala | 37 ++++++++++++------- .../RegularExpressionTranspilerSuite.scala | 5 ++- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index 77a035c37b5..b8c50f5c35c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -169,11 +169,9 @@ class RegexParser(pattern: String) { throw new RegexUnsupportedException( s"unexpected EOF while parsing escaped character", Some(pos)) case Some(ch) => - ch match { - case '\\' | '^' | '-' | ']' | '+' => - // escaped metacharacter within character class - characterClass.appendEscaped(consumeExpected(ch)) - } + // typically an escaped metacharacter ('\\' | '^' | '-' | ']' | '+') + // within the character class, but could be any escaped character + characterClass.appendEscaped(consumeExpected(ch)) } case '\u0000' => throw new RegexUnsupportedException( @@ -508,23 +506,36 @@ class CudfRegexTranspiler(replace: Boolean) { // `[^a\r]` => `(?:[\n]|[^a])` // `[^a\n]` => `(?:[\r]|[^a])` // `[^a\r\n]` => `[^a]` + // `[^\r\n]` => 
`[^\r\n]` - val newlineCharsInClass = characters.flatMap { - case RegexChar(ch) if ch == '\n' || ch == '\r' => - Seq(ch) - case _ => - Seq.empty + val allLinefeed = components.forall { + case RegexChar(ch) => ch == '\n' || ch == '\r' + case RegexEscaped(ch) => ch == 'n' || ch == 'r' + case _ => false } + + val newlineCharsInClass = components.flatMap { + case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch) + case RegexEscaped(ch) if ch == 'n' => Seq('\n') + case RegexEscaped(ch) if ch == 'r' => Seq('\r') + case _ => Seq.empty + }.distinct + val negatedNewlines = Seq('\r', '\n').diff(newlineCharsInClass) - if (negatedNewlines.isEmpty) { - RegexCharacterClass(negated, ListBuffer(components: _*)) + + if (allLinefeed && newlineCharsInClass.length == 2) { + // special case for `[^\r\n]` + RegexCharacterClass(negated = true, ListBuffer(components: _*)) + } else if (negatedNewlines.isEmpty) { + RegexCharacterClass(negated = true, ListBuffer(components: _*)) } else { RegexGroup(capture = false, RegexChoice( RegexCharacterClass(negated = false, characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), - RegexCharacterClass(negated, ListBuffer(components: _*)))) + RegexCharacterClass(negated = true, ListBuffer(components: _*)))) } + } else { RegexCharacterClass(negated, ListBuffer(components: _*)) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala index 777ff1cff25..73295d02ecb 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionTranspilerSuite.scala @@ -249,8 +249,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm { } test("compare CPU and GPU: regexp replace negated character class") { - val inputs = Seq("a", "b", "a\nb") - val patterns = Seq("[^z]") + val inputs = Seq("a", "b", "a\nb", 
"a\r\nb\n\rc\rd") + val patterns = Seq("[^z]", "[^\r]", "[^\n]", "[^\r]", "[^\r\n]", + "[^a\n]", "[^b\r]", "[^bc\r\n]", "[^\\r\\n]") assertCpuGpuMatchesRegexpReplace(patterns, inputs) } From d550d3e2cee48fbae983ef1ddf3c4100ee86819b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 4 Dec 2021 08:41:19 -0700 Subject: [PATCH 4/4] code cleanup --- .../com/nvidia/spark/rapids/RegexParser.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala index b8c50f5c35c..43ac51cb56d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala @@ -169,7 +169,7 @@ class RegexParser(pattern: String) { throw new RegexUnsupportedException( s"unexpected EOF while parsing escaped character", Some(pos)) case Some(ch) => - // typically an escaped metacharacter ('\\' | '^' | '-' | ']' | '+') + // typically an escaped metacharacter ('\\', '^', '-', ']', '+') // within the character class, but could be any escaped character characterClass.appendEscaped(consumeExpected(ch)) } @@ -508,23 +508,19 @@ class CudfRegexTranspiler(replace: Boolean) { // `[^a\r\n]` => `[^a]` // `[^\r\n]` => `[^\r\n]` - val allLinefeed = components.forall { - case RegexChar(ch) => ch == '\n' || ch == '\r' - case RegexEscaped(ch) => ch == 'n' || ch == 'r' - case _ => false - } - - val newlineCharsInClass = components.flatMap { + val linefeedCharsInPattern = components.flatMap { case RegexChar(ch) if ch == '\n' || ch == '\r' => Seq(ch) case RegexEscaped(ch) if ch == 'n' => Seq('\n') case RegexEscaped(ch) if ch == 'r' => Seq('\r') case _ => Seq.empty - }.distinct + } - val negatedNewlines = Seq('\r', '\n').diff(newlineCharsInClass) + val onlyLinefeedChars = components.length == linefeedCharsInPattern.length - if (allLinefeed && newlineCharsInClass.length 
== 2) { - // special case for `[^\r\n]` + val negatedNewlines = Seq('\r', '\n').diff(linefeedCharsInPattern.distinct) + + if (onlyLinefeedChars && linefeedCharsInPattern.length == 2) { + // special case for `[^\r\n]` and `[^\\r\\n]` RegexCharacterClass(negated = true, ListBuffer(components: _*)) } else if (negatedNewlines.isEmpty) { RegexCharacterClass(negated = true, ListBuffer(components: _*)) @@ -535,7 +531,6 @@ class CudfRegexTranspiler(replace: Boolean) { characters = ListBuffer(negatedNewlines.map(RegexChar): _*)), RegexCharacterClass(negated = true, ListBuffer(components: _*)))) } - } else { RegexCharacterClass(negated, ListBuffer(components: _*)) }