Skip to content

Commit

Permalink
regexp_replace with back-references should fall back to CPU (#4556)
Browse files Browse the repository at this point in the history
* Fall back to CPU if regexp_replace replacement expression contains back-references

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Update compatibility guide

* Add support for escape characters in replacement strings

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Use escaped characters in all regexp_replace tests

* Fix resource leak and add more escaped characters to test
  • Loading branch information
andygrove authored Jan 20, 2022
1 parent 8f9afa6 commit f89b19e
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ Here are some examples of regular expression patterns that are not supported on
- Empty groups: `()`
- Regular expressions containing null characters (unless the pattern is a simple literal string)
- Hex and octal digits
- `regexp_replace` does not support back-references (such as `$1`) in the replacement string

Work is ongoing to increase the range of regular expressions that can run on the GPU.

Expand Down
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,31 @@ def test_re_replace():
'REGEXP_REPLACE(a, "TEST", NULL)'),
conf={'spark.rapids.sql.expression.RegExpReplace': 'true'})

@allow_non_gpu('ProjectExec', 'RegExpReplace')
def test_re_replace_backrefs():
    """REGEXP_REPLACE whose replacement uses $0 / $1 is asserted to fall back for RegExpReplace."""
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
    enabled_conf = {'spark.rapids.sql.expression.RegExpReplace': 'true'}

    def build_df(spark):
        # Both the whole-match ($0) and first-group ($1) reference forms are exercised.
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "(TEST)", "[$0]")',
            'REGEXP_REPLACE(a, "(TEST)", "[$1]")')

    assert_gpu_fallback_collect(
        build_df,
        'RegExpReplace',
        conf=enabled_conf)

def test_re_replace_backrefs_escaped():
    """REGEXP_REPLACE with an escaped dollar sign before a digit must match between CPU and GPU."""
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
    enabled_conf = {'spark.rapids.sql.expression.RegExpReplace': 'true'}

    def build_df(spark):
        # The dollar signs are escaped in the replacement, so no fallback is asserted here.
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "(TEST)", "[\\\\$0]")',
            'REGEXP_REPLACE(a, "(TEST)", "[\\\\$1]")')

    assert_gpu_and_cpu_are_equal_collect(
        build_df,
        conf=enabled_conf)

def test_re_replace_escaped():
    """REGEXP_REPLACE with a mix of escaped characters in the replacement must match between CPU and GPU."""
    str_gen = mk_str_gen('.{0,5}TEST[\ud720 A]{0,5}')
    enabled_conf = {'spark.rapids.sql.expression.RegExpReplace': 'true'}

    def build_df(spark):
        return unary_op_df(spark, str_gen).selectExpr(
            'REGEXP_REPLACE(a, "[A-Z]+", "\\\\A\\A\\\\t\\\\r\\\\n\\t\\r\\n")')

    assert_gpu_and_cpu_are_equal_collect(
        build_df,
        conf=enabled_conf)

def test_re_replace_null():
gen = mk_str_gen('[\u0000 ]{0,2}TE[\u0000 ]{0,2}ST[\u0000 ]{0,2}')\
.with_special_case("\u0000")\
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexUnsupportedException, TernaryExprMeta}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -48,6 +49,15 @@ class GpuRegExpReplaceMeta(
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}
}

override def convertToGpu(
Expand All @@ -57,8 +67,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(lhs, regexp, rep)
} else {
GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends TernaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -48,6 +49,15 @@ class GpuRegExpReplaceMeta(
case _ =>
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}
}

override def convertToGpu(
Expand All @@ -57,8 +67,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(lhs, regexp, rep)
} else {
GpuRegExpReplace(lhs, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -49,6 +50,15 @@ class GpuRegExpReplaceMeta(
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}

GpuOverrides.extractLit(expr.pos).foreach { lit =>
if (lit.value.asInstanceOf[Int] != 1) {
willNotWorkOnGpu("only a search starting position of 1 is supported")
Expand All @@ -68,8 +78,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(subject, regexp, rep)
} else {
GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexUnsupportedException}

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuStringReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpUtils, GpuStringReplace}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -30,6 +30,7 @@ class GpuRegExpReplaceMeta(
extends QuaternaryExprMeta[RegExpReplace](expr, conf, parent, rule) {

private var pattern: Option[String] = None
private var replacement: Option[String] = None

override def tagExprForGpu(): Unit = {
expr.regexp match {
Expand All @@ -49,6 +50,15 @@ class GpuRegExpReplaceMeta(
willNotWorkOnGpu(s"only non-null literal strings are supported on GPU")
}

expr.rep match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
if (GpuRegExpUtils.containsBackrefs(s.toString)) {
willNotWorkOnGpu("regexp_replace with back-references is not supported")
}
replacement = Some(GpuRegExpUtils.unescapeReplaceString(s.toString))
case _ =>
}

GpuOverrides.extractLit(expr.pos).foreach { lit =>
if (lit.value.asInstanceOf[Int] != 1) {
willNotWorkOnGpu("only a search starting position of 1 is supported")
Expand All @@ -68,8 +78,12 @@ class GpuRegExpReplaceMeta(
if (GpuOverrides.isSupportedStringReplacePattern(expr.regexp)) {
GpuStringReplace(subject, regexp, rep)
} else {
GpuRegExpReplace(subject, regexp, rep, pattern.getOrElse(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern")))
(pattern, replacement) match {
case (Some(cudfPattern), Some(cudfReplacement)) =>
GpuRegExpReplace(lhs, regexp, rep, cudfPattern, cudfReplacement)
case _ =>
throw new IllegalStateException("Expression has not been tagged correctly")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -744,6 +744,48 @@ case class GpuLike(left: Expression, right: Expression, escapeChar: Char)
}
}

/**
 * Helpers for inspecting and pre-processing the replacement string of
 * `regexp_replace`.
 */
object GpuRegExpUtils {

  /**
   * Determine if a replacement string contains group references, ignoring any
   * `$` that is preceded by the escape character `\`.
   *
   * Both reference forms that Java's `Matcher.appendReplacement` substitutes are
   * detected: numbered references such as `$1` and named references such as
   * `${name}`.
   *
   * @param s the replacement string, still containing its escape characters
   * @return true if an unescaped `$` is immediately followed by a digit or by `{`
   */
  def containsBackrefs(s: String): Boolean = {
    var found = false
    var i = 0
    while (!found && i < s.length) {
      val c = s.charAt(i)
      if (c == '\\') {
        // skip the escape character together with the character it escapes
        i += 2
      } else {
        if (c == '$' && i + 1 < s.length) {
          val next = s.charAt(i + 1)
          // `$<digit>` is a numbered group reference, `${` starts a named one
          found = next.isDigit || next == '{'
        }
        i += 1
      }
    }
    found
  }

  /**
   * We need to remove escape characters in the regexp_replace
   * replacement string before passing to cuDF.
   *
   * Each `\` followed by another character is dropped and the following
   * character is kept verbatim (so `\\` becomes a single `\`). A trailing `\`
   * with nothing after it is kept as-is.
   *
   * @param s the replacement string, still containing its escape characters
   * @return the replacement string with the escape characters removed
   */
  def unescapeReplaceString(s: String): String = {
    val b = new StringBuilder
    var i = 0
    while (i < s.length) {
      if (s.charAt(i) == '\\' && i + 1 < s.length) {
        // step over the escape character so that only the escaped one is appended
        i += 1
      }
      b.append(s.charAt(i))
      i += 1
    }
    b.toString
  }

}

class GpuRLikeMeta(
expr: RLike,
conf: RapidsConf,
Expand Down Expand Up @@ -854,7 +896,8 @@ case class GpuRegExpReplace(
srcExpr: Expression,
searchExpr: Expression,
replaceExpr: Expression,
cudfRegexPattern: String)
cudfRegexPattern: String,
cudfReplacementString: String)
extends GpuRegExpTernaryBase with ImplicitCastInputTypes {

override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
Expand All @@ -863,15 +906,18 @@ case class GpuRegExpReplace(
override def second: Expression = searchExpr
override def third: Expression = replaceExpr

def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String) = {
this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern)
def this(srcExpr: Expression, searchExpr: Expression, cudfRegexPattern: String,
cudfReplacementString: String) = {
this(srcExpr, searchExpr, GpuLiteral("", StringType), cudfRegexPattern, cudfReplacementString)
}

/**
 * Replace matches of `cudfRegexPattern` in the input column with
 * `cudfReplacementString`.
 *
 * @param strExpr the input string column
 * @param searchExpr not read here; `cudfRegexPattern` is used instead
 * @param replaceExpr not read here; `cudfReplacementString` is used instead
 * @return the column produced by cuDF `replaceRegex`
 */
override def doColumnar(
    strExpr: GpuColumnVector,
    searchExpr: GpuScalar,
    replaceExpr: GpuScalar): ColumnVector = {
  // Only one replaceRegex call is made: an extra call whose result is discarded
  // would produce a column that is never closed.
  withResource(Scalar.fromString(cudfReplacementString)) { rep =>
    strExpr.getBase.replaceRegex(cudfRegexPattern, rep)
  }
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.util.{Random, Try}
import ai.rapids.cudf.{ColumnVector, CudfException}
import org.scalatest.FunSuite

import org.apache.spark.sql.rapids.GpuRegExpUtils
import org.apache.spark.sql.types.DataTypes

class RegularExpressionTranspilerSuite extends FunSuite with Arm {
Expand Down Expand Up @@ -418,13 +419,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
result
}

private val REPLACE_STRING = "_REPLACE_"
private val REPLACE_STRING = "\\_\\RE\\\\P\\L\\A\\C\\E\\_"

/** cuDF replaceRe helper */
private def gpuReplace(cudfPattern: String, input: Seq[String]): Array[String] = {
val result = new Array[String](input.length)
val replace = GpuRegExpUtils.unescapeReplaceString(REPLACE_STRING)
withResource(ColumnVector.fromStrings(input: _*)) { cv =>
withResource(GpuScalar.from(REPLACE_STRING, DataTypes.StringType)) { replace =>
withResource(GpuScalar.from(replace, DataTypes.StringType)) { replace =>
withResource(cv.replaceRegex(cudfPattern, replace)) { c =>
withResource(c.copyToHost()) { hv =>
result.indices.foreach(i => result(i) = new String(hv.getUTF8(i)))
Expand Down

0 comments on commit f89b19e

Please sign in to comment.