diff --git a/rust/frontend/src/optimizer/mod.rs b/rust/frontend/src/optimizer/mod.rs
index c1dd698edca3..ac995cbcd7c6 100644
--- a/rust/frontend/src/optimizer/mod.rs
+++ b/rust/frontend/src/optimizer/mod.rs
@@ -174,10 +174,13 @@ impl PlanRoot {
let stream_plan = match self.plan.convention() {
Convention::Logical => {
let plan = self.gen_optimized_logical_plan();
- let (plan, mut out_col_change) = plan.logical_rewrite_for_stream();
- self.required_dist =
- out_col_change.rewrite_distribution(self.required_dist.clone());
- self.required_order = out_col_change.rewrite_order(self.required_order.clone());
+ let (plan, out_col_change) = plan.logical_rewrite_for_stream();
+ self.required_dist = out_col_change
+ .rewrite_required_distribution(&self.required_dist)
+ .unwrap();
+ self.required_order = out_col_change
+ .rewrite_required_order(&self.required_order)
+ .unwrap();
self.out_fields = out_col_change.rewrite_bitset(&self.out_fields);
self.schema = plan.schema().clone();
plan.to_stream_with_dist_required(&self.required_dist)
diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
index a9083745437a..9d8c736825d9 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
@@ -33,12 +33,18 @@ pub struct BatchHashAgg {
impl BatchHashAgg {
+ /// Builds a `BatchHashAgg` from its logical aggregation node.
+ ///
+ /// The distribution provided by the new node is derived from the input's
+ /// distribution; the provided order is always `Order::any()`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the input's distribution is `Distribution::Broadcast`.
pub fn new(logical: LogicalAgg) -> Self {
let ctx = logical.base.ctx.clone();
- let base = PlanBase::new_batch(
- ctx,
- logical.schema().clone(),
- Distribution::any().clone(),
- Order::any().clone(),
- );
+ let input = logical.input();
+ let input_dist = input.distribution();
+ let dist = match input_dist {
+ // Distributions that do not name columns carry over unchanged.
+ Distribution::Any => Distribution::Any,
+ Distribution::Single => Distribution::Single,
+ Distribution::Broadcast => {
+ panic!("BatchHashAgg: Broadcast input distribution is not supported")
+ }
+ Distribution::AnyShard => Distribution::AnyShard,
+ // A hash shard names input columns, so its keys are rewritten through
+ // the aggregation's input-to-output column mapping.
+ Distribution::HashShard(_) => logical
+ .i2o_col_mapping()
+ .rewrite_provided_distribution(input_dist),
+ };
+ let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
BatchHashAgg { base, logical }
}
pub fn agg_calls(&self) -> &[PlanAggCall] {
diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
index 384698d0fc50..b12c1bb7de13 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
@@ -23,6 +23,7 @@ use super::{
ToDistributedBatch,
};
use crate::optimizer::property::{Distribution, Order, WithSchema};
+use crate::utils::ColIndexMapping;
/// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table
/// from inner (right-side) relation and then probes with data from outer (left-side) relation to
@@ -40,13 +41,13 @@ pub struct BatchHashJoin {
impl BatchHashJoin {
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let ctx = logical.base.ctx.clone();
- // TODO: derive from input
- let base = PlanBase::new_batch(
- ctx,
- logical.schema().clone(),
- Distribution::any().clone(),
- Order::any().clone(),
+ let dist = Self::derive_dist(
+ logical.left().distribution(),
+ logical.right().distribution(),
+ &eq_join_predicate,
+ &logical.l2o_col_mapping(),
);
+ let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
Self {
base,
@@ -55,6 +56,24 @@ impl BatchHashJoin {
}
}
+ /// Derives the distribution provided by the join from the distributions of
+ /// its left and right inputs.
+ ///
+ /// * `Any` with `Any` yields `Any`, and `Single` with `Single` yields `Single`.
+ /// * When both inputs are `HashShard`, each side is asserted to satisfy a
+ /// hash shard on its own equi-join key indexes (taken from `predicate`),
+ /// and the result is the left distribution rewritten through `l2o_mapping`.
+ ///
+ /// # Panics
+ ///
+ /// Panics on any other pairing of input distributions, or when a
+ /// hash-sharded input does not satisfy a hash shard on its join keys.
+ fn derive_dist(
+ left: &Distribution,
+ right: &Distribution,
+ predicate: &EqJoinPredicate,
+ l2o_mapping: &ColIndexMapping,
+ ) -> Distribution {
+ match (left, right) {
+ (Distribution::Any, Distribution::Any) => Distribution::Any,
+ (Distribution::Single, Distribution::Single) => Distribution::Single,
+ (Distribution::HashShard(_), Distribution::HashShard(_)) => {
+ assert!(
+ left.satisfies(&Distribution::HashShard(predicate.left_eq_indexes())),
+ "BatchHashJoin: left input distribution does not satisfy HashShard on the left equi-join keys"
+ );
+ assert!(
+ right.satisfies(&Distribution::HashShard(predicate.right_eq_indexes())),
+ "BatchHashJoin: right input distribution does not satisfy HashShard on the right equi-join keys"
+ );
+ l2o_mapping.rewrite_provided_distribution(left)
+ }
+ (_, _) => panic!("BatchHashJoin: unsupported combination of input distributions"),
+ }
+ }
+
/// Get a reference to the batch hash join's eq join predicate.
pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
&self.eq_join_predicate
diff --git a/rust/frontend/src/optimizer/plan_node/batch_project.rs b/rust/frontend/src/optimizer/plan_node/batch_project.rs
index df6e4ba0e8ef..7dd4fcc34647 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_project.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_project.rs
@@ -36,20 +36,9 @@ pub struct BatchProject {
impl BatchProject {
pub fn new(logical: LogicalProject) -> Self {
let ctx = logical.base.ctx.clone();
- let i2o = LogicalProject::i2o_col_mapping(logical.input().schema().len(), logical.exprs());
- let distribution = match logical.input().distribution() {
- Distribution::HashShard(dists) => {
- let new_dists = dists
- .iter()
- .map(|hash_col| i2o.try_map(*hash_col))
- .collect::