Skip to content

Commit

Permalink
[MINOR] Remove duplicate test utility and move one utility function f…
Browse files Browse the repository at this point in the history
…or better organization (#8652)

* Code rearrange

* Update stream_join_utils.rs
  • Loading branch information
metesynnada authored Dec 25, 2023
1 parent 3698693 commit 18c7566
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 153 deletions.
156 changes: 97 additions & 59 deletions datafusion/physical-plan/src/joins/stream_join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,25 @@ use std::usize;

use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
use crate::{handle_async_state, handle_state, metrics};
use crate::{handle_async_state, handle_state, metrics, ExecutionPlan};

use arrow::compute::concat_batches;
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue,
arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result,
ScalarValue,
};
use datafusion_execution::SendableRecordBatchStream;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};

use async_trait::async_trait;
use futures::{ready, FutureExt, StreamExt};
use hashbrown::raw::RawTable;
use hashbrown::HashSet;
Expand Down Expand Up @@ -175,7 +177,7 @@ impl PruningJoinHashMap {
prune_length: usize,
deleting_offset: u64,
shrink_factor: usize,
) -> Result<()> {
) {
// Remove elements from the list based on the pruning length.
self.next.drain(0..prune_length);

Expand All @@ -198,11 +200,10 @@ impl PruningJoinHashMap {

// Shrink the map if necessary.
self.shrink_if_necessary(shrink_factor);
Ok(())
}
}

pub fn check_filter_expr_contains_sort_information(
fn check_filter_expr_contains_sort_information(
expr: &Arc<dyn PhysicalExpr>,
reference: &Arc<dyn PhysicalExpr>,
) -> bool {
Expand All @@ -227,7 +228,7 @@ pub fn map_origin_col_to_filter_col(
side: &JoinSide,
) -> Result<HashMap<Column, Column>> {
let filter_schema = filter.schema();
let mut col_to_col_map: HashMap<Column, Column> = HashMap::new();
let mut col_to_col_map = HashMap::<Column, Column>::new();
for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
if index.side.eq(side) {
// Get the main field from column index:
Expand Down Expand Up @@ -581,7 +582,7 @@ where
// get the semi index
(0..prune_length)
.filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
.collect::<PrimitiveArray<T>>()
.collect()
}

pub fn combine_two_batches(
Expand Down Expand Up @@ -763,7 +764,6 @@ pub trait EagerJoinStream {
if batch.num_rows() == 0 {
return Ok(StatefulStreamResult::Continue);
}

self.set_state(EagerJoinStreamState::PullLeft);
self.process_batch_from_right(batch)
}
Expand Down Expand Up @@ -1032,6 +1032,91 @@ impl StreamJoinMetrics {
}
}

/// Updates sorted filter expressions with corresponding node indices from the
/// expression interval graph.
///
/// This function iterates through the provided sorted filter expressions,
/// gathers the corresponding node indices from the expression interval graph,
/// and then updates the sorted expressions with these indices. It ensures
/// that these sorted expressions are aligned with the structure of the graph.
fn update_sorted_exprs_with_node_indices(
graph: &mut ExprIntervalGraph,
sorted_exprs: &mut [SortedFilterExpr],
) {
// Extract filter expressions from the sorted expressions:
let filter_exprs = sorted_exprs
.iter()
.map(|expr| expr.filter_expr().clone())
.collect::<Vec<_>>();

// Gather corresponding node indices for the extracted filter expressions from the graph:
let child_node_indices = graph.gather_node_indices(&filter_exprs);

// Iterate through the sorted expressions and the gathered node indices:
for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
// Update each sorted expression with the corresponding node index:
sorted_expr.set_node_index(index);
}
}

/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions.
///
/// # Arguments
///
/// * `filter` - The join filter to base the sorting on.
/// * `left` - The left execution plan.
/// * `right` - The right execution plan.
/// * `left_sort_exprs` - The expressions to sort on the left side.
/// * `right_sort_exprs` - The expressions to sort on the right side.
///
/// # Returns
///
/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph.
pub fn prepare_sorted_exprs(
filter: &JoinFilter,
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
left_sort_exprs: &[PhysicalSortExpr],
right_sort_exprs: &[PhysicalSortExpr],
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
// Build the filter order for the left side
let err = || plan_datafusion_err!("Filter does not include the child order");

let left_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Left,
filter,
&left.schema(),
&left_sort_exprs[0],
)?
.ok_or_else(err)?;

// Build the filter order for the right side
let right_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Right,
filter,
&right.schema(),
&right_sort_exprs[0],
)?
.ok_or_else(err)?;

// Collect the sorted expressions
let mut sorted_exprs =
vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];

// Build the expression interval graph
let mut graph =
ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?;

// Update sorted expressions with node indices
update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);

// Swap and remove to get the final sorted filter expressions
let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
let left_sorted_filter_expr = sorted_exprs.swap_remove(0);

Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}

#[cfg(test)]
pub mod tests {
use std::sync::Arc;
Expand All @@ -1043,62 +1128,15 @@ pub mod tests {
};
use crate::{
expressions::{Column, PhysicalSortExpr},
joins::test_utils::complicated_filter,
joins::utils::{ColumnIndex, JoinFilter},
};

use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{JoinSide, ScalarValue};
use datafusion_common::JoinSide;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, cast, col, lit};

/// Filter expr for a + b > c + 10 AND a + b < c + 100
pub(crate) fn complicated_filter(
filter_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
let left_expr = binary(
cast(
binary(
col("0", filter_schema)?,
Operator::Plus,
col("1", filter_schema)?,
filter_schema,
)?,
filter_schema,
DataType::Int64,
)?,
Operator::Gt,
binary(
cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
Operator::Plus,
lit(ScalarValue::Int64(Some(10))),
filter_schema,
)?,
filter_schema,
)?;

let right_expr = binary(
cast(
binary(
col("0", filter_schema)?,
Operator::Plus,
col("1", filter_schema)?,
filter_schema,
)?,
filter_schema,
DataType::Int64,
)?,
Operator::Lt,
binary(
cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
Operator::Plus,
lit(ScalarValue::Int64(Some(100))),
filter_schema,
)?,
filter_schema,
)?;
binary(left_expr, Operator::And, right_expr, filter_schema)
}
use datafusion_physical_expr::expressions::{binary, cast, col};

#[test]
fn test_column_exchange() -> Result<()> {
Expand Down
11 changes: 6 additions & 5 deletions datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
use crate::joins::stream_join_utils::{
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr,
StreamJoinMetrics,
};
use crate::joins::utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter,
JoinOn, StatefulStreamResult,
partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn,
StatefulStreamResult,
};
use crate::{
expressions::{Column, PhysicalSortExpr},
Expand Down Expand Up @@ -936,7 +937,7 @@ impl OneSideHashJoiner {
prune_length,
self.deleted_offset as u64,
HASHMAP_SHRINK_SCALE_FACTOR,
)?;
);
// Remove pruned rows from the visited rows set:
for row in self.deleted_offset..(self.deleted_offset + prune_length) {
self.visited_rows.remove(&row);
Expand Down
90 changes: 1 addition & 89 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::usize;

use crate::joins::stream_join_utils::{build_filter_input_order, SortedFilterExpr};
use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};

Expand All @@ -39,13 +38,11 @@ use arrow::record_batch::{RecordBatch, RecordBatchOptions};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::stats::Precision;
use datafusion_common::{
plan_datafusion_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
SharedResult,
plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult,
};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::add_offset_to_expr;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::utils::merge_vectors;
use datafusion_physical_expr::{
LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr,
Expand Down Expand Up @@ -1208,91 +1205,6 @@ impl BuildProbeJoinMetrics {
}
}

/// Updates sorted filter expressions with corresponding node indices from the
/// expression interval graph.
///
/// This function iterates through the provided sorted filter expressions,
/// gathers the corresponding node indices from the expression interval graph,
/// and then updates the sorted expressions with these indices. It ensures
/// that these sorted expressions are aligned with the structure of the graph.
fn update_sorted_exprs_with_node_indices(
graph: &mut ExprIntervalGraph,
sorted_exprs: &mut [SortedFilterExpr],
) {
// Extract filter expressions from the sorted expressions:
let filter_exprs = sorted_exprs
.iter()
.map(|expr| expr.filter_expr().clone())
.collect::<Vec<_>>();

// Gather corresponding node indices for the extracted filter expressions from the graph:
let child_node_indices = graph.gather_node_indices(&filter_exprs);

// Iterate through the sorted expressions and the gathered node indices:
for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
// Update each sorted expression with the corresponding node index:
sorted_expr.set_node_index(index);
}
}

/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions.
///
/// # Arguments
///
/// * `filter` - The join filter to base the sorting on.
/// * `left` - The left execution plan.
/// * `right` - The right execution plan.
/// * `left_sort_exprs` - The expressions to sort on the left side.
/// * `right_sort_exprs` - The expressions to sort on the right side.
///
/// # Returns
///
/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph.
pub fn prepare_sorted_exprs(
filter: &JoinFilter,
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
left_sort_exprs: &[PhysicalSortExpr],
right_sort_exprs: &[PhysicalSortExpr],
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
// Build the filter order for the left side
let err = || plan_datafusion_err!("Filter does not include the child order");

let left_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Left,
filter,
&left.schema(),
&left_sort_exprs[0],
)?
.ok_or_else(err)?;

// Build the filter order for the right side
let right_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Right,
filter,
&right.schema(),
&right_sort_exprs[0],
)?
.ok_or_else(err)?;

// Collect the sorted expressions
let mut sorted_exprs =
vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];

// Build the expression interval graph
let mut graph =
ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?;

// Update sorted expressions with node indices
update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);

// Swap and remove to get the final sorted filter expressions
let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
let left_sorted_filter_expr = sorted_exprs.swap_remove(0);

Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}

/// The `handle_state` macro is designed to process the result of a state-changing
/// operation, encountered e.g. in implementations of `EagerJoinStream`. It
/// operates on a `StatefulStreamResult` by matching its variants and executing
Expand Down

0 comments on commit 18c7566

Please sign in to comment.