Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow *, ?, and {0,...} variants in StringSplit in non-empty match situations #6959

Merged
merged 2 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def test_split_re_negative_limit():
'split(a, "[^o]", -1)',
'split(a, "[o]{1,2}", -1)',
'split(a, "[bf]", -1)',
'split(a, "b[o]+", -1)',
'split(a, "b[o]*", -1)',
'split(a, "b[o]?", -1)',
'split(a, "[o]", -2)'),
conf=_regexp_conf)

Expand All @@ -58,6 +61,9 @@ def test_split_re_zero_limit():
'split(a, "[^o]", 0)',
'split(a, "[o]{1,2}", 0)',
'split(a, "[bf]", 0)',
'split(a, "f[o]+", 0)',
'split(a, "f[o]*", 0)',
'split(a, "f[o]?", 0)',
'split(a, "[o]", 0)'),
conf=_regexp_conf)

Expand All @@ -72,6 +78,9 @@ def test_split_re_one_limit():
'split(a, "[^o]", 1)',
'split(a, "[o]{1,2}", 1)',
'split(a, "[bf]", 1)',
'split(a, "b[o]+", 1)',
'split(a, "b[o]*", 1)',
'split(a, "b[o]?", 1)',
'split(a, "[o]", 1)'),
conf=_regexp_conf)

Expand All @@ -86,6 +95,9 @@ def test_split_re_positive_limit():
'split(a, "[^o]", 55)',
'split(a, "[o]{1,2}", 999)',
'split(a, "[bf]", 2)',
'split(a, "f[o]+", 2)',
'split(a, "f[o]*", 9)',
'split(a, "f[o]?", 5)',
'split(a, "[o]", 5)'),
conf=_regexp_conf)

Expand All @@ -103,6 +115,9 @@ def test_split_re_no_limit():
'split(a, "[o]")',
'split(a, "^(boo|foo):$")',
'split(a, "[bf]$:")',
'split(a, "b[o]+")',
'split(a, "b[o]*")',
'split(a, "b[o]?")',
'split(a, "b^")',
'split(a, "^[o]")'),
conf=_regexp_conf)
Expand Down Expand Up @@ -154,6 +169,19 @@ def test_split_optimized_no_re_combined():
'split(a, "\\\\}Z")'),
conf=_regexp_conf)

# See https://github.com/NVIDIA/spark-rapids/issues/6958 for issue with zero-width match
@allow_non_gpu('ProjectExec', 'StringSplit')
def test_split_unsupported_fallback():
data_gen = mk_str_gen('([bf]o{0,2}:){1,7}') \
.with_special_case('boo:and:foo')
assert_gpu_sql_fallback_collect(
lambda spark : unary_op_df(spark, data_gen),
'StringSplit',
'string_split_table',
'select ' +
'split(a, "o*"),' +
'split(a, "o?") from string_split_table')

def test_split_regexp_disabled_no_fallback():
conf = { 'spark.rapids.sql.regexp.enabled': 'false' }
data_gen = mk_str_gen('([bf]o{0,2}[.?+\\^$|&_]{1,2}){1,7}') \
Expand Down
71 changes: 47 additions & 24 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,8 @@ object RegexFindMode extends RegexMode
object RegexReplaceMode extends RegexMode
object RegexSplitMode extends RegexMode

sealed class RegexRewriteFlags(val emptyRepetition: Boolean)

/**
* Transpile Java/Spark regular expression to a format that cuDF supports, or throw an exception
* if this is not possible.
Expand Down Expand Up @@ -951,14 +953,34 @@ class CudfRegexTranspiler(mode: RegexMode) {
// ignore
}
}

def isEmptyRepetition(regex: RegexAST): Boolean = {
regex match {
case RegexRepetition(_, term) => term match {
case SimpleQuantifier('*') | SimpleQuantifier('?') => true
case QuantifierFixedLength(0) => true
case QuantifierVariableLength(0, _) => true
case _ => false
}
case RegexGroup(_, term) =>
isEmptyRepetition(term)
case RegexSequence(parts) =>
parts.forall(isEmptyRepetition)
// cuDF does not support repetitions adjacent to a choice (eg. "a*|a"), but if
// we did, we would need to add a `case RegexChoice()` here
case _ => false
}
}

checkUnsupported(regex)

rewrite(regex, replacement, previous)
val flags = new RegexRewriteFlags(isEmptyRepetition(regex))

rewrite(regex, replacement, previous, flags)
}

private def rewrite(regex: RegexAST, replacement: Option[RegexReplacement],
previous: Option[RegexAST]): RegexAST = {
previous: Option[RegexAST], flags: RegexRewriteFlags): RegexAST = {
regex match {

case RegexChar(ch) => ch match {
Expand Down Expand Up @@ -1112,7 +1134,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
case Some(RegexEscaped('Z')) =>
RegexEmpty()
case _ =>
rewrite(RegexChar('$'), replacement, previous)
rewrite(RegexChar('$'), replacement, previous, flags)
}
case 's' | 'S' =>
// whitespace characters
Expand Down Expand Up @@ -1184,7 +1206,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
val components: Seq[RegexCharacterClassComponent] = characters
.map {
case r @ RegexChar(ch) if "^$.".contains(ch) => r
case ch => rewrite(ch, replacement, None) match {
case ch => rewrite(ch, replacement, None, flags) match {
case valid: RegexCharacterClassComponent => valid
case _ =>
// this can happen when a character class contains a meta-sequence such as
Expand Down Expand Up @@ -1282,7 +1304,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
case RegexChar(ch) if "\r\u0085\u2028\u2029".contains(ch) =>
r(j) = RegexSequence(
ListBuffer(
rewrite(part, replacement, None),
rewrite(part, replacement, None, flags),
RegexSequence(ListBuffer(
RegexRepetition(lineTerminatorMatcher(Set(ch), true, false),
SimpleQuantifier('?')), RegexChar('$')))))
Expand All @@ -1295,10 +1317,10 @@ class CudfRegexTranspiler(mode: RegexMode) {
throw new RegexUnsupportedException(
"Regex sequences with \\b or \\B not supported around $", part.position)
case _ =>
r.append(rewrite(part, replacement, last))
r.append(rewrite(part, replacement, last, flags))
}
case _ =>
r.append(rewrite(part, replacement, last))
r.append(rewrite(part, replacement, last, flags))
}
r.last match {
case RegexEmpty() =>
Expand All @@ -1309,20 +1331,21 @@ class CudfRegexTranspiler(mode: RegexMode) {
})._1)

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (_, SimpleQuantifier(ch)) if mode == RegexSplitMode && "?*".contains(ch) =>
case (_, SimpleQuantifier(ch)) if mode == RegexSplitMode
&& flags.emptyRepetition && "?*".contains(ch) =>
// example: pattern " ?", input "] b[", replace with "X":
// java: X]XXbX[X
// cuDF: XXXX] b[
// see https://github.com/NVIDIA/spark-rapids/issues/4884
throw new RegexUnsupportedException(
"regexp_split on GPU does not support repetition with ? or * consistently with Spark",
"regexp_split on GPU does not support empty match repetition consistently with Spark",
quantifier.position)

case (_, QuantifierVariableLength(0, _)) if mode == RegexSplitMode =>
case (_, QuantifierVariableLength(0, _)) if mode == RegexSplitMode
&& flags.emptyRepetition =>
// see https://github.com/NVIDIA/spark-rapids/issues/4884
throw new RegexUnsupportedException(
"regexp_split on GPU does not support repetition with {0,} or {0,n} " +
"consistently with Spark",
"regexp_split on GPU does not support empty match repetition consistently with Spark",
quantifier.position)

case (_, QuantifierVariableLength(0, Some(0))) if mode != RegexFindMode =>
Expand All @@ -1344,7 +1367,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
// (\A)+ can be transpiled to (\A) (dropping the repetition)
// we use rewrite(...) here to handle logic regarding modes
// (\A is not supported in RegexSplitMode)
RegexGroup(capture, rewrite(term, replacement, previous))
RegexGroup(capture, rewrite(term, replacement, previous, flags))
// NOTE: (\A)* can be transpiled to (\A)?
// however, (\A)? is not supported in libcudf yet
case _ =>
Expand All @@ -1362,7 +1385,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
// (\A){1,} can be transpiled to (\A) (dropping the repetition)
// we use rewrite(...) here to handle logic regarding modes
// (\A is not supported in RegexSplitMode)
RegexGroup(capture, rewrite(term, replacement, previous))
RegexGroup(capture, rewrite(term, replacement, previous, flags))
// NOTE: (\A)* can be transpiled to (\A)?
// however, (\A)? is not supported in libcudf yet
case _ =>
Expand All @@ -1380,7 +1403,7 @@ class CudfRegexTranspiler(mode: RegexMode) {
// (\A){1,} can be transpiled to (\A) (dropping the repetition)
// we use rewrite(...) here to handle logic regarding modes
// (\A is not supported in RegexSplitMode)
RegexGroup(capture, rewrite(term, replacement, previous))
RegexGroup(capture, rewrite(term, replacement, previous, flags))
// NOTE: (\A)* can be transpiled to (\A)?
// however, (\A)? is not supported in libcudf yet
case _ =>
Expand All @@ -1394,25 +1417,25 @@ class CudfRegexTranspiler(mode: RegexMode) {
throw new RegexUnsupportedException(
s"cuDF does not support repetition of: ${term.toRegexString}", term.position)
}
RegexRepetition(rewrite(base, replacement, None), quantifier)
RegexRepetition(rewrite(base, replacement, None, flags), quantifier)
case (RegexEscaped(ch), SimpleQuantifier('+')) if "AZ".contains(ch) =>
// \A+ can be transpiled to \A (dropping the repetition)
// \Z+ can be transpiled to \Z (dropping the repetition)
// we use rewrite(...) here to handle logic regarding modes
// (\A and \Z are not supported in RegexSplitMode)
rewrite(base, replacement, previous)
rewrite(base, replacement, previous, flags)
// NOTE: \A* can be transpiled to \A?
// however, \A? is not supported in libcudf yet
case (RegexEscaped(ch), QuantifierFixedLength(n)) if n > 0 && "AZ".contains(ch) =>
// \A{2} can be transpiled to \A (dropping the repetition)
// \Z{2} can be transpiled to \Z (dropping the repetition)
rewrite(base, replacement, previous)
rewrite(base, replacement, previous, flags)
case (RegexEscaped(ch), QuantifierVariableLength(n,_)) if n > 0 && "AZ".contains(ch) =>
// \A{1,5} can be transpiled to \A (dropping the repetition)
// \Z{1,} can be transpiled to \Z (dropping the repetition)
rewrite(base, replacement, previous)
rewrite(base, replacement, previous, flags)
case _ if isSupportedRepetitionBase(base) =>
RegexRepetition(rewrite(base, replacement, None), quantifier)
RegexRepetition(rewrite(base, replacement, None, flags), quantifier)
case (RegexRepetition(_, SimpleQuantifier('*')), SimpleQuantifier('+')) =>
throw new RegexUnsupportedException("Possessive quantifier *+ not supported",
quantifier.position)
Expand All @@ -1426,8 +1449,8 @@ class CudfRegexTranspiler(mode: RegexMode) {
}

case RegexChoice(l, r) =>
val ll = rewrite(l, replacement, None)
val rr = rewrite(r, replacement, None)
val ll = rewrite(l, replacement, None, flags)
val rr = rewrite(r, replacement, None, flags)

// cuDF does not support repetition on one side of a choice, such as "a*|a"
if (isRepetition(ll)) {
Expand Down Expand Up @@ -1488,9 +1511,9 @@ class CudfRegexTranspiler(mode: RegexMode) {
case _ =>
}
}
RegexGroup(capture, rewrite(term, replacement, None))
RegexGroup(capture, rewrite(term, replacement, None, flags))
case _ =>
RegexGroup(capture, rewrite(term, replacement, None))
RegexGroup(capture, rewrite(term, replacement, None, flags))
}

case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,27 +478,6 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuMatchesRegexpReplace(patterns, inputs)
}

test("regexp_split - character class repetition - ? and * - fall back to CPU") {
// see https://github.com/NVIDIA/spark-rapids/issues/4884
val patterns = Seq(raw"[1a-zA-Z]?", raw"[1a-zA-Z]*")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexSplitMode,
"regexp_split on GPU does not support repetition with ? or * " +
"consistently with Spark"
)
)
}

test("regexp_split - fall back to CPU for {0,n}, or {0,}") {
// see https://github.com/NVIDIA/spark-rapids/issues/4884
val patterns = Seq("a{0,}", raw"\02{0,}", "a{0,2}", raw"\02{0,10}")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexSplitMode,
"regexp_split on GPU does not support repetition with {0,} or {0,n} " +
"consistently with Spark")
)
}

test("compare CPU and GPU: regexp find fuzz test with limited chars") {
// testing with this limited set of characters finds issues much
// faster than using the full ASCII set
Expand Down Expand Up @@ -668,6 +647,43 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
}
}

test("regexp_split - character class repetition - ? and *") {
val patterns = Set(raw"[a-z][0-9]?", raw"[a-z][0-9]*")
val data = Seq("a", "aa", "a1a1", "a1b2", "a1b")
for (limit <- Seq(Integer.MIN_VALUE, -2, -1)) {
doStringSplitTest(patterns, data, limit)
}
}

test("regexp_split - repetition with {0,n}, or {0,}") {
// see https://github.com/NVIDIA/spark-rapids/issues/6958
val patterns = Set("ba{0,}", raw"a\02{0,}", "ba{0,2}", raw"b\02{0,10}")
val data = Seq("abaa", "baba", "ba\u0002b", "ab\u0002b\u0002a")
for (limit <- Seq(Integer.MIN_VALUE, -2, -1)) {
doStringSplitTest(patterns, data, limit)
}
}

test("regexp_split - character class repetition - ? and * - fall back to CPU") {
// see https://github.com/NVIDIA/spark-rapids/issues/6958
val patterns = Seq(raw"[1a-zA-Z]?", raw"[1a-zA-Z]*")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexSplitMode,
"regexp_split on GPU does not support empty match repetition consistently with Spark"
)
)
}

test("regexp_split - fall back to CPU for {0,n}, or {0,}") {
// see https://github.com/NVIDIA/spark-rapids/issues/6958
val patterns = Seq("a{0,}", raw"\02{0,}", "a{0,2}", raw"\02{0,10}")
patterns.foreach(pattern =>
assertUnsupported(pattern, RegexSplitMode,
"regexp_split on GPU does not support empty match repetition consistently with Spark"
)
)
}

test("string split - limit < 0") {
val patterns = Set("[^A-Z]+", "[0-9]+", ":", "o", "[:o]")
val data = Seq("abc", "123", "1\n2\n3\n", "boo:and:foo")
Expand Down