Skip to content

Commit

Permalink
Regex: transpile . to [^\r\n] in cuDF (#4064)
Browse files Browse the repository at this point in the history
* rewrite dot to cuDF regex that will not match CR

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* remove unused import
  • Loading branch information
andygrove authored Nov 10, 2021
1 parent bd961dd commit dfc3963
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 62 deletions.
1 change: 0 additions & 1 deletion docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ is provided.
The GPU implementation of `RLike` has the following known issues where behavior is not consistent with Apache Spark and
this expression is disabled by default. It can be enabled setting `spark.rapids.sql.expression.RLike=true`.

- `.` matches `\r` on the GPU but not on the CPU ([cuDF issue #9619](https://github.com/rapidsai/cudf/issues/9619))
- `$` does not match the end of string if the string ends with a line-terminator
([cuDF issue #9620](https://github.com/rapidsai/cudf/issues/9620))

Expand Down
137 changes: 87 additions & 50 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RegexParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -398,36 +398,95 @@ class CudfRegexTranspiler {
// parse the source regular expression
val regex = new RegexParser(pattern).parse()
// validate that the regex is supported by cuDF
validate(regex)
val cudfRegex = rewrite(regex)
// write out to regex string, performing minor transformations
// such as adding additional escaping
regex.toRegexString
cudfRegex.toRegexString
}

private def validate(regex: RegexAST): Unit = {
private def rewrite(regex: RegexAST): RegexAST = {
regex match {

case RegexChar(ch) => ch match {
case '.' =>
// workaround for https://github.com/rapidsai/cudf/issues/9619
RegexCharacterClass(negated = true, ListBuffer(RegexChar('\r'), RegexChar('\n')))
case _ =>
regex
}

case RegexOctalChar(_) =>
// cuDF produced different results compared to Spark in some cases
// example: "a\141|.$"
throw new RegexUnsupportedException(
s"cuDF does not support octal digits consistently with Spark")
case RegexEscaped(ch) if ch == 'b' || ch == 'B' =>
// example: "a\Bb"
// this needs further analysis to determine why words boundaries behave
// differently between Java and cuDF
throw new RegexUnsupportedException("word boundaries are not supported")
case RegexSequence(parts) if parts.isEmpty =>
// examples: "", "()", "a|", "|b"
throw new RegexUnsupportedException("empty sequence not supported")
case RegexRepetition(RegexEscaped(_), _) =>
// example: "\B?"
throw new RegexUnsupportedException(nothingToRepeat)
case RegexRepetition(RegexChar(a), _) if "$^".contains(a) =>
// example: "$*"
throw new RegexUnsupportedException(nothingToRepeat)
case RegexRepetition(RegexRepetition(_, _), _) =>
// example: "a*+"
throw new RegexUnsupportedException(nothingToRepeat)

case RegexHexDigit(_) =>
regex

case RegexEscaped(ch) => ch match {
case 'b' | 'B' =>
// example: "a\Bb"
// this needs further analysis to determine why words boundaries behave
// differently between Java and cuDF
throw new RegexUnsupportedException("word boundaries are not supported")
case _ =>
regex
}

case RegexCharacterRange(_, _) =>
regex

case RegexCharacterClass(negated, characters) =>
characters.foreach {
case RegexChar(ch) if ch == '[' || ch == ']' =>
// examples:
// - "[a[]" should match the literal characters "a" and "["
// - "[a-b[c-d]]" is supported by Java but not cuDF
throw new RegexUnsupportedException("nested character classes are not supported")
case _ =>

}
val components: Seq[RegexCharacterClassComponent] = characters
.map(ch => rewrite(ch).asInstanceOf[RegexCharacterClassComponent])
RegexCharacterClass(negated, ListBuffer(components: _*))

case RegexSequence(parts) =>
if (parts.isEmpty) {
// examples: "", "()", "a|", "|b"
throw new RegexUnsupportedException("empty sequence not supported")
}
if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) {
// examples: "a|", "|b"
throw new RegexUnsupportedException(nothingToRepeat)
}
if (isRegexChar(parts.head, '{')) {
// example: "{"
// cuDF would treat this as a quantifier even though in this
// context (being at the start of a sequence) it is not quantifying anything
// note that we could choose to escape this in the transpiler rather than
// falling back to CPU
throw new RegexUnsupportedException(nothingToRepeat)
}
RegexSequence(parts.map(rewrite))

case RegexRepetition(base, quantifier) => (base, quantifier) match {
case (RegexEscaped(_), _) =>
// example: "\B?"
throw new RegexUnsupportedException(nothingToRepeat)

case (RegexChar(a), _) if "$^".contains(a) =>
// example: "$*"
throw new RegexUnsupportedException(nothingToRepeat)

case (RegexRepetition(_, _), _) =>
// example: "a*+"
throw new RegexUnsupportedException(nothingToRepeat)

case _ =>
RegexRepetition(rewrite(base), quantifier)
}

case RegexChoice(l, r) =>
(l, r) match {
// check for empty left-hand side caused by ^ or $ or a repetition
Expand All @@ -443,6 +502,7 @@ class CudfRegexTranspiler {
// example: "a*|a"
throw new RegexUnsupportedException(nothingToRepeat)
case _ =>
RegexChoice(rewrite(l), rewrite(r))
}
// check for empty right-hand side caused by ^ or $
case (_, RegexSequence(b)) =>
Expand All @@ -454,41 +514,23 @@ class CudfRegexTranspiler {
// example: "a|$"
throw new RegexUnsupportedException(nothingToRepeat)
case _ =>
RegexChoice(rewrite(l), rewrite(r))
}
case (RegexRepetition(_, _), _) =>
// example: "a*|a"
throw new RegexUnsupportedException(nothingToRepeat)
case _ =>
RegexChoice(rewrite(l), rewrite(r))
}

case RegexSequence(parts) =>
if (isRegexChar(parts.head, '|') || isRegexChar(parts.last, '|')) {
// examples: "a|", "|b"
case RegexGroup(term) => term match {
case RegexSequence(ListBuffer(RegexChar(ch))) if "?*+".contains(ch) =>
throw new RegexUnsupportedException(nothingToRepeat)
}
if (isRegexChar(parts.head, '{')) {
// example: "{"
// cuDF would treat this as a quantifier even though in this
// context (being at the start of a sequence) it is not quantifying anything
// note that we could choose to escape this in the transpiler rather than
// falling back to CPU
throw new RegexUnsupportedException(nothingToRepeat)
}
case RegexCharacterClass(_, characters) =>
characters.foreach {
case RegexChar(ch) if ch == '[' || ch == ']' =>
// examples:
// - "[a[]" should match the literal characters "a" and "["
// - "[a-b[c-d]]" is supported by Java but not cuDF
throw new RegexUnsupportedException("nested character classes are not supported")
case _ =>
}
case _ =>
RegexGroup(rewrite(term))
}

case _ =>
}

// walk down the tree and validate children
regex.children().foreach(validate)
}

private def isRegexChar(expr: RegexAST, value: Char): Boolean = expr match {
Expand Down Expand Up @@ -567,11 +609,6 @@ sealed case class RegexChar(a: Char) extends RegexCharacterClassComponent {
override def toRegexString: String = s"$a"
}

sealed case class RegexUnicodeChar(a: String) extends RegexCharacterClassComponent {
override def children(): Seq[RegexAST] = Seq.empty
override def toRegexString: String = s"\\u$a"
}

sealed case class RegexEscaped(a: Char) extends RegexCharacterClassComponent{
override def children(): Seq[RegexAST] = Seq.empty
override def toRegexString: String = s"\\$a"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuContainsMatches(Seq(pattern), Seq("2", "2\n", "2\r", "\2\r\n"))
}

ignore("known issue - dot matches CR on GPU but not on CPU") {
test("dot matches CR on GPU but not on CPU") {
// see https://github.com/rapidsai/cudf/issues/9619
val pattern = "1."
// '.' matches '\r' on GPU but not on CPU
assertCpuGpuContainsMatches(Seq(pattern), Seq("1\r2", "1\n2", "1\r\n2"))
}

Expand All @@ -127,6 +126,10 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assert(transpiled === expected)
}

test("transpile dot") {
assert(new CudfRegexTranspiler().transpile(".+") === "[^\r\n]+")
}

test("transpile complex regex 1") {
val VALID_FLOAT_REGEX =
"^" + // start of line
Expand Down Expand Up @@ -155,8 +158,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
"[0-9]{2}:[0-9]{2}:[0-9]{2})" +
"(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$"

// input and output should be identical
doTranspileTest(TIMESTAMP_TRUNCATE_REGEX, TIMESTAMP_TRUNCATE_REGEX)
// input and output should be identical except for `.` being replaced with `[^\r\n]`
doTranspileTest(TIMESTAMP_TRUNCATE_REGEX,
TIMESTAMP_TRUNCATE_REGEX.replaceAll("\\.", "[^\r\n]"))

}

Expand Down Expand Up @@ -187,21 +191,20 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
test("compare CPU and GPU: fuzz test with limited chars") {
// testing with this limited set of characters finds issues much
// faster than using the full ASCII set
// CR and LF has been excluded due to known issues
doFuzzTest(Some("|()[]{},.^$*+?abc123x\\ \tB"))
// LF has been excluded due to known issues
doFuzzTest(Some("|()[]{},.^$*+?abc123x\\ \r\tB"))
}

test("compare CPU and GPU: fuzz test printable ASCII chars plus TAB") {
test("compare CPU and GPU: fuzz test printable ASCII chars plus CR and TAB") {
// CR and LF has been excluded due to known issues
doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\t"))
doFuzzTest(Some((0x20 to 0x7F).map(_.toChar) + "\r\t"))
}

test("compare CPU and GPU: fuzz test ASCII chars") {
// CR and LF has been excluded due to known issues
// LF has been excluded due to known issues
val chars = (0x00 to 0x7F)
.map(_.toChar)
.filterNot(_ == '\n')
.filterNot(_ == '\r')
doFuzzTest(Some(chars.mkString))
}

Expand All @@ -215,7 +218,9 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
val r = new EnhancedRandom(new Random(seed = 0L),
options = FuzzerOptions(validChars, maxStringLen = 12))

val data = Range(0, 1000).map(_ => r.nextString())
val data = Range(0, 1000)
// remove trailing newlines as workaround for https://github.com/rapidsai/cudf/issues/9620
.map(_ => removeTrailingNewlines(r.nextString()))

// generate patterns that are valid on both CPU and GPU
val patterns = ListBuffer[String]()
Expand All @@ -229,6 +234,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
assertCpuGpuContainsMatches(patterns, data)
}

private def removeTrailingNewlines(input: String): String = {
var s = input
while (s.endsWith("\r") || s.endsWith("\n")) {
s = s.substring(0, s.length - 1)
}
s
}

private def assertCpuGpuContainsMatches(javaPatterns: Seq[String], input: Seq[String]) = {
for (javaPattern <- javaPatterns) {
val cpu = cpuContains(javaPattern, input)
Expand Down

0 comments on commit dfc3963

Please sign in to comment.