diff --git a/src/batch/src/executor/top_n.rs b/src/batch/src/executor/top_n.rs index 4c19c4e563cf..4eca84596c9a 100644 --- a/src/batch/src/executor/top_n.rs +++ b/src/batch/src/executor/top_n.rs @@ -25,7 +25,7 @@ use risingwave_common::estimate_size::EstimateSize; use risingwave_common::memory::MemoryContext; use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; -use risingwave_common::util::memcmp_encoding::encode_chunk; +use risingwave_common::util::memcmp_encoding::{encode_chunk, MemcmpEncoded}; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_pb::batch_plan::plan_node::NodeBody; @@ -200,7 +200,7 @@ impl TopNHeap { #[derive(Clone, EstimateSize)] pub struct HeapElem { - encoded_row: Vec, + encoded_row: MemcmpEncoded, row: OwnedRow, } @@ -225,7 +225,7 @@ impl Ord for HeapElem { } impl HeapElem { - pub fn new(encoded_row: Vec, row: impl Row) -> Self { + pub fn new(encoded_row: MemcmpEncoded, row: impl Row) -> Self { Self { encoded_row, row: row.into_owned_row(), diff --git a/src/common/benches/bench_encoding.rs b/src/common/benches/bench_encoding.rs index 994b86dabd25..ffa4005cb8c3 100644 --- a/src/common/benches/bench_encoding.rs +++ b/src/common/benches/bench_encoding.rs @@ -19,6 +19,7 @@ use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{ DataType, Date, Datum, Interval, ScalarImpl, StructType, Time, Timestamp, }; +use risingwave_common::util::memcmp_encoding::MemcmpEncoded; use risingwave_common::util::sort_util::OrderType; use risingwave_common::util::{memcmp_encoding, value_encoding}; @@ -42,7 +43,7 @@ impl Case { } } -fn key_serialization(datum: &Datum) -> Vec { +fn key_serialization(datum: &Datum) -> MemcmpEncoded { let result = memcmp_encoding::encode_value( datum.as_ref().map(ScalarImpl::as_scalar_ref_impl), OrderType::default(), diff --git a/src/common/src/util/memcmp_encoding.rs b/src/common/src/util/memcmp_encoding.rs index fce92676f5b7..08e2f715ca32 100644 --- a/src/common/src/util/memcmp_encoding.rs +++ b/src/common/src/util/memcmp_encoding.rs @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use bytes::{Buf, BufMut}; use itertools::Itertools; use serde::{Deserialize, Serialize}; use super::iter_util::{ZipEqDebug, ZipEqFast}; use crate::array::{ArrayImpl, DataChunk}; +use crate::estimate_size::EstimateSize; use crate::row::{OwnedRow, Row}; use crate::types::{ DataType, Date, Datum, Int256, ScalarImpl, Serial, Time, Timestamp, ToDatumRef, F32, F64, @@ -180,12 +183,83 @@ fn calculate_encoded_size_inner( Ok(deserializer.position() - base_position) } -pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> memcomparable::Result> { +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, EstimateSize)] +pub struct MemcmpEncoded(Box<[u8]>); + +impl MemcmpEncoded { + pub fn as_inner(&self) -> &[u8] { + &self.0 + } + + pub fn into_inner(self) -> Box<[u8]> { + self.0 + } +} + +impl AsRef<[u8]> for MemcmpEncoded { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl Deref for MemcmpEncoded { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl IntoIterator for MemcmpEncoded { + type IntoIter = std::vec::IntoIter; + type Item = u8; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_vec().into_iter() + } +} + +impl FromIterator for MemcmpEncoded { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl From> for MemcmpEncoded { + fn from(v: Vec) -> Self { + Self(v.into_boxed_slice()) + } +} + +impl From> for MemcmpEncoded { + fn from(v: Box<[u8]>) -> Self { + Self(v) + } +} + +impl From for Vec { + fn from(v: MemcmpEncoded) -> Self { + v.0.into() + } +} + +impl From for Box<[u8]> { + fn from(v: MemcmpEncoded) -> Self { + v.0 + } +} + +/// Encode a datum into memcomparable format. +pub fn encode_value( + value: impl ToDatumRef, + order: OrderType, +) -> memcomparable::Result { let mut serializer = memcomparable::Serializer::new(vec![]); serialize_datum(value, order, &mut serializer)?; - Ok(serializer.into_inner()) + Ok(serializer.into_inner().into()) } +/// Decode a datum from memcomparable format. pub fn decode_value( ty: &DataType, encoded_value: &[u8], @@ -195,7 +269,11 @@ pub fn decode_value( deserialize_datum(ty, order, &mut deserializer) } -pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Result>> { +/// Encode an array into memcomparable format. +pub fn encode_array( + array: &ArrayImpl, + order: OrderType, +) -> memcomparable::Result> { let mut data = Vec::with_capacity(array.len()); for datum in array.iter() { data.push(encode_value(datum, order)?); @@ -203,13 +281,11 @@ pub fn encode_array(array: &ArrayImpl, order: OrderType) -> memcomparable::Resul Ok(data) } -/// This function is used to accelerate the comparison of tuples. It takes datachunk and -/// user-defined order as input, yield encoded binary string with order preserved for each tuple in -/// the datachunk. +/// Encode a chunk into memcomparable format. pub fn encode_chunk( chunk: &DataChunk, column_orders: &[ColumnOrder], -) -> memcomparable::Result>> { +) -> memcomparable::Result> { let encoded_columns: Vec<_> = column_orders .iter() .map(|o| encode_array(chunk.column_at(o.column_index), o.order_type)) @@ -222,18 +298,22 @@ pub fn encode_chunk( } } - Ok(encoded_chunk) + Ok(encoded_chunk.into_iter().map(Into::into).collect()) } /// Encode a row into memcomparable format. -pub fn encode_row(row: impl Row, order_types: &[OrderType]) -> memcomparable::Result> { +pub fn encode_row( + row: impl Row, + order_types: &[OrderType], +) -> memcomparable::Result { let mut serializer = memcomparable::Serializer::new(vec![]); row.iter() .zip_eq_debug(order_types) .try_for_each(|(datum, order)| serialize_datum(datum, *order, &mut serializer))?; - Ok(serializer.into_inner()) + Ok(serializer.into_inner().into()) } +/// Decode a row from memcomparable format. pub fn decode_row( encoded_row: &[u8], data_types: &[DataType], @@ -259,11 +339,12 @@ mod tests { use crate::array::{DataChunk, ListValue, StructValue}; use crate::row::{OwnedRow, RowExt}; use crate::types::{DataType, FloatExt, ScalarImpl, F32}; + use crate::util::iter_util::ZipEqFast; use crate::util::sort_util::{ColumnOrder, OrderType}; #[test] fn test_memcomparable() { - fn encode_num(num: Option, order_type: OrderType) -> Vec { + fn encode_num(num: Option, order_type: OrderType) -> MemcmpEncoded { encode_value(num.map(ScalarImpl::from), order_type).unwrap() } @@ -465,11 +546,11 @@ mod tests { use num_traits::*; use rand::seq::SliceRandom; - fn serialize(f: F32) -> Vec { + fn serialize(f: F32) -> MemcmpEncoded { encode_value(&Some(ScalarImpl::from(f)), OrderType::default()).unwrap() } - fn deserialize(data: Vec) -> F32 { + fn deserialize(data: MemcmpEncoded) -> F32 { decode_value(&DataType::Float32, &data, OrderType::default()) .unwrap() .unwrap() @@ -539,7 +620,7 @@ mod tests { let concated_encoded_row1 = encoded_v10 .into_iter() .chain(encoded_v11.into_iter()) - .collect_vec(); + .collect(); assert_eq!(encoded_row1, concated_encoded_row1); let encoded_row2 = encode_row(row2.project(&order_col_indices), &order_types).unwrap(); diff --git a/src/stream/src/executor/aggregation/agg_state_cache.rs b/src/stream/src/executor/aggregation/agg_state_cache.rs index 22d120268f46..0fa83143ab67 100644 --- a/src/stream/src/executor/aggregation/agg_state_cache.rs +++ b/src/stream/src/executor/aggregation/agg_state_cache.rs @@ -17,6 +17,7 @@ use risingwave_common::array::{ArrayImpl, Op}; use risingwave_common::buffer::Bitmap; use risingwave_common::estimate_size::EstimateSize; use risingwave_common::types::{Datum, DatumRef}; +use risingwave_common::util::memcmp_encoding::MemcmpEncoded; use risingwave_common::util::row_serde::OrderedRowSerde; use smallvec::SmallVec; @@ -24,7 +25,7 @@ use super::minput_agg_impl::MInputAggregator; use crate::common::cache::{StateCache, StateCacheFiller}; /// Cache key type. -type CacheKey = Vec; +type CacheKey = MemcmpEncoded; // TODO(yuchao): May extract common logic here to `struct [Data/Stream]ChunkRef` if there's other // usage in the future. https://github.com/risingwavelabs/risingwave/pull/5908#discussion_r1002896176 @@ -76,7 +77,7 @@ impl<'a> Iterator for StateCacheInputBatch<'a> { .map(|col_idx| self.columns[*col_idx].value_at(self.idx)), &mut key, ); - key + key.into() }; let value = self .arg_col_indices diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index fbb3bbd26393..723f6eafb614 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -204,7 +204,7 @@ impl MaterializedInputState { .project(&self.state_table_order_col_indices), &mut cache_key, ); - cache_key + cache_key.into() }; let cache_value = self .state_table_arg_col_indices diff --git a/src/stream/src/executor/over_window/eowc.rs b/src/stream/src/executor/over_window/eowc.rs index 317f76e3cd3e..87cf3cbbde99 100644 --- a/src/stream/src/executor/over_window/eowc.rs +++ b/src/stream/src/executor/over_window/eowc.rs @@ -24,7 +24,7 @@ use risingwave_common::estimate_size::EstimateSize; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::types::{DataType, ToDatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast}; -use risingwave_common::util::memcmp_encoding; +use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded}; use risingwave_common::util::sort_util::OrderType; use risingwave_common::{must_match, row}; use risingwave_expr::function::window::WindowFuncCall; @@ -32,7 +32,6 @@ use risingwave_storage::store::PrefetchOptions; use risingwave_storage::StateStore; use super::state::{create_window_state, EstimatedVecDeque, WindowState}; -use super::MemcmpEncoded; use crate::cache::{new_unbounded, ManagedLruCache}; use crate::common::table::state_table::StateTable; use crate::executor::over_window::state::{StateEvictHint, StateKey}; @@ -241,8 +240,7 @@ impl EowcOverWindowExecutor { let encoded_pk = memcmp_encoding::encode_row( (&row).project(&this.input_pk_indices), &vec![OrderType::ascending(); this.input_pk_indices.len()], - )? - .into_boxed_slice(); + )?; let key = StateKey { order_key: order_key.into(), encoded_pk, @@ -292,8 +290,7 @@ impl EowcOverWindowExecutor { let encoded_partition_key = memcmp_encoding::encode_row( &partition_key, &vec![OrderType::ascending(); this.partition_key_indices.len()], - )? - .into_boxed_slice(); + )?; // Get the partition. Self::ensure_key_in_cache( @@ -316,8 +313,7 @@ impl EowcOverWindowExecutor { let encoded_pk = memcmp_encoding::encode_row( input_row.project(&this.input_pk_indices), &vec![OrderType::ascending(); this.input_pk_indices.len()], - )? - .into_boxed_slice(); + )?; let key = StateKey { order_key: order_key.into(), encoded_pk, diff --git a/src/stream/src/executor/over_window/mod.rs b/src/stream/src/executor/over_window/mod.rs index fc461bb70c20..d44415a84bf5 100644 --- a/src/stream/src/executor/over_window/mod.rs +++ b/src/stream/src/executor/over_window/mod.rs @@ -16,5 +16,3 @@ mod eowc; mod state; pub use eowc::{EowcOverWindowExecutor, EowcOverWindowExecutorArgs}; - -type MemcmpEncoded = Box<[u8]>; diff --git a/src/stream/src/executor/over_window/state/mod.rs b/src/stream/src/executor/over_window/state/mod.rs index 8bbc411f3ce4..aa2e7435fe8f 100644 --- a/src/stream/src/executor/over_window/state/mod.rs +++ b/src/stream/src/executor/over_window/state/mod.rs @@ -17,10 +17,10 @@ use std::collections::{BTreeSet, VecDeque}; use educe::Educe; use risingwave_common::estimate_size::{EstimateSize, KvSize}; use risingwave_common::types::{Datum, DefaultOrdered, ScalarImpl}; +use risingwave_common::util::memcmp_encoding::MemcmpEncoded; use risingwave_expr::function::window::{WindowFuncCall, WindowFuncKind}; use smallvec::SmallVec; -use super::MemcmpEncoded; use crate::executor::{StreamExecutorError, StreamExecutorResult}; mod buffer; diff --git a/src/stream/src/executor/sort_buffer.rs b/src/stream/src/executor/sort_buffer.rs index 17f2c9bf4a87..769d8f5c8978 100644 --- a/src/stream/src/executor/sort_buffer.rs +++ b/src/stream/src/executor/sort_buffer.rs @@ -27,6 +27,7 @@ use risingwave_common::row::{self, OwnedRow, Row, RowExt}; use risingwave_common::types::{ DefaultOrd, DefaultOrdered, ScalarImpl, ScalarRefImpl, ToOwnedDatum, }; +use risingwave_common::util::memcmp_encoding::MemcmpEncoded; use risingwave_storage::row_serde::row_serde_util::deserialize_pk_with_vnode; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::StateStore; @@ -35,9 +36,6 @@ use super::{StreamExecutorError, StreamExecutorResult}; use crate::common::cache::{OrderedStateCache, StateCache, StateCacheFiller}; use crate::common::table::state_table::StateTable; -// TODO(rc): This should be a struct in `memcmp_encoding` module. See #8606. -type MemcmpEncoded = Box<[u8]>; - type CacheKey = ( DefaultOrdered, // sort (watermark) column value MemcmpEncoded, // memcmp-encoded pk @@ -56,7 +54,7 @@ fn row_to_cache_key( buffer_table .pk_serde() .serialize((&row).project(buffer_table.pk_indices()), &mut pk); - (timestamp_val.into(), pk.into_boxed_slice()) + (timestamp_val.into(), pk.into()) } /// [`SortBuffer`] is a common component that consume an unordered stream and produce an ordered