Skip to content

Commit

Permalink
Fix: Internal error in regexp_replace() for some StringView input (#1…
Browse files Browse the repository at this point in the history
…2203)

* Fix: Internal error in regexp_replace() for some StringView input

* fix regex bench

* fmt

* fix bench regx

* clippy

* fmt

* adds tests for flags + includes type signature for utf8view with flag

* fix: adding collect for string view type
  • Loading branch information
devanbenz authored Sep 12, 2024
1 parent 533fbc4 commit 04e8e53
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 94 deletions.
14 changes: 7 additions & 7 deletions datafusion/functions/benches/regx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
extern crate criterion;

use arrow::array::builder::StringBuilder;
use arrow::array::{ArrayRef, StringArray};
use arrow::array::{ArrayRef, AsArray, StringArray};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_functions::regex::regexplike::regexp_like;
use datafusion_functions::regex::regexpmatch::regexp_match;
Expand Down Expand Up @@ -122,12 +122,12 @@ fn criterion_benchmark(c: &mut Criterion) {

b.iter(|| {
black_box(
regexp_replace::<i32>(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&replacement),
Arc::clone(&flags),
])
regexp_replace::<i32, _, _>(
data.as_string::<i32>(),
regex.as_string::<i32>(),
replacement.as_string::<i32>(),
Some(&flags),
)
.expect("regexp_replace should work on valid values"),
)
})
Expand Down
258 changes: 171 additions & 87 deletions datafusion/functions/src/regex/regexpreplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
// under the License.

//! Regex expressions
use arrow::array::new_null_array;
use arrow::array::ArrayAccessor;
use arrow::array::ArrayDataBuilder;
use arrow::array::BufferBuilder;
use arrow::array::GenericStringArray;
use arrow::array::StringViewBuilder;
use arrow::array::{new_null_array, ArrayIter, AsArray};
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use arrow::array::{ArrayAccessor, StringViewArray};
use arrow::datatypes::DataType;
use datafusion_common::cast::as_string_view_array;
use datafusion_common::exec_err;
Expand Down Expand Up @@ -59,6 +59,7 @@ impl RegexpReplaceFunc {
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8]),
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -187,104 +188,147 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
/// # Ok(())
/// # }
/// ```
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>(
string_array: V,
pattern_array: B,
replacement_array: B,
flags: Option<&ArrayRef>,
) -> Result<ArrayRef>
where
V: ArrayAccessor<Item = &'a str>,
B: ArrayAccessor<Item = &'a str>,
{
// Default implementation for regexp_replace, assumes all args are arrays
// and args is a sequence of 3 or 4 elements.

// creating Regex is expensive so create hashmap for memoization
let mut patterns: HashMap<String, Regex> = HashMap::new();

match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;

let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(pattern) {
Some(re) => Ok(re),
None => {
match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re);
Ok(patterns.get(pattern).unwrap())
let datatype = string_array.data_type().to_owned();

let string_array_iter = ArrayIter::new(string_array);
let pattern_array_iter = ArrayIter::new(pattern_array);
let replacement_array_iter = ArrayIter::new(replacement_array);

match flags {
None => {
let result_iter = string_array_iter
.zip(pattern_array_iter)
.zip(replacement_array_iter)
.map(|((string, pattern), replacement)| {
match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);
// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(pattern) {
Some(re) => Ok(re),
None => match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re);
Ok(patterns.get(pattern).unwrap())
}
Err(err) => {
Err(DataFusionError::External(Box::new(err)))
}
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
}
};
};

Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
Some(re.map(|re| re.replace(string, replacement.as_str())))
.transpose()
}
_ => Ok(None),
}
});

match datatype {
DataType::Utf8 | DataType::LargeUtf8 => {
let result =
result_iter.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;

Ok(Arc::new(result) as ArrayRef)
DataType::Utf8View => {
let result = result_iter.collect::<Result<StringViewArray>>()?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!(
"Unsupported data type {other:?} for function regex_replace"
)
}
}
}
4 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;
let flags_array = as_generic_string_array::<T>(&args[3])?;

let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);

// format flags into rust pattern
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true)
} else {
(format!("(?{flags}){pattern}"), false)
};

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(&pattern) {
Some(re) => Ok(re),
None => {
match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern.clone(), re);
Ok(patterns.get(&pattern).unwrap())
Some(flags) => {
let flags_array = as_generic_string_array::<T>(flags)?;

let result_iter = string_array_iter
.zip(pattern_array_iter)
.zip(replacement_array_iter)
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| {
match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);

// format flags into rust pattern
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(
format!(
"(?{}){}",
flags.to_string().replace('g', ""),
pattern
),
true,
)
} else {
(format!("(?{flags}){pattern}"), false)
};

// if patterns hashmap already has regexp then use else create and return
let re = match patterns.get(&pattern) {
Some(re) => Ok(re),
None => match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern.clone(), re);
Ok(patterns.get(&pattern).unwrap())
}
Err(err) => {
Err(DataFusionError::External(Box::new(err)))
}
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
};

Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
}))
.transpose()
}
};

Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
})).transpose()
_ => Ok(None),
}
});

match datatype {
DataType::Utf8 | DataType::LargeUtf8 => {
let result =
result_iter.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;

Ok(Arc::new(result) as ArrayRef)
DataType::Utf8View => {
let result = result_iter.collect::<Result<StringViewArray>>()?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!(
"Unsupported data type {other:?} for function regex_replace"
)
}
}
}
other => exec_err!(
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
),
}
}

Expand Down Expand Up @@ -495,7 +539,47 @@ pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
.iter()
.map(|arg| arg.clone().into_array(inferred_length))
.collect::<Result<Vec<_>>>()?;
regexp_replace::<T>(&args)

match args[0].data_type() {
DataType::Utf8View => {
let string_array = args[0].as_string_view();
let pattern_array = args[1].as_string::<i32>();
let replacement_array = args[2].as_string::<i32>();
regexp_replace::<i32, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)
}
DataType::Utf8 => {
let string_array = args[0].as_string::<i32>();
let pattern_array = args[1].as_string::<i32>();
let replacement_array = args[2].as_string::<i32>();
regexp_replace::<i32, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)
}
DataType::LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let pattern_array = args[1].as_string::<i64>();
let replacement_array = args[2].as_string::<i64>();
regexp_replace::<i64, _, _>(
string_array,
pattern_array,
replacement_array,
args.get(3),
)
}
other => {
exec_err!(
"Unsupported data type {other:?} for function regex_replace"
)
}
}
}
}
}
Expand Down
Loading

0 comments on commit 04e8e53

Please sign in to comment.