feat(optimizer): distribution derive and pass through (risingwavelabs#1313)

* fix project dist passThrough

* fix project dist passThrough

* fix agg mapping

* add dist derive for batch hash agg

* fix col mapping rewrite for distribution and order

* fix agg dist derive

* fix agg dist derive

* hash join dist derive

* fix comments

* rewrite provided && required with mapping

* fix some

* fix clippy

* fix some

* change values as single and fix ut

* refactor project pass through

* refactor project mapping

* fix comments

* fix simple agg

* make some fn private

* fix clippy

* remove useless inner mapping function for agg

* fix hash join distribution

* fix ut

* change batch values new
st1page authored Mar 28, 2022
1 parent 6c818e6 commit b620f15
Showing 21 changed files with 307 additions and 189 deletions.
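The core mechanism across these files is rewriting a HashShard distribution through a column-index mapping, in both directions: the distribution *provided* by an input is mapped onto an operator's output schema, and the distribution *required* on the output is mapped back onto input columns. Below is a minimal, self-contained sketch of that idea, assuming simplified stand-in types rather than the real Distribution and ColIndexMapping in rust/frontend/src/optimizer/property and rust/frontend/src/utils.

// Simplified sketch of distribution derive / pass-through; not the actual RisingWave types.
#[derive(Clone, Debug, PartialEq)]
enum Distribution {
    Any,
    AnyShard,
    Single,
    Broadcast,
    HashShard(Vec<usize>),
}

/// Hypothetical stand-in for `ColIndexMapping`: `map[i]` is the target index of
/// source column `i`, or `None` if that column does not survive the mapping.
struct ColIndexMapping {
    map: Vec<Option<usize>>,
}

impl ColIndexMapping {
    fn try_map(&self, index: usize) -> Option<usize> {
        self.map.get(index).copied().flatten()
    }

    /// Rewrite a distribution *provided* by the input in terms of the output schema.
    /// If any hash column is dropped by the mapping, we can only promise `AnyShard`.
    fn rewrite_provided_distribution(&self, dist: &Distribution) -> Distribution {
        match dist {
            Distribution::HashShard(cols) => {
                let mapped: Option<Vec<usize>> = cols.iter().map(|c| self.try_map(*c)).collect();
                match mapped {
                    Some(cols) => Distribution::HashShard(cols),
                    None => Distribution::AnyShard,
                }
            }
            other => other.clone(),
        }
    }

    /// Rewrite a distribution *required* on the output in terms of the input schema.
    /// Returns `None` when the requirement cannot be expressed on input columns.
    fn rewrite_required_distribution(&self, dist: &Distribution) -> Option<Distribution> {
        match dist {
            Distribution::HashShard(cols) => cols
                .iter()
                .map(|c| self.try_map(*c))
                .collect::<Option<Vec<usize>>>()
                .map(Distribution::HashShard),
            other => Some(other.clone()),
        }
    }
}

fn main() {
    // A projection that outputs [input2, input0]: i2o sends 0 -> 1, 2 -> 0, and drops column 1.
    let i2o = ColIndexMapping { map: vec![Some(1), None, Some(0)] };
    // Provided direction: the input is hashed on column 2, so the output is hashed on column 0.
    assert_eq!(
        i2o.rewrite_provided_distribution(&Distribution::HashShard(vec![2])),
        Distribution::HashShard(vec![0])
    );
    // A hash column that is projected away can no longer be promised downstream.
    assert_eq!(
        i2o.rewrite_provided_distribution(&Distribution::HashShard(vec![1])),
        Distribution::AnyShard
    );
    // Required direction: the corresponding o2i mapping sends output 0 -> input 2, output 1 -> input 0.
    let o2i = ColIndexMapping { map: vec![Some(2), Some(0)] };
    assert_eq!(
        o2i.rewrite_required_distribution(&Distribution::HashShard(vec![0])),
        Some(Distribution::HashShard(vec![2]))
    );
}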
11 changes: 7 additions & 4 deletions rust/frontend/src/optimizer/mod.rs
@@ -174,10 +174,13 @@ impl PlanRoot {
let stream_plan = match self.plan.convention() {
Convention::Logical => {
let plan = self.gen_optimized_logical_plan();
let (plan, mut out_col_change) = plan.logical_rewrite_for_stream();
self.required_dist =
out_col_change.rewrite_distribution(self.required_dist.clone());
self.required_order = out_col_change.rewrite_order(self.required_order.clone());
let (plan, out_col_change) = plan.logical_rewrite_for_stream();
self.required_dist = out_col_change
.rewrite_required_distribution(&self.required_dist)
.unwrap();
self.required_order = out_col_change
.rewrite_required_order(&self.required_order)
.unwrap();
self.out_fields = out_col_change.rewrite_bitset(&self.out_fields);
self.schema = plan.schema().clone();
plan.to_stream_with_dist_required(&self.required_dist)
18 changes: 12 additions & 6 deletions rust/frontend/src/optimizer/plan_node/batch_hash_agg.rs
@@ -33,12 +33,18 @@ pub struct BatchHashAgg {
impl BatchHashAgg {
pub fn new(logical: LogicalAgg) -> Self {
let ctx = logical.base.ctx.clone();
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
Distribution::any().clone(),
Order::any().clone(),
);
let input = logical.input();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::Any => Distribution::Any,
Distribution::Single => Distribution::Single,
Distribution::Broadcast => panic!(),
Distribution::AnyShard => Distribution::AnyShard,
Distribution::HashShard(_) => logical
.i2o_col_mapping()
.rewrite_provided_distribution(input_dist),
};
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
BatchHashAgg { base, logical }
}
pub fn agg_calls(&self) -> &[PlanAggCall] {
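With this change BatchHashAgg::new derives its output distribution from its input instead of defaulting to Distribution::any(): because the group keys become the leading output columns, an input hashed on the group-key columns yields an output hashed on columns 0..group_keys.len(). A small illustration with plain indices (hypothetical query and column numbers, not the real plan-node types):

fn main() {
    // Hypothetical `SELECT sum(v) FROM t GROUP BY k2, k0`: group keys [2, 0], so the agg's
    // i2o mapping sends input column 2 -> output 0 and input column 0 -> output 1.
    let group_keys = vec![2usize, 0];
    // The input is hash-distributed on its group-key columns...
    let input_hash_cols = vec![2usize, 0];
    // ...so the derived output distribution hashes on the corresponding output columns.
    let output_hash_cols: Vec<usize> = input_hash_cols
        .iter()
        .map(|c| group_keys.iter().position(|k| k == c).expect("must be a group key"))
        .collect();
    assert_eq!(output_hash_cols, vec![0, 1]);
}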
31 changes: 25 additions & 6 deletions rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
@@ -23,6 +23,7 @@ use super::{
ToDistributedBatch,
};
use crate::optimizer::property::{Distribution, Order, WithSchema};
use crate::utils::ColIndexMapping;

/// `BatchHashJoin` implements [`super::LogicalJoin`] with hash table. It builds a hash table
/// from inner (right-side) relation and then probes with data from outer (left-side) relation to
@@ -40,13 +41,13 @@ pub struct BatchHashJoin {
impl BatchHashJoin {
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let ctx = logical.base.ctx.clone();
// TODO: derive from input
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
Distribution::any().clone(),
Order::any().clone(),
let dist = Self::derive_dist(
logical.left().distribution(),
logical.right().distribution(),
&eq_join_predicate,
&logical.l2o_col_mapping(),
);
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());

Self {
base,
@@ -55,6 +56,24 @@ impl BatchHashJoin {
}
}

fn derive_dist(
left: &Distribution,
right: &Distribution,
predicate: &EqJoinPredicate,
l2o_mapping: &ColIndexMapping,
) -> Distribution {
match (left, right) {
(Distribution::Any, Distribution::Any) => Distribution::Any,
(Distribution::Single, Distribution::Single) => Distribution::Single,
(Distribution::HashShard(_), Distribution::HashShard(_)) => {
assert!(left.satisfies(&Distribution::HashShard(predicate.left_eq_indexes())));
assert!(right.satisfies(&Distribution::HashShard(predicate.right_eq_indexes())));
l2o_mapping.rewrite_provided_distribution(left)
}
(_, _) => panic!(),
}
}

/// Get a reference to the batch hash join's eq join predicate.
pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
&self.eq_join_predicate
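derive_dist only accepts input combinations that can occur after enforcement: both sides hash-distributed on their equi-join keys (the distribution is then taken from the left side and mapped onto the join's output schema via l2o_col_mapping), both Single, or both Any; anything else panics. A rough sketch of the hash-shard case, assuming a join output that keeps the left columns first and using plain index vectors instead of the real types (the real code checks satisfies() rather than strict equality):

fn derive_join_dist(
    left_hash_cols: &[usize],
    right_hash_cols: &[usize],
    left_eq_keys: &[usize],
    right_eq_keys: &[usize],
) -> Vec<usize> {
    // Both sides must already be hash-distributed on their equi-join keys;
    // otherwise an exchange should have been enforced below the join.
    assert_eq!(left_hash_cols, left_eq_keys);
    assert_eq!(right_hash_cols, right_eq_keys);
    // The join output keeps the left columns first, so the left distribution
    // passes through unchanged (l2o is the identity on left columns here).
    left_hash_cols.to_vec()
}

fn main() {
    // left: hashed on column 0; right: hashed on column 1; join on left.0 = right.1.
    assert_eq!(derive_join_dist(&[0], &[1], &[0], &[1]), vec![0]);
}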
41 changes: 12 additions & 29 deletions rust/frontend/src/optimizer/plan_node/batch_project.rs
@@ -36,20 +36,9 @@ pub struct BatchProject {
impl BatchProject {
pub fn new(logical: LogicalProject) -> Self {
let ctx = logical.base.ctx.clone();
let i2o = LogicalProject::i2o_col_mapping(logical.input().schema().len(), logical.exprs());
let distribution = match logical.input().distribution() {
Distribution::HashShard(dists) => {
let new_dists = dists
.iter()
.map(|hash_col| i2o.try_map(*hash_col))
.collect::<Option<Vec<_>>>();
match new_dists {
Some(new_dists) => Distribution::HashShard(new_dists),
None => Distribution::AnyShard,
}
}
dist => dist.clone(),
};
let distribution = logical
.i2o_col_mapping()
.rewrite_provided_distribution(logical.input().distribution());
// TODO: Derive order from input
let base = PlanBase::new_batch(
ctx,
@@ -96,24 +85,18 @@ impl ToDistributedBatch for BatchProject {
required_order: &Order,
required_dist: &Distribution,
) -> PlanRef {
let o2i =
LogicalProject::o2i_col_mapping(self.input().schema().len(), self.logical.exprs());
let input_dist = match required_dist {
Distribution::HashShard(dists) => {
let input_dists = dists
.iter()
.map(|hash_col| o2i.try_map(*hash_col))
.collect::<Option<Vec<_>>>();
match input_dists {
Some(input_dists) => Distribution::HashShard(input_dists),
None => Distribution::AnyShard,
}
}
dist => dist.clone(),
let input_required = match required_dist {
Distribution::HashShard(_) => self
.logical
.o2i_col_mapping()
.rewrite_required_distribution(required_dist)
.unwrap_or(Distribution::AnyShard),
Distribution::AnyShard => Distribution::AnyShard,
_ => Distribution::Any,
};
let new_input = self
.input()
.to_distributed_with_required(required_order, &input_dist);
.to_distributed_with_required(required_order, &input_required);
let new_logical = self.logical.clone_with_input(new_input);
let batch_plan = BatchProject::new(new_logical);
let batch_plan = required_order.enforce_if_not_satisfies(batch_plan.into());
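For the required direction, BatchProject::to_distributed_with_required now translates a HashShard requirement on its output back onto input columns through o2i_col_mapping; if a required hash column is a computed expression with no single source column, the project only asks its input for AnyShard and relies on the enforcer to add an exchange above it. A worked example with a hypothetical projection:

fn main() {
    // Hypothetical o2i mapping for `SELECT c3, a + b FROM t`: output column 0 comes from
    // input column 3, output column 1 is a computed expression with no source column.
    let o2i: Vec<Option<usize>> = vec![Some(3), None];

    // Required HashShard([0]) on the output can be pushed down as HashShard([3]).
    let pushed: Option<Vec<usize>> = [0usize].iter().map(|c| o2i[*c]).collect();
    assert_eq!(pushed, Some(vec![3]));

    // Required HashShard([1]) cannot be expressed on the input, so the project
    // only requires AnyShard from its input and an exchange is added above it.
    let pushed: Option<Vec<usize>> = [1usize].iter().map(|c| o2i[*c]).collect();
    assert_eq!(pushed, None);
}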
25 changes: 17 additions & 8 deletions rust/frontend/src/optimizer/plan_node/batch_seq_scan.rs
@@ -36,18 +36,27 @@ impl WithSchema for BatchSeqScan {
}

impl BatchSeqScan {
pub fn new(logical: LogicalScan) -> Self {
pub fn new_inner(logical: LogicalScan, dist: Distribution) -> Self {
let ctx = logical.base.ctx.clone();
// TODO: derive from input
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
Distribution::any().clone(),
Order::any().clone(),
);
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());

Self { base, logical }
}

pub fn new(logical: LogicalScan) -> Self {
Self::new_inner(logical, Distribution::Any)
}

pub fn new_with_dist(logical: LogicalScan) -> Self {
Self::new_inner(logical, Distribution::AnyShard)
}

/// Get a reference to the batch seq scan's logical.
#[must_use]
pub fn logical(&self) -> &LogicalScan {
&self.logical
}
}

impl_plan_tree_node_for_leaf! { BatchSeqScan }
@@ -65,7 +74,7 @@ impl fmt::Display for BatchSeqScan {

impl ToDistributedBatch for BatchSeqScan {
fn to_distributed(&self) -> PlanRef {
self.clone().into()
Self::new_with_dist(self.logical.clone()).into()
}
}

14 changes: 8 additions & 6 deletions rust/frontend/src/optimizer/plan_node/batch_simple_agg.rs
@@ -31,12 +31,14 @@ pub struct BatchSimpleAgg {
impl BatchSimpleAgg {
pub fn new(logical: LogicalAgg) -> Self {
let ctx = logical.base.ctx.clone();
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
Distribution::any().clone(),
Order::any().clone(),
);
let input = logical.input();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::Any => Distribution::Any,
Distribution::Single => Distribution::Single,
_ => panic!(),
};
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
BatchSimpleAgg { base, logical }
}
pub fn agg_calls(&self) -> &[PlanAggCall] {
19 changes: 12 additions & 7 deletions rust/frontend/src/optimizer/plan_node/batch_values.rs
@@ -34,15 +34,20 @@ impl_plan_tree_node_for_leaf!(BatchValues);

impl BatchValues {
pub fn new(logical: LogicalValues) -> Self {
Self::with_dist(logical, Distribution::Any)
}

pub fn with_dist(logical: LogicalValues, dist: Distribution) -> Self {
let ctx = logical.base.ctx.clone();
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
Distribution::Broadcast,
Order::any().clone(),
);
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any().clone());
BatchValues { base, logical }
}

/// Get a reference to the batch values's logical.
#[must_use]
pub fn logical(&self) -> &LogicalValues {
&self.logical
}
}

impl fmt::Display for BatchValues {
@@ -61,7 +66,7 @@ impl WithSchema for BatchValues {

impl ToDistributedBatch for BatchValues {
fn to_distributed(&self) -> PlanRef {
self.clone().into()
Self::with_dist(self.logical().clone(), Distribution::Single).into()
}
}

17 changes: 13 additions & 4 deletions rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs
@@ -27,6 +27,8 @@ pub struct EqJoinPredicate {
/// the first is from the left table and the second is from the right table.
/// now all are normal equal(not null-safe-equal),
eq_keys: Vec<(InputRef, InputRef)>,

left_cols_num: usize,
}

impl fmt::Display for EqJoinPredicate {
@@ -48,10 +50,15 @@ impl fmt::Display for EqJoinPredicate {

impl EqJoinPredicate {
/// The new method for `JoinPredicate` without any analysis, check or rewrite.
pub fn new(other_cond: Condition, eq_keys: Vec<(InputRef, InputRef)>) -> Self {
pub fn new(
other_cond: Condition,
eq_keys: Vec<(InputRef, InputRef)>,
left_cols_num: usize,
) -> Self {
Self {
other_cond,
eq_keys,
left_cols_num,
}
}

@@ -72,7 +79,7 @@ impl EqJoinPredicate {
/// ```
pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
let (eq_keys, other_cond) = on_clause.split_eq_keys(left_cols_num, right_cols_num);
Self::new(other_cond, eq_keys)
Self::new(other_cond, eq_keys, left_cols_num)
}

/// Get join predicate's eq conds.
@@ -121,17 +128,19 @@ impl EqJoinPredicate {
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
self.eq_keys
.iter()
.map(|(left, right)| (left.index(), right.index()))
.map(|(left, right)| (left.index(), right.index() - self.left_cols_num))
.collect()
}

pub fn left_eq_indexes(&self) -> Vec<usize> {
self.eq_keys.iter().map(|(left, _)| left.index()).collect()
}

/// return the eq keys column index **based on the right input schema**
pub fn right_eq_indexes(&self) -> Vec<usize> {
self.eq_keys
.iter()
.map(|(_, right)| right.index())
.map(|(_, right)| right.index() - self.left_cols_num)
.collect()
}
}
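The InputRefs stored in eq_keys index into the joined schema (left columns followed by right columns), which is why EqJoinPredicate now records left_cols_num and why eq_indexes/right_eq_indexes subtract it: the right-side key must be expressed against the right input's own schema. A worked example with hypothetical column counts:

fn main() {
    // Hypothetical join: the left input has 3 columns, and the condition is
    // left.1 = right.1, which in the joined schema is column 1 = column 4.
    let left_cols_num = 3usize;
    let eq_keys = vec![(1usize, 4usize)]; // (left index, right index) in the joined schema

    let right_eq_indexes: Vec<usize> =
        eq_keys.iter().map(|(_, r)| *r - left_cols_num).collect();
    // Expressed against the right input's own schema, the key is column 1.
    assert_eq!(right_eq_indexes, vec![1]);
}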
24 changes: 12 additions & 12 deletions rust/frontend/src/optimizer/plan_node/logical_agg.rs
@@ -225,22 +225,22 @@ impl LogicalAgg {
}
}

/// get the Mapping of columnIndex from input column index to out column index
pub fn o2i_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping {
let mut map = vec![None; exprs.len()];
for (i, expr) in exprs.iter().enumerate() {
map[i] = match expr {
ExprImpl::InputRef(input) => Some(input.index()),
_ => None,
}
/// get the Mapping of columnIndex from input column index to output column index; if an input
/// column corresponds to more than one output column, map to any one of them
pub fn o2i_col_mapping(&self) -> ColIndexMapping {
let input_len = self.input.schema().len();
let agg_cal_num = self.agg_calls().len();
let group_keys = self.group_keys();
let mut map = vec![None; agg_cal_num + group_keys.len()];
for (i, key) in group_keys.iter().enumerate() {
map[i] = Some(*key);
}
ColIndexMapping::with_target_size(map, input_len)
}

/// get the Mapping of columnIndex from input column index to output column index; if an input
/// column corresponds to more than one output column, map to any one of them
pub fn i2o_col_mapping(input_len: usize, exprs: &[ExprImpl]) -> ColIndexMapping {
Self::o2i_col_mapping(input_len, exprs).inverse()
/// get the Mapping of columnIndex from input column index to out column index
pub fn i2o_col_mapping(&self) -> ColIndexMapping {
self.o2i_col_mapping().inverse()
}

fn derive_schema(
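o2i_col_mapping is now built from the aggregation itself rather than from projection expressions: output columns start with the group keys, so output i maps back to input column group_keys[i], agg-call outputs map to nothing, and i2o_col_mapping is just the inverse. A small worked example (hypothetical group keys and schema sizes):

fn main() {
    // Hypothetical agg: group keys [2, 0] on a 4-column input, plus one agg call.
    // Output schema: [group key 2, group key 0, agg call] -> 3 output columns.
    let input_len = 4usize;
    let group_keys = vec![2usize, 0];
    let agg_call_num = 1usize;

    // o2i: output index -> input index (agg-call outputs have no source column).
    let mut o2i: Vec<Option<usize>> = vec![None; group_keys.len() + agg_call_num];
    for (i, key) in group_keys.iter().enumerate() {
        o2i[i] = Some(*key);
    }
    assert_eq!(o2i, vec![Some(2), Some(0), None]);

    // i2o: the inverse mapping, input index -> output index.
    let mut i2o: Vec<Option<usize>> = vec![None; input_len];
    for (out, src) in o2i.iter().enumerate() {
        if let Some(src) = src {
            i2o[*src] = Some(out);
        }
    }
    assert_eq!(i2o, vec![Some(1), None, Some(0), None]);
}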
14 changes: 10 additions & 4 deletions rust/frontend/src/optimizer/plan_node/logical_join.rs
@@ -337,8 +337,11 @@ impl ToBatch for LogicalJoin {
// For inner joins, pull non-equal conditions to a filter operator on top of it
let pull_filter = self.join_type == JoinType::Inner && predicate.has_non_eq();
if pull_filter {
let eq_cond =
EqJoinPredicate::new(Condition::true_cond(), predicate.eq_keys().to_vec());
let eq_cond = EqJoinPredicate::new(
Condition::true_cond(),
predicate.eq_keys().to_vec(),
self.left.schema().len(),
);
let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
let hash_join = BatchHashJoin::new(logical_join, eq_cond).into();
let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond());
@@ -373,8 +376,11 @@ impl ToStream for LogicalJoin {
// For inner joins, pull non-equal conditions to a filter operator on top of it
let pull_filter = self.join_type == JoinType::Inner && predicate.has_non_eq();
if pull_filter {
let eq_cond =
EqJoinPredicate::new(Condition::true_cond(), predicate.eq_keys().to_vec());
let eq_cond = EqJoinPredicate::new(
Condition::true_cond(),
predicate.eq_keys().to_vec(),
self.left.schema().len(),
);
let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
let hash_join = StreamHashJoin::new(logical_join, eq_cond).into();
let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond());