Skip to content

Commit

Permalink
refactor(plan_node): simplify Expand and Filter (risingwavelabs#8932
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ice1000 authored Apr 3, 2023
1 parent 626ff72 commit 30dca71
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 109 deletions.
25 changes: 14 additions & 11 deletions src/frontend/src/optimizer/plan_node/batch_expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,36 @@ use risingwave_pb::batch_plan::expand_node::Subset;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::ExpandNode;

use super::ExprRewritable;
use super::{generic, ExprRewritable};
use crate::optimizer::plan_node::{
LogicalExpand, PlanBase, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch, ToLocalBatch,
PlanBase, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch, ToLocalBatch,
};
use crate::optimizer::property::{Distribution, Order};
use crate::optimizer::PlanRef;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchExpand {
pub base: PlanBase,
logical: LogicalExpand,
logical: generic::Expand<PlanRef>,
}

impl BatchExpand {
pub fn new(logical: LogicalExpand) -> Self {
let ctx = logical.base.ctx.clone();
let dist = match logical.input().distribution() {
pub fn new(logical: generic::Expand<PlanRef>) -> Self {
let base = PlanBase::new_logical_with_core(&logical);
let ctx = base.ctx;
let dist = match logical.input.distribution() {
Distribution::Single => Distribution::Single,
Distribution::SomeShard
| Distribution::HashShard(_)
| Distribution::UpstreamHashShard(_, _) => Distribution::SomeShard,
Distribution::Broadcast => unreachable!(),
};
let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any());
let base = PlanBase::new_batch(ctx, base.schema, dist, Order::any());
BatchExpand { base, logical }
}

pub fn column_subsets(&self) -> &Vec<Vec<usize>> {
self.logical.column_subsets()
pub fn column_subsets(&self) -> &[Vec<usize>] {
&self.logical.column_subsets
}
}

Expand All @@ -60,11 +61,13 @@ impl fmt::Display for BatchExpand {

impl PlanTreeNodeUnary for BatchExpand {
fn input(&self) -> PlanRef {
self.logical.input()
self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
Self::new(self.logical.clone_with_input(input))
let mut logical = self.logical.clone();
logical.input = input;
Self::new(logical)
}
}

Expand Down
40 changes: 17 additions & 23 deletions src/frontend/src/optimizer/plan_node/batch_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ use risingwave_common::error::Result;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::FilterNode;

use super::{
ExprRewritable, LogicalFilter, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch,
};
use super::{generic, ExprRewritable, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch};
use crate::expr::{Expr, ExprImpl, ExprRewriter};
use crate::optimizer::plan_node::{PlanBase, ToLocalBatch};
use crate::utils::Condition;
Expand All @@ -29,24 +27,25 @@ use crate::utils::Condition;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchFilter {
pub base: PlanBase,
logical: LogicalFilter,
logical: generic::Filter<PlanRef>,
}

impl BatchFilter {
pub fn new(logical: LogicalFilter) -> Self {
let ctx = logical.base.ctx.clone();
pub fn new(logical: generic::Filter<PlanRef>) -> Self {
let base = PlanBase::new_logical_with_core(&logical);
let ctx = base.ctx;
// TODO: derive from input
let base = PlanBase::new_batch(
ctx,
logical.schema().clone(),
logical.input().distribution().clone(),
logical.input().order().clone(),
base.schema,
logical.input.distribution().clone(),
logical.input.order().clone(),
);
BatchFilter { base, logical }
}

pub fn predicate(&self) -> &Condition {
self.logical.predicate()
&self.logical.predicate
}
}

Expand All @@ -58,11 +57,13 @@ impl fmt::Display for BatchFilter {

impl PlanTreeNodeUnary for BatchFilter {
fn input(&self) -> PlanRef {
self.logical.input()
self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
Self::new(self.logical.clone_with_input(input))
let mut logical = self.logical.clone();
logical.input = input;
Self::new(logical)
}
}

Expand All @@ -78,9 +79,7 @@ impl ToDistributedBatch for BatchFilter {
impl ToBatchPb for BatchFilter {
fn to_batch_prost_body(&self) -> NodeBody {
NodeBody::Filter(FilterNode {
search_condition: Some(
ExprImpl::from(self.logical.predicate().clone()).to_expr_proto(),
),
search_condition: Some(ExprImpl::from(self.logical.predicate.clone()).to_expr_proto()),
})
}
}
Expand All @@ -98,13 +97,8 @@ impl ExprRewritable for BatchFilter {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
Self::new(
self.logical
.rewrite_exprs(r)
.as_logical_filter()
.unwrap()
.clone(),
)
.into()
let mut logical = self.logical.clone();
logical.rewrite_exprs(r);
Self::new(logical).into()
}
}
15 changes: 13 additions & 2 deletions src/frontend/src/optimizer/plan_node/generic/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

use itertools::Itertools;
use risingwave_common::catalog::{Field, FieldDisplay, Schema};
use risingwave_common::types::DataType;
Expand Down Expand Up @@ -106,9 +108,18 @@ impl<PlanRef: GenericPlanRef> Expand<PlanRef> {
subset
.iter()
.map(|&i| FieldDisplay(self.input.schema().fields.get(i).unwrap()))
.collect_vec()
.collect()
})
.collect_vec()
.collect()
}

pub(crate) fn fmt_with_name(&self, f: &mut fmt::Formatter<'_>, name: &str) -> fmt::Result {
write!(
f,
"{} {{ column_subsets: {:?} }}",
name,
self.column_subsets_display()
)
}

pub fn i2o_col_mapping(&self) -> ColIndexMapping {
Expand Down
23 changes: 21 additions & 2 deletions src/frontend/src/optimizer/plan_node/generic/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

use risingwave_common::catalog::Schema;

use super::{GenericPlanNode, GenericPlanRef};
use crate::expr::ExprRewriter;
use crate::optimizer::optimizer_context::OptimizerContextRef;
use crate::optimizer::property::FunctionalDependencySet;
use crate::utils::Condition;
use crate::utils::{Condition, ConditionDisplay};

/// [`Filter`] iterates over its input and returns elements for which `predicate` evaluates to
/// true, filtering out the others.
Expand All @@ -30,7 +32,24 @@ pub struct Filter<PlanRef> {
pub input: PlanRef,
}

impl<PlanRef: GenericPlanRef> Filter<PlanRef> {}
impl<PlanRef: GenericPlanRef> Filter<PlanRef> {
pub(crate) fn fmt_with_name(&self, f: &mut fmt::Formatter<'_>, name: &str) -> fmt::Result {
let input_schema = self.input.schema();
write!(
f,
"{} {{ predicate: {} }}",
name,
ConditionDisplay {
condition: &self.predicate,
input_schema
}
)
}

pub fn new(predicate: Condition, input: PlanRef) -> Self {
Filter { predicate, input }
}
}
impl<PlanRef: GenericPlanRef> GenericPlanNode for Filter<PlanRef> {
fn schema(&self) -> Schema {
self.input.schema().clone()
Expand Down
18 changes: 5 additions & 13 deletions src/frontend/src/optimizer/plan_node/logical_expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::fmt;

use itertools::Itertools;
use risingwave_common::catalog::FieldDisplay;
use risingwave_common::error::Result;

use super::{
Expand Down Expand Up @@ -64,17 +63,8 @@ impl LogicalExpand {
&self.core.column_subsets
}

pub fn column_subsets_display(&self) -> Vec<Vec<FieldDisplay<'_>>> {
self.core.column_subsets_display()
}

pub(super) fn fmt_with_name(&self, f: &mut fmt::Formatter<'_>, name: &str) -> fmt::Result {
write!(
f,
"{} {{ column_subsets: {:?} }}",
name,
self.column_subsets_display()
)
self.core.fmt_with_name(f, name)
}
}

Expand Down Expand Up @@ -153,7 +143,8 @@ impl PredicatePushdown for LogicalExpand {
impl ToBatch for LogicalExpand {
fn to_batch(&self) -> Result<PlanRef> {
let new_input = self.input().to_batch()?;
let new_logical = self.clone_with_input(new_input);
let mut new_logical = self.core.clone();
new_logical.input = new_input;
Ok(BatchExpand::new(new_logical).into())
}
}
Expand All @@ -170,7 +161,8 @@ impl ToStream for LogicalExpand {

fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
let new_input = self.input().to_stream(ctx)?;
let new_logical = self.clone_with_input(new_input);
let mut new_logical = self.core.clone();
new_logical.input = new_input;
Ok(StreamExpand::new(new_logical).into())
}
}
Expand Down
25 changes: 6 additions & 19 deletions src/frontend/src/optimizer/plan_node/logical_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::optimizer::plan_node::{
BatchFilter, ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext,
StreamFilter, ToStreamContext,
};
use crate::utils::{ColIndexMapping, Condition, ConditionDisplay};
use crate::utils::{ColIndexMapping, Condition};

/// `LogicalFilter` iterates over its input and returns elements for which `predicate` evaluates to
/// true, filtering out the others.
Expand Down Expand Up @@ -94,20 +94,6 @@ impl LogicalFilter {
pub fn predicate(&self) -> &Condition {
&self.core.predicate
}

pub(super) fn fmt_with_name(&self, f: &mut fmt::Formatter<'_>, name: &str) -> fmt::Result {
let input = self.input();
let input_schema = input.schema();
write!(
f,
"{} {{ predicate: {} }}",
name,
ConditionDisplay {
condition: self.predicate(),
input_schema
}
)
}
}

impl PlanTreeNodeUnary for LogicalFilter {
Expand All @@ -134,7 +120,7 @@ impl_plan_tree_node_for_unary! {LogicalFilter}

impl fmt::Display for LogicalFilter {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with_name(f, "LogicalFilter")
self.core.fmt_with_name(f, "LogicalFilter")
}
}

Expand Down Expand Up @@ -210,7 +196,8 @@ impl PredicatePushdown for LogicalFilter {
impl ToBatch for LogicalFilter {
fn to_batch(&self) -> Result<PlanRef> {
let new_input = self.input().to_batch()?;
let new_logical = self.clone_with_input(new_input);
let mut new_logical = self.core.clone();
new_logical.input = new_input;
Ok(BatchFilter::new(new_logical).into())
}
}
Expand Down Expand Up @@ -240,7 +227,8 @@ impl ToStream for LogicalFilter {
"All `now()` exprs were valid, but the condition must have at least one now expr as a lower bound."
);
}
let new_logical = self.clone_with_input(new_input);
let mut new_logical = self.core.clone();
new_logical.input = new_input;
Ok(StreamFilter::new(new_logical).into())
}

Expand All @@ -256,7 +244,6 @@ impl ToStream for LogicalFilter {

#[cfg(test)]
mod tests {

use std::collections::HashSet;

use risingwave_common::catalog::{Field, Schema};
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ impl LogicalJoin {
);
let logical_join = logical_join.clone_with_cond(eq_cond.eq_cond());
let hash_join = StreamHashJoin::new(logical_join.core, eq_cond).into();
let logical_filter = LogicalFilter::new(hash_join, predicate.non_eq_cond());
let logical_filter = generic::Filter::new(predicate.non_eq_cond(), hash_join);
let plan = StreamFilter::new(logical_filter).into();
if self.output_indices() != &default_indices {
let logical_project = LogicalProject::with_mapping(
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/logical_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ impl LogicalScan {

let mut plan: PlanRef = BatchSeqScan::new(scan, scan_ranges).into();
if !predicate.always_true() {
plan = BatchFilter::new(LogicalFilter::new(plan, predicate)).into();
plan = BatchFilter::new(generic::Filter::new(predicate, plan)).into();
}
if let Some(exprs) = project_expr {
plan = BatchProject::new(LogicalProject::new(plan, exprs)).into()
Expand Down
Loading

0 comments on commit 30dca71

Please sign in to comment.