diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 9d122f6101a74..713b371e3fda5 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,10 +18,10 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; @@ -45,7 +45,14 @@ impl RepeatFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`. + // If that fails, it proceeds to `(Utf8, Int64)`. + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + ], Volatility::Immutable, ), } @@ -71,9 +78,10 @@ impl ScalarUDFImpl for RepeatFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { + DataType::Utf8View => make_scalar_function(repeat_utf8view, vec![])(args), DataType::Utf8 => make_scalar_function(repeat::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(repeat::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function repeat"), + other => exec_err!("Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8"), } } } @@ -87,18 +95,35 @@ fn repeat(args: &[ArrayRef]) -> Result { let result = string_array .iter() .zip(number_array.iter()) - .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - Some(string.repeat(number as usize)) - } - (Some(_), Some(_)) => Some("".to_string()), - _ => None, - }) + .map(|(string, number)| repeat_common(string, number)) .collect::>(); Ok(Arc::new(result) as ArrayRef) } +fn repeat_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + let number_array = as_int64_array(&args[1])?; + + let result = string_view_array + .iter() + .zip(number_array.iter()) + .map(|(string, number)| repeat_common(string, number)) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +fn repeat_common(string: Option<&str>, number: Option) -> Option { + match (string, number) { + (Some(string), Some(number)) if number >= 0 => { + Some(string.repeat(number as usize)) + } + (Some(_), Some(_)) => Some("".to_string()), + _ => None, + } +} + #[cfg(test)] mod tests { use arrow::array::{Array, StringArray}; @@ -124,7 +149,6 @@ mod tests { Utf8, StringArray ); - test_function!( RepeatFunc::new(), &[ @@ -148,6 +172,40 @@ mod tests { StringArray ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + Ok(()) } }