diff --git a/rust/frontend/src/optimizer/mod.rs b/rust/frontend/src/optimizer/mod.rs index c1dd698edca3..ac995cbcd7c6 100644 --- a/rust/frontend/src/optimizer/mod.rs +++ b/rust/frontend/src/optimizer/mod.rs @@ -174,10 +174,13 @@ impl PlanRoot { let stream_plan = match self.plan.convention() { Convention::Logical => { let plan = self.gen_optimized_logical_plan(); - let (plan, mut out_col_change) = plan.logical_rewrite_for_stream(); - self.required_dist = - out_col_change.rewrite_distribution(self.required_dist.clone()); - self.required_order = out_col_change.rewrite_order(self.required_order.clone()); + let (plan, out_col_change) = plan.logical_rewrite_for_stream(); + self.required_dist = out_col_change + .rewrite_required_distribution(&self.required_dist) + .unwrap(); + self.required_order = out_col_change + .rewrite_required_order(&self.required_order) + .unwrap(); self.out_fields = out_col_change.rewrite_bitset(&self.out_fields); self.schema = plan.schema().clone(); plan.to_stream_with_dist_required(&self.required_dist) diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs index a9083745437a..9d8c736825d9 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs @@ -33,12 +33,18 @@ pub struct BatchHashAgg { impl BatchHashAgg { pub fn new(logical: LogicalAgg) -> Self { let ctx = logical.base.ctx.clone(); - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - Distribution::any().clone(), - Order::any().clone(), - ); + let input = logical.input(); + let input_dist = input.distribution(); + let dist = match input_dist { + Distribution::Any => Distribution::Any, + Distribution::Single => Distribution::Single, + Distribution::Broadcast => panic!(), + Distribution::AnyShard => Distribution::AnyShard, + Distribution::HashShard(_) => logical + .i2o_col_mapping() + .rewrite_provided_distribution(input_dist), + }; + let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone()); BatchHashAgg { base, logical } } pub fn agg_calls(&self) -> &[PlanAggCall] { diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs index 384698d0fc50..b12c1bb7de13 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs @@ -23,6 +23,7 @@ use super::{ ToDistributedBatch, }; use crate::optimizer::property::{Distribution, Order, WithSchema}; +use crate::utils::ColIndexMapping; /// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table /// from inner (right-side) relation and then probes with data from outer (left-side) relation to @@ -40,13 +41,13 @@ pub struct BatchHashJoin { impl BatchHashJoin { pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self { let ctx = logical.base.ctx.clone(); - // TODO: derive from input - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - Distribution::any().clone(), - Order::any().clone(), + let dist = Self::derive_dist( + logical.left().distribution(), + logical.right().distribution(), + &eq_join_predicate, + &logical.l2o_col_mapping(), ); + let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone()); Self { base, @@ -55,6 +56,24 @@ impl BatchHashJoin { } } + fn derive_dist( + left: &Distribution, + right: &Distribution, + predicate: &EqJoinPredicate, + l2o_mapping: &ColIndexMapping, + ) -> Distribution { + match (left, right) { + (Distribution::Any, Distribution::Any) => Distribution::Any, + (Distribution::Single, Distribution::Single) => Distribution::Single, + (Distribution::HashShard(_), Distribution::HashShard(_)) => { + assert!(left.satisfies(&Distribution::HashShard(predicate.left_eq_indexes()))); + assert!(right.satisfies(&Distribution::HashShard(predicate.right_eq_indexes()))); + l2o_mapping.rewrite_provided_distribution(left) + } + (_, _) => panic!(), + } + } + /// Get a reference to the batch hash join's eq join predicate. pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate diff --git a/rust/frontend/src/optimizer/plan_node/batch_project.rs b/rust/frontend/src/optimizer/plan_node/batch_project.rs index df6e4ba0e8ef..7dd4fcc34647 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_project.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_project.rs @@ -36,20 +36,9 @@ pub struct BatchProject { impl BatchProject { pub fn new(logical: LogicalProject) -> Self { let ctx = logical.base.ctx.clone(); - let i2o = LogicalProject::i2o_col_mapping(logical.input().schema().len(), logical.exprs()); - let distribution = match logical.input().distribution() { - Distribution::HashShard(dists) => { - let new_dists = dists - .iter() - .map(|hash_col| i2o.try_map(*hash_col)) - .collect::>>(); - match new_dists { - Some(new_dists) => Distribution::HashShard(new_dists), - None => Distribution::AnyShard, - } - } - dist => dist.clone(), - }; + let distribution = logical + .i2o_col_mapping() + .rewrite_provided_distribution(logical.input().distribution()); // TODO: Derive order from input let base = PlanBase::new_batch( ctx, @@ -96,24 +85,18 @@ impl ToDistributedBatch for BatchProject { required_order: &Order, required_dist: &Distribution, ) -> PlanRef { - let o2i = - LogicalProject::o2i_col_mapping(self.input().schema().len(), self.logical.exprs()); - let input_dist = match required_dist { - Distribution::HashShard(dists) => { - let input_dists = dists - .iter() - .map(|hash_col| o2i.try_map(*hash_col)) - .collect::>>(); - match input_dists { - Some(input_dists) => Distribution::HashShard(input_dists), - None => Distribution::AnyShard, - } - } - dist => dist.clone(), + let input_required = match required_dist { + Distribution::HashShard(_) => self + .logical + .o2i_col_mapping() + .rewrite_required_distribution(required_dist) + .unwrap_or(Distribution::AnyShard), + Distribution::AnyShard => Distribution::AnyShard, + _ => Distribution::Any, }; let new_input = self .input() - .to_distributed_with_required(required_order, &input_dist); + .to_distributed_with_required(required_order, &input_required); let new_logical = self.logical.clone_with_input(new_input); let batch_plan = BatchProject::new(new_logical); let batch_plan = required_order.enforce_if_not_satisfies(batch_plan.into()); diff --git a/rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs b/rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs index 1ab5d54f0266..6e1f69260230 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs @@ -36,18 +36,27 @@ impl WithSchema for BatchSeqScan { } impl BatchSeqScan { - pub fn new(logical: LogicalScan) -> Self { + pub fn new_inner(logical: LogicalScan, dist: Distribution) -> Self { let ctx = logical.base.ctx.clone(); // TODO: derive from input - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - Distribution::any().clone(), - Order::any().clone(), - ); + let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone()); Self { base, logical } } + + pub fn new(logical: LogicalScan) -> Self { + Self::new_inner(logical, Distribution::Any) + } + + pub fn new_with_dist(logical: LogicalScan) -> Self { + Self::new_inner(logical, Distribution::AnyShard) + } + + /// Get a reference to the batch seq scan's logical. + #[must_use] + pub fn logical(&self) -> &LogicalScan { + &self.logical + } } impl_plan_tree_node_for_leaf! { BatchSeqScan } @@ -65,7 +74,7 @@ impl fmt::Display for BatchSeqScan { impl ToDistributedBatch for BatchSeqScan { fn to_distributed(&self) -> PlanRef { - self.clone().into() + Self::new_with_dist(self.logical.clone()).into() } } diff --git a/rust/frontend/src/optimizer/plan_node/batch_simple_agg.rs b/rust/frontend/src/optimizer/plan_node/batch_simple_agg.rs index e9316edda4c8..a8f9506ec702 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_simple_agg.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_simple_agg.rs @@ -31,12 +31,14 @@ pub struct BatchSimpleAgg { impl BatchSimpleAgg { pub fn new(logical: LogicalAgg) -> Self { let ctx = logical.base.ctx.clone(); - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - Distribution::any().clone(), - Order::any().clone(), - ); + let input = logical.input(); + let input_dist = input.distribution(); + let dist = match input_dist { + Distribution::Any => Distribution::Any, + Distribution::Single => Distribution::Single, + _ => panic!(), + }; + let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone()); BatchSimpleAgg { base, logical } } pub fn agg_calls(&self) -> &[PlanAggCall] { diff --git a/rust/frontend/src/optimizer/plan_node/batch_values.rs b/rust/frontend/src/optimizer/plan_node/batch_values.rs index 2b42bd683118..5c2c185bcf57 100644 --- a/rust/frontend/src/optimizer/plan_node/batch_values.rs +++ b/rust/frontend/src/optimizer/plan_node/batch_values.rs @@ -34,15 +34,20 @@ impl_plan_tree_node_for_leaf!(BatchValues); impl BatchValues { pub fn new(logical: LogicalValues) -> Self { + Self::with_dist(logical, Distribution::Any) + } + + pub fn with_dist(logical: LogicalValues, dist: Distribution) -> Self { let ctx = logical.base.ctx.clone(); - let base = PlanBase::new_batch( - ctx, - logical.schema().clone(), - Distribution::Broadcast, - Order::any().clone(), - ); + let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone()); BatchValues { base, logical } } + + /// Get a reference to the batch values's logical. + #[must_use] + pub fn logical(&self) -> &LogicalValues { + &self.logical + } } impl fmt::Display for BatchValues { @@ -61,7 +66,7 @@ impl WithSchema for BatchValues { impl ToDistributedBatch for BatchValues { fn to_distributed(&self) -> PlanRef { - self.clone().into() + Self::with_dist(self.logical().clone(), Distribution::Single).into() } } diff --git a/rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs b/rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs index 872d2b7a5011..8e727a8456fd 100644 --- a/rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs +++ b/rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs @@ -27,6 +27,8 @@ pub struct EqJoinPredicate { /// the first is from the left table and the second is from the right table. /// now all are normal equal(not null-safe-equal), eq_keys: Vec<(InputRef, InputRef)>, + + left_cols_num: usize, } impl fmt::Display for EqJoinPredicate { @@ -48,10 +50,15 @@ impl fmt::Display for EqJoinPredicate { impl EqJoinPredicate { /// The new method for `JoinPredicate` without any analysis, check or rewrite. - pub fn new(other_cond: Condition, eq_keys: Vec<(InputRef, InputRef)>) -> Self { + pub fn new( + other_cond: Condition, + eq_keys: Vec<(InputRef, InputRef)>, + left_cols_num: usize, + ) -> Self { Self { other_cond, eq_keys, + left_cols_num, } } @@ -72,7 +79,7 @@ impl EqJoinPredicate { /// ``` pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self { let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num); - Self::new(other_cond, eq_keys) + Self::new(other_cond, eq_keys, left_cols_num) } /// Get join predicate's eq conds. @@ -121,17 +128,19 @@ impl EqJoinPredicate { pub fn eq_indexes(&self) -> Vec<(usize, usize)> { self.eq_keys .iter() - .map(|(left, right)| (left.index(), right.index())) + .map(|(left, right)| (left.index(), right.index() - self.left_cols_num)) .collect() } pub fn left_eq_indexes(&self) -> Vec { self.eq_keys.iter().map(|(left, _)| left.index()).collect() } + + /// return the eq keys column index **based on the right input schema** pub fn right_eq_indexes(&self) -> Vec { self.eq_keys .iter() - .map(|(_, right)| right.index()) + .map(|(_, right)| right.index() - self.left_cols_num) .collect() } } diff --git a/rust/frontend/src/optimizer/plan_node/logical_agg.rs b/rust/frontend/src/optimizer/plan_node/logical_agg.rs index b9143f02aff4..6268ab30897f 100644 --- a/rust/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/rust/frontend/src/optimizer/plan_node/logical_agg.rs @@ -225,22 +225,22 @@ impl LogicalAgg { } } - /// get the Mapping of columnIndex from input column index to out column index - pub fn o2i_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { - let mut map = vec![None; exprs.len()]; - for (i, expr) in exprs.iter().enumerate() { - map[i] = match expr { - ExprImpl::InputRef(input) => Some(input.index()), - _ => None, - } + /// get the Mapping of columnIndex from input column index to output column index,if a input + /// column corresponds more than one out columns, mapping to any one + pub fn o2i_col_mapping(&self) -> ColIndexMapping { + let input_len = self.input.schema().len(); + let agg_cal_num = self.agg_calls().len(); + let group_keys = self.group_keys(); + let mut map = vec![None; agg_cal_num + group_keys.len()]; + for (i, key) in group_keys.iter().enumerate() { + map[i] = Some(*key); } ColIndexMapping::with_target_size(map, input_len) } - /// get the Mapping of columnIndex from input column index to output column index,if a input - /// column corresponds more than one out columns, mapping to any one - pub fn i2o_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { - Self::o2i_col_mapping(input_len, exprs).inverse() + /// get the Mapping of columnIndex from input column index to out column index + pub fn i2o_col_mapping(&self) -> ColIndexMapping { + self.o2i_col_mapping().inverse() } fn derive_schema( diff --git a/rust/frontend/src/optimizer/plan_node/logical_join.rs b/rust/frontend/src/optimizer/plan_node/logical_join.rs index 5b263f6fd9e5..88ebb7e81062 100644 --- a/rust/frontend/src/optimizer/plan_node/logical_join.rs +++ b/rust/frontend/src/optimizer/plan_node/logical_join.rs @@ -337,8 +337,11 @@ impl ToBatch for LogicalJoin { // For inner joins, pull non-equal conditions to a filter operator on top of it let pull_filter = self.join_type == JoinType::Inner && predicate.has_non_eq(); if pull_filter { - let eq_cond = - EqJoinPredicate::new(Condition::true_cond(), predicate.eq_keys().to_vec()); + let eq_cond = EqJoinPredicate::new( + Condition::true_cond(), + predicate.eq_keys().to_vec(), + self.left.schema().len(), + ); let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond()); let hash_join = BatchHashJoin::new(logical_join, eq_cond).into(); let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond()); @@ -373,8 +376,11 @@ impl ToStream for LogicalJoin { // For inner joins, pull non-equal conditions to a filter operator on top of it let pull_filter = self.join_type == JoinType::Inner && predicate.has_non_eq(); if pull_filter { - let eq_cond = - EqJoinPredicate::new(Condition::true_cond(), predicate.eq_keys().to_vec()); + let eq_cond = EqJoinPredicate::new( + Condition::true_cond(), + predicate.eq_keys().to_vec(), + self.left.schema().len(), + ); let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond()); let hash_join = StreamHashJoin::new(logical_join, eq_cond).into(); let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond()); diff --git a/rust/frontend/src/optimizer/plan_node/logical_project.rs b/rust/frontend/src/optimizer/plan_node/logical_project.rs index 8127b59d7ece..6bc470ea35e6 100644 --- a/rust/frontend/src/optimizer/plan_node/logical_project.rs +++ b/rust/frontend/src/optimizer/plan_node/logical_project.rs @@ -56,7 +56,7 @@ impl LogicalProject { } /// get the Mapping of columnIndex from input column index to out column index - pub fn o2i_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { + fn o2i_col_mapping_inner(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { let mut map = vec![None; exprs.len()]; for (i, expr) in exprs.iter().enumerate() { map[i] = match expr { @@ -69,8 +69,16 @@ impl LogicalProject { /// get the Mapping of columnIndex from input column index to output column index,if a input /// column corresponds more than one out columns, mapping to any one - pub fn i2o_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { - Self::o2i_col_mapping(input_len, exprs).inverse() + fn i2o_col_mapping_inner(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping { + Self::o2i_col_mapping_inner(input_len, exprs).inverse() + } + + pub fn o2i_col_mapping(&self) -> ColIndexMapping { + Self::o2i_col_mapping_inner(self.input.schema().len(), self.exprs()) + } + + pub fn i2o_col_mapping(&self) -> ColIndexMapping { + Self::i2o_col_mapping_inner(self.input.schema().len(), self.exprs()) } pub fn create( @@ -133,7 +141,7 @@ impl LogicalProject { } fn derive_pk(input_schema: &Schema, input_pk: &[usize], exprs: &[ExprImpl]) -> Vec { - let i2o = Self::i2o_col_mapping(input_schema.len(), exprs); + let i2o = Self::i2o_col_mapping_inner(input_schema.len(), exprs); input_pk .iter() .map(|pk_col| i2o.try_map(*pk_col)) @@ -251,21 +259,15 @@ impl ToBatch for LogicalProject { impl ToStream for LogicalProject { fn to_stream_with_dist_required(&self, required_dist: &Distribution) -> PlanRef { - let o2i = LogicalProject::o2i_col_mapping(self.input().schema().len(), self.exprs()); - let input_dist = match required_dist { - Distribution::HashShard(dists) => { - let input_dists = dists - .iter() - .map(|hash_col| o2i.try_map(*hash_col)) - .collect::>>(); - match input_dists { - Some(input_dists) => Distribution::HashShard(input_dists), - None => Distribution::AnyShard, - } - } - dist => dist.clone(), + let input_required = match required_dist { + Distribution::HashShard(_) => self + .o2i_col_mapping() + .rewrite_required_distribution(required_dist) + .unwrap_or(Distribution::AnyShard), + Distribution::AnyShard => Distribution::AnyShard, + _ => Distribution::Any, }; - let new_input = self.input().to_stream_with_dist_required(&input_dist); + let new_input = self.input().to_stream_with_dist_required(&input_required); let new_logical = self.clone_with_input(new_input); let stream_plan = StreamProject::new(new_logical); required_dist.enforce_if_not_satisfies(stream_plan.into(), Order::any()) @@ -278,7 +280,7 @@ impl ToStream for LogicalProject { let (input, input_col_change) = self.input.logical_rewrite_for_stream(); let (proj, out_col_change) = self.rewrite_with_input(input.clone(), input_col_change); let input_pk = input.pk_indices(); - let i2o = Self::i2o_col_mapping(input.schema().len(), proj.exprs()); + let i2o = Self::i2o_col_mapping_inner(input.schema().len(), proj.exprs()); let col_need_to_add = input_pk.iter().cloned().filter(|i| i2o.try_map(*i) == None); let input_schema = input.schema(); let (exprs, expr_alias) = proj diff --git a/rust/frontend/src/optimizer/plan_node/logical_topn.rs b/rust/frontend/src/optimizer/plan_node/logical_topn.rs index 0d8c6ec80758..263394c00a1b 100644 --- a/rust/frontend/src/optimizer/plan_node/logical_topn.rs +++ b/rust/frontend/src/optimizer/plan_node/logical_topn.rs @@ -70,7 +70,9 @@ impl PlanTreeNodeUnary for LogicalTopN { input, self.limit, self.offset, - input_col_change.rewrite_order(self.order.clone()), + input_col_change + .rewrite_required_order(&self.order) + .unwrap(), ), input_col_change, ) diff --git a/rust/frontend/src/optimizer/plan_node/stream_hash_agg.rs b/rust/frontend/src/optimizer/plan_node/stream_hash_agg.rs index aa4b5aa3032e..3803af0cc93c 100644 --- a/rust/frontend/src/optimizer/plan_node/stream_hash_agg.rs +++ b/rust/frontend/src/optimizer/plan_node/stream_hash_agg.rs @@ -33,14 +33,24 @@ impl StreamHashAgg { pub fn new(logical: LogicalAgg) -> Self { let ctx = logical.base.ctx.clone(); let pk_indices = logical.base.pk_indices.to_vec(); + let input = logical.input(); + let input_dist = input.distribution(); + let dist = match input_dist { + Distribution::Any => panic!(), + Distribution::Single => Distribution::Single, + Distribution::Broadcast => panic!(), + Distribution::AnyShard => panic!(), + Distribution::HashShard(_) => { + assert!( + input_dist.satisfies(&Distribution::HashShard(logical.group_keys().to_vec())) + ); + logical + .i2o_col_mapping() + .rewrite_provided_distribution(input_dist) + } + }; // Hash agg executor might change the append-only behavior of the stream. - let base = PlanBase::new_stream( - ctx, - logical.schema().clone(), - pk_indices, - Distribution::HashShard(logical.group_keys().to_vec()), - false, - ); + let base = PlanBase::new_stream(ctx, logical.schema().clone(), pk_indices, dist, false); StreamHashAgg { base, logical } } pub fn agg_calls(&self) -> &[PlanAggCall] { diff --git a/rust/frontend/src/optimizer/plan_node/stream_hash_join.rs b/rust/frontend/src/optimizer/plan_node/stream_hash_join.rs index 13f8ef454764..30ac148d2e01 100644 --- a/rust/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/rust/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -23,6 +23,7 @@ use super::{LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, ToStreamProst}; use crate::expr::Expr; use crate::optimizer::plan_node::EqJoinPredicate; use crate::optimizer::property::{Distribution, WithSchema}; +use crate::utils::ColIndexMapping; /// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table /// from inner (right-side) relation and probes with data from outer (left-side) relation to @@ -45,12 +46,18 @@ impl StreamHashJoin { JoinType::Inner => logical.left().append_only() && logical.right().append_only(), _ => false, }; + let dist = Self::derive_dist( + logical.left().distribution(), + logical.right().distribution(), + &eq_join_predicate, + &logical.l2o_col_mapping(), + ); // TODO: derive from input let base = PlanBase::new_stream( ctx, logical.schema().clone(), logical.base.pk_indices.to_vec(), - Distribution::any().clone(), + dist, append_only, ); @@ -65,6 +72,23 @@ impl StreamHashJoin { pub fn eq_join_predicate(&self) -> &EqJoinPredicate { &self.eq_join_predicate } + + fn derive_dist( + left: &Distribution, + right: &Distribution, + predicate: &EqJoinPredicate, + l2o_mapping: &ColIndexMapping, + ) -> Distribution { + match (left, right) { + (Distribution::Single, Distribution::Single) => Distribution::Single, + (Distribution::HashShard(_), Distribution::HashShard(_)) => { + assert!(left.satisfies(&Distribution::HashShard(predicate.left_eq_indexes()))); + assert!(right.satisfies(&Distribution::HashShard(predicate.right_eq_indexes()))); + l2o_mapping.rewrite_provided_distribution(left) + } + (_, _) => panic!(), + } + } } impl fmt::Display for StreamHashJoin { diff --git a/rust/frontend/src/optimizer/plan_node/stream_project.rs b/rust/frontend/src/optimizer/plan_node/stream_project.rs index b929aa07be74..435cd2803c24 100644 --- a/rust/frontend/src/optimizer/plan_node/stream_project.rs +++ b/rust/frontend/src/optimizer/plan_node/stream_project.rs @@ -20,7 +20,7 @@ use risingwave_pb::stream_plan::ProjectNode; use super::{LogicalProject, PlanBase, PlanRef, PlanTreeNodeUnary, ToStreamProst}; use crate::expr::Expr; -use crate::optimizer::property::{Distribution, WithSchema}; +use crate::optimizer::property::WithSchema; /// `StreamProject` implements [`super::LogicalProject`] to evaluate specified expressions on input /// rows. @@ -41,20 +41,9 @@ impl StreamProject { let ctx = logical.base.ctx.clone(); let input = logical.input(); let pk_indices = logical.base.pk_indices.to_vec(); - let i2o = LogicalProject::i2o_col_mapping(logical.input().schema().len(), logical.exprs()); - let distribution = match input.distribution() { - Distribution::HashShard(dists) => { - let new_dists = dists - .iter() - .map(|hash_col| i2o.try_map(*hash_col)) - .collect::>>(); - match new_dists { - Some(new_dists) => Distribution::HashShard(new_dists), - None => Distribution::AnyShard, - } - } - dist => dist.clone(), - }; + let distribution = logical + .i2o_col_mapping() + .rewrite_provided_distribution(input.distribution()); // Project executor won't change the append-only behavior of the stream, so it depends on // input's `append_only`. let base = PlanBase::new_stream( diff --git a/rust/frontend/src/optimizer/plan_node/stream_simple_agg.rs b/rust/frontend/src/optimizer/plan_node/stream_simple_agg.rs index 495d515f7a8d..5924d7480d99 100644 --- a/rust/frontend/src/optimizer/plan_node/stream_simple_agg.rs +++ b/rust/frontend/src/optimizer/plan_node/stream_simple_agg.rs @@ -31,14 +31,16 @@ impl StreamSimpleAgg { pub fn new(logical: LogicalAgg) -> Self { let ctx = logical.base.ctx.clone(); let pk_indices = logical.base.pk_indices.to_vec(); + let input = logical.input(); + let input_dist = input.distribution(); + let dist = match input_dist { + Distribution::Any => Distribution::Any, + Distribution::Single => Distribution::Single, + _ => panic!(), + }; + // Simple agg executor might change the append-only behavior of the stream. - let base = PlanBase::new_stream( - ctx, - logical.schema().clone(), - pk_indices, - Distribution::any().clone(), - false, - ); + let base = PlanBase::new_stream(ctx, logical.schema().clone(), pk_indices, dist, false); StreamSimpleAgg { base, logical } } pub fn agg_calls(&self) -> &[PlanAggCall] { diff --git a/rust/frontend/src/optimizer/property/distribution.rs b/rust/frontend/src/optimizer/property/distribution.rs index 5069b3909373..b91e83b93c21 100644 --- a/rust/frontend/src/optimizer/property/distribution.rs +++ b/rust/frontend/src/optimizer/property/distribution.rs @@ -92,7 +92,7 @@ impl Distribution { // +-------+-------+ // |hash_shard(a,b)| // +---------------+ - fn satisfies(&self, other: &Distribution) -> bool { + pub fn satisfies(&self, other: &Distribution) -> bool { match self { Distribution::Any => matches!(other, Distribution::Any), Distribution::Single => matches!(other, Distribution::Any | Distribution::Single), diff --git a/rust/frontend/src/utils/column_index_mapping.rs b/rust/frontend/src/utils/column_index_mapping.rs index f145566ea97b..5ef6afe14c69 100644 --- a/rust/frontend/src/utils/column_index_mapping.rs +++ b/rust/frontend/src/utils/column_index_mapping.rs @@ -20,7 +20,7 @@ use itertools::Itertools; use log::debug; use crate::expr::{ExprImpl, ExprRewriter, InputRef}; -use crate::optimizer::property::{Distribution, Order}; +use crate::optimizer::property::{Distribution, FieldOrder, Order}; /// `ColIndexMapping` is a partial mapping from usize to usize. /// @@ -226,22 +226,77 @@ impl ColIndexMapping { self.target_size() == 0 } - pub fn rewrite_order(&self, mut order: Order) -> Order { - for field in &mut order.field_order { - field.index = self.map(field.index) + /// Rewrite the provided order's field index. It will try its best to give the most accurate + /// order. Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2) + /// Order(0,1,2) with mapping(0->1,2->0) will be rewritten to Order(1) + pub fn rewrite_provided_order(&self, order: &Order) -> Order { + let mut mapped_field = vec![]; + for field in &order.field_order { + match self.try_map(field.index) { + Some(mapped_index) => mapped_field.push(FieldOrder { + index: mapped_index, + direct: field.direct, + }), + None => break, + } + } + Order { + field_order: mapped_field, } + } + + /// Rewrite the required order's field index. if it can't give a corresponding + /// required order after the column index mapping, it will return None. + /// Order(0,1,2) with mapping(0->1,1->0,2->2) will be rewritten to Order(1,0,2) + /// Order(0,1,2) with mapping(0->1,2->0) will return None + pub fn rewrite_required_order(&self, order: &Order) -> Option { order + .field_order + .iter() + .map(|field| { + self.try_map(field.index).map(|mapped_index| FieldOrder { + index: mapped_index, + direct: field.direct, + }) + }) + .collect::>>() + .map(|mapped_field| Order { + field_order: mapped_field, + }) } - pub fn rewrite_distribution(&mut self, dist: Distribution) -> Distribution { + /// Rewrite the provided distribution's field index. It will try its best to give the most + /// accurate distribution. + /// HashShard(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to HashShard(1,0,2). + /// HashShard(0,1,2), with mapping(0->1,2->0) will be rewritten to `AnyShard`. + pub fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution { match dist { - Distribution::HashShard(mut col_idxes) => { - for idx in &mut col_idxes { - *idx = self.map(*idx); + Distribution::HashShard(col_idxes) => { + let mapped_dist = col_idxes + .iter() + .map(|col_idx| self.try_map(*col_idx)) + .collect::>>(); + match mapped_dist { + Some(col_idx) => Distribution::HashShard(col_idx), + None => Distribution::AnyShard, } - Distribution::HashShard(col_idxes) } - _ => dist, + _ => dist.clone(), + } + } + + /// Rewrite the required distribution's field index. if it can't give a corresponding + /// required distribution after the column index mapping, it will return None. + /// HashShard(0,1,2), with mapping(0->1,1->0,2->2) will be rewritten to HashShard(1,0,2). + /// HashShard(0,1,2), with mapping(0->1,2->0) will return None. + pub fn rewrite_required_distribution(&self, dist: &Distribution) -> Option { + match dist { + Distribution::HashShard(col_idxes) => col_idxes + .iter() + .map(|col_idx| self.try_map(*col_idx)) + .collect::>>() + .map(Distribution::HashShard), + _ => Some(dist.clone()), } } diff --git a/rust/frontend/test_runner/tests/testdata/basic_query_1.yaml b/rust/frontend/test_runner/tests/testdata/basic_query_1.yaml index 3129474daade..13a3632e9631 100644 --- a/rust/frontend/test_runner/tests/testdata/basic_query_1.yaml +++ b/rust/frontend/test_runner/tests/testdata/basic_query_1.yaml @@ -1,7 +1,6 @@ - sql: values (11, 22), (33+(1+2), 44); batch_plan: | - BatchExchange { order: [], dist: Single } - BatchValues { rows: [[11:Int32, 22:Int32], [(33:Int32 + (1:Int32 + 2:Int32)), 44:Int32]] } + BatchValues { rows: [[11:Int32, 22:Int32], [(33:Int32 + (1:Int32 + 2:Int32)), 44:Int32]] } - sql: select * from t binder_error: 'Catalog error: table not found: t' - sql: | @@ -43,15 +42,15 @@ create table t (); select (((((false is not true) is true) is not false) is false) is not null) is null from t; batch_plan: | - BatchProject { exprs: [IsNull(IsNotNull(IsFalse(IsNotFalse(IsTrue(IsNotTrue(false:Boolean))))))], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } + BatchExchange { order: [], dist: Single } + BatchProject { exprs: [IsNull(IsNotNull(IsFalse(IsNotFalse(IsTrue(IsNotTrue(false:Boolean))))))], expr_alias: [ ] } BatchScan { table: t, columns: [] } - sql: | create table t (v1 int); select (case when v1=1 then 1 when v1=2 then 2 else 0.0 end) from t; batch_plan: | - BatchProject { exprs: [Case(($0 = 1:Int32), 1:Int32::Decimal, ($0 = 2:Int32), 2:Int32::Decimal, Normalized(0.0):Decimal)], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } + BatchExchange { order: [], dist: Single } + BatchProject { exprs: [Case(($0 = 1:Int32), 1:Int32::Decimal, ($0 = 2:Int32), 2:Int32::Decimal, Normalized(0.0):Decimal)], expr_alias: [ ] } BatchScan { table: t, columns: [v1] } stream_plan: | StreamMaterialize { columns: ["expr#0", "expr#1"(hidden)], pk_columns: ["expr#1"] } @@ -61,15 +60,13 @@ select length(trim(trailing '1' from '12'))+length(trim(leading '2' from '23'))+length(trim(both '3' from '34')); batch_plan: | BatchProject { exprs: [((Length(Rtrim("12":Varchar, "1":Varchar)) + Length(Ltrim("23":Varchar, "2":Varchar))) + Length(Trim("34":Varchar, "3":Varchar)))], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } - BatchValues { rows: [[]] } + BatchValues { rows: [[]] } - sql: | select position(replace('1','1','2'),'123') where '12' like '%1'; batch_plan: | BatchProject { exprs: [Position(Replace("1":Varchar, "1":Varchar, "2":Varchar), "123":Varchar)], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } - BatchFilter { predicate: Like("12":Varchar, "%1":Varchar) } - BatchValues { rows: [[]] } + BatchFilter { predicate: Like("12":Varchar, "%1":Varchar) } + BatchValues { rows: [[]] } - sql: | create table t (v1 int, v2 int); insert into t values (22, 33), (44, 55); diff --git a/rust/frontend/test_runner/tests/testdata/basic_query_2.yaml b/rust/frontend/test_runner/tests/testdata/basic_query_2.yaml index 60922ab5c68c..d68c8904b281 100644 --- a/rust/frontend/test_runner/tests/testdata/basic_query_2.yaml +++ b/rust/frontend/test_runner/tests/testdata/basic_query_2.yaml @@ -16,13 +16,11 @@ - sql: | values(cast(1 as bigint)); batch_plan: | - BatchExchange { order: [], dist: Single } - BatchValues { rows: [[1:Int32::Int64]] } + BatchValues { rows: [[1:Int32::Int64]] } - sql: | values(not true); batch_plan: | - BatchExchange { order: [], dist: Single } - BatchValues { rows: [[Not(true:Boolean)]] } + BatchValues { rows: [[Not(true:Boolean)]] } - sql: | values(must_be_unimplemented_func(1)); binder_error: 'Feature is not yet implemented: unsupported function: Ident { value: "must_be_unimplemented_func", quote_style: None }' @@ -54,9 +52,9 @@ BatchHashJoin { type: Inner, predicate: $1 = $4 } BatchExchange { order: [], dist: HashShard([1]) } BatchScan { table: t1, columns: [_row_id, v1, v2] } - BatchExchange { order: [], dist: HashShard([4]) } + BatchExchange { order: [], dist: HashShard([1]) } BatchScan { table: t2, columns: [_row_id, v1, v2] } - BatchExchange { order: [], dist: HashShard([8]) } + BatchExchange { order: [], dist: HashShard([2]) } BatchScan { table: t3, columns: [_row_id, v1, v2] } stream_plan: | StreamMaterialize { columns: ["_row_id", "v1", "v2", "_row_id", "v1", "v2", "_row_id", "v1", "v2"], pk_columns: ["_row_id", "_row_id", "_row_id"] } @@ -65,21 +63,21 @@ StreamHashJoin { type: Inner, predicate: $1 = $4 } StreamExchange { dist: HashShard([1]) } StreamTableScan { table: t1, columns: [_row_id, v1, v2], pk_indices: [0] } - StreamExchange { dist: HashShard([4]) } + StreamExchange { dist: HashShard([1]) } StreamTableScan { table: t2, columns: [_row_id, v1, v2], pk_indices: [0] } - StreamExchange { dist: HashShard([8]) } + StreamExchange { dist: HashShard([2]) } StreamTableScan { table: t3, columns: [_row_id, v1, v2], pk_indices: [0] } - sql: | create table t1 (v1 int not null, v2 int not null); create table t2 (v1 int not null, v2 int not null); select t1.v2, t2.v2 from t1 join t2 on t1.v1 = t2.v1; batch_plan: | - BatchProject { exprs: [$1, $3], expr_alias: [v2, v2] } - BatchExchange { order: [], dist: Single } + BatchExchange { order: [], dist: Single } + BatchProject { exprs: [$1, $3], expr_alias: [v2, v2] } BatchHashJoin { type: Inner, predicate: $0 = $2 } BatchExchange { order: [], dist: HashShard([0]) } BatchScan { table: t1, columns: [v1, v2] } - BatchExchange { order: [], dist: HashShard([2]) } + BatchExchange { order: [], dist: HashShard([0]) } BatchScan { table: t2, columns: [v1, v2] } stream_plan: | StreamMaterialize { columns: ["v2", "v2", "expr#2"(hidden), "expr#3"(hidden)], pk_columns: ["expr#2", "expr#3"] } @@ -87,19 +85,18 @@ StreamHashJoin { type: Inner, predicate: $0 = $3 } StreamExchange { dist: HashShard([0]) } StreamTableScan { table: t1, columns: [v1, v2, _row_id], pk_indices: [2] } - StreamExchange { dist: HashShard([3]) } + StreamExchange { dist: HashShard([0]) } StreamTableScan { table: t2, columns: [v1, v2, _row_id], pk_indices: [2] } - sql: select 1 batch_plan: | BatchProject { exprs: [1:Int32], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } - BatchValues { rows: [[]] } + BatchValues { rows: [[]] } - sql: | create table t(v1 int, v2 int, v3 int); select v1, min(v2) + max(v3) * count(v1) from t group by v1; batch_plan: | - BatchProject { exprs: [$0, ($1 + ($2 * $3))], expr_alias: [v1, ] } - BatchExchange { order: [], dist: Single } + BatchExchange { order: [], dist: Single } + BatchProject { exprs: [$0, ($1 + ($2 * $3))], expr_alias: [v1, ] } BatchHashAgg { group_keys: [$0], aggs: [min($1), max($2), count($3)] } BatchExchange { order: [], dist: HashShard([0]) } BatchProject { exprs: [$0, $1, $2, $0], expr_alias: [ , , , ] } @@ -118,10 +115,9 @@ select min(v1) + max(v2) * count(v3) from t; batch_plan: | BatchProject { exprs: [($0 + ($1 * $2))], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } - BatchSimpleAgg { aggs: [min($0), max($1), count($2)] } - BatchExchange { order: [], dist: Single } - BatchScan { table: t, columns: [v1, v2, v3] } + BatchSimpleAgg { aggs: [min($0), max($1), count($2)] } + BatchExchange { order: [], dist: Single } + BatchScan { table: t, columns: [v1, v2, v3] } stream_plan: | StreamMaterialize { columns: ["expr#0", "expr#1"(hidden), "expr#2"(hidden), "expr#3"(hidden), "expr#4"(hidden)], pk_columns: ["expr#1", "expr#2", "expr#3", "expr#4"] } StreamProject { exprs: [($1 + ($2 * $3)), $0, $1, $2, $3], expr_alias: [ , , , , ] } @@ -135,8 +131,8 @@ create table t(v1 int, v2 int, v3 int); select v3, min(v1) * avg(v1+v2) from t group by v3; batch_plan: | - BatchProject { exprs: [$0, ($1 * ($2 / $3))], expr_alias: [v3, ] } - BatchExchange { order: [], dist: Single } + BatchExchange { order: [], dist: Single } + BatchProject { exprs: [$0, ($1 * ($2 / $3))], expr_alias: [v3, ] } BatchHashAgg { group_keys: [$0], aggs: [min($1), sum($2), count($2)] } BatchProject { exprs: [$2, $0, ($0 + $1)], expr_alias: [ , , ] } BatchExchange { order: [], dist: HashShard([2]) } diff --git a/rust/frontend/test_runner/tests/testdata/tpch.yaml b/rust/frontend/test_runner/tests/testdata/tpch.yaml index 9c40828d4900..859e992dfa95 100644 --- a/rust/frontend/test_runner/tests/testdata/tpch.yaml +++ b/rust/frontend/test_runner/tests/testdata/tpch.yaml @@ -124,10 +124,10 @@ BatchExchange { order: [], dist: HashShard([0]) } BatchFilter { predicate: ($1 = "FURNITURE":Varchar) AND true:Boolean AND true:Boolean } BatchScan { table: customer, columns: [c_custkey, c_mktsegment] } - BatchExchange { order: [], dist: HashShard([2]) } + BatchExchange { order: [], dist: HashShard([1]) } BatchFilter { predicate: } BatchScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority] } - BatchExchange { order: [], dist: HashShard([3]) } + BatchExchange { order: [], dist: HashShard([0]) } BatchFilter { predicate: } BatchScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount] } stream_plan: | @@ -144,10 +144,10 @@ StreamExchange { dist: HashShard([0]) } StreamFilter { predicate: ($1 = "FURNITURE":Varchar) AND true:Boolean AND true:Boolean } StreamTableScan { table: customer, columns: [c_custkey, c_mktsegment, _row_id], pk_indices: [2] } - StreamExchange { dist: HashShard([3]) } + StreamExchange { dist: HashShard([1]) } StreamFilter { predicate: } StreamTableScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority, _row_id], pk_indices: [4] } - StreamExchange { dist: HashShard([5]) } + StreamExchange { dist: HashShard([0]) } StreamFilter { predicate: } StreamTableScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount, _row_id], pk_indices: [3] } - id: tpch_q6 @@ -165,12 +165,11 @@ and */ l_quantity < 24; batch_plan: | BatchProject { exprs: [$0], expr_alias: [revenue] } - BatchExchange { order: [], dist: Single } - BatchSimpleAgg { aggs: [sum($0)] } + BatchSimpleAgg { aggs: [sum($0)] } + BatchExchange { order: [], dist: Single } BatchProject { exprs: [($1 * $2)], expr_alias: [ ] } - BatchExchange { order: [], dist: Single } - BatchFilter { predicate: ($0 < 24:Int32) } - BatchScan { table: lineitem, columns: [l_quantity, l_extendedprice, l_discount] } + BatchFilter { predicate: ($0 < 24:Int32) } + BatchScan { table: lineitem, columns: [l_quantity, l_extendedprice, l_discount] } stream_plan: | StreamMaterialize { columns: ["revenue", "expr#1"(hidden)], pk_columns: ["expr#1", "revenue"] } StreamProject { exprs: [$1, $0], expr_alias: [revenue, ] }