Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): support JSONB in UDF #9103

Merged
merged 12 commits into from
Apr 12, 2023
19 changes: 19 additions & 0 deletions e2e_test/udf/python.slt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ create function hex_to_dec(varchar) returns decimal language python as hex_to_de
statement ok
create function array_access(varchar[], int) returns varchar language python as array_access using link 'http://localhost:8815';

statement ok
create function jsonb_access(jsonb, int) returns jsonb language python as jsonb_access using link 'http://localhost:8815';

statement ok
create function jsonb_concat(jsonb[]) returns jsonb language python as jsonb_concat using link 'http://localhost:8815';

query I
select int_42();
----
Expand All @@ -68,6 +74,19 @@ select array_access(ARRAY['a', 'b', 'c'], 2);
----
b

query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
----
"b"
NULL
false

query T
select jsonb_concat(ARRAY['null'::jsonb, '1'::jsonb, '"str"'::jsonb, '{}'::jsonb]);
----
[null, 1, "str", {}]

query I
select series(5);
----
Expand Down
22 changes: 19 additions & 3 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import socket
import struct
import sys
from typing import Iterator, List, Optional, Tuple
from typing import Iterator, List, Optional, Tuple, Any
from decimal import Decimal
sys.path.append('src/udf/python') # noqa

Expand All @@ -20,7 +20,7 @@ def gcd(x: int, y: int) -> int:
return x


@udf(input_types=['INT', 'INT', 'INT'], result_type='INT')
@udf(name='gcd3', input_types=['INT', 'INT', 'INT'], result_type='INT')
def gcd3(x: int, y: int, z: int) -> int:
return gcd(gcd(x, y), z)

Expand Down Expand Up @@ -69,8 +69,22 @@ def array_access(list: List[str], idx: int) -> Optional[str]:
return list[idx - 1]


@udf(input_types=['JSONB', 'INT'], result_type='JSONB')
def jsonb_access(json: Any, i: int) -> Any:
if not json:
return None
return json[i]


@udf(input_types=['JSONB[]'], result_type='JSONB')
def jsonb_concat(list: List[Any]) -> Any:
if not list:
return None
return list


if __name__ == '__main__':
server = UdfServer()
server = UdfServer(location="0.0.0.0:8815")
server.add_function(int_42)
server.add_function(gcd)
server.add_function(gcd3)
Expand All @@ -79,4 +93,6 @@ def array_access(list: List[str], idx: int) -> Optional[str]:
server.add_function(extract_tcp_info)
server.add_function(hex_to_dec)
server.add_function(array_access)
server.add_function(jsonb_access)
server.add_function(jsonb_concat)
server.serve()
131 changes: 101 additions & 30 deletions src/common/src/array/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

//! Converts between arrays and Apache Arrow arrays.

use std::fmt::Write;

use arrow_schema::{Field, DECIMAL256_MAX_PRECISION};
use chrono::{NaiveDateTime, NaiveTime};
use itertools::Itertools;
Expand All @@ -38,16 +40,16 @@ impl From<&DataChunk> for arrow_array::RecordBatch {
}
}

impl From<&arrow_array::RecordBatch> for DataChunk {
fn from(batch: &arrow_array::RecordBatch) -> Self {
DataChunk::new(
batch
.columns()
.iter()
.map(|array| Column::new(Arc::new(array.into())))
.collect(),
batch.num_rows(),
)
impl TryFrom<&arrow_array::RecordBatch> for DataChunk {
type Error = ArrayError;

fn try_from(batch: &arrow_array::RecordBatch) -> Result<Self, Self::Error> {
let mut columns = Vec::with_capacity(batch.num_columns());
for array in batch.columns() {
let column = Column::new(Arc::new(array.try_into()?));
columns.push(column);
}
Ok(DataChunk::new(columns, batch.num_rows()))
}
}

Expand All @@ -64,20 +66,21 @@ macro_rules! converts_generic {
}
}
// Arrow array -> RisingWave array
impl From<&arrow_array::ArrayRef> for ArrayImpl {
fn from(array: &arrow_array::ArrayRef) -> Self {
impl TryFrom<&arrow_array::ArrayRef> for ArrayImpl {
type Error = ArrayError;
fn try_from(array: &arrow_array::ArrayRef) -> Result<Self, Self::Error> {
use arrow_schema::DataType::*;
use arrow_schema::IntervalUnit::*;
use arrow_schema::TimeUnit::*;
match array.data_type() {
$($ArrowPattern => $ArrayImplPattern(
$($ArrowPattern => Ok($ArrayImplPattern(
array
.as_any()
.downcast_ref::<$ArrowType>()
.unwrap()
.into(),
),)*
t => todo!("Unsupported arrow data type: {t:?}"),
.try_into()?,
)),)*
t => Err(ArrayError::FromArrow(format!("unsupported data type: {t:?}"))),
}
}
}
Expand All @@ -99,7 +102,8 @@ converts_generic! {
{ arrow_array::Time64NanosecondArray, Time64(Nanosecond), ArrayImpl::Time },
{ arrow_array::StructArray, Struct(_), ArrayImpl::Struct },
{ arrow_array::ListArray, List(_), ArrayImpl::List },
{ arrow_array::BinaryArray, Binary, ArrayImpl::Bytea }
{ arrow_array::BinaryArray, Binary, ArrayImpl::Bytea },
{ arrow_array::LargeStringArray, LargeUtf8, ArrayImpl::Jsonb } // we use LargeUtf8 to represent Jsonb in arrow
}

// Arrow Datatype -> Risingwave Datatype
Expand All @@ -119,6 +123,7 @@ impl From<&arrow_schema::DataType> for DataType {
Interval(_) => Self::Interval, // TODO: check time unit
Binary => Self::Bytea,
Utf8 => Self::Varchar,
LargeUtf8 => Self::Jsonb,
Struct(field) => Self::Struct(Arc::new(struct_type::StructType {
fields: field.iter().map(|f| f.data_type().into()).collect(),
field_names: field.iter().map(|f| f.name().clone()).collect(),
Expand Down Expand Up @@ -154,6 +159,7 @@ impl From<&DataType> for arrow_schema::DataType {
DataType::Time => Self::Time64(arrow_schema::TimeUnit::Millisecond),
DataType::Interval => Self::Interval(arrow_schema::IntervalUnit::DayTime),
DataType::Varchar => Self::Utf8,
DataType::Jsonb => Self::LargeUtf8,
DataType::Bytea => Self::Binary,
DataType::Decimal => Self::Decimal128(28, 0), // arrow precision can not be 0
DataType::Struct(struct_type) => {
Expand Down Expand Up @@ -411,6 +417,40 @@ impl From<&arrow_array::Decimal128Array> for DecimalArray {
}
}

impl From<&JsonbArray> for arrow_array::LargeStringArray {
fn from(array: &JsonbArray) -> Self {
let mut builder =
arrow_array::builder::LargeStringBuilder::with_capacity(array.len(), array.len() * 16);
for value in array.iter() {
match value {
Some(jsonb) => {
write!(&mut builder, "{}", jsonb).unwrap();
builder.append_value("");
}
None => builder.append_null(),
}
}
builder.finish()
}
}

impl TryFrom<&arrow_array::LargeStringArray> for JsonbArray {
type Error = ArrayError;

fn try_from(array: &arrow_array::LargeStringArray) -> Result<Self, Self::Error> {
array
.iter()
.map(|o| {
o.map(|s| {
s.parse()
.map_err(|_| ArrayError::FromArrow(format!("invalid json: {s}")))
})
.transpose()
})
.try_collect()
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl From<&Int256Array> for arrow_array::Decimal256Array {
fn from(array: &Int256Array) -> Self {
array
Expand Down Expand Up @@ -529,7 +569,12 @@ impl From<&ListArray> for arrow_array::ListArray {
Time64NanosecondBuilder::with_capacity(a.len()),
|b, v| b.append_option(v.map(|d| d.into_arrow())),
),
ArrayImpl::Jsonb(_) => todo!("list of jsonb"),
ArrayImpl::Jsonb(a) => build(
array,
a,
LargeStringBuilder::with_capacity(a.len(), a.len() * 16),
|b, v| b.append_option(v.map(|j| j.to_string())),
),
ArrayImpl::Serial(_) => todo!("list of serial"),
ArrayImpl::Struct(_) => todo!("list of struct"),
ArrayImpl::List(_) => todo!("list of list"),
Expand All @@ -543,10 +588,15 @@ impl From<&ListArray> for arrow_array::ListArray {
}
}

impl From<&arrow_array::ListArray> for ListArray {
fn from(array: &arrow_array::ListArray) -> Self {
let iter = array.iter().map(|o| o.map(|a| ArrayImpl::from(&a)));
ListArray::from_iter(iter, (&array.value_type()).into())
impl TryFrom<&arrow_array::ListArray> for ListArray {
type Error = ArrayError;

fn try_from(array: &arrow_array::ListArray) -> Result<Self, Self::Error> {
let iter: Vec<_> = array
.iter()
.map(|o| o.map(|a| ArrayImpl::try_from(&a)).transpose())
.try_collect()?;
Ok(ListArray::from_iter(iter, (&array.value_type()).into()))
}
}

Expand All @@ -573,24 +623,30 @@ impl From<&StructArray> for arrow_array::StructArray {
}
}

impl From<&arrow_array::StructArray> for StructArray {
fn from(array: &arrow_array::StructArray) -> Self {
impl TryFrom<&arrow_array::StructArray> for StructArray {
type Error = ArrayError;

fn try_from(array: &arrow_array::StructArray) -> Result<Self, Self::Error> {
let mut null_bitmap = Vec::new();
for i in 0..arrow_array::Array::len(&array) {
null_bitmap.push(!arrow_array::Array::is_null(&array, i))
}
match arrow_array::Array::data_type(&array) {
Ok(match arrow_array::Array::data_type(&array) {
arrow_schema::DataType::Struct(fields) => StructArray::from_slices_with_field_names(
&(null_bitmap),
array.columns().iter().map(ArrayImpl::from).collect(),
array
.columns()
.iter()
.map(ArrayImpl::try_from)
.try_collect()?,
fields
.iter()
.map(|f| DataType::from(f.data_type()))
.collect(),
array.column_names().into_iter().map(String::from).collect(),
),
_ => panic!("nested field types cannot be determined."),
}
})
}
}

Expand Down Expand Up @@ -686,6 +742,20 @@ mod tests {
assert_eq!(DecimalArray::from(&arrow), array);
}

#[test]
fn jsonb() {
let array = JsonbArray::from_iter([
None,
Some("null".parse().unwrap()),
Some("false".parse().unwrap()),
Some("1".parse().unwrap()),
Some("[1, 2, 3]".parse().unwrap()),
Some(r#"{ "a": 1, "b": null }"#.parse().unwrap()),
]);
let arrow = arrow_array::LargeStringArray::from(&array);
assert_eq!(JsonbArray::try_from(&arrow).unwrap(), array);
}

#[test]
fn int256() {
let values = vec![
Expand Down Expand Up @@ -720,7 +790,7 @@ mod tests {

// Empty array - arrow to risingwave conversion.
let test_arr_2 = arrow_array::StructArray::from(vec![]);
assert_eq!(StructArray::from(&test_arr_2).len(), 0);
assert_eq!(StructArray::try_from(&test_arr_2).unwrap().len(), 0);

// Struct array with primitive types. arrow to risingwave conversion.
let test_arrow_struct_array = arrow_array::StructArray::try_from(vec![
Expand All @@ -744,7 +814,8 @@ mod tests {
),
])
.unwrap();
let actual_risingwave_struct_array = StructArray::from(&test_arrow_struct_array);
let actual_risingwave_struct_array =
StructArray::try_from(&test_arrow_struct_array).unwrap();
let expected_risingwave_struct_array = StructArray::from_slices_with_field_names(
&[true, true, true, false],
vec![
Expand Down Expand Up @@ -772,6 +843,6 @@ mod tests {
DataType::Int32,
);
let arrow = arrow_array::ListArray::from(&array);
assert_eq!(ListArray::from(&arrow), array);
assert_eq!(ListArray::try_from(&arrow).unwrap(), array);
}
}
11 changes: 11 additions & 0 deletions src/common/src/array/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::Infallible;

pub use anyhow::anyhow;
use risingwave_pb::PbFieldNotFound;
use thiserror::Error;
Expand All @@ -28,6 +30,9 @@ pub enum ArrayError {

#[error(transparent)]
Internal(#[from] anyhow::Error),

#[error("Arrow error: {0}")]
FromArrow(String),
}

impl From<ArrayError> for RwError {
Expand All @@ -42,6 +47,12 @@ impl From<PbFieldNotFound> for ArrayError {
}
}

impl From<Infallible> for ArrayError {
fn from(err: Infallible) -> Self {
unreachable!("Infallible error: {:?}", err)
}
}

impl ArrayError {
pub fn internal(msg: impl ToString) -> Self {
ArrayError::Internal(anyhow!(msg.to_string()))
Expand Down
Loading