Skip to content

Commit

Permalink
Support regular expressions containing \s and \S (#5089)
Browse files Browse the repository at this point in the history
Closes #4528

Signed-off-by: sperlingxx lovedreamf@gmail.com

Support regular expressions which contain \s and \S through transpiling `\s` to `[ \f\t\r\n\\x0b]`.
  • Loading branch information
sperlingxx authored Apr 2, 2022
1 parent 00d4340 commit 17984c0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 7 deletions.
17 changes: 17 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,23 @@ def test_regexp_extract_idx_0():
'regexp_extract(a, "^([a-d]*)[0-9]*([a-d]*)\\z", 0)'),
conf=_regexp_conf)

def test_regexp_whitespace():
gen = mk_str_gen('\u001e[abcd]\t\n{1,3} [0-9]\n {1,3}\x0b\t[abcd]\r\f[0-9]{0,10}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'rlike(a, "\\s{2}")',
'rlike(a, "\\s{3}")',
'rlike(a, "[abcd]+\\s+[0-9]+")',
'rlike(a, "\\S{3}")',
'rlike(a, "[abcd]+\\s+\\S{2,3}")',
'regexp_extract(a, "([a-d]+)([0-9\\s]+)([a-d]+)", 2)',
'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 2)',
'regexp_extract(a, "([a-d]+)(\\S+)([0-9]+)", 3)',
'regexp_replace(a, "(\\s+)", "@")',
'regexp_replace(a, "(\\S+)", "#")',
),
conf=_regexp_conf)

def test_rlike():
gen = mk_str_gen('[abcd]{1,3}')
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'b' | 'B' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4517
throw new RegexUnsupportedException("word boundaries are not supported")
case 's' | 'S' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4528
throw new RegexUnsupportedException("whitespace classes are not supported")
case 'A' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
case 'Z' if mode == RegexSplitMode =>
Expand All @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'Z' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4532
throw new RegexUnsupportedException("string anchor \\Z is not supported")
case 's' | 'S' =>
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case _ =>
regex
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,6 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'b' | 'B' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4517
throw new RegexUnsupportedException("word boundaries are not supported")
case 's' | 'S' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4528
throw new RegexUnsupportedException("whitespace classes are not supported")
case 'A' if mode == RegexSplitMode =>
throw new RegexUnsupportedException("string anchor \\A is not supported in split mode")
case 'Z' if mode == RegexSplitMode =>
Expand All @@ -604,6 +601,11 @@ class CudfRegexTranspiler(mode: RegexMode) {
case 'Z' =>
// see https://github.com/NVIDIA/spark-rapids/issues/4532
throw new RegexUnsupportedException("string anchor \\Z is not supported")
case 's' | 'S' =>
val chars: ListBuffer[RegexCharacterClassComponent] = ListBuffer(
RegexChar(' '), RegexChar('\u000b'))
chars ++= Seq('n', 't', 'r', 'f').map(RegexEscaped)
RegexCharacterClass(negated = ch.isUpper, characters = chars)
case _ =>
regex
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
}

test("whitespace boundaries - replace") {
assertCpuGpuMatchesRegexpReplace(
Seq("\\s", "\\S"),
Seq("\u001eTEST"))
}

test("match literal $ - find") {
assertCpuGpuMatchesRegexpFind(
Seq("\\$", "\\$[0-9]"),
Expand Down Expand Up @@ -292,7 +298,7 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpFind(patterns, inputs)
}

private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\nBsdwSDWzZ"
private val REGEXP_LIMITED_CHARS_COMMON = "|()[]{},.^$*+?abc123x\\ \t\r\n\f\u000bBsdwSDWzZ"

private val REGEXP_LIMITED_CHARS_FIND = REGEXP_LIMITED_CHARS_COMMON

Expand Down

0 comments on commit 17984c0

Please sign in to comment.