From b620f158af5e508d2bc38f4c25eead41785f6558 Mon Sep 17 00:00:00 2001
From: stonepage <40830455+st1page@users.noreply.github.com>
Date: Mon, 28 Mar 2022 14:31:17 +0800
Subject: [PATCH] feat(optimizer): distribution derive and pass through (#1313)
* fix project dist passThrough
* fix project dist passThrough
* fix agg mapping
* add dist derive for batch hash agg
* fix col mapping rewrite for distribution and order
* fix agg dist derive
* fix agg dist derive
* hash join dist derive
* fix comments
* rewrite provided && required with mapping
* fix some
* fix clippy
* fix some
* change values as single and fix ut
* refactor project pass through
* refactor project mapping
* fix comments
* fix simple agg
* make some fn private
* fix clippy
* remove useless inner mapping function for agg
* fix hash join distribution
* fix ut
* change batch values new
---
rust/frontend/src/optimizer/mod.rs | 11 ++-
.../src/optimizer/plan_node/batch_hash_agg.rs | 18 +++--
.../optimizer/plan_node/batch_hash_join.rs | 31 ++++++--
.../src/optimizer/plan_node/batch_project.rs | 41 +++-------
.../src/optimizer/plan_node/batch_seq_scan.rs | 25 +++++--
.../optimizer/plan_node/batch_simple_agg.rs | 14 ++--
.../src/optimizer/plan_node/batch_values.rs | 19 +++--
.../optimizer/plan_node/eq_join_predicate.rs | 17 ++++-
.../src/optimizer/plan_node/logical_agg.rs | 24 +++---
.../src/optimizer/plan_node/logical_join.rs | 14 +++-
.../optimizer/plan_node/logical_project.rs | 40 +++++-----
.../src/optimizer/plan_node/logical_topn.rs | 4 +-
.../optimizer/plan_node/stream_hash_agg.rs | 24 ++++--
.../optimizer/plan_node/stream_hash_join.rs | 26 ++++++-
.../src/optimizer/plan_node/stream_project.rs | 19 +----
.../optimizer/plan_node/stream_simple_agg.rs | 16 ++--
.../src/optimizer/property/distribution.rs | 2 +-
.../src/utils/column_index_mapping.rs | 75 ++++++++++++++++---
.../tests/testdata/basic_query_1.yaml | 19 ++---
.../tests/testdata/basic_query_2.yaml | 40 +++++-----
.../test_runner/tests/testdata/tpch.yaml | 17 ++---
21 files changed, 307 insertions(+), 189 deletions(-)
diff --git a/rust/frontend/src/optimizer/mod.rs b/rust/frontend/src/optimizer/mod.rs
index c1dd698edca3..ac995cbcd7c6 100644
--- a/rust/frontend/src/optimizer/mod.rs
+++ b/rust/frontend/src/optimizer/mod.rs
@@ -174,10 +174,13 @@ impl PlanRoot {
let stream_plan = match self.plan.convention() {
Convention::Logical => {
let plan = self.gen_optimized_logical_plan();
- let (plan, mut out_col_change) = plan.logical_rewrite_for_stream();
- self.required_dist =
- out_col_change.rewrite_distribution(self.required_dist.clone());
- self.required_order = out_col_change.rewrite_order(self.required_order.clone());
+ let (plan, out_col_change) = plan.logical_rewrite_for_stream();
+ self.required_dist = out_col_change
+ .rewrite_required_distribution(&self.required_dist)
+ .unwrap();
+ self.required_order = out_col_change
+ .rewrite_required_order(&self.required_order)
+ .unwrap();
self.out_fields = out_col_change.rewrite_bitset(&self.out_fields);
self.schema = plan.schema().clone();
plan.to_stream_with_dist_required(&self.required_dist)
diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
index a9083745437a..9d8c736825d9 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
@@ -33,12 +33,18 @@ pub struct BatchHashAgg {
impl BatchHashAgg {
pub fn new(logical: LogicalAgg) -> Self {
let ctx = logical.base.ctx.clone();
- let base = PlanBase::new_batch(
- ctx,
- logical.schema().clone(),
- Distribution::any().clone(),
- Order::any().clone(),
- );
+ let input = logical.input();
+ let input_dist = input.distribution();
+ let dist = match input_dist {
+ Distribution::Any => Distribution::Any,
+ Distribution::Single => Distribution::Single,
+ Distribution::Broadcast => panic!(),
+ Distribution::AnyShard => Distribution::AnyShard,
+ Distribution::HashShard(_) => logical
+ .i2o_col_mapping()
+ .rewrite_provided_distribution(input_dist),
+ };
+ let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
BatchHashAgg { base, logical }
}
pub fn agg_calls(&self) -> &[PlanAggCall] {
diff --git a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
index 384698d0fc50..b12c1bb7de13 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
@@ -23,6 +23,7 @@ use super::{
ToDistributedBatch,
};
use crate::optimizer::property::{Distribution, Order, WithSchema};
+use crate::utils::ColIndexMapping;
/// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table
/// from inner (right-side) relation and then probes with data from outer (left-side) relation to
@@ -40,13 +41,13 @@ pub struct BatchHashJoin {
impl BatchHashJoin {
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let ctx = logical.base.ctx.clone();
- // TODO: derive from input
- let base = PlanBase::new_batch(
- ctx,
- logical.schema().clone(),
- Distribution::any().clone(),
- Order::any().clone(),
+ let dist = Self::derive_dist(
+ logical.left().distribution(),
+ logical.right().distribution(),
+ &eq_join_predicate,
+ &logical.l2o_col_mapping(),
);
+ let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
Self {
base,
@@ -55,6 +56,24 @@ impl BatchHashJoin {
}
}
+ fn derive_dist(
+ left: &Distribution,
+ right: &Distribution,
+ predicate: &EqJoinPredicate,
+ l2o_mapping: &ColIndexMapping,
+ ) -> Distribution {
+ match (left, right) {
+ (Distribution::Any, Distribution::Any) => Distribution::Any,
+ (Distribution::Single, Distribution::Single) => Distribution::Single,
+ (Distribution::HashShard(_), Distribution::HashShard(_)) => {
+ assert!(left.satisfies(&Distribution::HashShard(predicate.left_eq_indexes())));
+ assert!(right.satisfies(&Distribution::HashShard(predicate.right_eq_indexes())));
+ l2o_mapping.rewrite_provided_distribution(left)
+ }
+ (_, _) => panic!(),
+ }
+ }
+
/// Get a reference to the batch hash join's eq join predicate.
pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
&self.eq_join_predicate
diff --git a/rust/frontend/src/optimizer/plan_node/batch_project.rs b/rust/frontend/src/optimizer/plan_node/batch_project.rs
index df6e4ba0e8ef..7dd4fcc34647 100644
--- a/rust/frontend/src/optimizer/plan_node/batch_project.rs
+++ b/rust/frontend/src/optimizer/plan_node/batch_project.rs
@@ -36,20 +36,9 @@ pub struct BatchProject {
impl BatchProject {
pub fn new(logical: LogicalProject) -> Self {
let ctx = logical.base.ctx.clone();
- let i2o = LogicalProject::i2o_col_mapping(logical.input().schema().len(), logical.exprs());
- let distribution = match logical.input().distribution() {
- Distribution::HashShard(dists) => {
- let new_dists = dists
- .iter()
- .map(|hash_col| i2o.try_map(*hash_col))
- .collect::