diff --git a/Cargo.lock b/Cargo.lock index f740010071dc..d6fca8241d77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3949,6 +3949,7 @@ version = "0.8.2" dependencies = [ "api", "arrow-schema", + "async-recursion", "async-trait", "bytes", "catalog", diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index 285f8dbeec41..fcf33e45fe44 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] api.workspace = true arrow-schema.workspace = true +async-recursion = "1.0" async-trait.workspace = true bytes.workspace = true catalog.workspace = true diff --git a/src/flow/src/adapter/worker.rs b/src/flow/src/adapter/worker.rs index 4d9ad2f52447..f69a396cda27 100644 --- a/src/flow/src/adapter/worker.rs +++ b/src/flow/src/adapter/worker.rs @@ -15,14 +15,14 @@ //! For single-thread flow worker use std::collections::{BTreeMap, VecDeque}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use common_telemetry::info; use enum_as_inner::EnumAsInner; use hydroflow::scheduled::graph::Hydroflow; -use snafu::{ensure, OptionExt}; -use tokio::sync::{broadcast, mpsc, Mutex}; +use snafu::ensure; +use tokio::sync::{broadcast, mpsc, oneshot, Mutex}; use crate::adapter::error::{Error, FlowAlreadyExistSnafu, InternalSnafu, UnexpectedSnafu}; use crate::adapter::FlowId; @@ -39,7 +39,7 @@ type ReqId = usize; pub fn create_worker<'a>() -> (WorkerHandle, Worker<'a>) { let (itc_client, itc_server) = create_inter_thread_call(); let worker_handle = WorkerHandle { - itc_client: Mutex::new(itc_client), + itc_client, shutdown: AtomicBool::new(false), }; let worker = Worker { @@ -106,7 +106,7 @@ impl<'subgraph> ActiveDataflowState<'subgraph> { #[derive(Debug)] pub struct WorkerHandle { - itc_client: Mutex, + itc_client: InterThreadCallClient, shutdown: AtomicBool, } @@ -122,12 +122,7 @@ impl WorkerHandle { } ); - let ret = self - .itc_client - .lock() - .await - .call_with_resp(create_reqs) - .await?; + let ret = self.itc_client.call_with_resp(create_reqs).await?; ret.into_create().map_err(|ret| { InternalSnafu { reason: format!( @@ -141,7 +136,8 @@ impl WorkerHandle { /// remove task, return task id pub async fn remove_flow(&self, flow_id: FlowId) -> Result { let req = Request::Remove { flow_id }; - let ret = self.itc_client.lock().await.call_with_resp(req).await?; + + let ret = self.itc_client.call_with_resp(req).await?; ret.into_remove().map_err(|ret| { InternalSnafu { @@ -157,15 +153,12 @@ impl WorkerHandle { /// /// the returned error is unrecoverable, and the worker should be shutdown/rebooted pub async fn run_available(&self, now: repr::Timestamp) -> Result<(), Error> { - self.itc_client - .lock() - .await - .call_no_resp(Request::RunAvail { now }) + self.itc_client.call_no_resp(Request::RunAvail { now }) } pub async fn contains_flow(&self, flow_id: FlowId) -> Result { let req = Request::ContainTask { flow_id }; - let ret = self.itc_client.lock().await.call_with_resp(req).await?; + let ret = self.itc_client.call_with_resp(req).await?; ret.into_contain_task().map_err(|ret| { InternalSnafu { @@ -178,23 +171,9 @@ impl WorkerHandle { } /// shutdown the worker - pub async fn shutdown(&self) -> Result<(), Error> { + pub fn shutdown(&self) -> Result<(), Error> { if !self.shutdown.fetch_or(true, Ordering::SeqCst) { - self.itc_client.lock().await.call_no_resp(Request::Shutdown) - } else { - UnexpectedSnafu { - reason: "Worker already shutdown", - } - .fail() - } - } - - /// shutdown the worker - pub fn 
shutdown_blocking(&self) -> Result<(), Error> { - if !self.shutdown.fetch_or(true, Ordering::SeqCst) { - self.itc_client - .blocking_lock() - .call_no_resp(Request::Shutdown) + self.itc_client.call_no_resp(Request::Shutdown) } else { UnexpectedSnafu { reason: "Worker already shutdown", @@ -206,8 +185,7 @@ impl WorkerHandle { impl Drop for WorkerHandle { fn drop(&mut self) { - let ret = futures::executor::block_on(async { self.shutdown().await }); - if let Err(ret) = ret { + if let Err(ret) = self.shutdown() { common_telemetry::error!( ret; "While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state." @@ -276,7 +254,7 @@ impl<'s> Worker<'s> { /// Run the worker, blocking, until shutdown signal is received pub fn run(&mut self) { loop { - let (req_id, req) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() { + let (req, ret_tx) = if let Some(ret) = self.itc_server.blocking_lock().blocking_recv() { ret } else { common_telemetry::error!( @@ -285,19 +263,26 @@ impl<'s> Worker<'s> { break; }; - let ret = self.handle_req(req_id, req); - match ret { - Ok(Some((id, resp))) => { - if let Err(err) = self.itc_server.blocking_lock().resp(id, resp) { + let ret = self.handle_req(req); + match (ret, ret_tx) { + (Ok(Some(resp)), Some(ret_tx)) => { + if let Err(err) = ret_tx.send(resp) { common_telemetry::error!( err; - "Worker's itc server has been closed unexpectedly, shutting down worker" + "Result receiver is dropped, can't send result" ); - break; }; } - Ok(None) => continue, - Err(()) => { + (Ok(None), None) => continue, + (Ok(Some(resp)), None) => { + common_telemetry::error!( + "Expect no result for current request, but found {resp:?}" + ) + } + (Ok(None), Some(_)) => { + common_telemetry::error!("Expect result for current request, but found nothing") + } + (Err(()), _) => { break; } } @@ -315,7 +300,7 @@ impl<'s> Worker<'s> { /// handle request, return response if any, Err if receive shutdown signal /// /// return `Err(())` if receive shutdown request - fn handle_req(&mut self, req_id: ReqId, req: Request) -> Result, ()> { + fn handle_req(&mut self, req: Request) -> Result, ()> { let ret = match req { Request::Create { flow_id, @@ -339,16 +324,13 @@ impl<'s> Worker<'s> { create_if_not_exists, err_collector, ); - Some(( - req_id, - Response::Create { - result: task_create_result, - }, - )) + Some(Response::Create { + result: task_create_result, + }) } Request::Remove { flow_id } => { let ret = self.remove_flow(flow_id); - Some((req_id, Response::Remove { result: ret })) + Some(Response::Remove { result: ret }) } Request::RunAvail { now } => { self.run_tick(now); @@ -356,7 +338,7 @@ impl<'s> Worker<'s> { } Request::ContainTask { flow_id } => { let ret = self.task_states.contains_key(&flow_id); - Some((req_id, Response::ContainTask { result: ret })) + Some(Response::ContainTask { result: ret }) } Request::Shutdown => return Err(()), }; @@ -406,83 +388,50 @@ enum Response { fn create_inter_thread_call() -> (InterThreadCallClient, InterThreadCallServer) { let (arg_send, arg_recv) = mpsc::unbounded_channel(); - let (ret_send, ret_recv) = mpsc::unbounded_channel(); let client = InterThreadCallClient { - call_id: AtomicUsize::new(0), arg_sender: arg_send, - ret_recv, - }; - let server = InterThreadCallServer { - arg_recv, - ret_sender: ret_send, }; + let server = InterThreadCallServer { arg_recv }; (client, server) } #[derive(Debug)] struct InterThreadCallClient { - call_id: AtomicUsize, - arg_sender: mpsc::UnboundedSender<(ReqId, Request)>, - ret_recv: 
mpsc::UnboundedReceiver<(ReqId, Response)>, + arg_sender: mpsc::UnboundedSender<(Request, Option>)>, } impl InterThreadCallClient { - /// call without expecting responses or blocking fn call_no_resp(&self, req: Request) -> Result<(), Error> { - // TODO(discord9): relax memory order later - let call_id = self.call_id.fetch_add(1, Ordering::SeqCst); - self.arg_sender - .send((call_id, req)) - .map_err(from_send_error) + self.arg_sender.send((req, None)).map_err(from_send_error) } - /// call blocking, and return the result - async fn call_with_resp(&mut self, req: Request) -> Result { - // TODO(discord9): relax memory order later - let call_id = self.call_id.fetch_add(1, Ordering::SeqCst); + async fn call_with_resp(&self, req: Request) -> Result { + let (tx, rx) = oneshot::channel(); self.arg_sender - .send((call_id, req)) + .send((req, Some(tx))) .map_err(from_send_error)?; - - // TODO(discord9): better inter thread call impl, i.e. support multiple client(also consider if it's necessary) - // since one node manger might manage multiple worker, but one worker should only belong to one node manager - let (ret_call_id, ret) = self - .ret_recv - .recv() - .await - .context(InternalSnafu { reason: "InterThreadCallClient call_blocking failed, ret_recv has been closed and there are no remaining messages in the channel's buffer" })?; - - ensure!( - ret_call_id == call_id, + rx.await.map_err(|_| { InternalSnafu { - reason: "call id mismatch, worker/worker handler should be in sync", + reason: "Sender is dropped", } - ); - Ok(ret) + .build() + }) } } #[derive(Debug)] struct InterThreadCallServer { - pub arg_recv: mpsc::UnboundedReceiver<(ReqId, Request)>, - pub ret_sender: mpsc::UnboundedSender<(ReqId, Response)>, + pub arg_recv: mpsc::UnboundedReceiver<(Request, Option>)>, } impl InterThreadCallServer { - pub async fn recv(&mut self) -> Option<(usize, Request)> { + pub async fn recv(&mut self) -> Option<(Request, Option>)> { self.arg_recv.recv().await } - pub fn blocking_recv(&mut self) -> Option<(usize, Request)> { + pub fn blocking_recv(&mut self) -> Option<(Request, Option>)> { self.arg_recv.blocking_recv() } - - /// Send response back to the client - pub fn resp(&self, call_id: ReqId, resp: Response) -> Result<(), Error> { - self.ret_sender - .send((call_id, resp)) - .map_err(from_send_error) - } } fn from_send_error(err: mpsc::error::SendError) -> Error { @@ -546,7 +495,10 @@ mod test { create_if_not_exists: true, err_collector: ErrCollector::default(), }; - handle.create_flow(create_reqs).await.unwrap(); + assert_eq!( + handle.create_flow(create_reqs).await.unwrap(), + Some(flow_id) + ); tx.send((Row::empty(), 0, 0)).unwrap(); handle.run_available(0).await.unwrap(); assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty()); diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index 2109356ad621..c30b67dbffa4 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -43,7 +43,7 @@ use crate::repr::{self, value_to_internal_ts, Row}; /// UnmaterializableFunc is a function that can't be eval independently, /// and require special handling -#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] +#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)] pub enum UnmaterializableFunc { Now, CurrentSchema, diff --git a/src/flow/src/expr/linear.rs b/src/flow/src/expr/linear.rs index dcfed4eb0d28..b0e32c94d87b 100644 --- a/src/flow/src/expr/linear.rs +++ b/src/flow/src/expr/linear.rs @@ -49,7 +49,7 @@ use crate::repr::{self, 
value_to_internal_ts, Diff, Row}; /// expressions in `self.expressions`, even though this is not something /// we can directly evaluate. The plan creation methods will defensively /// ensure that the right thing happens. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct MapFilterProject { /// A sequence of expressions that should be appended to the row. /// @@ -462,7 +462,7 @@ impl MapFilterProject { } /// A wrapper type which indicates it is safe to simply evaluate all expressions. -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct SafeMfpPlan { /// the inner `MapFilterProject` that is safe to evaluate. pub(crate) mfp: MapFilterProject, diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index a873c267b1a5..661f716dcd29 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -23,7 +23,7 @@ mod accum; mod func; /// Describes an aggregation expression. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AggregateExpr { /// Names the aggregation function. pub func: AggregateFunc, @@ -32,6 +32,5 @@ pub struct AggregateExpr { /// so it only used in generate KeyValPlan from AggregateExpr pub expr: ScalarExpr, /// Should the aggregation be applied only to distinct results in each group. - #[serde(default)] pub distinct: bool, } diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 7335511be0f0..591d2c246fc1 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -43,7 +43,7 @@ use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFun use crate::repr::{ColumnType, RelationDesc, RelationType}; use crate::transform::{from_scalar_fn_to_df_fn_impl, FunctionExtensions}; /// A scalar expression with a known type. -#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)] +#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)] pub struct TypedExpr { /// The expression. pub expr: ScalarExpr, @@ -129,7 +129,7 @@ impl TypedExpr { } /// A scalar expression, which can be evaluated to a value. 
-#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ScalarExpr { /// A column of the input row Column(usize), @@ -191,9 +191,9 @@ impl DfScalarFunction { }) } - pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result { + pub async fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result { Ok(Self { - fn_impl: raw_fn.get_fn_impl()?, + fn_impl: raw_fn.get_fn_impl().await?, df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?), raw_fn, }) @@ -264,27 +264,7 @@ impl DfScalarFunction { } } -// simply serialize the raw_fn instead of derive to avoid complex deserialize of struct -impl Serialize for DfScalarFunction { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.raw_fn.serialize(serializer) - } -} - -impl<'de> serde::de::Deserialize<'de> for DfScalarFunction { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - let raw_fn = RawDfScalarFn::deserialize(deserializer)?; - DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom) - } -} - -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct RawDfScalarFn { /// The raw bytes encoded datafusion scalar function pub(crate) f: bytes::BytesMut, @@ -311,7 +291,7 @@ impl RawDfScalarFn { extensions, }) } - fn get_fn_impl(&self) -> Result, Error> { + async fn get_fn_impl(&self) -> Result, Error> { let f = ScalarFunction::decode(&mut self.f.as_ref()) .context(DecodeRelSnafu) .map_err(BoxedError::new) @@ -320,7 +300,7 @@ impl RawDfScalarFn { let input_schema = &self.input_schema; let extensions = &self.extensions; - from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions) + from_scalar_fn_to_df_fn_impl(&f, input_schema, extensions).await } } @@ -894,10 +874,7 @@ mod test { .unwrap(); let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]); let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap(); - let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap(); - let as_str = serde_json::to_string(&df_func).unwrap(); - let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap(); - assert_eq!(df_func, from_str); + let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await.unwrap(); assert_eq!( df_func .eval(&[Value::Null], &[ScalarExpr::Column(0)]) diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index 6e4b13673302..95816b17cb03 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -33,7 +33,7 @@ pub(crate) use crate::plan::reduce::{AccumulablePlan, AggrWithIndex, KeyValPlan, use crate::repr::{ColumnType, DiffRow, RelationDesc, RelationType}; /// A plan for a dataflow component. But with type to indicate the output type of the relation. -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct TypedPlan { /// output type of the relation pub schema: RelationDesc, @@ -121,7 +121,7 @@ impl TypedPlan { /// TODO(discord9): support `TableFunc`(by define FlatMap that map 1 to n) /// Plan describe how to transform data in dataflow -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum Plan { /// A constant collection of rows. 
Constant { rows: Vec }, diff --git a/src/flow/src/plan/join.rs b/src/flow/src/plan/join.rs index 13bb95f51159..4acf0db2342e 100644 --- a/src/flow/src/plan/join.rs +++ b/src/flow/src/plan/join.rs @@ -18,13 +18,13 @@ use crate::expr::ScalarExpr; use crate::plan::SafeMfpPlan; /// TODO(discord9): consider impl more join strategies -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub enum JoinPlan { Linear(LinearJoinPlan), } /// Determine if a given row should stay in the output. And apply a map filter project before output the row -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct JoinFilter { /// each element in the outer vector will check if each expr in itself can be eval to same value /// if not, the row will be filtered out. Useful for equi-join(join based on equality of some columns) @@ -37,7 +37,7 @@ pub struct JoinFilter { /// /// A linear join is a sequence of stages, each of which introduces /// a new collection. Each stage is represented by a [LinearStagePlan]. -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct LinearJoinPlan { /// The source relation from which we start the join. pub source_relation: usize, @@ -60,7 +60,7 @@ pub struct LinearJoinPlan { /// Each stage is a binary join between the current accumulated /// join results, and a new collection. The former is referred to /// as the "stream" and the latter the "lookup". -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq, Ord, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct LinearStagePlan { /// The index of the relation into which we will look up. pub lookup_relation: usize, diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index 85c84a42f342..3d0d8b356a37 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use crate::expr::{AggregateExpr, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr}; /// Describe how to extract key-value pair from a `Row` -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct KeyValPlan { /// Extract key from row pub key_plan: SafeMfpPlan, @@ -27,7 +27,7 @@ pub struct KeyValPlan { /// TODO(discord9): def&impl of Hierarchical aggregates(for min/max with support to deletion) and /// basic aggregates(for other aggregate functions) and mixed aggregate -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub enum ReducePlan { /// Plan for not computing any aggregations, just determining the set of /// distinct keys. @@ -38,7 +38,7 @@ pub enum ReducePlan { } /// Accumulable plan for the execution of a reduction. -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AccumulablePlan { /// All of the aggregations we were asked to compute, stored /// in order. 
@@ -57,7 +57,7 @@ pub struct AccumulablePlan { /// Invariant: the output index is the index of the aggregation in `full_aggrs` /// which means output index is always smaller than the length of `full_aggrs` -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct AggrWithIndex { /// aggregation expression pub expr: AggregateExpr, diff --git a/src/flow/src/transform.rs b/src/flow/src/transform.rs index 35d811a03732..e86dac85fb7f 100644 --- a/src/flow/src/transform.rs +++ b/src/flow/src/transform.rs @@ -140,7 +140,7 @@ pub async fn sql_to_flow_plan( .map_err(BoxedError::new) .context(ExternalSnafu)?; - let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan)?; + let flow_plan = TypedPlan::from_substrait_plan(ctx, &sub_plan).await?; Ok(flow_plan) } diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 3cc1512692d5..6456f00a5c75 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -58,7 +58,7 @@ use crate::repr::{self, ColumnType, RelationDesc, RelationType}; use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; impl TypedExpr { - fn from_substrait_agg_grouping( + async fn from_substrait_agg_grouping( ctx: &mut FlownodeContext, groupings: &[Grouping], typ: &RelationDesc, @@ -69,7 +69,7 @@ impl TypedExpr { match groupings.len() { 1 => { for e in &groupings[0].grouping_expressions { - let x = TypedExpr::from_substrait_rex(e, typ, extensions)?; + let x = TypedExpr::from_substrait_rex(e, typ, extensions).await?; group_expr.push(x); } } @@ -87,7 +87,7 @@ impl AggregateExpr { /// Convert list of `Measure` into Flow's AggregateExpr /// /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function - fn from_substrait_agg_measures( + async fn from_substrait_agg_measures( ctx: &mut FlownodeContext, measures: &[Measure], typ: &RelationDesc, @@ -98,11 +98,15 @@ impl AggregateExpr { let mut post_maps = vec![]; for m in measures { - let filter = &m + let filter = match m .filter .as_ref() .map(|fil| TypedExpr::from_substrait_rex(fil, typ, extensions)) - .transpose()?; + { + Some(fut) => Some(fut.await), + None => None, + } + .transpose()?; let (aggr_expr, post_mfp) = match &m.measure { Some(f) => { @@ -112,9 +116,10 @@ impl AggregateExpr { _ => false, }; AggregateExpr::from_substrait_agg_func( - f, typ, extensions, filter, // TODO(discord9): impl order_by + f, typ, extensions, &filter, // TODO(discord9): impl order_by &None, distinct, ) + .await } None => not_impl_err!("Aggregate without aggregate function is not supported"), }?; @@ -142,7 +147,7 @@ impl AggregateExpr { /// /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)` - pub fn from_substrait_agg_func( + pub async fn from_substrait_agg_func( f: &proto::AggregateFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -157,7 +162,7 @@ impl AggregateExpr { for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - TypedExpr::from_substrait_rex(e, input_schema, extensions) + TypedExpr::from_substrait_rex(e, input_schema, extensions).await } _ => not_impl_err!("Aggregated function argument non-Value type not supported"), }?; @@ -306,13 +311,14 @@ impl TypedPlan { /// The output of aggr plan is: /// /// .. 
- pub fn from_substrait_agg_rel( + #[async_recursion::async_recursion] + pub async fn from_substrait_agg_rel( ctx: &mut FlownodeContext, agg: &proto::AggregateRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = agg.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? } else { return not_impl_err!("Aggregate without an input is not supported"); }; @@ -323,7 +329,8 @@ impl TypedPlan { &agg.groupings, &input.schema, extensions, - )?; + ) + .await?; TypedExpr::expand_multi_value(&input.schema.typ, &group_exprs)? }; @@ -335,7 +342,8 @@ impl TypedPlan { &agg.measures, &input.schema, extensions, - )?; + ) + .await?; let key_val_plan = KeyValPlan::from_substrait_gen_key_val_plan( &mut aggr_exprs, @@ -479,7 +487,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan).is_err()); + assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .is_err()); } #[tokio::test] @@ -489,7 +499,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -578,6 +590,7 @@ mod test { }, }, ) + .await .unwrap(), exprs: vec![ScalarExpr::Column(0)], }]) @@ -630,7 +643,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -743,6 +758,7 @@ mod test { ]), }, }) + .await .unwrap(), exprs: vec![ScalarExpr::Column(3)], }, @@ -766,7 +782,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_exprs = vec![ AggregateExpr { @@ -913,7 +931,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1029,7 +1049,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1145,7 +1167,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ AggregateExpr { @@ -1250,7 +1272,9 @@ mod test { let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let 
aggr_exprs = vec![ AggregateExpr { @@ -1341,7 +1365,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let typ = RelationType::new(vec![ColumnType::new( ConcreteDataType::uint64_datatype(), true, @@ -1404,7 +1428,9 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap(); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan) + .await + .unwrap(); let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, @@ -1482,7 +1508,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_expr = AggregateExpr { func: AggregateFunc::SumUInt32, diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index a6b312504d28..a10e9b121f8c 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use datafusion_physical_expr::PhysicalExpr; use datatypes::data_type::ConcreteDataType as CDT; -use itertools::Itertools; use snafu::{OptionExt, ResultExt}; use substrait_proto::proto::expression::field_reference::ReferenceType::DirectReference; use substrait_proto::proto::expression::reference_segment::ReferenceType::StructField; @@ -60,7 +59,7 @@ fn typename_to_cdt(name: &str) -> CDT { } /// Convert [`ScalarFunction`] to corresponding Datafusion's [`PhysicalExpr`] -pub(crate) fn from_scalar_fn_to_df_fn_impl( +pub(crate) async fn from_scalar_fn_to_df_fn_impl( f: &ScalarFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -70,7 +69,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl( }; let schema = input_schema.to_df_schema()?; - let df_expr = futures::executor::block_on(async { + let df_expr = // TODO(discord9): consider coloring everything async.... 
substrait::df_logical_plan::consumer::from_substrait_rex( &datafusion::prelude::SessionContext::new(), @@ -79,7 +78,7 @@ pub(crate) fn from_scalar_fn_to_df_fn_impl( &extensions.inner_ref(), ) .await - }); + ; let expr = df_expr.map_err(|err| { DatafusionSnafu { raw: err, @@ -138,7 +137,7 @@ fn rewrite_scalar_function(f: &ScalarFunction) -> ScalarFunction { } impl TypedExpr { - pub fn from_substrait_to_datafusion_scalar_func( + pub async fn from_substrait_to_datafusion_scalar_func( f: &ScalarFunction, arg_exprs_typed: Vec, extensions: &FunctionExtensions, @@ -152,7 +151,7 @@ impl TypedExpr { let raw_fn = RawDfScalarFn::from_proto(&f_rewrite, input_schema.clone(), extensions.clone())?; - let df_func = DfScalarFunction::try_from_raw_fn(raw_fn)?; + let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).await?; let expr = ScalarExpr::CallDf { df_scalar_fn: df_func, exprs: arg_exprs, @@ -163,7 +162,7 @@ impl TypedExpr { } /// Convert ScalarFunction into Flow's ScalarExpr - pub fn from_substrait_scalar_func( + pub async fn from_substrait_scalar_func( f: &ScalarFunction, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -178,16 +177,19 @@ impl TypedExpr { ), })?; let arg_len = f.arguments.len(); - let arg_typed_exprs: Vec = f - .arguments - .iter() - .map(|arg| match &arg.arg_type { - Some(ArgType::Value(e)) => { - TypedExpr::from_substrait_rex(e, input_schema, extensions) - } - _ => not_impl_err!("Aggregated function argument non-Value type not supported"), - }) - .try_collect()?; + let arg_typed_exprs: Vec = { + let mut rets = Vec::new(); + for arg in f.arguments.iter() { + let ret = match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions).await + } + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }?; + rets.push(ret); + } + rets + }; // literal's type is determined by the function and type of other args let (arg_exprs, arg_types): (Vec<_>, Vec<_>) = arg_typed_exprs @@ -293,7 +295,8 @@ impl TypedExpr { f, arg_typed_exprs, extensions, - )?; + ) + .await?; Ok(try_as_df) } } @@ -301,38 +304,44 @@ impl TypedExpr { } /// Convert IfThen into Flow's ScalarExpr - pub fn from_substrait_ifthen_rex( + pub async fn from_substrait_ifthen_rex( if_then: &IfThen, input_schema: &RelationDesc, extensions: &FunctionExtensions, ) -> Result { - let ifs: Vec<_> = if_then - .ifs - .iter() - .map(|if_clause| { + let ifs: Vec<_> = { + let mut ifs = Vec::new(); + for if_clause in if_then.ifs.iter() { let proto_if = if_clause.r#if.as_ref().with_context(|| InvalidQuerySnafu { reason: "IfThen clause without if", })?; let proto_then = if_clause.then.as_ref().with_context(|| InvalidQuerySnafu { reason: "IfThen clause without then", })?; - let cond = TypedExpr::from_substrait_rex(proto_if, input_schema, extensions)?; - let then = TypedExpr::from_substrait_rex(proto_then, input_schema, extensions)?; - Ok((cond, then)) - }) - .try_collect()?; + let cond = + TypedExpr::from_substrait_rex(proto_if, input_schema, extensions).await?; + let then = + TypedExpr::from_substrait_rex(proto_then, input_schema, extensions).await?; + ifs.push((cond, then)); + } + ifs + }; // if no else is presented - let els = if_then + let els = match if_then .r#else .as_ref() .map(|e| TypedExpr::from_substrait_rex(e, input_schema, extensions)) - .transpose()? 
- .unwrap_or_else(|| { - TypedExpr::new( - ScalarExpr::literal_null(), - ColumnType::new_nullable(CDT::null_datatype()), - ) - }); + { + Some(fut) => Some(fut.await), + None => None, + } + .transpose()? + .unwrap_or_else(|| { + TypedExpr::new( + ScalarExpr::literal_null(), + ColumnType::new_nullable(CDT::null_datatype()), + ) + }); fn build_if_then_recur( mut next_if_then: impl Iterator, @@ -356,7 +365,8 @@ impl TypedExpr { Ok(expr_if) } /// Convert Substrait Rex into Flow's ScalarExpr - pub fn from_substrait_rex( + #[async_recursion::async_recursion] + pub async fn from_substrait_rex( e: &Expression, input_schema: &RelationDesc, extensions: &FunctionExtensions, @@ -377,7 +387,7 @@ impl TypedExpr { if !s.options.is_empty() { return not_impl_err!("In list expression is not supported"); } - TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions) + TypedExpr::from_substrait_rex(substrait_expr, input_schema, extensions).await } Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { @@ -400,16 +410,16 @@ impl TypedExpr { _ => not_impl_err!("unsupported field ref type"), }, Some(RexType::ScalarFunction(f)) => { - TypedExpr::from_substrait_scalar_func(f, input_schema, extensions) + TypedExpr::from_substrait_scalar_func(f, input_schema, extensions).await } Some(RexType::IfThen(if_then)) => { - TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions) + TypedExpr::from_substrait_ifthen_rex(if_then, input_schema, extensions).await } Some(RexType::Cast(cast)) => { let input = cast.input.as_ref().with_context(|| InvalidQuerySnafu { reason: "Cast expression without input", })?; - let input = TypedExpr::from_substrait_rex(input, input_schema, extensions)?; + let input = TypedExpr::from_substrait_rex(input, input_schema, extensions).await?; let cast_type = from_substrait_type(cast.r#type.as_ref().with_context(|| { InvalidQuerySnafu { reason: "Cast expression without type", @@ -453,7 +463,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; // optimize binary and to variadic and let filter = ScalarExpr::CallVariadic { @@ -509,7 +519,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::boolean_datatype(), true)]) @@ -534,7 +544,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) @@ -572,7 +582,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::int16_datatype(), true)]) @@ -611,7 +621,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = 
create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), true)]) @@ -641,8 +651,8 @@ mod test { assert_eq!(flow_plan.unwrap(), expected); } - #[test] - fn test_func_sig() { + #[tokio::test] + async fn test_func_sig() { fn lit(v: impl ToString) -> substrait_proto::proto::FunctionArgument { use substrait_proto::proto::expression; let expr = Expression { @@ -669,7 +679,9 @@ mod test { let input_schema = RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]).into_unnamed(); let extensions = FunctionExtensions::from_iter([(0, "is_null".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -695,7 +707,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter([(0, "add".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -722,7 +736,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, @@ -750,7 +766,9 @@ mod test { ]) .into_unnamed(); let extensions = FunctionExtensions::from_iter(vec![(0, "tumble".to_string())]); - let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions).unwrap(); + let res = TypedExpr::from_substrait_scalar_func(&f, &input_schema, &extensions) + .await + .unwrap(); assert_eq!( res, diff --git a/src/flow/src/transform/literal.rs b/src/flow/src/transform/literal.rs index 9dc93d17c549..1fa5bc86a81c 100644 --- a/src/flow/src/transform/literal.rs +++ b/src/flow/src/transform/literal.rs @@ -172,7 +172,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)]) diff --git a/src/flow/src/transform/plan.rs b/src/flow/src/transform/plan.rs index a9d9e29310e9..f1f6ba53dd35 100644 --- a/src/flow/src/transform/plan.rs +++ b/src/flow/src/transform/plan.rs @@ -32,7 +32,7 @@ use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; impl TypedPlan { /// Convert Substrait Plan into Flow's TypedPlan - pub fn from_substrait_plan( + pub async fn from_substrait_plan( ctx: &mut FlownodeContext, plan: &SubPlan, ) -> Result { @@ -45,13 +45,13 @@ impl TypedPlan { match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension)?) + Ok(TypedPlan::from_substrait_rel(ctx, rel, &function_extension).await?) }, plan_rel::RelType::Root(root) => { let input = root.input.as_ref().with_context(|| InvalidQuerySnafu { reason: "Root relation without input", })?; - Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension)?) 
+ Ok(TypedPlan::from_substrait_rel(ctx, input, &function_extension).await?) } }, None => plan_err!("Cannot parse plan relation: None") @@ -64,13 +64,14 @@ impl TypedPlan { } } - pub fn from_substrait_project( + #[async_recursion::async_recursion] + pub async fn from_substrait_project( ctx: &mut FlownodeContext, p: &ProjectRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = p.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? } else { return not_impl_err!("Projection without an input is not supported"); }; @@ -93,7 +94,7 @@ impl TypedPlan { let mut exprs: Vec = Vec::with_capacity(p.expressions.len()); for e in &p.expressions { - let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions)?; + let expr = TypedExpr::from_substrait_rex(e, &schema_before_expand, extensions).await?; exprs.push(expr); } let is_literal = exprs.iter().all(|expr| expr.expr.is_literal()); @@ -131,26 +132,27 @@ impl TypedPlan { } } - pub fn from_substrait_filter( + #[async_recursion::async_recursion] + pub async fn from_substrait_filter( ctx: &mut FlownodeContext, filter: &FilterRel, extensions: &FunctionExtensions, ) -> Result { let input = if let Some(input) = filter.input.as_ref() { - TypedPlan::from_substrait_rel(ctx, input, extensions)? + TypedPlan::from_substrait_rel(ctx, input, extensions).await? } else { return not_impl_err!("Filter without an input is not supported"); }; let expr = if let Some(condition) = filter.condition.as_ref() { - TypedExpr::from_substrait_rex(condition, &input.schema, extensions)? + TypedExpr::from_substrait_rex(condition, &input.schema, extensions).await? } else { return not_impl_err!("Filter without an condition is not valid"); }; input.filter(expr) } - pub fn from_substrait_read( + pub async fn from_substrait_read( ctx: &mut FlownodeContext, read: &ReadRel, _extensions: &FunctionExtensions, @@ -212,16 +214,22 @@ impl TypedPlan { /// Convert Substrait Rel into Flow's TypedPlan /// TODO(discord9): SELECT DISTINCT(does it get compile with something else?) 
- pub fn from_substrait_rel( + pub async fn from_substrait_rel( ctx: &mut FlownodeContext, rel: &Rel, extensions: &FunctionExtensions, ) -> Result { match &rel.rel_type { - Some(RelType::Project(p)) => Self::from_substrait_project(ctx, p.as_ref(), extensions), - Some(RelType::Filter(filter)) => Self::from_substrait_filter(ctx, filter, extensions), - Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions), - Some(RelType::Aggregate(agg)) => Self::from_substrait_agg_rel(ctx, agg, extensions), + Some(RelType::Project(p)) => { + Self::from_substrait_project(ctx, p.as_ref(), extensions).await + } + Some(RelType::Filter(filter)) => { + Self::from_substrait_filter(ctx, filter, extensions).await + } + Some(RelType::Read(read)) => Self::from_substrait_read(ctx, read, extensions).await, + Some(RelType::Aggregate(agg)) => { + Self::from_substrait_agg_rel(ctx, agg, extensions).await + } _ => not_impl_err!("Unsupported relation type: {:?}", rel.rel_type), } } @@ -353,7 +361,7 @@ mod test { let plan = sql_to_substrait(engine.clone(), sql).await; let mut ctx = create_test_ctx(); - let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan); + let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)]) diff --git a/src/flow/src/utils.rs b/src/flow/src/utils.rs index 69c300ab8f5c..69ff8fa2d248 100644 --- a/src/flow/src/utils.rs +++ b/src/flow/src/utils.rs @@ -40,7 +40,7 @@ pub type Spine = BTreeMap; /// If a key is expired, any future updates to it should be ignored. /// /// Note that key is expired by it's event timestamp (contained in the key), not by the time it's inserted (system timestamp). -#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct KeyExpiryManager { /// A map from event timestamp to key, used for expire keys. event_ts_to_key: BTreeMap>, @@ -157,7 +157,7 @@ impl KeyExpiryManager { /// /// Note the two way arrow between reduce operator and arrange, it's because reduce operator need to query existing state /// and also need to update existing state. -#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] +#[derive(Debug, Clone, Default, Eq, PartialEq, Ord, PartialOrd)] pub struct Arrangement { /// A name or identifier for the arrangement which can be used for debugging or logging purposes. /// This field is not critical to the functionality but aids in monitoring and management of arrangements. 
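A note on the `src/flow` changes above: the worker's request/response plumbing drops the `ReqId` counter, the response channel, and the `Mutex` around `InterThreadCallClient`; instead each request carries an optional `oneshot::Sender`, so a reply (when one is expected) comes back on a per-call channel and `call_with_resp` no longer needs `&mut self` or call-id matching. The following is a minimal, self-contained sketch of that pattern only — the `Request`/`Response` enums and `CallClient` here are stand-ins, not the crate's types, and the server side runs as an async task for brevity, whereas the real `Worker` loops on a dedicated thread via `blocking_recv`.

```rust
use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum Request {
    Ping,
    Shutdown,
}

#[derive(Debug)]
enum Response {
    Pong,
}

/// Client half: every request travels with an optional oneshot sender.
/// `None` means fire-and-forget, `Some(tx)` means "reply on this channel".
struct CallClient {
    arg_sender: mpsc::UnboundedSender<(Request, Option<oneshot::Sender<Response>>)>,
}

impl CallClient {
    fn call_no_resp(&self, req: Request) {
        // No response channel attached, so the server replies to no one.
        let _ = self.arg_sender.send((req, None));
    }

    async fn call_with_resp(&self, req: Request) -> Option<Response> {
        let (tx, rx) = oneshot::channel();
        self.arg_sender.send((req, Some(tx))).ok()?;
        // The reply arrives on this call's own channel, so there is no
        // call-id bookkeeping and no shared receiver to lock.
        rx.await.ok()
    }
}

#[tokio::main]
async fn main() {
    let (arg_sender, mut arg_recv) = mpsc::unbounded_channel();
    let client = CallClient { arg_sender };

    // Server half: one loop that answers on the provided channel, if any.
    tokio::spawn(async move {
        while let Some((req, ret_tx)) = arg_recv.recv().await {
            match req {
                Request::Ping => {
                    if let Some(tx) = ret_tx {
                        let _ = tx.send(Response::Pong);
                    }
                }
                Request::Shutdown => break,
            }
        }
    });

    println!("{:?}", client.call_with_resp(Request::Ping).await); // Some(Pong)
    client.call_no_resp(Request::Shutdown);
}
```

Because the response path no longer goes through the handle's own receiver, `WorkerHandle::shutdown` can be synchronous and `Drop` no longer needs `futures::executor::block_on`, which is what the diff reflects.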
diff --git a/src/mito2/benches/memtable_bench.rs b/src/mito2/benches/memtable_bench.rs index f50d3e32dd5e..4309520cdd0c 100644 --- a/src/mito2/benches/memtable_bench.rs +++ b/src/mito2/benches/memtable_bench.rs @@ -24,6 +24,7 @@ use datatypes::schema::ColumnSchema; use mito2::memtable::partition_tree::{PartitionTreeConfig, PartitionTreeMemtable}; use mito2::memtable::time_series::TimeSeriesMemtable; use mito2::memtable::{KeyValues, Memtable}; +use mito2::region::options::MergeMode; use mito2::test_util::memtable_util::{self, region_metadata_to_row_schema}; use rand::rngs::ThreadRng; use rand::seq::SliceRandom; @@ -51,7 +52,7 @@ fn write_rows(c: &mut Criterion) { }); }); group.bench_function("time_series", |b| { - let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true); + let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true, MergeMode::LastRow); let kvs = memtable_util::build_key_values(&metadata, "hello".to_string(), 42, ×tamps, 1); b.iter(|| { @@ -83,7 +84,7 @@ fn full_scan(c: &mut Criterion) { }); }); group.bench_function("time_series", |b| { - let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true); + let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true, MergeMode::LastRow); for kvs in generator.iter() { memtable.write(&kvs).unwrap(); } @@ -121,7 +122,7 @@ fn filter_1_host(c: &mut Criterion) { }); }); group.bench_function("time_series", |b| { - let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true); + let memtable = TimeSeriesMemtable::new(metadata.clone(), 1, None, true, MergeMode::LastRow); for kvs in generator.iter() { memtable.write(&kvs).unwrap(); } diff --git a/src/mito2/src/compaction.rs b/src/mito2/src/compaction.rs index 5204a2b32bd6..d9593bfe5892 100644 --- a/src/mito2/src/compaction.rs +++ b/src/mito2/src/compaction.rs @@ -55,6 +55,7 @@ use crate::read::projection::ProjectionMapper; use crate::read::scan_region::ScanInput; use crate::read::seq_scan::SeqScan; use crate::read::BoxedBatchReader; +use crate::region::options::MergeMode; use crate::region::version::{VersionControlRef, VersionRef}; use crate::region::ManifestContextRef; use crate::request::{OptionOutputTx, OutputTx, WorkerRequest}; @@ -454,31 +455,39 @@ pub struct SerializedCompactionOutput { output_time_range: Option, } -/// Builds [BoxedBatchReader] that reads all SST files and yields batches in primary key order. -async fn build_sst_reader( +/// Builders to create [BoxedBatchReader] for compaction. +struct CompactionSstReaderBuilder<'a> { metadata: RegionMetadataRef, sst_layer: AccessLayerRef, cache: Option, - inputs: &[FileHandle], + inputs: &'a [FileHandle], append_mode: bool, filter_deleted: bool, time_range: Option, -) -> Result { - let mut scan_input = ScanInput::new(sst_layer, ProjectionMapper::all(&metadata)?) - .with_files(inputs.to_vec()) - .with_append_mode(append_mode) - .with_cache(cache) - .with_filter_deleted(filter_deleted) - // We ignore file not found error during compaction. - .with_ignore_file_not_found(true); - - // This serves as a workaround of https://github.com/GreptimeTeam/greptimedb/issues/3944 - // by converting time ranges into predicate. - if let Some(time_range) = time_range { - scan_input = scan_input.with_predicate(time_range_to_predicate(time_range, &metadata)?); - } + merge_mode: MergeMode, +} - SeqScan::new(scan_input).build_reader().await +impl<'a> CompactionSstReaderBuilder<'a> { + /// Builds [BoxedBatchReader] that reads all SST files and yields batches in primary key order. 
+ async fn build_sst_reader(self) -> Result { + let mut scan_input = ScanInput::new(self.sst_layer, ProjectionMapper::all(&self.metadata)?) + .with_files(self.inputs.to_vec()) + .with_append_mode(self.append_mode) + .with_cache(self.cache) + .with_filter_deleted(self.filter_deleted) + // We ignore file not found error during compaction. + .with_ignore_file_not_found(true) + .with_merge_mode(self.merge_mode); + + // This serves as a workaround of https://github.com/GreptimeTeam/greptimedb/issues/3944 + // by converting time ranges into predicate. + if let Some(time_range) = self.time_range { + scan_input = + scan_input.with_predicate(time_range_to_predicate(time_range, &self.metadata)?); + } + + SeqScan::new(scan_input).build_reader().await + } } /// Converts time range to predicates so that rows outside the range will be filtered. diff --git a/src/mito2/src/compaction/compactor.rs b/src/mito2/src/compaction/compactor.rs index 827aba80cca0..216dff88e792 100644 --- a/src/mito2/src/compaction/compactor.rs +++ b/src/mito2/src/compaction/compactor.rs @@ -26,8 +26,8 @@ use store_api::storage::RegionId; use crate::access_layer::{AccessLayer, AccessLayerRef, SstWriteRequest}; use crate::cache::{CacheManager, CacheManagerRef}; -use crate::compaction::build_sst_reader; use crate::compaction::picker::{new_picker, PickerOutput}; +use crate::compaction::CompactionSstReaderBuilder; use crate::config::MitoConfig; use crate::error::{EmptyRegionDirSnafu, JoinSnafu, ObjectStoreNotFoundSnafu, Result}; use crate::manifest::action::{RegionEdit, RegionMetaAction, RegionMetaActionList}; @@ -137,7 +137,8 @@ pub async fn open_compaction_region( let memtable_builder = MemtableBuilderProvider::new(None, Arc::new(mito_config.clone())) .builder_for_options( req.region_options.memtable.as_ref(), - !req.region_options.append_mode, + req.region_options.need_dedup(), + req.region_options.merge_mode(), ); // Initial memtable id is 0. 
@@ -282,16 +283,19 @@ impl Compactor for DefaultCompactor { .index_options .clone(); let append_mode = compaction_region.current_version.options.append_mode; + let merge_mode = compaction_region.current_version.options.merge_mode(); futs.push(async move { - let reader = build_sst_reader( - region_metadata.clone(), - sst_layer.clone(), - Some(cache_manager.clone()), - &output.inputs, + let reader = CompactionSstReaderBuilder { + metadata: region_metadata.clone(), + sst_layer: sst_layer.clone(), + cache: Some(cache_manager.clone()), + inputs: &output.inputs, append_mode, - output.filter_deleted, - output.output_time_range, - ) + filter_deleted: output.filter_deleted, + time_range: output.output_time_range, + merge_mode, + } + .build_sst_reader() .await?; let file_meta_opt = sst_layer .write_sst( diff --git a/src/mito2/src/compaction/window.rs b/src/mito2/src/compaction/window.rs index 1683d28f9a9c..cf5f1721555a 100644 --- a/src/mito2/src/compaction/window.rs +++ b/src/mito2/src/compaction/window.rs @@ -260,6 +260,7 @@ mod tests { wal_options: Default::default(), index_options: Default::default(), memtable: None, + merge_mode: None, }, }) } diff --git a/src/mito2/src/engine.rs b/src/mito2/src/engine.rs index 86324ec9f097..7af21da298be 100644 --- a/src/mito2/src/engine.rs +++ b/src/mito2/src/engine.rs @@ -39,6 +39,8 @@ mod flush_test; #[cfg(any(test, feature = "test"))] pub mod listener; #[cfg(test)] +mod merge_mode_test; +#[cfg(test)] mod open_test; #[cfg(test)] mod parallel_test; diff --git a/src/mito2/src/engine/append_mode_test.rs b/src/mito2/src/engine/append_mode_test.rs index 6e4d1e62734c..0fb148be44b0 100644 --- a/src/mito2/src/engine/append_mode_test.rs +++ b/src/mito2/src/engine/append_mode_test.rs @@ -113,7 +113,7 @@ async fn test_append_mode_compaction() { .await .unwrap(); - // Flush 2 SSTs for compaction. + // Flush 3 SSTs for compaction. // a, field 1, 2 let rows = Rows { schema: column_schemas.clone(), diff --git a/src/mito2/src/engine/merge_mode_test.rs b/src/mito2/src/engine/merge_mode_test.rs new file mode 100644 index 000000000000..1adf51d12f41 --- /dev/null +++ b/src/mito2/src/engine/merge_mode_test.rs @@ -0,0 +1,208 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Tests for merge mode.
+ +use api::v1::Rows; +use common_recordbatch::RecordBatches; +use store_api::region_engine::RegionEngine; +use store_api::region_request::{RegionCompactRequest, RegionRequest}; +use store_api::storage::{RegionId, ScanRequest}; + +use crate::config::MitoConfig; +use crate::test_util::batch_util::sort_batches_and_print; +use crate::test_util::{ + build_delete_rows_for_key, build_rows_with_fields, delete_rows, delete_rows_schema, + flush_region, put_rows, reopen_region, rows_schema, CreateRequestBuilder, TestEnv, +}; + +#[tokio::test] +async fn test_merge_mode_write_query() { + common_telemetry::init_default_ut_logging(); + + let mut env = TestEnv::new(); + let engine = env.create_engine(MitoConfig::default()).await; + + let region_id = RegionId::new(1, 1); + let request = CreateRequestBuilder::new() + .field_num(2) + .insert_option("merge_mode", "last_non_null") + .build(); + + let column_schemas = rows_schema(&request); + engine + .handle_request(region_id, RegionRequest::Create(request)) + .await + .unwrap(); + + let rows = build_rows_with_fields( + "a", + &[1, 2, 3], + &[(Some(1), None), (None, None), (None, Some(3))], + ); + let rows = Rows { + schema: column_schemas.clone(), + rows, + }; + put_rows(&engine, region_id, rows).await; + + let rows = build_rows_with_fields("a", &[2, 3], &[(Some(12), None), (Some(13), None)]); + let rows = Rows { + schema: column_schemas.clone(), + rows, + }; + put_rows(&engine, region_id, rows).await; + + let rows = build_rows_with_fields("a", &[1, 2], &[(Some(11), None), (Some(22), Some(222))]); + let rows = Rows { + schema: column_schemas, + rows, + }; + put_rows(&engine, region_id, rows).await; + + let request = ScanRequest::default(); + let stream = engine.scan_to_stream(region_id, request).await.unwrap(); + let batches = RecordBatches::try_collect(stream).await.unwrap(); + let expected = "\ ++-------+---------+---------+---------------------+ +| tag_0 | field_0 | field_1 | ts | ++-------+---------+---------+---------------------+ +| a | 11.0 | | 1970-01-01T00:00:01 | +| a | 22.0 | 222.0 | 1970-01-01T00:00:02 | +| a | 13.0 | 3.0 | 1970-01-01T00:00:03 | ++-------+---------+---------+---------------------+"; + assert_eq!(expected, batches.pretty_print().unwrap()); +} + +#[tokio::test] +async fn test_merge_mode_compaction() { + common_telemetry::init_default_ut_logging(); + + let mut env = TestEnv::new(); + let engine = env + .create_engine(MitoConfig { + scan_parallelism: 2, + ..Default::default() + }) + .await; + let region_id = RegionId::new(1, 1); + + let request = CreateRequestBuilder::new() + .field_num(2) + .insert_option("compaction.type", "twcs") + .insert_option("compaction.twcs.max_active_window_files", "2") + .insert_option("compaction.twcs.max_inactive_window_files", "2") + .insert_option("merge_mode", "last_non_null") + .build(); + let region_dir = request.region_dir.clone(); + let region_opts = request.options.clone(); + let delete_schema = delete_rows_schema(&request); + let column_schemas = rows_schema(&request); + engine + .handle_request(region_id, RegionRequest::Create(request)) + .await + .unwrap(); + + // Flush 3 SSTs for compaction. 
+ // a, 1 => (1, null), 2 => (null, null), 3 => (null, 3), 4 => (4, 4) + let rows = build_rows_with_fields( + "a", + &[1, 2, 3, 4], + &[ + (Some(1), None), + (None, None), + (None, Some(3)), + (Some(4), Some(4)), + ], + ); + let rows = Rows { + schema: column_schemas.clone(), + rows, + }; + put_rows(&engine, region_id, rows).await; + flush_region(&engine, region_id, None).await; + + // a, 1 => (null, 11), 2 => (2, null), 3 => (null, 13) + let rows = build_rows_with_fields( + "a", + &[1, 2, 3], + &[(None, Some(11)), (Some(2), None), (None, Some(13))], + ); + let rows = Rows { + schema: column_schemas.clone(), + rows, + }; + put_rows(&engine, region_id, rows).await; + flush_region(&engine, region_id, None).await; + + // Delete a, 4 + let rows = Rows { + schema: delete_schema.clone(), + rows: build_delete_rows_for_key("a", 4, 5), + }; + delete_rows(&engine, region_id, rows).await; + flush_region(&engine, region_id, None).await; + + let output = engine + .handle_request( + region_id, + RegionRequest::Compact(RegionCompactRequest::default()), + ) + .await + .unwrap(); + assert_eq!(output.affected_rows, 0); + + // a, 1 => (21, null), 2 => (22, null) + let rows = build_rows_with_fields("a", &[1, 2], &[(Some(21), None), (Some(22), None)]); + let rows = Rows { + schema: column_schemas.clone(), + rows, + }; + put_rows(&engine, region_id, rows).await; + + let expected = "\ ++-------+---------+---------+---------------------+ +| tag_0 | field_0 | field_1 | ts | ++-------+---------+---------+---------------------+ +| a | 21.0 | 11.0 | 1970-01-01T00:00:01 | +| a | 22.0 | | 1970-01-01T00:00:02 | +| a | | 13.0 | 1970-01-01T00:00:03 | ++-------+---------+---------+---------------------+"; + // Scans in parallel. + let scanner = engine.scanner(region_id, ScanRequest::default()).unwrap(); + assert_eq!(1, scanner.num_files()); + assert_eq!(1, scanner.num_memtables()); + let stream = scanner.scan().await.unwrap(); + let batches = RecordBatches::try_collect(stream).await.unwrap(); + assert_eq!(expected, sort_batches_and_print(&batches, &["tag_0", "ts"])); + + // Reopens engine with parallelism 1. + let engine = env + .reopen_engine( + engine, + MitoConfig { + scan_parallelism: 1, + ..Default::default() + }, + ) + .await; + // Reopens the region. 
+ reopen_region(&engine, region_id, region_dir, false, region_opts).await; + let stream = engine + .scan_to_stream(region_id, ScanRequest::default()) + .await + .unwrap(); + let batches = RecordBatches::try_collect(stream).await.unwrap(); + assert_eq!(expected, sort_batches_and_print(&batches, &["tag_0", "ts"])); +} diff --git a/src/mito2/src/memtable.rs b/src/mito2/src/memtable.rs index b807197f099d..3cc497b25405 100644 --- a/src/mito2/src/memtable.rs +++ b/src/mito2/src/memtable.rs @@ -34,7 +34,7 @@ use crate::memtable::partition_tree::{PartitionTreeConfig, PartitionTreeMemtable use crate::memtable::time_series::TimeSeriesMemtableBuilder; use crate::metrics::WRITE_BUFFER_BYTES; use crate::read::Batch; -use crate::region::options::MemtableOptions; +use crate::region::options::{MemtableOptions, MergeMode}; pub mod bulk; pub mod key_values; @@ -251,11 +251,13 @@ impl MemtableBuilderProvider { &self, options: Option<&MemtableOptions>, dedup: bool, + merge_mode: MergeMode, ) -> MemtableBuilderRef { match options { Some(MemtableOptions::TimeSeries) => Arc::new(TimeSeriesMemtableBuilder::new( self.write_buffer_manager.clone(), dedup, + merge_mode, )), Some(MemtableOptions::PartitionTree(opts)) => { Arc::new(PartitionTreeMemtableBuilder::new( @@ -264,15 +266,16 @@ impl MemtableBuilderProvider { data_freeze_threshold: opts.data_freeze_threshold, fork_dictionary_bytes: opts.fork_dictionary_bytes, dedup, + merge_mode, }, self.write_buffer_manager.clone(), )) } - None => self.default_memtable_builder(dedup), + None => self.default_memtable_builder(dedup, merge_mode), } } - fn default_memtable_builder(&self, dedup: bool) -> MemtableBuilderRef { + fn default_memtable_builder(&self, dedup: bool, merge_mode: MergeMode) -> MemtableBuilderRef { match &self.config.memtable { MemtableConfig::PartitionTree(config) => { let mut config = config.clone(); @@ -285,6 +288,7 @@ impl MemtableBuilderProvider { MemtableConfig::TimeSeries => Arc::new(TimeSeriesMemtableBuilder::new( self.write_buffer_manager.clone(), dedup, + merge_mode, )), } } diff --git a/src/mito2/src/memtable/partition_tree.rs b/src/mito2/src/memtable/partition_tree.rs index af3b1e343751..0d902aaa8537 100644 --- a/src/mito2/src/memtable/partition_tree.rs +++ b/src/mito2/src/memtable/partition_tree.rs @@ -43,6 +43,7 @@ use crate::memtable::{ AllocTracker, BoxedBatchIterator, BulkPart, IterBuilder, KeyValues, Memtable, MemtableBuilder, MemtableId, MemtableRange, MemtableRangeContext, MemtableRef, MemtableStats, }; +use crate::region::options::MergeMode; /// Use `1/DICTIONARY_SIZE_FACTOR` of OS memory as dictionary size. pub(crate) const DICTIONARY_SIZE_FACTOR: u64 = 8; @@ -80,6 +81,9 @@ pub struct PartitionTreeConfig { pub dedup: bool, /// Total bytes of dictionary to keep in fork. pub fork_dictionary_bytes: ReadableSize, + /// Merge mode of the tree. 
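+    ///
+    /// Not deserialized from the memtable options (`skip_deserializing`); the memtable
+    /// builder provider sets it from the region options and it defaults to `MergeMode::LastRow`.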
+ #[serde(skip_deserializing)] + pub merge_mode: MergeMode, } impl Default for PartitionTreeConfig { @@ -98,6 +102,7 @@ impl Default for PartitionTreeConfig { data_freeze_threshold: 131072, dedup: true, fork_dictionary_bytes, + merge_mode: MergeMode::LastRow, } } } diff --git a/src/mito2/src/memtable/partition_tree/tree.rs b/src/mito2/src/memtable/partition_tree/tree.rs index 110032a68ed5..3df106f7da4a 100644 --- a/src/mito2/src/memtable/partition_tree/tree.rs +++ b/src/mito2/src/memtable/partition_tree/tree.rs @@ -40,7 +40,9 @@ use crate::memtable::partition_tree::partition::{ use crate::memtable::partition_tree::PartitionTreeConfig; use crate::memtable::{BoxedBatchIterator, KeyValues}; use crate::metrics::{PARTITION_TREE_READ_STAGE_ELAPSED, READ_ROWS_TOTAL, READ_STAGE_ELAPSED}; +use crate::read::dedup::LastNonNullIter; use crate::read::Batch; +use crate::region::options::MergeMode; use crate::row_converter::{McmpRowCodec, RowCodec, SortField}; /// The partition tree. @@ -83,9 +85,13 @@ impl PartitionTree { .collect(), }; let is_partitioned = Partition::has_multi_partitions(&metadata); + let mut config = config.clone(); + if config.merge_mode == MergeMode::LastNonNull { + config.dedup = false; + } PartitionTree { - config: config.clone(), + config, metadata, row_codec: Arc::new(row_codec), partitions: Default::default(), @@ -237,7 +243,13 @@ impl PartitionTree { iter.fetch_next_partition(context)?; iter.metrics.iter_elapsed += start.elapsed(); - Ok(Box::new(iter)) + + if self.config.merge_mode == MergeMode::LastNonNull { + let iter = LastNonNullIter::new(iter); + Ok(Box::new(iter)) + } else { + Ok(Box::new(iter)) + } } /// Returns true if the tree is empty. diff --git a/src/mito2/src/memtable/time_series.rs b/src/mito2/src/memtable/time_series.rs index ac3965f6f1c3..4c7a11456620 100644 --- a/src/mito2/src/memtable/time_series.rs +++ b/src/mito2/src/memtable/time_series.rs @@ -47,7 +47,9 @@ use crate::memtable::{ MemtableId, MemtableRange, MemtableRangeContext, MemtableRef, MemtableStats, }; use crate::metrics::{READ_ROWS_TOTAL, READ_STAGE_ELAPSED}; +use crate::read::dedup::LastNonNullIter; use crate::read::{Batch, BatchBuilder, BatchColumn}; +use crate::region::options::MergeMode; use crate::row_converter::{McmpRowCodec, RowCodec, SortField}; /// Initial vector builder capacity. @@ -58,14 +60,20 @@ const INITIAL_BUILDER_CAPACITY: usize = 0; pub struct TimeSeriesMemtableBuilder { write_buffer_manager: Option, dedup: bool, + merge_mode: MergeMode, } impl TimeSeriesMemtableBuilder { /// Creates a new builder with specific `write_buffer_manager`. 
- pub fn new(write_buffer_manager: Option, dedup: bool) -> Self { + pub fn new( + write_buffer_manager: Option, + dedup: bool, + merge_mode: MergeMode, + ) -> Self { Self { write_buffer_manager, dedup, + merge_mode, } } } @@ -77,6 +85,7 @@ impl MemtableBuilder for TimeSeriesMemtableBuilder { id, self.write_buffer_manager.clone(), self.dedup, + self.merge_mode, )) } } @@ -91,6 +100,7 @@ pub struct TimeSeriesMemtable { max_timestamp: AtomicI64, min_timestamp: AtomicI64, dedup: bool, + merge_mode: MergeMode, } impl TimeSeriesMemtable { @@ -99,6 +109,7 @@ impl TimeSeriesMemtable { id: MemtableId, write_buffer_manager: Option, dedup: bool, + merge_mode: MergeMode, ) -> Self { let row_codec = Arc::new(McmpRowCodec::new( region_metadata @@ -107,6 +118,11 @@ impl TimeSeriesMemtable { .collect(), )); let series_set = SeriesSet::new(region_metadata.clone(), row_codec.clone()); + let dedup = if merge_mode == MergeMode::LastNonNull { + false + } else { + dedup + }; Self { id, region_metadata, @@ -116,6 +132,7 @@ impl TimeSeriesMemtable { max_timestamp: AtomicI64::new(i64::MIN), min_timestamp: AtomicI64::new(i64::MAX), dedup, + merge_mode, } } @@ -251,7 +268,13 @@ impl Memtable for TimeSeriesMemtable { let iter = self .series_set .iter_series(projection, filters, self.dedup)?; - Ok(Box::new(iter)) + + if self.merge_mode == MergeMode::LastNonNull { + let iter = LastNonNullIter::new(iter); + Ok(Box::new(iter)) + } else { + Ok(Box::new(iter)) + } } fn ranges( @@ -272,6 +295,7 @@ impl Memtable for TimeSeriesMemtable { projection, predicate, dedup: self.dedup, + merge_mode: self.merge_mode, }); let context = Arc::new(MemtableRangeContext::new(self.id, builder)); @@ -320,6 +344,7 @@ impl Memtable for TimeSeriesMemtable { id, self.alloc_tracker.write_buffer_manager(), self.dedup, + self.merge_mode, )) } } @@ -856,6 +881,7 @@ struct TimeSeriesIterBuilder { projection: HashSet, predicate: Option, dedup: bool, + merge_mode: MergeMode, } impl IterBuilder for TimeSeriesIterBuilder { @@ -865,7 +891,13 @@ impl IterBuilder for TimeSeriesIterBuilder { self.predicate.clone(), self.dedup, )?; - Ok(Box::new(iter)) + + if self.merge_mode == MergeMode::LastNonNull { + let iter = LastNonNullIter::new(iter); + Ok(Box::new(iter)) + } else { + Ok(Box::new(iter)) + } } } @@ -1234,7 +1266,7 @@ mod tests { fn check_memtable_dedup(dedup: bool) { let schema = schema_for_test(); let kvs = build_key_values(&schema, "hello".to_string(), 42, 100); - let memtable = TimeSeriesMemtable::new(schema, 42, None, dedup); + let memtable = TimeSeriesMemtable::new(schema, 42, None, dedup, MergeMode::LastRow); memtable.write(&kvs).unwrap(); memtable.write(&kvs).unwrap(); @@ -1283,7 +1315,7 @@ mod tests { common_telemetry::init_default_ut_logging(); let schema = schema_for_test(); let kvs = build_key_values(&schema, "hello".to_string(), 42, 100); - let memtable = TimeSeriesMemtable::new(schema, 42, None, true); + let memtable = TimeSeriesMemtable::new(schema, 42, None, true, MergeMode::LastRow); memtable.write(&kvs).unwrap(); let iter = memtable.iter(Some(&[3]), None).unwrap(); diff --git a/src/mito2/src/read/dedup.rs b/src/mito2/src/read/dedup.rs index c8709edcd965..52ff05fd1231 100644 --- a/src/mito2/src/read/dedup.rs +++ b/src/mito2/src/read/dedup.rs @@ -317,9 +317,9 @@ impl LastFieldsBuilder { self.contains_null = self.last_fields.iter().any(Value::is_null); } - /// Merges last not null fields, builds a new batch and resets the builder. + /// Merges last non-null fields, builds a new batch and resets the builder. 
     /// It may overwrites the last row of the `buffer`.
-    fn merge_last_not_null(
+    fn merge_last_non_null(
         &mut self,
         buffer: Batch,
         metrics: &mut DedupMetrics,
@@ -380,20 +380,20 @@ impl LastFieldsBuilder {
     }
 }
 
-/// Dedup strategy that keeps the last not null field for the same key.
+/// Dedup strategy that keeps the last non-null field for the same key.
 ///
 /// It assumes that batches from files and memtables don't contain duplicate rows
 /// and the merge reader never concatenates batches from different source.
 ///
 /// We might implement a new strategy if we need to process files with duplicate rows.
-pub(crate) struct LastNotNull {
+pub(crate) struct LastNonNull {
     /// Buffered batch that fields in the last row may be updated.
     buffer: Option<Batch>,
     /// Fields that overlaps with the last row of the `buffer`.
     last_fields: LastFieldsBuilder,
 }
 
-impl LastNotNull {
+impl LastNonNull {
     /// Creates a new strategy with the given `filter_deleted` flag.
     #[allow(dead_code)]
     pub(crate) fn new(filter_deleted: bool) -> Self {
@@ -404,7 +404,7 @@ impl LastNotNull {
     }
 }
 
-impl DedupStrategy for LastNotNull {
+impl DedupStrategy for LastNonNull {
     fn push_batch(&mut self, batch: Batch, metrics: &mut DedupMetrics) -> Result<Option<Batch>> {
         if batch.is_empty() {
             return Ok(None);
@@ -422,14 +422,14 @@ impl DedupStrategy for LastNotNull {
         if buffer.primary_key() != batch.primary_key() {
             // Next key is different.
             let buffer = std::mem::replace(buffer, batch);
-            let merged = self.last_fields.merge_last_not_null(buffer, metrics)?;
+            let merged = self.last_fields.merge_last_non_null(buffer, metrics)?;
             return Ok(merged);
         }
 
         if buffer.last_timestamp() != batch.first_timestamp() {
             // The next batch has a different timestamp.
             let buffer = std::mem::replace(buffer, batch);
-            let merged = self.last_fields.merge_last_not_null(buffer, metrics)?;
+            let merged = self.last_fields.merge_last_non_null(buffer, metrics)?;
             return Ok(merged);
         }
 
@@ -449,7 +449,7 @@ impl DedupStrategy for LastNotNull {
         // Moves the remaining rows to the buffer.
         let batch = batch.slice(1, batch.num_rows() - 1);
         let buffer = std::mem::replace(buffer, batch);
-        let merged = self.last_fields.merge_last_not_null(buffer, metrics)?;
+        let merged = self.last_fields.merge_last_non_null(buffer, metrics)?;
 
         Ok(merged)
     }
@@ -462,12 +462,107 @@ impl DedupStrategy for LastNotNull {
         // Initializes last fields with the first buffer.
         self.last_fields.maybe_init(&buffer);
 
-        let merged = self.last_fields.merge_last_not_null(buffer, metrics)?;
+        let merged = self.last_fields.merge_last_non_null(buffer, metrics)?;
 
         Ok(merged)
     }
 }
 
+/// An iterator that dedups rows by the [LastNonNull] strategy.
+/// The input iterator must return sorted batches.
+pub(crate) struct LastNonNullIter<I> {
+    /// Inner iterator that returns sorted batches.
+    iter: Option<I>,
+    /// Dedup strategy.
+    strategy: LastNonNull,
+    /// Dedup metrics.
+    metrics: DedupMetrics,
+    /// The current batch returned by the iterator. If it is None, we need to
+    /// fetch a new batch.
+    /// The batch is never empty.
+    current_batch: Option<Batch>,
+}
+
+impl<I> LastNonNullIter<I> {
+    /// Creates a new iterator with the given inner iterator.
+    pub(crate) fn new(iter: I) -> Self {
+        Self {
+            iter: Some(iter),
+            // We only use the iter in memtables. Memtables never filter deleted rows.
+            strategy: LastNonNull::new(false),
+            metrics: DedupMetrics::default(),
+            current_batch: None,
+        }
+    }
+
+    /// Finds the index of the first row that has the same timestamp as the next row.
+    /// If there are no duplicate rows, returns None.
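+    ///
+    /// E.g. for timestamps `[1, 1, 2]` this returns `Some(0)`: row 0 and row 1 share the
+    /// same timestamp, so the batch is split after row 0 before merging their fields.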
+    fn find_split_index(batch: &Batch) -> Option<usize> {
+        if batch.num_rows() < 2 {
+            return None;
+        }
+
+        // Safety: The batch is not empty.
+        let timestamps = batch.timestamps_native().unwrap();
+        timestamps.windows(2).position(|t| t[0] == t[1])
+    }
+}
+
+impl<I: Iterator<Item = Result<Batch>>> LastNonNullIter<I> {
+    /// Fetches the next batch from the inner iterator. It will slice the batch if it
+    /// contains duplicate rows.
+    fn next_batch_for_merge(&mut self) -> Result<Option<Batch>> {
+        if self.current_batch.is_none() {
+            // No current batch. Fetches a new batch from the inner iterator.
+            let Some(iter) = self.iter.as_mut() else {
+                // The iterator is exhausted.
+                return Ok(None);
+            };
+
+            self.current_batch = iter.next().transpose()?;
+            if self.current_batch.is_none() {
+                // The iterator is exhausted.
+                self.iter = None;
+                return Ok(None);
+            }
+        }
+
+        if let Some(batch) = &self.current_batch {
+            let Some(index) = Self::find_split_index(batch) else {
+                // No duplicate rows in the current batch.
+                return Ok(self.current_batch.take());
+            };
+
+            let first = batch.slice(0, index + 1);
+            let batch = batch.slice(index + 1, batch.num_rows() - index - 1);
+            // A `Some` split index means the remaining batch still has at least one row.
+            debug_assert!(!batch.is_empty());
+            self.current_batch = Some(batch);
+            return Ok(Some(first));
+        }
+
+        Ok(None)
+    }
+
+    fn next_batch(&mut self) -> Result<Option<Batch>> {
+        while let Some(batch) = self.next_batch_for_merge()? {
+            if let Some(batch) = self.strategy.push_batch(batch, &mut self.metrics)? {
+                return Ok(Some(batch));
+            }
+        }
+
+        self.strategy.finish(&mut self.metrics)
+    }
+}
+
+impl<I: Iterator<Item = Result<Batch>>> Iterator for LastNonNullIter<I> {
+    type Item = Result<Batch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.next_batch().transpose()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
@@ -506,9 +601,9 @@ mod tests {
         assert_eq!(0, reader.metrics().num_unselected_rows);
         assert_eq!(0, reader.metrics().num_deleted_rows);
 
-        // Test last not null.
+        // Test last non-null.
         let reader = VecBatchReader::new(&input);
-        let mut reader = DedupReader::new(reader, LastNotNull::new(true));
+        let mut reader = DedupReader::new(reader, LastNonNull::new(true));
        check_reader_result(&mut reader, &input).await;
         assert_eq!(0, reader.metrics().num_unselected_rows);
         assert_eq!(0, reader.metrics().num_deleted_rows);
@@ -640,7 +735,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_last_not_null_merge() {
+    async fn test_last_non_null_merge() {
         let input = [
             new_batch_multi_fields(
                 b"k1",
@@ -688,7 +783,7 @@ mod tests {
 
         // Filter deleted.
         let reader = VecBatchReader::new(&input);
-        let mut reader = DedupReader::new(reader, LastNotNull::new(true));
+        let mut reader = DedupReader::new(reader, LastNonNull::new(true));
         check_reader_result(
             &mut reader,
             &[
@@ -722,7 +817,7 @@ mod tests {
 
         // Does not filter deleted.
let reader = VecBatchReader::new(&input); - let mut reader = DedupReader::new(reader, LastNotNull::new(false)); + let mut reader = DedupReader::new(reader, LastNonNull::new(false)); check_reader_result( &mut reader, &[ @@ -762,7 +857,7 @@ mod tests { } #[tokio::test] - async fn test_last_not_null_skip_merge_single() { + async fn test_last_non_null_skip_merge_single() { let input = [new_batch_multi_fields( b"k1", &[1, 2, 3], @@ -772,7 +867,7 @@ mod tests { )]; let reader = VecBatchReader::new(&input); - let mut reader = DedupReader::new(reader, LastNotNull::new(true)); + let mut reader = DedupReader::new(reader, LastNonNull::new(true)); check_reader_result( &mut reader, &[new_batch_multi_fields( @@ -788,14 +883,14 @@ mod tests { assert_eq!(1, reader.metrics().num_deleted_rows); let reader = VecBatchReader::new(&input); - let mut reader = DedupReader::new(reader, LastNotNull::new(false)); + let mut reader = DedupReader::new(reader, LastNonNull::new(false)); check_reader_result(&mut reader, &input).await; assert_eq!(0, reader.metrics().num_unselected_rows); assert_eq!(0, reader.metrics().num_deleted_rows); } #[tokio::test] - async fn test_last_not_null_skip_merge_no_null() { + async fn test_last_non_null_skip_merge_no_null() { let input = [ new_batch_multi_fields( b"k1", @@ -815,7 +910,7 @@ mod tests { ]; let reader = VecBatchReader::new(&input); - let mut reader = DedupReader::new(reader, LastNotNull::new(true)); + let mut reader = DedupReader::new(reader, LastNonNull::new(true)); check_reader_result( &mut reader, &[ @@ -835,7 +930,7 @@ mod tests { } #[tokio::test] - async fn test_last_not_null_merge_null() { + async fn test_last_non_null_merge_null() { let input = [ new_batch_multi_fields( b"k1", @@ -849,7 +944,7 @@ mod tests { ]; let reader = VecBatchReader::new(&input); - let mut reader = DedupReader::new(reader, LastNotNull::new(true)); + let mut reader = DedupReader::new(reader, LastNonNull::new(true)); check_reader_result( &mut reader, &[ @@ -884,7 +979,7 @@ mod tests { } #[test] - fn test_last_not_null_strategy_delete_last() { + fn test_last_non_null_strategy_delete_last() { let input = [ new_batch_multi_fields(b"k1", &[1], &[6], &[OpType::Put], &[(Some(11), None)]), new_batch_multi_fields( @@ -905,7 +1000,7 @@ mod tests { new_batch_multi_fields(b"k2", &[3], &[3], &[OpType::Put], &[(None, Some(3))]), ]; - let mut strategy = LastNotNull::new(true); + let mut strategy = LastNonNull::new(true); check_dedup_strategy( &input, &mut strategy, @@ -918,13 +1013,13 @@ mod tests { } #[test] - fn test_last_not_null_strategy_delete_one() { + fn test_last_non_null_strategy_delete_one() { let input = [ new_batch_multi_fields(b"k1", &[1], &[1], &[OpType::Delete], &[(None, None)]), new_batch_multi_fields(b"k2", &[1], &[6], &[OpType::Put], &[(Some(11), None)]), ]; - let mut strategy = LastNotNull::new(true); + let mut strategy = LastNonNull::new(true); check_dedup_strategy( &input, &mut strategy, @@ -939,18 +1034,18 @@ mod tests { } #[test] - fn test_last_not_null_strategy_delete_all() { + fn test_last_non_null_strategy_delete_all() { let input = [ new_batch_multi_fields(b"k1", &[1], &[1], &[OpType::Delete], &[(None, None)]), new_batch_multi_fields(b"k2", &[1], &[6], &[OpType::Delete], &[(Some(11), None)]), ]; - let mut strategy = LastNotNull::new(true); + let mut strategy = LastNonNull::new(true); check_dedup_strategy(&input, &mut strategy, &[]); } #[test] - fn test_last_not_null_strategy_same_batch() { + fn test_last_non_null_strategy_same_batch() { let input = [ new_batch_multi_fields(b"k1", &[1], 
&[6], &[OpType::Put], &[(Some(11), None)]), new_batch_multi_fields( @@ -971,7 +1066,7 @@ mod tests { new_batch_multi_fields(b"k1", &[3], &[3], &[OpType::Put], &[(None, Some(3))]), ]; - let mut strategy = LastNotNull::new(true); + let mut strategy = LastNonNull::new(true); check_dedup_strategy( &input, &mut strategy, @@ -982,4 +1077,92 @@ mod tests { ], ); } + + #[test] + fn test_last_non_null_iter_on_batch() { + let input = [new_batch_multi_fields( + b"k1", + &[1, 1, 2], + &[13, 12, 13], + &[OpType::Put, OpType::Put, OpType::Put], + &[(None, None), (Some(1), None), (Some(2), Some(22))], + )]; + let iter = input.into_iter().map(Ok); + let iter = LastNonNullIter::new(iter); + let actual: Vec<_> = iter.map(|batch| batch.unwrap()).collect(); + let expect = [ + new_batch_multi_fields(b"k1", &[1], &[13], &[OpType::Put], &[(Some(1), None)]), + new_batch_multi_fields(b"k1", &[2], &[13], &[OpType::Put], &[(Some(2), Some(22))]), + ]; + assert_eq!(&expect, &actual[..]); + } + + #[test] + fn test_last_non_null_iter_same_row() { + let input = [ + new_batch_multi_fields( + b"k1", + &[1, 1, 1], + &[13, 12, 11], + &[OpType::Put, OpType::Put, OpType::Put], + &[(None, None), (Some(1), None), (Some(11), None)], + ), + new_batch_multi_fields( + b"k1", + &[1, 1], + &[10, 9], + &[OpType::Put, OpType::Put], + &[(None, Some(11)), (Some(21), Some(31))], + ), + ]; + let iter = input.into_iter().map(Ok); + let iter = LastNonNullIter::new(iter); + let actual: Vec<_> = iter.map(|batch| batch.unwrap()).collect(); + let expect = [new_batch_multi_fields( + b"k1", + &[1], + &[13], + &[OpType::Put], + &[(Some(1), Some(11))], + )]; + assert_eq!(&expect, &actual[..]); + } + + #[test] + fn test_last_non_null_iter_multi_batch() { + let input = [ + new_batch_multi_fields( + b"k1", + &[1, 1, 2], + &[13, 12, 13], + &[OpType::Put, OpType::Put, OpType::Put], + &[(None, None), (Some(1), None), (Some(2), Some(22))], + ), + new_batch_multi_fields( + b"k1", + &[2, 3], + &[12, 13], + &[OpType::Put, OpType::Delete], + &[(None, Some(12)), (None, None)], + ), + new_batch_multi_fields( + b"k2", + &[1, 1, 2], + &[13, 12, 13], + &[OpType::Put, OpType::Put, OpType::Put], + &[(None, None), (Some(1), None), (Some(2), Some(22))], + ), + ]; + let iter = input.into_iter().map(Ok); + let iter = LastNonNullIter::new(iter); + let actual: Vec<_> = iter.map(|batch| batch.unwrap()).collect(); + let expect = [ + new_batch_multi_fields(b"k1", &[1], &[13], &[OpType::Put], &[(Some(1), None)]), + new_batch_multi_fields(b"k1", &[2], &[13], &[OpType::Put], &[(Some(2), Some(22))]), + new_batch_multi_fields(b"k1", &[3], &[13], &[OpType::Delete], &[(None, None)]), + new_batch_multi_fields(b"k2", &[1], &[13], &[OpType::Put], &[(Some(1), None)]), + new_batch_multi_fields(b"k2", &[2], &[13], &[OpType::Put], &[(Some(2), Some(22))]), + ]; + assert_eq!(&expect, &actual[..]); + } } diff --git a/src/mito2/src/read/scan_region.rs b/src/mito2/src/read/scan_region.rs index 4fe783ce9e35..e29b1611a2f7 100644 --- a/src/mito2/src/read/scan_region.rs +++ b/src/mito2/src/read/scan_region.rs @@ -42,6 +42,7 @@ use crate::read::projection::ProjectionMapper; use crate::read::seq_scan::SeqScan; use crate::read::unordered_scan::UnorderedScan; use crate::read::{Batch, Source}; +use crate::region::options::MergeMode; use crate::region::version::VersionRef; use crate::sst::file::{overlaps, FileHandle, FileMeta}; use crate::sst::index::applier::builder::SstIndexApplierBuilder; @@ -295,7 +296,8 @@ impl ScanRegion { .with_parallelism(self.parallelism) .with_start_time(self.start_time) 
.with_append_mode(self.version.options.append_mode) - .with_filter_deleted(filter_deleted); + .with_filter_deleted(filter_deleted) + .with_merge_mode(self.version.options.merge_mode()); Ok(input) } @@ -398,6 +400,8 @@ pub(crate) struct ScanInput { pub(crate) append_mode: bool, /// Whether to remove deletion markers. pub(crate) filter_deleted: bool, + /// Mode to merge duplicate rows. + pub(crate) merge_mode: MergeMode, } impl ScanInput { @@ -418,6 +422,7 @@ impl ScanInput { query_start: None, append_mode: false, filter_deleted: true, + merge_mode: MergeMode::default(), } } @@ -497,6 +502,13 @@ impl ScanInput { self } + /// Sets the merge mode. + #[must_use] + pub(crate) fn with_merge_mode(mut self, merge_mode: MergeMode) -> Self { + self.merge_mode = merge_mode; + self + } + /// Scans sources in parallel. /// /// # Panics if the input doesn't allow parallel scan. diff --git a/src/mito2/src/read/seq_scan.rs b/src/mito2/src/read/seq_scan.rs index 17151b624d49..9a0038135f4a 100644 --- a/src/mito2/src/read/seq_scan.rs +++ b/src/mito2/src/read/seq_scan.rs @@ -35,12 +35,13 @@ use tokio::sync::Semaphore; use crate::error::{PartitionOutOfRangeSnafu, Result}; use crate::memtable::MemtableRef; -use crate::read::dedup::{DedupReader, LastRow}; +use crate::read::dedup::{DedupReader, LastNonNull, LastRow}; use crate::read::merge::MergeReaderBuilder; use crate::read::scan_region::{ FileRangeCollector, ScanInput, ScanPart, ScanPartList, StreamContext, }; use crate::read::{BatchReader, BoxedBatchReader, ScannerMetrics, Source}; +use crate::region::options::MergeMode; use crate::sst::file::FileMeta; use crate::sst::parquet::file_range::FileRange; use crate::sst::parquet::reader::ReaderMetrics; @@ -210,10 +211,16 @@ impl SeqScan { let dedup = !stream_ctx.input.append_mode; if dedup { - let reader = Box::new(DedupReader::new( - reader, - LastRow::new(stream_ctx.input.filter_deleted), - )); + let reader = match stream_ctx.input.merge_mode { + MergeMode::LastRow => Box::new(DedupReader::new( + reader, + LastRow::new(stream_ctx.input.filter_deleted), + )) as _, + MergeMode::LastNonNull => Box::new(DedupReader::new( + reader, + LastNonNull::new(stream_ctx.input.filter_deleted), + )) as _, + }; Ok(Some(reader)) } else { let reader = Box::new(reader); diff --git a/src/mito2/src/region/opener.rs b/src/mito2/src/region/opener.rs index e20a00d35ae1..50aa7c68cd37 100644 --- a/src/mito2/src/region/opener.rs +++ b/src/mito2/src/region/opener.rs @@ -119,9 +119,10 @@ impl RegionOpener { } /// Sets options for the region. - pub(crate) fn options(mut self, options: RegionOptions) -> Self { + pub(crate) fn options(mut self, options: RegionOptions) -> Result { + options.validate()?; self.options = Some(options); - self + Ok(self) } /// Sets the cache manager for the region. @@ -192,9 +193,11 @@ impl RegionOpener { ) .await?; - let memtable_builder = self - .memtable_builder_provider - .builder_for_options(options.memtable.as_ref(), !options.append_mode); + let memtable_builder = self.memtable_builder_provider.builder_for_options( + options.memtable.as_ref(), + options.need_dedup(), + options.merge_mode(), + ); // Initial memtable id is 0. let part_duration = options.compaction.time_window(); let mutable = Arc::new(TimePartitions::new( @@ -323,7 +326,8 @@ impl RegionOpener { )); let memtable_builder = self.memtable_builder_provider.builder_for_options( region_options.memtable.as_ref(), - !region_options.append_mode, + region_options.need_dedup(), + region_options.merge_mode(), ); // Initial memtable id is 0. 
        let part_duration = region_options.compaction.time_window();
diff --git a/src/mito2/src/region/options.rs b/src/mito2/src/region/options.rs
index 970c70c74421..93fa84dbadf3 100644
--- a/src/mito2/src/region/options.rs
+++ b/src/mito2/src/region/options.rs
@@ -24,15 +24,28 @@ use common_wal::options::{WalOptions, WAL_OPTIONS_KEY};
 use serde::de::Error as _;
 use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::Value;
-use serde_with::{serde_as, with_prefix, DisplayFromStr};
+use serde_with::{serde_as, with_prefix, DisplayFromStr, NoneAsEmptyString};
 use snafu::{ensure, ResultExt};
 use store_api::storage::ColumnId;
+use strum::EnumString;
 
 use crate::error::{Error, InvalidRegionOptionsSnafu, JsonOptionsSnafu, Result};
 use crate::memtable::partition_tree::{DEFAULT_FREEZE_THRESHOLD, DEFAULT_MAX_KEYS_PER_SHARD};
 
 const DEFAULT_INDEX_SEGMENT_ROW_COUNT: usize = 1024;
 
+/// Mode to handle duplicate rows while merging.
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, EnumString)]
+#[serde(rename_all = "snake_case")]
+#[strum(serialize_all = "snake_case")]
+pub enum MergeMode {
+    /// Keeps the last row.
+    #[default]
+    LastRow,
+    /// Keeps the last non-null field for each row.
+    LastNonNull,
+}
+
 /// Options that affect the entire region.
 ///
 /// Users need to specify the options while creating/opening a region.
@@ -54,6 +67,34 @@ pub struct RegionOptions {
     pub index_options: IndexOptions,
     /// Memtable options.
     pub memtable: Option<MemtableOptions>,
+    /// The mode to merge duplicate rows.
+    /// Only takes effect when `append_mode` is `false`.
+    pub merge_mode: Option<MergeMode>,
+}
+
+impl RegionOptions {
+    /// Validates options.
+    pub fn validate(&self) -> Result<()> {
+        if self.append_mode {
+            ensure!(
+                self.merge_mode.is_none(),
+                InvalidRegionOptionsSnafu {
+                    reason: "merge_mode is not allowed when append_mode is enabled",
+                }
+            );
+        }
+        Ok(())
+    }
+
+    /// Returns `true` if deduplication is needed.
+    pub fn need_dedup(&self) -> bool {
+        !self.append_mode
+    }
+
+    /// Returns the `merge_mode` if it is set, otherwise returns the default `MergeMode`.
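+    /// (The default is `MergeMode::LastRow`.)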
+ pub fn merge_mode(&self) -> MergeMode { + self.merge_mode.unwrap_or_default() + } } impl TryFrom<&HashMap> for RegionOptions { @@ -89,7 +130,7 @@ impl TryFrom<&HashMap> for RegionOptions { None }; - Ok(RegionOptions { + let opts = RegionOptions { ttl: options.ttl, compaction, storage: options.storage, @@ -97,7 +138,11 @@ impl TryFrom<&HashMap> for RegionOptions { wal_options, index_options, memtable, - }) + merge_mode: options.merge_mode, + }; + opts.validate()?; + + Ok(opts) } } @@ -179,6 +224,8 @@ struct RegionOptionsWithoutEnum { storage: Option, #[serde_as(as = "DisplayFromStr")] append_mode: bool, + #[serde_as(as = "NoneAsEmptyString")] + merge_mode: Option, } impl Default for RegionOptionsWithoutEnum { @@ -188,6 +235,7 @@ impl Default for RegionOptionsWithoutEnum { ttl: options.ttl, storage: options.storage, append_mode: options.append_mode, + merge_mode: options.merge_mode, } } } @@ -477,6 +525,21 @@ mod tests { assert_eq!(StatusCode::InvalidArguments, err.status_code()); } + #[test] + fn test_with_merge_mode() { + let map = make_map(&[("merge_mode", "last_row")]); + let options = RegionOptions::try_from(&map).unwrap(); + assert_eq!(MergeMode::LastRow, options.merge_mode()); + + let map = make_map(&[("merge_mode", "last_non_null")]); + let options = RegionOptions::try_from(&map).unwrap(); + assert_eq!(MergeMode::LastNonNull, options.merge_mode()); + + let map = make_map(&[("merge_mode", "unknown")]); + let err = RegionOptions::try_from(&map).unwrap_err(); + assert_eq!(StatusCode::InvalidArguments, err.status_code()); + } + #[test] fn test_with_all() { let wal_options = WalOptions::Kafka(KafkaWalOptions { @@ -489,7 +552,7 @@ mod tests { ("compaction.twcs.time_window", "2h"), ("compaction.type", "twcs"), ("storage", "S3"), - ("append_mode", "true"), + ("append_mode", "false"), ("index.inverted_index.ignore_column_ids", "1,2,3"), ("index.inverted_index.segment_row_count", "512"), ( @@ -500,6 +563,7 @@ mod tests { ("memtable.partition_tree.index_max_keys_per_shard", "2048"), ("memtable.partition_tree.data_freeze_threshold", "2048"), ("memtable.partition_tree.fork_dictionary_bytes", "128M"), + ("merge_mode", "last_non_null"), ]); let options = RegionOptions::try_from(&map).unwrap(); let expect = RegionOptions { @@ -510,7 +574,7 @@ mod tests { time_window: Some(Duration::from_secs(3600 * 2)), }), storage: Some("S3".to_string()), - append_mode: true, + append_mode: false, wal_options, index_options: IndexOptions { inverted_index: InvertedIndexOptions { @@ -523,6 +587,7 @@ mod tests { data_freeze_threshold: 2048, fork_dictionary_bytes: ReadableSize::mb(128), })), + merge_mode: Some(MergeMode::LastNonNull), }; assert_eq!(expect, options); } diff --git a/src/mito2/src/test_util.rs b/src/mito2/src/test_util.rs index ac5f34f7895d..374e7548b05e 100644 --- a/src/mito2/src/test_util.rs +++ b/src/mito2/src/test_util.rs @@ -920,6 +920,38 @@ pub fn build_rows(start: usize, end: usize) -> Vec { .collect() } +/// Build rows with schema (string, f64, f64, ts_millis). +/// - `key`: A string key that is common across all rows. +/// - `timestamps`: Array of timestamp values. +/// - `fields`: Array of tuples where each tuple contains two optional i64 values, representing two optional float fields. +/// Returns a vector of `Row` each containing the key, two optional float fields, and a timestamp. 
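+///
+/// For example, `build_rows_with_fields("a", &[1], &[(Some(1), None)])` builds a single row
+/// `("a", 1.0, null, 1970-01-01T00:00:01)`; timestamps are given in seconds and stored as
+/// milliseconds.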
+pub fn build_rows_with_fields(
+    key: &str,
+    timestamps: &[i64],
+    fields: &[(Option<i64>, Option<i64>)],
+) -> Vec<Row> {
+    timestamps
+        .iter()
+        .zip(fields.iter())
+        .map(|(ts, (field1, field2))| api::v1::Row {
+            values: vec![
+                api::v1::Value {
+                    value_data: Some(ValueData::StringValue(key.to_string())),
+                },
+                api::v1::Value {
+                    value_data: field1.map(|v| ValueData::F64Value(v as f64)),
+                },
+                api::v1::Value {
+                    value_data: field2.map(|v| ValueData::F64Value(v as f64)),
+                },
+                api::v1::Value {
+                    value_data: Some(ValueData::TimestampMillisecondValue(*ts * 1000)),
+                },
+            ],
+        })
+        .collect()
+}
+
 /// Get column schemas for rows.
 pub fn rows_schema(request: &RegionCreateRequest) -> Vec<api::v1::ColumnSchema> {
     request
diff --git a/src/mito2/src/worker/handle_catchup.rs b/src/mito2/src/worker/handle_catchup.rs
index 93a84b92e272..a4353fe52952 100644
--- a/src/mito2/src/worker/handle_catchup.rs
+++ b/src/mito2/src/worker/handle_catchup.rs
@@ -57,7 +57,7 @@ impl<S: LogStore> RegionWorkerLoop<S> {
                     self.intermediate_manager.clone(),
                 )
                 .cache(Some(self.cache_manager.clone()))
-                .options(region.version().options.clone())
+                .options(region.version().options.clone())?
                 .skip_wal_replay(true)
                 .open(&self.config, &self.wal)
                 .await?,