diff --git a/src/common/src/types/chrono_wrapper.rs b/src/common/src/types/chrono_wrapper.rs index 1a3c4ba39755..7ec81c640e6f 100644 --- a/src/common/src/types/chrono_wrapper.rs +++ b/src/common/src/types/chrono_wrapper.rs @@ -61,6 +61,12 @@ macro_rules! impl_chrono_wrapper { Ok($variant_name(s.parse()?)) } } + + impl From<$chrono> for $variant_name { + fn from(data: $chrono) -> Self { + $variant_name(data) + } + } }; } diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 7cf41a2b2c16..e67654712c63 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -634,6 +634,12 @@ impl Zero for Decimal { } } +impl From for Decimal { + fn from(d: RustDecimal) -> Self { + Self::Normalized(d) + } +} + #[cfg(test)] mod tests { diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 088abd7f5eb8..7d08f8bde7bc 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -19,11 +19,12 @@ use std::sync::Arc; use bytes::{Buf, BufMut, Bytes}; use num_traits::Float; use parse_display::{Display, FromStr}; +use postgres_types::FromSql; use risingwave_pb::data::DataType as ProstDataType; use serde::{Deserialize, Serialize}; use crate::array::{ArrayError, ArrayResult, NULL_VAL_FOR_HASH}; -use crate::error::BoxedError; +use crate::error::{BoxedError, ErrorCode}; mod native_type; mod ops; @@ -32,7 +33,7 @@ mod successor; use std::fmt::Debug; use std::io::Cursor; -use std::str::FromStr; +use std::str::{FromStr, Utf8Error}; pub use native_type::*; use risingwave_pb::data::data_type::IntervalType::*; @@ -752,6 +753,174 @@ impl From<&String> for ScalarImpl { } } +impl ScalarImpl { + pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> RwResult { + let res = match data_type { + DataType::Varchar => Self::Utf8( + String::from_sql(&Type::VARCHAR, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Bytea => Self::Bytea( + Vec::::from_sql(&Type::BYTEA, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Boolean => Self::Bool( + bool::from_sql(&Type::BOOL, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + ), + DataType::Int16 => Self::Int16( + i16::from_sql(&Type::INT2, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + ), + DataType::Int32 => Self::Int32( + i32::from_sql(&Type::INT4, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + ), + DataType::Int64 => Self::Int64( + i64::from_sql(&Type::INT8, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + ), + DataType::Float32 => Self::Float32( + f32::from_sql(&Type::FLOAT4, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Float64 => Self::Float64( + f64::from_sql(&Type::FLOAT8, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Decimal => Self::Decimal( + rust_decimal::Decimal::from_sql(&Type::NUMERIC, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Date => Self::NaiveDate( + chrono::NaiveDate::from_sql(&Type::DATE, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Time => Self::NaiveTime( + chrono::NaiveTime::from_sql(&Type::TIME, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Timestamp => Self::NaiveDateTime( + chrono::NaiveDateTime::from_sql(&Type::TIMESTAMP, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .into(), + ), + DataType::Timestamptz => Self::Int64( + chrono::DateTime::::from_sql(&Type::TIMESTAMPTZ, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .timestamp_micros(), + ), + DataType::Interval => Self::Interval( + IntervalUnit::from_sql(&Type::INTERVAL, bytes) + .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + ), + DataType::Jsonb => { + Self::Jsonb(JsonbVal::value_deserialize(bytes).ok_or_else(|| { + ErrorCode::InvalidInputSyntax("Invalid value of Jsonb".to_string()) + })?) + } + DataType::Struct(_) | DataType::List { .. } => { + return Err(ErrorCode::NotSupported( + format!("param type: {}", data_type), + "".to_string(), + ) + .into()) + } + }; + Ok(res) + } + + pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> { + let without_null = if b.last() == Some(&0) { + &b[..b.len() - 1] + } else { + &b[..] + }; + std::str::from_utf8(without_null) + } + + pub fn from_text(bytes: &Bytes, data_type: &DataType) -> RwResult { + let str = Self::cstr_to_str(bytes).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {:?}", bytes)) + })?; + let res = match data_type { + DataType::Varchar => Self::Utf8(str.to_string().into()), + DataType::Boolean => Self::Bool(bool::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Int16 => Self::Int16(i16::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Int32 => Self::Int32(i32::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Int64 => Self::Int64(i64::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Float32 => Self::Float32( + f32::from_str(str) + .map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })? + .into(), + ), + DataType::Float64 => Self::Float64( + f64::from_str(str) + .map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })? + .into(), + ), + DataType::Decimal => Self::Decimal( + rust_decimal::Decimal::from_str(str) + .map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })? + .into(), + ), + DataType::Date => Self::NaiveDate(NaiveDateWrapper::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Time => Self::NaiveTime(NaiveTimeWrapper::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Timestamp => { + Self::NaiveDateTime(NaiveDateTimeWrapper::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?) + } + DataType::Timestamptz => Self::Int64( + chrono::DateTime::::from_str(str) + .map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })? + .timestamp_micros(), + ), + DataType::Interval => Self::Interval(IntervalUnit::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Jsonb => Self::Jsonb(JsonbVal::from_str(str).map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) + })?), + DataType::Bytea | DataType::Struct(_) | DataType::List { .. } => { + return Err(ErrorCode::NotSupported( + format!("param type: {}", data_type), + "".to_string(), + ) + .into()) + } + }; + Ok(res) + } +} + macro_rules! impl_scalar_impl_ref_conversion { ($( { $variant_name:ident, $suffix_name:ident, $scalar:ty, $scalar_ref:ty } ),*) => { impl ScalarImpl {