Skip to content

Commit

Permalink
Use escaped characters in all regexp_replace tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Jan 19, 2022
1 parent 5d5727b commit cc5d1fa
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.util.{Random, Try}
import ai.rapids.cudf.{ColumnVector, CudfException}
import org.scalatest.FunSuite

import org.apache.spark.sql.rapids.GpuRegExpUtils
import org.apache.spark.sql.types.DataTypes

class RegularExpressionTranspilerSuite extends FunSuite with Arm {
Expand Down Expand Up @@ -418,13 +419,14 @@ class RegularExpressionTranspilerSuite extends FunSuite with Arm {
result
}

private val REPLACE_STRING = "_REPLACE_"
private val REPLACE_STRING = "\\_\\RE\\\\P\\L\\A\\C\\E\\_"

/** cuDF replaceRe helper */
private def gpuReplace(cudfPattern: String, input: Seq[String]): Array[String] = {
val result = new Array[String](input.length)
val replace = GpuRegExpUtils.unescapeReplaceString(REPLACE_STRING)
withResource(ColumnVector.fromStrings(input: _*)) { cv =>
withResource(GpuScalar.from(REPLACE_STRING, DataTypes.StringType)) { replace =>
withResource(GpuScalar.from(replace, DataTypes.StringType)) { replace =>
withResource(cv.replaceRegex(cudfPattern, replace)) { c =>
withResource(c.copyToHost()) { hv =>
result.indices.foreach(i => result(i) = new String(hv.getUTF8(i)))
Expand Down

0 comments on commit cc5d1fa

Please sign in to comment.