From 58da15970dc0ec9e3c1c369fe89f6ba38e09d9c9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 5 Jul 2021 00:35:47 -0600 Subject: [PATCH 01/12] Implement metrics for shuffle read and write (#676) --- ballista/rust/core/Cargo.toml | 1 + .../src/execution_plans/shuffle_reader.rs | 26 ++++++++++-- .../src/execution_plans/shuffle_writer.rs | 42 ++++++++++++++++--- ballista/rust/core/src/serde/scheduler/mod.rs | 12 +++++- ballista/rust/core/src/utils.rs | 9 +++- ballista/rust/executor/src/executor.rs | 9 ++++ 6 files changed, 87 insertions(+), 12 deletions(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index bedc0973e6ad9..3a89c75a5cd72 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -33,6 +33,7 @@ simd = ["datafusion/simd"] ahash = "0.7" async-trait = "0.1.36" futures = "0.3" +hashbrown = "0.11" log = "0.4" prost = "0.7" serde = {version = "1", features = ["derive"]} diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 9ab064115acea..db03d3ddf0800 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -28,13 +28,17 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::error::Result as ArrowResult; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; +use datafusion::physical_plan::{ + DisplayFormatType, ExecutionPlan, Partitioning, SQLMetric, +}; use datafusion::{ error::{DataFusionError, Result}, physical_plan::RecordBatchStream, }; use futures::{future, Stream, StreamExt}; +use hashbrown::HashMap; use log::info; +use std::time::Instant; /// ShuffleReaderExec reads partitions that have already been materialized by a ShuffleWriterExec /// being executed by an executor @@ -43,6 +47,8 @@ pub struct ShuffleReaderExec { /// Each partition of a shuffle can read data from multiple locations pub(crate) partition: Vec>, pub(crate) schema: SchemaRef, + /// Time to fetch data from executor + fetch_time: Arc, } impl ShuffleReaderExec { @@ -51,7 +57,11 @@ impl ShuffleReaderExec { partition: Vec>, schema: SchemaRef, ) -> Result { - Ok(Self { partition, schema }) + Ok(Self { + partition, + schema, + fetch_time: SQLMetric::time_nanos(), + }) } } @@ -88,11 +98,13 @@ impl ExecutionPlan for ShuffleReaderExec { ) -> Result>> { info!("ShuffleReaderExec::execute({})", partition); + let start = Instant::now(); let partition_locations = &self.partition[partition]; let result = future::join_all(partition_locations.iter().map(fetch_partition)) .await .into_iter() .collect::>>()?; + self.fetch_time.add_elapsed(start); let result = WrappedStream::new( Box::pin(futures::stream::iter(result).flatten()), @@ -115,7 +127,7 @@ impl ExecutionPlan for ShuffleReaderExec { x.iter() .map(|l| { format!( - "[executor={} part={}:{}:{} stats={:?}]", + "[executor={} part={}:{}:{} stats={}]", l.executor_meta.id, l.partition_id.job_id, l.partition_id.stage_id, @@ -127,11 +139,17 @@ impl ExecutionPlan for ShuffleReaderExec { .join(",") }) .collect::>() - .join("\n"); + .join(", "); write!(f, "ShuffleReaderExec: partition_locations={}", loc_str) } } } + + fn metrics(&self) -> HashMap { + let mut metrics = HashMap::new(); + metrics.insert("fetchTime".to_owned(), (*self.fetch_time).clone()); + metrics + } } async fn fetch_partition( diff --git 
a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 7fffaba13217c..92b4448a69ec6 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -20,6 +20,7 @@ //! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query //! will use the ShuffleReaderExec to read these results. +use std::fs::File; use std::iter::Iterator; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -43,11 +44,11 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::hash_join::create_hashes; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SQLMetric, }; use futures::StreamExt; +use hashbrown::HashMap; use log::info; -use std::fs::File; use uuid::Uuid; /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and @@ -66,6 +67,22 @@ pub struct ShuffleWriterExec { work_dir: String, /// Optional shuffle output partitioning shuffle_output_partitioning: Option, + /// Shuffle write metrics + metrics: ShuffleWriteMetrics, +} + +#[derive(Debug, Clone)] +struct ShuffleWriteMetrics { + /// Time spend writing batches to shuffle files + write_time: Arc, +} + +impl ShuffleWriteMetrics { + fn new() -> Self { + Self { + write_time: SQLMetric::time_nanos(), + } + } } impl ShuffleWriterExec { @@ -83,6 +100,7 @@ impl ShuffleWriterExec { plan, work_dir, shuffle_output_partitioning, + metrics: ShuffleWriteMetrics::new(), }) } @@ -150,12 +168,16 @@ impl ExecutionPlan for ShuffleWriterExec { info!("Writing results to {}", path); // stream results to disk - let stats = utils::write_stream_to_disk(&mut stream, path) - .await - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + let stats = utils::write_stream_to_disk( + &mut stream, + path, + self.metrics.write_time.clone(), + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; info!( - "Executed partition {} in {} seconds. Statistics: {:?}", + "Executed partition {} in {} seconds. Statistics: {}", partition, now.elapsed().as_secs(), stats @@ -231,6 +253,7 @@ impl ExecutionPlan for ShuffleWriterExec { RecordBatch::try_new(input_batch.schema(), columns)?; // write batch out + let start = Instant::now(); match &mut writers[num_output_partition] { Some(w) => { w.write(&output_batch)?; @@ -251,6 +274,7 @@ impl ExecutionPlan for ShuffleWriterExec { writers[num_output_partition] = Some(writer); } } + self.metrics.write_time.add_elapsed(start); } } @@ -310,6 +334,12 @@ impl ExecutionPlan for ShuffleWriterExec { } } + fn metrics(&self) -> HashMap { + let mut metrics = HashMap::new(); + metrics.insert("writeTime".to_owned(), (*self.metrics.write_time).clone()); + metrics + } + fn fmt_as( &self, t: DisplayFormatType, diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs index f66bb08189d28..cbe1a31227c68 100644 --- a/ballista/rust/core/src/serde/scheduler/mod.rs +++ b/ballista/rust/core/src/serde/scheduler/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, fmt, sync::Arc}; use datafusion::arrow::array::{ ArrayBuilder, ArrayRef, StructArray, StructBuilder, UInt64Array, UInt64Builder, @@ -113,6 +113,16 @@ impl Default for PartitionStats { } } +impl fmt::Display for PartitionStats { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "numBatches={:?}, numRows={:?}, numBytes={:?}", + self.num_batches, self.num_rows, self.num_bytes + ) + } +} + impl PartitionStats { pub fn new( num_rows: Option, diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 8a510f4808760..f7d884d502985 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -53,15 +53,17 @@ use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sort::SortExec; use datafusion::physical_plan::{ - AggregateExpr, ExecutionPlan, PhysicalExpr, RecordBatchStream, + AggregateExpr, ExecutionPlan, PhysicalExpr, RecordBatchStream, SQLMetric, }; use futures::{future, Stream, StreamExt}; +use std::time::Instant; /// Stream data to disk in Arrow IPC format pub async fn write_stream_to_disk( stream: &mut Pin>, path: &str, + disk_write_metric: Arc, ) -> Result { let file = File::create(&path).map_err(|e| { BallistaError::General(format!( @@ -86,9 +88,14 @@ pub async fn write_stream_to_disk( num_batches += 1; num_rows += batch.num_rows(); num_bytes += batch_size_bytes; + + let start = Instant::now(); writer.write(&batch)?; + disk_write_metric.add_elapsed(start); } + let start = Instant::now(); writer.finish()?; + disk_write_metric.add_elapsed(start); Ok(PartitionStats::new( Some(num_rows as u64), Some(num_batches), diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index 86aaa7e9f4956..4a75448b5f06b 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -23,6 +23,7 @@ use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::utils; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::ExecutionPlan; /// Ballista executor @@ -60,6 +61,14 @@ impl Executor { )?; let mut stream = exec.execute(part).await?; let batches = utils::collect_stream(&mut stream).await?; + + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(&exec) + .indent() + .to_string() + ); + // the output should be a single batch containing metadata (path and statistics) assert!(batches.len() == 1); Ok(batches[0].clone()) From 8cbb750faab3189813e95681bc2af53f20c9f0c7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jul 2021 12:41:28 -0400 Subject: [PATCH 02/12] Add End-to-end test for parquet pruning + metrics for ParquetExec (#657) * End to end tests for parquet pruning * remove unused dep * Make the separation of per-partition and per-exec metrics clearer * Account for statistics once rather than per row group * Fix timestamps to use UTC time --- .../src/physical_optimizer/repartition.rs | 22 +- datafusion/src/physical_plan/mod.rs | 14 + datafusion/src/physical_plan/parquet.rs | 156 ++++++-- datafusion/src/test/mod.rs | 10 +- datafusion/tests/parquet_pruning.rs | 343 ++++++++++++++++++ 5 files changed, 508 insertions(+), 37 deletions(-) create mode 100644 datafusion/tests/parquet_pruning.rs diff --git 
a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index 011db64aaf8a2..4504c81daa06d 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -110,7 +110,9 @@ mod tests { use super::*; use crate::datasource::datasource::Statistics; - use crate::physical_plan::parquet::{ParquetExec, ParquetPartition}; + use crate::physical_plan::parquet::{ + ParquetExec, ParquetExecMetrics, ParquetPartition, + }; use crate::physical_plan::projection::ProjectionExec; #[test] @@ -119,12 +121,13 @@ mod tests { let parquet_project = ProjectionExec::try_new( vec![], Arc::new(ParquetExec::new( - vec![ParquetPartition { - filenames: vec!["x".to_string()], - statistics: Statistics::default(), - }], + vec![ParquetPartition::new( + vec!["x".to_string()], + Statistics::default(), + )], schema, None, + ParquetExecMetrics::new(), None, 2048, None, @@ -156,12 +159,13 @@ mod tests { Arc::new(ProjectionExec::try_new( vec![], Arc::new(ParquetExec::new( - vec![ParquetPartition { - filenames: vec!["x".to_string()], - statistics: Statistics::default(), - }], + vec![ParquetPartition::new( + vec!["x".to_string()], + Statistics::default(), + )], schema, None, + ParquetExecMetrics::new(), None, 2048, None, diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index a940cbe7963a6..d89eb11885041 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -297,6 +297,20 @@ pub fn visit_execution_plan( Ok(()) } +/// Recursively gateher all execution metrics from this plan and all of its input plans +pub fn plan_metrics(plan: Arc) -> HashMap { + fn get_metrics_inner( + plan: &dyn ExecutionPlan, + mut metrics: HashMap, + ) -> HashMap { + metrics.extend(plan.metrics().into_iter()); + plan.children().into_iter().fold(metrics, |metrics, child| { + get_metrics_inner(child.as_ref(), metrics) + }) + } + get_metrics_inner(plan.as_ref(), HashMap::new()) +} + /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect(plan: Arc) -> Result> { match plan.output_partitioning().partition_count() { diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index 3d20a9bf98c19..f31b921d663b0 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -40,6 +40,8 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use hashbrown::HashMap; +use log::debug; use parquet::file::{ metadata::RowGroupMetaData, reader::{FileReader, SerializedFileReader}, @@ -59,6 +61,8 @@ use crate::datasource::datasource::{ColumnStatistics, Statistics}; use async_trait::async_trait; use futures::stream::{Stream, StreamExt}; +use super::SQLMetric; + /// Execution plan for scanning one or more Parquet partitions #[derive(Debug, Clone)] pub struct ParquetExec { @@ -72,6 +76,8 @@ pub struct ParquetExec { batch_size: usize, /// Statistics for the data set (sum of statistics for all partitions) statistics: Statistics, + /// metrics for the overall execution + metrics: ParquetExecMetrics, /// Optional predicate builder predicate_builder: Option, /// Optional limit of the number of rows @@ -93,6 +99,24 @@ pub struct ParquetPartition { pub filenames: Vec, /// Statistics for this partition pub statistics: Statistics, + /// Execution metrics + metrics: ParquetPartitionMetrics, +} + +/// Stores metrics about the overall parquet execution +#[derive(Debug, Clone)] +pub 
struct ParquetExecMetrics { + /// Numer of times the pruning predicate could not be created + pub predicate_creation_errors: Arc, +} + +/// Stores metrics about the parquet execution for a particular ParquetPartition +#[derive(Debug, Clone)] +struct ParquetPartitionMetrics { + /// Numer of times the predicate could not be evaluated + pub predicate_evaluation_errors: Arc, + /// Number of row groups pruned using + pub row_groups_pruned: Arc, } impl ParquetExec { @@ -140,6 +164,8 @@ impl ParquetExec { max_concurrency: usize, limit: Option, ) -> Result { + debug!("Creating ParquetExec, filenames: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", + filenames, projection, predicate, limit); // build a list of Parquet partitions with statistics and gather all unique schemas // used in this data set let mut schemas: Vec = vec![]; @@ -205,10 +231,7 @@ impl ParquetExec { }; // remove files that are not needed in case of limit filenames.truncate(total_files); - partitions.push(ParquetPartition { - filenames, - statistics, - }); + partitions.push(ParquetPartition::new(filenames, statistics)); if limit_exhausted { break; } @@ -225,14 +248,27 @@ impl ParquetExec { ))); } let schema = Arc::new(schemas.pop().unwrap()); + let metrics = ParquetExecMetrics::new(); + let predicate_builder = predicate.and_then(|predicate_expr| { - PruningPredicate::try_new(&predicate_expr, schema.clone()).ok() + match PruningPredicate::try_new(&predicate_expr, schema.clone()) { + Ok(predicate_builder) => Some(predicate_builder), + Err(e) => { + debug!( + "Could not create pruning predicate for {:?}: {}", + predicate_expr, e + ); + metrics.predicate_creation_errors.add(1); + None + } + } }); Ok(Self::new( partitions, schema, projection, + metrics, predicate_builder, batch_size, limit, @@ -244,6 +280,7 @@ impl ParquetExec { partitions: Vec, schema: SchemaRef, projection: Option>, + metrics: ParquetExecMetrics, predicate_builder: Option, batch_size: usize, limit: Option, @@ -307,6 +344,7 @@ impl ParquetExec { partitions, schema: Arc::new(projected_schema), projection, + metrics, predicate_builder, batch_size, statistics, @@ -341,6 +379,7 @@ impl ParquetPartition { Self { filenames, statistics, + metrics: ParquetPartitionMetrics::new(), } } @@ -355,6 +394,25 @@ impl ParquetPartition { } } +impl ParquetExecMetrics { + /// Create new metrics + pub fn new() -> Self { + Self { + predicate_creation_errors: SQLMetric::counter(), + } + } +} + +impl ParquetPartitionMetrics { + /// Create new metrics + pub fn new() -> Self { + Self { + predicate_evaluation_errors: SQLMetric::counter(), + row_groups_pruned: SQLMetric::counter(), + } + } +} + #[async_trait] impl ExecutionPlan for ParquetExec { /// Return a reference to Any that can be used for downcasting @@ -398,7 +456,9 @@ impl ExecutionPlan for ParquetExec { Receiver>, ) = channel(2); - let filenames = self.partitions[partition].filenames.clone(); + let partition = &self.partitions[partition]; + let filenames = partition.filenames.clone(); + let metrics = partition.metrics.clone(); let projection = self.projection.clone(); let predicate_builder = self.predicate_builder.clone(); let batch_size = self.batch_size; @@ -407,6 +467,7 @@ impl ExecutionPlan for ParquetExec { task::spawn_blocking(move || { if let Err(e) = read_files( &filenames, + metrics, &projection, &predicate_builder, batch_size, @@ -448,6 +509,31 @@ impl ExecutionPlan for ParquetExec { } } } + + fn metrics(&self) -> HashMap { + self.partitions + .iter() + .flat_map(|p| { + [ + ( + format!( + "numPredicateEvaluationErrors 
for {}", + p.filenames.join(",") + ), + p.metrics.predicate_evaluation_errors.as_ref().clone(), + ), + ( + format!("numRowGroupsPruned for {}", p.filenames.join(",")), + p.metrics.row_groups_pruned.as_ref().clone(), + ), + ] + }) + .chain(std::iter::once(( + "numPredicateCreationErrors".to_string(), + self.metrics.predicate_creation_errors.as_ref().clone(), + ))) + .collect() + } } fn send_result( @@ -547,6 +633,7 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn build_row_group_predicate( predicate_builder: &PruningPredicate, + metrics: ParquetPartitionMetrics, row_group_metadata: &[RowGroupMetaData], ) -> Box bool> { let parquet_schema = predicate_builder.schema().as_ref(); @@ -555,21 +642,28 @@ fn build_row_group_predicate( row_group_metadata, parquet_schema, }; - let predicate_values = predicate_builder.prune(&pruning_stats); - let predicate_values = match predicate_values { - Ok(values) => values, + match predicate_values { + Ok(values) => { + // NB: false means don't scan row group + let num_pruned = values.iter().filter(|&v| !v).count(); + metrics.row_groups_pruned.add(num_pruned); + Box::new(move |_, i| values[i]) + } // stats filter array could not be built // return a closure which will not filter out any row groups - _ => return Box::new(|_r, _i| true), - }; - - Box::new(move |_, i| predicate_values[i]) + Err(e) => { + debug!("Error evaluating row group predicate values {}", e); + metrics.predicate_evaluation_errors.add(1); + Box::new(|_r, _i| true) + } + } } fn read_files( filenames: &[String], + metrics: ParquetPartitionMetrics, projection: &[usize], predicate_builder: &Option, batch_size: usize, @@ -583,6 +677,7 @@ fn read_files( if let Some(predicate_builder) = predicate_builder { let row_group_predicate = build_row_group_predicate( predicate_builder, + metrics.clone(), file_reader.metadata().row_groups(), ); file_reader.filter_row_groups(&row_group_predicate); @@ -757,8 +852,11 @@ mod tests { vec![ParquetStatistics::int32(Some(11), Some(20), None, 0, false)], ); let row_group_metadata = vec![rgm1, rgm2]; - let row_group_predicate = - build_row_group_predicate(&predicate_builder, &row_group_metadata); + let row_group_predicate = build_row_group_predicate( + &predicate_builder, + ParquetPartitionMetrics::new(), + &row_group_metadata, + ); let row_group_filter = row_group_metadata .iter() .enumerate() @@ -787,8 +885,11 @@ mod tests { vec![ParquetStatistics::int32(Some(11), Some(20), None, 0, false)], ); let row_group_metadata = vec![rgm1, rgm2]; - let row_group_predicate = - build_row_group_predicate(&predicate_builder, &row_group_metadata); + let row_group_predicate = build_row_group_predicate( + &predicate_builder, + ParquetPartitionMetrics::new(), + &row_group_metadata, + ); let row_group_filter = row_group_metadata .iter() .enumerate() @@ -832,8 +933,11 @@ mod tests { ], ); let row_group_metadata = vec![rgm1, rgm2]; - let row_group_predicate = - build_row_group_predicate(&predicate_builder, &row_group_metadata); + let row_group_predicate = build_row_group_predicate( + &predicate_builder, + ParquetPartitionMetrics::new(), + &row_group_metadata, + ); let row_group_filter = row_group_metadata .iter() .enumerate() @@ -847,8 +951,11 @@ mod tests { // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").modulus(lit(2))); let predicate_builder = PruningPredicate::try_new(&expr, schema)?; - let row_group_predicate = - build_row_group_predicate(&predicate_builder, &row_group_metadata); + 
let row_group_predicate = build_row_group_predicate( + &predicate_builder, + ParquetPartitionMetrics::new(), + &row_group_metadata, + ); let row_group_filter = row_group_metadata .iter() .enumerate() @@ -891,8 +998,11 @@ mod tests { ], ); let row_group_metadata = vec![rgm1, rgm2]; - let row_group_predicate = - build_row_group_predicate(&predicate_builder, &row_group_metadata); + let row_group_predicate = build_row_group_predicate( + &predicate_builder, + ParquetPartitionMetrics::new(), + &row_group_metadata, + ); let row_group_filter = row_group_metadata .iter() .enumerate() diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index 7ca7cc12d9efb..df3aec4a68502 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -251,11 +251,11 @@ pub fn make_timestamps() -> RecordBatch { let arr_names = StringArray::from(names); let schema = Schema::new(vec![ - Field::new("nanos", arr_nanos.data_type().clone(), false), - Field::new("micros", arr_micros.data_type().clone(), false), - Field::new("millis", arr_millis.data_type().clone(), false), - Field::new("secs", arr_secs.data_type().clone(), false), - Field::new("name", arr_names.data_type().clone(), false), + Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new("micros", arr_micros.data_type().clone(), true), + Field::new("millis", arr_millis.data_type().clone(), true), + Field::new("secs", arr_secs.data_type().clone(), true), + Field::new("name", arr_names.data_type().clone(), true), ]); let schema = Arc::new(schema); diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs new file mode 100644 index 0000000000000..86b3946e47121 --- /dev/null +++ b/datafusion/tests/parquet_pruning.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This file contains an end to end test of parquet pruning. 
It writes +// data into a parquet file and then +use std::sync::Arc; + +use arrow::{ + array::{ + Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, + }, + datatypes::{Field, Schema}, + record_batch::RecordBatch, + util::pretty::pretty_format_batches, +}; +use chrono::Duration; +use datafusion::{ + physical_plan::{plan_metrics, SQLMetric}, + prelude::ExecutionContext, +}; +use hashbrown::HashMap; +use parquet::{arrow::ArrowWriter, file::properties::WriterProperties}; +use tempfile::NamedTempFile; + +#[tokio::test] +async fn prune_timestamps_nanos() { + let output = ContextWithParquet::new() + .await + .query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") + .await; + println!("{}", output.description()); + // TODO This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 10, "{}", output.description()); +} + +#[tokio::test] +async fn prune_timestamps_micros() { + let output = ContextWithParquet::new() + .await + .query( + "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", + ) + .await; + println!("{}", output.description()); + // TODO This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 10, "{}", output.description()); +} + +#[tokio::test] +async fn prune_timestamps_millis() { + let output = ContextWithParquet::new() + .await + .query( + "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')", + ) + .await; + println!("{}", output.description()); + // TODO This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 10, "{}", output.description()); +} + +#[tokio::test] +async fn prune_timestamps_seconds() { + let output = ContextWithParquet::new() + .await + .query( + "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", + ) + .await; + println!("{}", output.description()); + // TODO This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(1)); + assert_eq!(output.row_groups_pruned(), Some(0)); + assert_eq!(output.result_rows, 10, "{}", output.description()); +} + +// ---------------------- +// Begin test fixture +// ---------------------- + +/// Test fixture that has an execution context that has an external +/// table "t" registered, pointing at a parquet file made with +/// `make_test_file` +struct ContextWithParquet { + file: NamedTempFile, + ctx: ExecutionContext, +} + +/// The output of running one of the test cases +struct TestOutput { + /// The input string + sql: String, + /// Normalized metrics (filename replaced by a constant) + metrics: HashMap, + /// number of rows in results + result_rows: usize, + /// the contents of the input, as a string + pretty_input: String, + /// the raw results, as a string + pretty_results: String, +} + +impl TestOutput { + /// retrieve the value of the named metric, if any + fn metric_value(&self, metric_name: &str) -> Option { + self.metrics.get(metric_name).map(|m| m.value()) + } + + /// The number of times the pruning predicate evaluation errors + fn predicate_evaluation_errors(&self) -> Option { + self.metric_value("numPredicateEvaluationErrors for 
PARQUET_FILE") + } + + /// The number of times the pruning predicate evaluation errors + fn row_groups_pruned(&self) -> Option { + self.metric_value("numRowGroupsPruned for PARQUET_FILE") + } + + fn description(&self) -> String { + let metrics = self + .metrics + .iter() + .map(|(name, val)| format!(" {} = {:?}", name, val)) + .collect::>(); + + format!( + "Input:\n{}\nQuery:\n{}\nOutput:\n{}\nMetrics:\n{}", + self.pretty_input, + self.sql, + self.pretty_results, + metrics.join("\n") + ) + } +} + +/// Creates an execution context that has an external table "t" +/// registered pointing at a parquet file made with `make_test_file` +impl ContextWithParquet { + async fn new() -> Self { + let file = make_test_file().await; + + // now, setup a the file as a data source and run a query against it + let mut ctx = ExecutionContext::new(); + let parquet_path = file.path().to_string_lossy(); + ctx.register_parquet("t", &parquet_path) + .expect("registering"); + + Self { file, ctx } + } + + /// Runs the specified SQL query and returns the number of output + /// rows and normalized execution metrics + async fn query(&mut self, sql: &str) -> TestOutput { + println!("Planning sql {}", sql); + + let input = self + .ctx + .sql("SELECT * from t") + .expect("planning") + .collect() + .await + .expect("getting input"); + let pretty_input = pretty_format_batches(&input).unwrap(); + + let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan(); + + let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan"); + let execution_plan = self + .ctx + .create_physical_plan(&logical_plan) + .expect("creating physical plan"); + + let results = datafusion::physical_plan::collect(execution_plan.clone()) + .await + .expect("Running"); + + // replace the path name, which varies test to test,a with some + // constant for test comparisons + let path = self.file.path(); + let path_name = path.to_string_lossy(); + let metrics = plan_metrics(execution_plan) + .into_iter() + .map(|(name, metric)| { + (name.replace(path_name.as_ref(), "PARQUET_FILE"), metric) + }) + .collect(); + + let result_rows = results.iter().map(|b| b.num_rows()).sum(); + + let pretty_results = pretty_format_batches(&results).unwrap(); + + let sql = sql.to_string(); + TestOutput { + sql, + metrics, + result_rows, + pretty_input, + pretty_results, + } + } +} + +/// Create a test parquet file with varioud data types +async fn make_test_file() -> NamedTempFile { + let output_file = tempfile::Builder::new() + .prefix("parquet_pruning") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); + + let props = WriterProperties::builder() + .set_max_row_group_size(5) + .build(); + + let batches = vec![ + make_batch(Duration::seconds(0)), + make_batch(Duration::seconds(10)), + make_batch(Duration::minutes(10)), + make_batch(Duration::days(10)), + ]; + let schema = batches[0].schema(); + + let mut writer = ArrowWriter::try_new( + output_file + .as_file() + .try_clone() + .expect("cloning file descriptor"), + schema, + Some(props), + ) + .unwrap(); + + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + writer.close().unwrap(); + + output_file +} + +/// Return record batch with a few rows of data for all of the supported timestamp types +/// values with the specified offset +/// +/// Columns are named: +/// "nanos" --> TimestampNanosecondArray +/// "micros" --> TimestampMicrosecondArray +/// "millis" --> TimestampMillisecondArray +/// "seconds" --> TimestampSecondArray +/// "names" --> StringArray +pub 
fn make_batch(offset: Duration) -> RecordBatch { + let ts_strings = vec![ + Some("2020-01-01T01:01:01.0000000000001"), + Some("2020-01-01T01:02:01.0000000000001"), + Some("2020-01-01T02:01:01.0000000000001"), + None, + Some("2020-01-02T01:01:01.0000000000001"), + ]; + + let offset_nanos = offset.num_nanoseconds().expect("non overflow nanos"); + + let ts_nanos = ts_strings + .into_iter() + .map(|t| { + t.map(|t| { + offset_nanos + + t.parse::() + .unwrap() + .timestamp_nanos() + }) + }) + .collect::>(); + + let ts_micros = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000)) + .collect::>(); + + let ts_millis = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000)) + .collect::>(); + + let ts_seconds = ts_nanos + .iter() + .map(|t| t.as_ref().map(|ts_nanos| ts_nanos / 1000000000)) + .collect::>(); + + let names = ts_nanos + .iter() + .enumerate() + .map(|(i, _)| format!("Row {} + {}", i, offset)) + .collect::>(); + + let arr_nanos = TimestampNanosecondArray::from_opt_vec(ts_nanos, None); + let arr_micros = TimestampMicrosecondArray::from_opt_vec(ts_micros, None); + let arr_millis = TimestampMillisecondArray::from_opt_vec(ts_millis, None); + let arr_seconds = TimestampSecondArray::from_opt_vec(ts_seconds, None); + + let names = names.iter().map(|s| s.as_str()).collect::>(); + let arr_names = StringArray::from(names); + + let schema = Schema::new(vec![ + Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new("micros", arr_micros.data_type().clone(), true), + Field::new("millis", arr_millis.data_type().clone(), true), + Field::new("seconds", arr_seconds.data_type().clone(), true), + Field::new("name", arr_names.data_type().clone(), true), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(arr_nanos), + Arc::new(arr_micros), + Arc::new(arr_millis), + Arc::new(arr_seconds), + Arc::new(arr_names), + ], + ) + .unwrap() +} From fdf41ad509ddc63e9bbd422768eff9810dea4da4 Mon Sep 17 00:00:00 2001 From: rdettai Date: Tue, 6 Jul 2021 19:45:20 +0200 Subject: [PATCH 03/12] [fix] benchmark run with compose (#666) --- benchmarks/README.md | 2 +- benchmarks/docker-compose.yaml | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 0b5ccfc16e466..a63761b6c2b3d 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -140,7 +140,7 @@ docker-compose up Then you can run the benchmark with: ```bash -docker-compose run ballista-client cargo run benchmark ballista --host ballista-scheduler --port 50050 --query 1 --path /data --format tbl +docker-compose run ballista-client bash -c '/tpch benchmark ballista --host ballista-scheduler --port 50050 --query 1 --path /data --format tbl' ``` ## Expected output diff --git a/benchmarks/docker-compose.yaml b/benchmarks/docker-compose.yaml index 74c6703f30b1c..e025ea360e76c 100644 --- a/benchmarks/docker-compose.yaml +++ b/benchmarks/docker-compose.yaml @@ -41,12 +41,10 @@ services: ballista-client: image: ballista:0.5.0-SNAPSHOT command: "/bin/sh" # do nothing - working_dir: /ballista/benchmarks/tpch environment: - RUST_LOG=info volumes: - ./data:/data - - ../..:/ballista depends_on: - ballista-scheduler - ballista-executor From 0368f59016b943448124f72d1f70b4108c45860e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 7 Jul 2021 13:53:29 +0200 Subject: [PATCH 04/12] Allow non-equijoin filters in join condition (#660) * Allow non-equijoin filters in join condition * Revert change to 
query * Fix, only do for inner join * Add test * docs update * Update test name Co-authored-by: Andrew Lamb * Add negative test Co-authored-by: Andrew Lamb --- datafusion/src/sql/planner.rs | 83 +++++++++++++++++++++++++---------- datafusion/tests/sql.rs | 22 ++++++++++ 2 files changed, 81 insertions(+), 24 deletions(-) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 213ae890d7d09..e34f0e6c9b674 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -368,15 +368,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // parse ON expression let expr = self.sql_to_rex(sql_expr, &join_schema)?; + // expression that didn't match equi-join pattern + let mut filter = vec![]; + // extract join keys - extract_join_keys(&expr, &mut keys)?; + extract_join_keys(&expr, &mut keys, &mut filter); let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); // return the logical plan representing the join - LogicalPlanBuilder::from(left) - .join(right, join_type, left_keys, right_keys)? + let join = LogicalPlanBuilder::from(left) + .join(right, join_type, left_keys, right_keys)?; + + if filter.is_empty() { + join.build() + } else if join_type == JoinType::Inner { + join.filter( + filter + .iter() + .skip(1) + .fold(filter[0].clone(), |acc, e| acc.and(e.clone())), + )? .build() + } else { + Err(DataFusionError::NotImplemented(format!( + "Unsupported expressions in {:?} JOIN: {:?}", + join_type, filter + ))) + } } JoinConstraint::Using(idents) => { let keys: Vec = idents @@ -1550,39 +1569,41 @@ fn remove_join_expressions( } } -/// Parse equijoin ON condition which could be a single Eq or multiple conjunctive Eqs +/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs +/// Filters matching this pattern are added to `accum` +/// Filters that don't match this pattern are added to `accum_filter` +/// Examples: /// -/// Examples +/// foo = bar => accum=[(foo, bar)] accum_filter=[] +/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] +/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] /// -/// foo = bar -/// foo = bar AND bar = baz AND ... 
-/// -fn extract_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) -> Result<()> { +fn extract_join_keys( + expr: &Expr, + accum: &mut Vec<(Column, Column)>, + accum_filter: &mut Vec, +) { match expr { Expr::BinaryExpr { left, op, right } => match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => { accum.push((l.clone(), r.clone())); - Ok(()) } - other => Err(DataFusionError::SQL(ParserError(format!( - "Unsupported expression '{:?}' in JOIN condition", - other - )))), + _other => { + accum_filter.push(expr.clone()); + } }, Operator::And => { - extract_join_keys(left, accum)?; - extract_join_keys(right, accum) + extract_join_keys(left, accum, accum_filter); + extract_join_keys(right, accum, accum_filter); + } + _other => { + accum_filter.push(expr.clone()); } - other => Err(DataFusionError::SQL(ParserError(format!( - "Unsupported expression '{:?}' in JOIN condition", - other - )))), }, - other => Err(DataFusionError::SQL(ParserError(format!( - "Unsupported expression '{:?}' in JOIN condition", - other - )))), + _other => { + accum_filter.push(expr.clone()); + } } } @@ -2702,6 +2723,20 @@ mod tests { quick_test(sql, expected); } + #[test] + fn equijoin_unsupported_expression() { + let sql = "SELECT id, order_id \ + FROM person \ + JOIN orders \ + ON id = customer_id AND order_id > 1 "; + let expected = "Projection: #person.id, #orders.order_id\ + \n Filter: #orders.order_id Gt Int64(1)\ + \n Join: #person.id = #orders.customer_id\ + \n TableScan: person projection=None\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn join_with_table_name() { let sql = "SELECT id, order_id \ diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index bd73cb15610a7..f6f8b6f041e6e 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1687,6 +1687,28 @@ async fn equijoin() -> Result<()> { Ok(()) } +#[tokio::test] +async fn equijoin_and_other_condition() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["11", "a", "z"], vec!["22", "b", "y"]]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn equijoin_and_unsupported_condition() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id")?; + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; + let res = ctx.create_logical_plan(sql); + assert!(res.is_err()); + assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t2.t2_name GtEq Utf8(\"y\")]"); + Ok(()) +} + #[tokio::test] async fn left_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; From 36647662c69b2635cce300b03e5462b39bacd2a4 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 7 Jul 2021 14:02:12 +0200 Subject: [PATCH 05/12] use `Weak` ptr to break catalog list <> info schema cyclic reference (#681) Fixes #680. 
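The cycle existed because the `CatalogWithInformationSchema` wrapper held an `Arc` back to the `CatalogList` that owns it; the change below stores a `Weak` pointer instead and upgrades it only when the information_schema is actually queried. A minimal standalone sketch of the same ownership pattern, with `CatalogList`/`Catalog` as simplified stand-ins rather than the real DataFusion types:

use std::sync::{Arc, Mutex, Weak};

// Stand-in for DataFusion's `CatalogList`: owns all registered catalogs.
#[derive(Default)]
struct CatalogList {
    catalogs: Mutex<Vec<Arc<Catalog>>>,
}

// Stand-in for `CatalogWithInformationSchema`: it needs to look back at the
// list it is registered in, but must not keep that list alive, so it holds
// a `Weak` pointer and upgrades it on demand.
struct Catalog {
    catalog_list: Weak<CatalogList>,
}

impl Catalog {
    fn registered_catalogs(&self) -> usize {
        match Weak::upgrade(&self.catalog_list) {
            Some(list) => list.catalogs.lock().unwrap().len(),
            None => 0, // the list has already been dropped
        }
    }
}

fn main() {
    let list = Arc::new(CatalogList::default());

    // The catalog only gets a downgraded, non-owning pointer to the list.
    let catalog = Arc::new(Catalog {
        catalog_list: Arc::downgrade(&list),
    });
    list.catalogs.lock().unwrap().push(Arc::clone(&catalog));

    assert_eq!(catalog.registered_catalogs(), 1);

    // Dropping the last `Arc`s frees both structs. If `Catalog` held an
    // `Arc<CatalogList>` instead, the two would keep each other alive and leak.
    drop(catalog);
    drop(list);
}

The new `catalogs_not_leaked` test in this patch asserts the same property on the real types: after dropping the `ExecutionContext`, downgraded handles observe that the strong counts of both the catalog list and the registered catalog have fallen to zero.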
--- datafusion/src/catalog/information_schema.rs | 16 ++++++----- datafusion/src/execution/context.rs | 28 ++++++++++++++++++-- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/datafusion/src/catalog/information_schema.rs b/datafusion/src/catalog/information_schema.rs index fd7fcb4b901a6..cd1e612245ec0 100644 --- a/datafusion/src/catalog/information_schema.rs +++ b/datafusion/src/catalog/information_schema.rs @@ -19,7 +19,10 @@ //! //! Information Schema](https://en.wikipedia.org/wiki/Information_schema) -use std::{any, sync::Arc}; +use std::{ + any, + sync::{Arc, Weak}, +}; use arrow::{ array::{StringBuilder, UInt64Builder}, @@ -41,14 +44,14 @@ const COLUMNS: &str = "columns"; /// Wraps another [`CatalogProvider`] and adds a "information_schema" /// schema that can introspect on tables in the catalog_list pub(crate) struct CatalogWithInformationSchema { - catalog_list: Arc, + catalog_list: Weak, /// wrapped provider inner: Arc, } impl CatalogWithInformationSchema { pub(crate) fn new( - catalog_list: Arc, + catalog_list: Weak, inner: Arc, ) -> Self { Self { @@ -73,9 +76,10 @@ impl CatalogProvider for CatalogWithInformationSchema { fn schema(&self, name: &str) -> Option> { if name.eq_ignore_ascii_case(INFORMATION_SCHEMA) { - Some(Arc::new(InformationSchemaProvider { - catalog_list: self.catalog_list.clone(), - })) + Weak::upgrade(&self.catalog_list).map(|catalog_list| { + Arc::new(InformationSchemaProvider { catalog_list }) + as Arc + }) } else { self.inner.schema(name) } diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index d5a84869ad94a..6a26e0401bb87 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -144,7 +144,7 @@ impl ExecutionContext { let default_catalog: Arc = if config.information_schema { Arc::new(CatalogWithInformationSchema::new( - catalog_list.clone(), + Arc::downgrade(&catalog_list), Arc::new(default_catalog), )) } else { @@ -346,7 +346,7 @@ impl ExecutionContext { let state = self.state.lock().unwrap(); let catalog = if state.config.information_schema { Arc::new(CatalogWithInformationSchema::new( - state.catalog_list.clone(), + Arc::downgrade(&state.catalog_list), catalog, )) } else { @@ -924,6 +924,7 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use std::fs::File; + use std::sync::Weak; use std::thread::{self, JoinHandle}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; @@ -3364,6 +3365,29 @@ mod tests { assert_batches_sorted_eq!(expected, &result); } + #[tokio::test] + async fn catalogs_not_leaked() { + // the information schema used to introduce cyclic Arcs + let ctx = ExecutionContext::with_config( + ExecutionConfig::new().with_information_schema(true), + ); + + // register a single catalog + let catalog = Arc::new(MemoryCatalogProvider::new()); + let catalog_weak = Arc::downgrade(&catalog); + ctx.register_catalog("my_catalog", catalog); + + let catalog_list_weak = { + let state = ctx.state.lock().unwrap(); + Arc::downgrade(&state.catalog_list) + }; + + drop(ctx); + + assert_eq!(Weak::strong_count(&catalog_list_weak), 0); + assert_eq!(Weak::strong_count(&catalog_weak), 0); + } + struct MyPhysicalPlanner {} impl PhysicalPlanner for MyPhysicalPlanner { From 18c581c4dbfbc3b5d135b3bc0d1cdb5c16af9c78 Mon Sep 17 00:00:00 2001 From: QP Hou Date: Wed, 7 Jul 2021 05:06:12 -0700 Subject: [PATCH 06/12] fix join column handling logic for `On` and `Using` constraints (#605) * fix join column handling logic for `On` and `Using` constraints * 
handling join column expansion during USING JOIN planning get rid of shared field and move column expansion logic into plan builder and optimizer. * add more comments & fix clippy * add more comment * reduce duplicate code in join predicate pushdown --- ballista/rust/core/proto/ballista.proto | 10 +- .../core/src/serde/logical_plan/from_proto.rs | 41 ++-- .../core/src/serde/logical_plan/to_proto.rs | 15 +- ballista/rust/core/src/serde/mod.rs | 46 +++- .../src/serde/physical_plan/from_proto.rs | 16 +- .../rust/core/src/serde/physical_plan/mod.rs | 3 +- .../core/src/serde/physical_plan/to_proto.rs | 13 +- benchmarks/queries/q7.sql | 2 +- datafusion/src/execution/context.rs | 90 +++++++ datafusion/src/execution/dataframe_impl.rs | 23 +- datafusion/src/logical_plan/builder.rs | 94 ++------ datafusion/src/logical_plan/dfschema.rs | 149 ++++++------ datafusion/src/logical_plan/expr.rs | 94 ++++++-- datafusion/src/logical_plan/mod.rs | 8 +- datafusion/src/logical_plan/plan.rs | 54 ++++- datafusion/src/optimizer/filter_push_down.rs | 226 ++++++++++++++---- .../src/optimizer/projection_push_down.rs | 88 ++++++- datafusion/src/optimizer/utils.rs | 9 +- datafusion/src/physical_plan/hash_join.rs | 198 ++++++++------- datafusion/src/physical_plan/hash_utils.rs | 57 +---- datafusion/src/physical_plan/planner.rs | 13 +- datafusion/src/sql/planner.rs | 76 +++--- datafusion/src/test/mod.rs | 11 +- 23 files changed, 836 insertions(+), 500 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index e3788066d33fc..4696d21852fc2 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -378,12 +378,18 @@ enum JoinType { ANTI = 5; } +enum JoinConstraint { + ON = 0; + USING = 1; +} + message JoinNode { LogicalPlanNode left = 1; LogicalPlanNode right = 2; JoinType join_type = 3; - repeated Column left_join_column = 4; - repeated Column right_join_column = 5; + JoinConstraint join_constraint = 4; + repeated Column left_join_column = 5; + repeated Column right_join_column = 6; } message LimitNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index a1136cf4a7d6e..cad0543923081 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{ }; use datafusion::logical_plan::{ abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin, - sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, + sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType, + LogicalPlan, LogicalPlanBuilder, Operator, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::csv::CsvReadOptions; @@ -257,23 +257,32 @@ impl TryInto for &protobuf::LogicalPlanNode { join.join_type )) })?; - let join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; - LogicalPlanBuilder::from(convert_box_required!(join.left)?) 
- .join( + let join_constraint = protobuf::JoinConstraint::from_i32( + join.join_constraint, + ) + .ok_or_else(|| { + proto_error(format!( + "Received a JoinNode message with unknown JoinConstraint {}", + join.join_constraint + )) + })?; + + let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?); + let builder = match join_constraint.into() { + JoinConstraint::On => builder.join( &convert_box_required!(join.right)?, - join_type, + join_type.into(), left_keys, right_keys, - )? - .build() - .map_err(|e| e.into()) + )?, + JoinConstraint::Using => builder.join_using( + &convert_box_required!(join.right)?, + join_type.into(), + left_keys, + )?, + }; + + builder.build().map_err(|e| e.into()) } } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4049622b83dc5..07d7a59c114c6 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn use datafusion::datasource::CsvFile; use datafusion::logical_plan::{ window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, - Column, Expr, JoinType, LogicalPlan, + Column, Expr, JoinConstraint, JoinType, LogicalPlan, }; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::functions::BuiltinScalarFunction; @@ -804,26 +804,23 @@ impl TryInto for &LogicalPlan { right, on, join_type, + join_constraint, .. } => { let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?; let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?; - let join_type = match join_type { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; let (left_join_column, right_join_column) = on.iter().map(|(l, r)| (l.into(), r.into())).unzip(); + let join_type: protobuf::JoinType = join_type.to_owned().into(); + let join_constraint: protobuf::JoinConstraint = + join_constraint.to_owned().into(); Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), right: Some(Box::new(right)), join_type: join_type.into(), + join_constraint: join_constraint.into(), left_join_column, right_join_column, }, diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index af83660baab56..1df0675ecae54 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -20,7 +20,7 @@ use std::{convert::TryInto, io::Cursor}; -use datafusion::logical_plan::Operator; +use datafusion::logical_plan::{JoinConstraint, JoinType, Operator}; use datafusion::physical_plan::aggregates::AggregateFunction; use datafusion::physical_plan::window_functions::BuiltInWindowFunction; @@ -291,3 +291,47 @@ impl Into for protobuf::PrimitiveScalarT } } } + +impl From for JoinType { + fn from(t: protobuf::JoinType) -> Self { + match t { + protobuf::JoinType::Inner => JoinType::Inner, + protobuf::JoinType::Left => JoinType::Left, + protobuf::JoinType::Right => JoinType::Right, + protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, + } + } +} + +impl From for protobuf::JoinType { + fn 
from(t: JoinType) -> Self { + match t { + JoinType::Inner => protobuf::JoinType::Inner, + JoinType::Left => protobuf::JoinType::Left, + JoinType::Right => protobuf::JoinType::Right, + JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, + } + } +} + +impl From for JoinConstraint { + fn from(t: protobuf::JoinConstraint) -> Self { + match t { + protobuf::JoinConstraint::On => JoinConstraint::On, + protobuf::JoinConstraint::Using => JoinConstraint::Using, + } + } +} + +impl From for protobuf::JoinConstraint { + fn from(t: JoinConstraint) -> Self { + match t { + JoinConstraint::On => protobuf::JoinConstraint::On, + JoinConstraint::Using => protobuf::JoinConstraint::Using, + } + } +} diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 717ee209dbe91..12c1743c0747c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; -use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr}; +use datafusion::logical_plan::{ + window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, +}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; @@ -57,7 +59,6 @@ use datafusion::physical_plan::{ filter::FilterExec, functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, hash_join::HashJoinExec, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, parquet::ParquetExec, projection::ProjectionExec, @@ -348,14 +349,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { hashjoin.join_type )) })?; - let join_type = match join_type { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Semi => JoinType::Semi, - protobuf::JoinType::Anti => JoinType::Anti, - }; + let partition_mode = protobuf::PartitionMode::from_i32(hashjoin.partition_mode) .ok_or_else(|| { @@ -372,7 +366,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { left, right, on, - &join_type, + &join_type.into(), partition_mode, )?)) } diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index a393d7fdab1f7..3bf7e9c3063b5 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -27,7 +27,7 @@ mod roundtrip_tests { compute::kernels::sort::SortOptions, datatypes::{DataType, Field, Schema}, }, - logical_plan::Operator, + logical_plan::{JoinType, Operator}, physical_plan::{ empty::EmptyExec, expressions::{binary, col, lit, InListExpr, NotExpr}, @@ -35,7 +35,6 @@ mod roundtrip_tests { filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::{HashJoinExec, PartitionMode}, - hash_utils::JoinType, limit::{GlobalLimitExec, LocalLimitExec}, sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs 
b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 0fc27850074c3..875dbf213441d 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -26,6 +26,7 @@ use std::{ sync::Arc, }; +use datafusion::logical_plan::JoinType; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ @@ -35,7 +36,6 @@ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; -use datafusion::physical_plan::hash_utils::JoinType; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; @@ -135,18 +135,13 @@ impl TryInto for Arc { }), }) .collect(); - let join_type = match exec.join_type() { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::Semi => protobuf::JoinType::Semi, - JoinType::Anti => protobuf::JoinType::Anti, - }; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let partition_mode = match exec.partition_mode() { PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, }; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { diff --git a/benchmarks/queries/q7.sql b/benchmarks/queries/q7.sql index d53877c8dde68..512e5be55a2d9 100644 --- a/benchmarks/queries/q7.sql +++ b/benchmarks/queries/q7.sql @@ -36,4 +36,4 @@ group by order by supp_nation, cust_nation, - l_year; \ No newline at end of file + l_year; diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6a26e0401bb87..d2dcec5f47d73 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1278,6 +1278,96 @@ mod tests { Ok(()) } + #[tokio::test] + async fn left_join_using() -> Result<()> { + let results = execute( + "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 0 | 1 |", + "| 0 | 2 |", + "| 0 | 3 |", + "| 0 | 4 |", + "| 0 | 5 |", + "| 0 | 6 |", + "| 0 | 7 |", + "| 0 | 8 |", + "| 0 | 9 |", + "| 0 | 10 |", + "+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join_using_join_key_projection() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + + #[tokio::test] + async fn left_join() -> Result<()> { + let results = execute( + "SELECT t1.c1, t1.c2, t2.c2 FROM 
test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2", + 1, + ) + .await?; + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+----+", + "| c1 | c2 | c2 |", + "+----+----+----+", + "| 0 | 1 | 1 |", + "| 0 | 2 | 2 |", + "| 0 | 3 | 3 |", + "| 0 | 4 | 4 |", + "| 0 | 5 | 5 |", + "| 0 | 6 | 6 |", + "| 0 | 7 | 7 |", + "| 0 | 8 | 8 |", + "| 0 | 9 | 9 |", + "| 0 | 10 | 10 |", + "+----+----+----+", + ]; + + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn window() -> Result<()> { let results = execute( diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 7cf779740c473..4edd01c2c0a99 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -264,7 +264,7 @@ mod tests { #[tokio::test] async fn join() -> Result<()> { let left = test_table()?.select_columns(&["c1", "c2"])?; - let right = test_table()?.select_columns(&["c1", "c3"])?; + let right = test_table_with_name("c2")?.select_columns(&["c1", "c3"])?; let left_rows = left.collect().await?; let right_rows = right.collect().await?; let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; @@ -315,7 +315,7 @@ mod tests { #[test] fn registry() -> Result<()> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; // declare the udf let my_fn: ScalarFunctionImplementation = @@ -366,21 +366,28 @@ mod tests { /// Create a logical plan from a SQL query fn create_plan(sql: &str) -> Result { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; + register_aggregate_csv(&mut ctx, "aggregate_test_100")?; ctx.create_logical_plan(sql) } - fn test_table() -> Result> { + fn test_table_with_name(name: &str) -> Result> { let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx)?; - ctx.table("aggregate_test_100") + register_aggregate_csv(&mut ctx, name)?; + ctx.table(name) + } + + fn test_table() -> Result> { + test_table_with_name("aggregate_test_100") } - fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { + fn register_aggregate_csv( + ctx: &mut ExecutionContext, + table_name: &str, + ) -> Result<()> { let schema = test::aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); ctx.register_csv( - "aggregate_test_100", + table_name, &format!("{}/csv/aggregate_test_100.csv", testdata), CsvReadOptions::new().schema(schema.as_ref()), )?; diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 1a53e2185a4bc..41f29c4b99052 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -40,7 +40,6 @@ use crate::logical_plan::{ columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema, DFSchemaRef, Partitioning, }; -use std::collections::HashSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -217,7 +216,6 @@ impl LogicalPlanBuilder { /// * An invalid expression is used (e.g. 
a `sort` expression) pub fn project(&self, expr: impl IntoIterator) -> Result { let input_schema = self.plan.schema(); - let all_schemas = self.plan.all_schemas(); let mut projected_expr = vec![]; for e in expr { match e { @@ -227,10 +225,8 @@ impl LogicalPlanBuilder { .push(Expr::Column(input_schema.field(i).qualified_column())) }); } - _ => projected_expr.push(columnize_expr( - normalize_col(e, &all_schemas)?, - input_schema, - )), + _ => projected_expr + .push(columnize_expr(normalize_col(e, &self.plan)?, input_schema)), } } @@ -247,7 +243,7 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(&self, expr: Expr) -> Result { - let expr = normalize_col(expr, &self.plan.all_schemas())?; + let expr = normalize_col(expr, &self.plan)?; Ok(Self::from(LogicalPlan::Filter { predicate: expr, input: Arc::new(self.plan.clone()), @@ -264,9 +260,8 @@ impl LogicalPlanBuilder { /// Apply a sort pub fn sort(&self, exprs: impl IntoIterator) -> Result { - let schemas = self.plan.all_schemas(); Ok(Self::from(LogicalPlan::Sort { - expr: normalize_cols(exprs, &schemas)?, + expr: normalize_cols(exprs, &self.plan)?, input: Arc::new(self.plan.clone()), })) } @@ -292,20 +287,15 @@ impl LogicalPlanBuilder { let left_keys: Vec = left_keys .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = right_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::On, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -327,21 +317,16 @@ impl LogicalPlanBuilder { let left_keys: Vec = using_keys .clone() .into_iter() - .map(|c| c.into().normalize(&self.plan.all_schemas())) + .map(|c| c.into().normalize(&self.plan)) .collect::>()?; let right_keys: Vec = using_keys .into_iter() - .map(|c| c.into().normalize(&right.all_schemas())) + .map(|c| c.into().normalize(right)) .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); - let join_schema = build_join_schema( - self.plan.schema(), - right.schema(), - &on, - &join_type, - &JoinConstraint::Using, - )?; + let join_schema = + build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join { left: Arc::new(self.plan.clone()), @@ -394,9 +379,8 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator, aggr_expr: impl IntoIterator, ) -> Result { - let schemas = self.plan.all_schemas(); - let group_expr = normalize_cols(group_expr, &schemas)?; - let aggr_expr = normalize_cols(aggr_expr, &schemas)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; + let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; let all_expr = group_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?; @@ -440,33 +424,12 @@ impl LogicalPlanBuilder { pub fn build_join_schema( left: &DFSchema, right: &DFSchema, - on: &[(Column, Column)], join_type: &JoinType, - join_constraint: &JoinConstraint, ) -> Result { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - 
.iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // right join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.1.clone()).collect::>() - } - }; - + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let right_fields = right.fields().iter(); let left_fields = left.fields().iter(); - - // remove right-side join keys if they have the same names as the left-side - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); - // left then right left_fields.chain(right_fields).cloned().collect() } @@ -474,31 +437,6 @@ pub fn build_join_schema( // Only use the left side for the schema left.fields().clone() } - JoinType::Right => { - let duplicate_keys = match join_constraint { - JoinConstraint::On => on - .iter() - .filter(|(l, r)| l == r) - .map(|on| on.1.clone()) - .collect::>(), - // using join requires unique join columns in the output schema, so we mark all - // left join keys as duplicate - JoinConstraint::Using => { - on.iter().map(|on| on.0.clone()).collect::>() - } - }; - - // remove left-side join keys if they have the same names as the right-side - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(&f.qualified_column())); - - let right_fields = right.fields().iter(); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } }; DFSchema::new(fields) diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index b4d864f55ebdb..b4bde87f3471f 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -48,6 +48,7 @@ impl DFSchema { pub fn new(fields: Vec) -> Result { let mut qualified_names = HashSet::new(); let mut unqualified_names = HashSet::new(); + for field in &fields { if let Some(qualifier) = field.qualifier() { if !qualified_names.insert((qualifier, field.name())) { @@ -94,10 +95,7 @@ impl DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: Some(qualifier.to_owned()), - }) + .map(|f| DFField::from_qualified(qualifier, f.clone())) .collect(), ) } @@ -149,47 +147,80 @@ impl DFSchema { ))) } - /// Find the index of the column with the given qualifer and name - pub fn index_of_column(&self, col: &Column) -> Result { - for i in 0..self.fields.len() { - let field = &self.fields[i]; - if field.qualifier() == col.relation.as_ref() && field.name() == &col.name { - return Ok(i); - } + fn index_of_column_by_name( + &self, + qualifier: Option<&str>, + name: &str, + ) -> Result { + let matches: Vec = self + .fields + .iter() + .enumerate() + .filter(|(_, field)| match (qualifier, &field.qualifier) { + // field to lookup is qualified. + // current field is qualified and not shared between relations, compare both + // qualifer and name. + (Some(q), Some(field_q)) => q == field_q && field.name() == name, + // field to lookup is qualified but current field is unqualified. + (Some(_), None) => false, + // field to lookup is unqualified, no need to compare qualifier + _ => field.name() == name, + }) + .map(|(idx, _)| idx) + .collect(); + + match matches.len() { + 0 => Err(DataFusionError::Plan(format!( + "No field named '{}.{}'. 
Valid fields are {}.", + qualifier.unwrap_or(""), + name, + self.get_field_names() + ))), + 1 => Ok(matches[0]), + _ => Err(DataFusionError::Internal(format!( + "Ambiguous reference to qualified field named '{}.{}'", + qualifier.unwrap_or(""), + name + ))), } - Err(DataFusionError::Plan(format!( - "No field matches column '{}'. Available fields: {}", - col, self - ))) + } + + /// Find the index of the column with the given qualifier and name + pub fn index_of_column(&self, col: &Column) -> Result { + self.index_of_column_by_name(col.relation.as_deref(), &col.name) } /// Find the field with the given name pub fn field_with_name( &self, - relation_name: Option<&str>, + qualifier: Option<&str>, name: &str, - ) -> Result { - if let Some(relation_name) = relation_name { - self.field_with_qualified_name(relation_name, name) + ) -> Result<&DFField> { + if let Some(qualifier) = qualifier { + self.field_with_qualified_name(qualifier, name) } else { self.field_with_unqualified_name(name) } } - /// Find the field with the given name - pub fn field_with_unqualified_name(&self, name: &str) -> Result { - let matches: Vec<&DFField> = self - .fields + /// Find all fields match the given name + pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> { + self.fields .iter() .filter(|field| field.name() == name) - .collect(); + .collect() + } + + /// Find the field with the given name + pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> { + let matches = self.fields_with_unqualified_name(name); match matches.len() { 0 => Err(DataFusionError::Plan(format!( "No field with unqualified name '{}'. Valid fields are {}.", name, self.get_field_names() ))), - 1 => Ok(matches[0].to_owned()), + 1 => Ok(matches[0]), _ => Err(DataFusionError::Plan(format!( "Ambiguous reference to field named '{}'", name @@ -200,33 +231,15 @@ impl DFSchema { /// Find the field with the given qualified name pub fn field_with_qualified_name( &self, - relation_name: &str, + qualifier: &str, name: &str, - ) -> Result { - let matches: Vec<&DFField> = self - .fields - .iter() - .filter(|field| { - field.qualifier == Some(relation_name.to_string()) && field.name() == name - }) - .collect(); - match matches.len() { - 0 => Err(DataFusionError::Plan(format!( - "No field named '{}.{}'. 
Valid fields are {}.", - relation_name, - name, - self.get_field_names() - ))), - 1 => Ok(matches[0].to_owned()), - _ => Err(DataFusionError::Internal(format!( - "Ambiguous reference to qualified field named '{}.{}'", - relation_name, name - ))), - } + ) -> Result<&DFField> { + let idx = self.index_of_column_by_name(Some(qualifier), name)?; + Ok(self.field(idx)) } /// Find the field with the given qualified column - pub fn field_from_qualified_column(&self, column: &Column) -> Result { + pub fn field_from_column(&self, column: &Column) -> Result<&DFField> { match &column.relation { Some(r) => self.field_with_qualified_name(r, &column.name), None => self.field_with_unqualified_name(&column.name), @@ -247,31 +260,20 @@ impl DFSchema { fields: self .fields .into_iter() - .map(|f| { - if f.qualifier().is_some() { - DFField::new( - None, - f.name(), - f.data_type().to_owned(), - f.is_nullable(), - ) - } else { - f - } - }) + .map(|f| f.strip_qualifier()) .collect(), } } /// Replace all field qualifier with new value in schema - pub fn replace_qualifier(self, qualifer: &str) -> Self { + pub fn replace_qualifier(self, qualifier: &str) -> Self { DFSchema { fields: self .fields .into_iter() .map(|f| { DFField::new( - Some(qualifer), + Some(qualifier), f.name(), f.data_type().to_owned(), f.is_nullable(), @@ -328,10 +330,7 @@ impl TryFrom for DFSchema { schema .fields() .iter() - .map(|f| DFField { - field: f.clone(), - qualifier: None, - }) + .map(|f| DFField::from(f.clone())) .collect(), ) } @@ -454,8 +453,8 @@ impl DFField { /// Returns a string to the `DFField`'s qualified name pub fn qualified_name(&self) -> String { - if let Some(relation_name) = &self.qualifier { - format!("{}.{}", relation_name, self.field.name()) + if let Some(qualifier) = &self.qualifier { + format!("{}.{}", qualifier, self.field.name()) } else { self.field.name().to_owned() } @@ -469,6 +468,14 @@ impl DFField { } } + /// Builds an unqualified column based on self + pub fn unqualified_column(&self) -> Column { + Column { + relation: None, + name: self.field.name().to_string(), + } + } + /// Get the optional qualifier pub fn qualifier(&self) -> Option<&String> { self.qualifier.as_ref() @@ -478,6 +485,12 @@ impl DFField { pub fn field(&self) -> &Field { &self.field } + + /// Return field with qualifier stripped + pub fn strip_qualifier(mut self) -> Self { + self.qualifier = None; + self + } } #[cfg(test)] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 1fab9bb875ae9..9454d7593c3f3 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,7 +20,7 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef}; +use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, @@ -29,7 +29,7 @@ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -89,14 +89,46 @@ impl Column { /// /// For example, `foo` will be normalized to `t.foo` if there is a /// column named `foo` in a relation named `t` found in 
`schemas` - pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result { + pub fn normalize(self, plan: &LogicalPlan) -> Result { if self.relation.is_some() { return Ok(self); } - for schema in schemas { - if let Ok(field) = schema.field_with_unqualified_name(&self.name) { - return Ok(field.qualified_column()); + let schemas = plan.all_schemas(); + let using_columns = plan.using_columns()?; + + for schema in &schemas { + let fields = schema.fields_with_unqualified_name(&self.name); + match fields.len() { + 0 => continue, + 1 => { + return Ok(fields[0].qualified_column()); + } + _ => { + // More than 1 fields in this schema have their names set to self.name. + // + // This should only happen when a JOIN query with USING constraint references + // join columns using unqualified column name. For example: + // + // ```sql + // SELECT id FROM t1 JOIN t2 USING(id) + // ``` + // + // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. + // We will use the relation from the first matched field to normalize self. + + // Compare matched fields with one USING JOIN clause at a time + for using_col in &using_columns { + let all_matched = fields + .iter() + .all(|f| using_col.contains(&f.qualified_column())); + // All matched fields belong to the same using column set, in orther words + // the same join clause. We simply pick the qualifer from the first match. + if all_matched { + return Ok(fields[0].qualified_column()); + } + } + } } } @@ -321,9 +353,7 @@ impl Expr { pub fn get_type(&self, schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), - Expr::Column(c) => { - Ok(schema.field_from_qualified_column(c)?.data_type().clone()) - } + Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()), Expr::ScalarVariable(_) => Ok(DataType::Utf8), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema), @@ -395,9 +425,7 @@ impl Expr { pub fn nullable(&self, input_schema: &DFSchema) -> Result { match self { Expr::Alias(expr, _) => expr.nullable(input_schema), - Expr::Column(c) => { - Ok(input_schema.field_from_qualified_column(c)?.is_nullable()) - } + Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()), Expr::Literal(value) => Ok(value.is_null()), Expr::ScalarVariable(_) => Ok(true), Expr::Case { @@ -1118,36 +1146,56 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { } } +/// Recursively replace all Column expressions in a given expression tree with Column expressions +/// provided by the hash map argument. +pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { + struct ColumnReplacer<'a> { + replace_map: &'a HashMap<&'a Column, &'a Column>, + } + + impl<'a> ExprRewriter for ColumnReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(c) = &expr { + match self.replace_map.get(c) { + Some(new_c) => Ok(Expr::Column((*new_c).to_owned())), + None => Ok(expr), + } + } else { + Ok(expr) + } + } + } + + e.rewrite(&mut ColumnReplacer { replace_map }) +} + /// Recursively call [`Column::normalize`] on all Column expressions /// in the `expr` expression tree. 
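///
/// Editor's sketch (not part of this patch): a minimal illustration of the
/// USING-join disambiguation described above, using builder helpers that
/// appear elsewhere in this change; the table and column names are assumed.
///
/// ```ignore
/// let schema = Schema::new(vec![Field::new("id", DataType::UInt32, false)]);
/// let t1 = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?.build()?;
/// let t2 = LogicalPlanBuilder::scan_empty(Some("t2"), &schema, None)?.build()?;
/// let plan = LogicalPlanBuilder::from(t1)
///     .join_using(&t2, JoinType::Inner, vec!["id"])?
///     .build()?;
/// // the unqualified `id` matches both `t1.id` and `t2.id`; both belong to the
/// // same USING set, so the qualifier of the first match (`t1`) is used
/// assert_eq!(
///     normalize_col(col("id"), &plan)?,
///     Expr::Column(Column {
///         relation: Some("t1".to_string()),
///         name: "id".to_string(),
///     })
/// );
/// ```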
-pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result { - struct ColumnNormalizer<'a, 'b> { - schemas: &'a [&'b DFSchemaRef], +pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result { + struct ColumnNormalizer<'a> { + plan: &'a LogicalPlan, } - impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> { + impl<'a> ExprRewriter for ColumnNormalizer<'a> { fn mutate(&mut self, expr: Expr) -> Result { if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize(self.schemas)?)) + Ok(Expr::Column(c.normalize(self.plan)?)) } else { Ok(expr) } } } - e.rewrite(&mut ColumnNormalizer { schemas }) + e.rewrite(&mut ColumnNormalizer { plan }) } /// Recursively normalize all Column expressions in a list of expression trees #[inline] pub fn normalize_cols( exprs: impl IntoIterator, - schemas: &[&DFSchemaRef], + plan: &LogicalPlan, ) -> Result> { - exprs - .into_iter() - .map(|e| normalize_col(e, schemas)) - .collect() + exprs.into_iter().map(|e| normalize_col(e, plan)).collect() } /// Create an expression to represent the min() aggregate function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 69d03d22bb21a..86a2f567d7de4 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -41,10 +41,10 @@ pub use expr::{ cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, - sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, - to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter, - ExpressionVisitor, Literal, Recursion, + regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, + sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, + substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr, + ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 99f0fa14a2d97..b954b6a97950c 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -21,9 +21,11 @@ use super::display::{GraphvizVisitor, IndentVisitor}; use super::expr::{Column, Expr}; use super::extension::UserDefinedLogicalNode; use crate::datasource::TableProvider; +use crate::error::DataFusionError; use crate::logical_plan::dfschema::DFSchemaRef; use crate::sql::parser::FileType; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::collections::HashSet; use std::{ fmt::{self, Display}, sync::Arc, @@ -354,6 +356,43 @@ impl LogicalPlan { | LogicalPlan::CreateExternalTable { .. } => vec![], } } + + /// returns all `Using` join columns in a logical plan + pub fn using_columns(&self) -> Result>, DataFusionError> { + struct UsingJoinColumnVisitor { + using_columns: Vec>, + } + + impl PlanVisitor for UsingJoinColumnVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + if let LogicalPlan::Join { + join_constraint: JoinConstraint::Using, + on, + .. 
+ } = plan + { + self.using_columns.push( + on.iter() + .map(|entry| { + std::iter::once(entry.0.clone()) + .chain(std::iter::once(entry.1.clone())) + }) + .flatten() + .collect::>(), + ); + } + Ok(true) + } + } + + let mut visitor = UsingJoinColumnVisitor { + using_columns: vec![], + }; + self.accept(&mut visitor)?; + Ok(visitor.using_columns) + } } /// Logical partitioning schemes supported by the repartition operator. @@ -709,10 +748,21 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Join { on: ref keys, .. } => { + LogicalPlan::Join { + on: ref keys, + join_constraint, + .. + } => { let join_expr: Vec = keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); - write!(f, "Join: {}", join_expr.join(", ")) + match join_constraint { + JoinConstraint::On => { + write!(f, "Join: {}", join_expr.join(", ")) + } + JoinConstraint::Using => { + write!(f, "Join: Using {}", join_expr.join(", ")) + } + } } LogicalPlan::CrossJoin { .. } => { write!(f, "CrossJoin:") diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index c1d81fe629345..76d8c05bed4c6 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -16,7 +16,7 @@ use crate::datasource::datasource::TableProviderFilterPushDown; use crate::execution::context::ExecutionProps; -use crate::logical_plan::{and, Column, LogicalPlan}; +use crate::logical_plan::{and, replace_col, Column, LogicalPlan}; use crate::logical_plan::{DFSchema, Expr}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -96,12 +96,21 @@ fn get_join_predicates<'a>( let left_columns = &left .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + // we need to push down filter using unqualified column as well + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let right_columns = &right .fields() .iter() - .map(|f| f.qualified_column()) + .map(|f| { + std::iter::once(f.qualified_column()) + .chain(std::iter::once(f.unqualified_column())) + }) + .flatten() .collect::>(); let filters = state @@ -232,6 +241,38 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) { } } +fn optimize_join( + mut state: State, + plan: &LogicalPlan, + left: &LogicalPlan, + right: &LogicalPlan, +) -> Result { + let (pushable_to_left, pushable_to_right, keep) = + get_join_predicates(&state, left.schema(), right.schema()); + + let mut left_state = state.clone(); + left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); + let left = optimize(left, left_state)?; + + let mut right_state = state.clone(); + right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); + let right = optimize(right, right_state)?; + + // create a new Join with the new `left` and `right` + let expr = plan.expressions(); + let plan = utils::from_plan(plan, &expr, &[left, right])?; + + if keep.0.is_empty() { + Ok(plan) + } else { + // wrap the join on the filter whose predicates must be kept + let plan = add_filter(plan, &keep.0); + state.filters = remove_filters(&state.filters, &keep.1); + + Ok(plan) + } +} + fn optimize(plan: &LogicalPlan, mut state: State) -> Result { match plan { LogicalPlan::Explain { .. } => { @@ -336,32 +377,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result { .collect::>(); issue_filters(state, used_columns, plan) } - LogicalPlan::Join { left, right, .. } - | LogicalPlan::CrossJoin { left, right, .. 
} => { - let (pushable_to_left, pushable_to_right, keep) = - get_join_predicates(&state, left.schema(), right.schema()); - - let mut left_state = state.clone(); - left_state.filters = keep_filters(&left_state.filters, &pushable_to_left); - let left = optimize(left, left_state)?; - - let mut right_state = state.clone(); - right_state.filters = keep_filters(&right_state.filters, &pushable_to_right); - let right = optimize(right, right_state)?; - - // create a new Join with the new `left` and `right` - let expr = plan.expressions(); - let plan = utils::from_plan(plan, &expr, &[left, right])?; + LogicalPlan::CrossJoin { left, right, .. } => { + optimize_join(state, plan, left, right) + } + LogicalPlan::Join { + left, right, on, .. + } => { + // duplicate filters for joined columns so filters can be pushed down to both sides. + // Take the following query as an example: + // + // ```sql + // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 + // ``` + // + // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while + // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. + // + // Join clauses with `Using` constraints also take advantage of this logic to make sure + // predicates reference the shared join columns are pushed to both sides. + let join_side_filters = state + .filters + .iter() + .filter_map(|(predicate, columns)| { + let mut join_cols_to_replace = HashMap::new(); + for col in columns.iter() { + for (l, r) in on { + if col == l { + join_cols_to_replace.insert(col, r); + break; + } else if col == r { + join_cols_to_replace.insert(col, l); + break; + } + } + } - if keep.0.is_empty() { - Ok(plan) - } else { - // wrap the join on the filter whose predicates must be kept - let plan = add_filter(plan, &keep.0); - state.filters = remove_filters(&state.filters, &keep.1); + if join_cols_to_replace.is_empty() { + return None; + } - Ok(plan) - } + let join_side_predicate = + match replace_col(predicate.clone(), &join_cols_to_replace) { + Ok(p) => p, + Err(e) => { + return Some(Err(e)); + } + }; + + let join_side_columns = columns + .clone() + .into_iter() + // replace keys in join_cols_to_replace with values in resulting column + // set + .filter(|c| !join_cols_to_replace.contains_key(c)) + .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone())) + .collect(); + + Some(Ok((join_side_predicate, join_side_columns))) + }) + .collect::>>()?; + state.filters.extend(join_side_filters); + + optimize_join(state, plan, left, right) } LogicalPlan::TableScan { source, @@ -878,12 +955,13 @@ mod tests { Ok(()) } - /// post-join predicates on a column common to both sides is pushed to both sides + /// post-on-join predicates on a column common to both sides is pushed to both sides #[test] - fn filter_join_on_common_independent() -> Result<()> { + fn filter_on_join_on_common_independent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()).build()?; - let right = LogicalPlanBuilder::from(table_scan) + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a")])? 
.build()?; let plan = LogicalPlanBuilder::from(left) @@ -901,20 +979,61 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.a LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + \n Join: #test.a = #test2.a\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n TableScan: test projection=None" + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" ); // filter sent to side before the join let expected = "\ - Join: #test.a = #test.a\ + Join: #test.a = #test2.a\ \n Filter: #test.a LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a\ - \n Filter: #test.a LtEq Int64(1)\ - \n TableScan: test projection=None"; + \n Projection: #test2.a\ + \n Filter: #test2.a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + /// post-using-join predicates on a column common to both sides is pushed to both sides + #[test] + fn filter_using_join_on_common_independent() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + )? + .filter(col("a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{:?}", plan), + "\ + Filter: #test.a LtEq Int64(1)\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n TableScan: test2 projection=None" + ); + + // filter sent to side before the join + let expected = "\ + Join: Using #test.a = #test2.a\ + \n Filter: #test.a LtEq Int64(1)\ + \n TableScan: test projection=None\ + \n Projection: #test2.a\ + \n Filter: #test2.a LtEq Int64(1)\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } @@ -923,10 +1042,11 @@ mod tests { #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("c")])? .build()?; - let right = LogicalPlanBuilder::from(table_scan) + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a"), col("b")])? .build()?; let plan = LogicalPlanBuilder::from(left) @@ -944,12 +1064,12 @@ mod tests { assert_eq!( format!("{:?}", plan), "\ - Filter: #test.c LtEq #test.b\ - \n Join: #test.a = #test.a\ + Filter: #test.c LtEq #test2.b\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.c\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.b\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.b\ + \n TableScan: test2 projection=None" ); // expected is equal: no push-down @@ -962,12 +1082,14 @@ mod tests { #[test] fn filter_join_on_one_side() -> Result<()> { let table_scan = test_table_scan()?; - let left = LogicalPlanBuilder::from(table_scan.clone()) + let left = LogicalPlanBuilder::from(table_scan) .project(vec![col("a"), col("b")])? .build()?; - let right = LogicalPlanBuilder::from(table_scan) + let table_scan_right = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(table_scan_right) .project(vec![col("a"), col("c")])? 
.build()?; + let plan = LogicalPlanBuilder::from(left) .join( &right, @@ -983,20 +1105,20 @@ mod tests { format!("{:?}", plan), "\ Filter: #test.b LtEq Int64(1)\ - \n Join: #test.a = #test.a\ + \n Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None" + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None" ); let expected = "\ - Join: #test.a = #test.a\ + Join: #test.a = #test2.a\ \n Projection: #test.a, #test.b\ \n Filter: #test.b LtEq Int64(1)\ \n TableScan: test projection=None\ - \n Projection: #test.a, #test.c\ - \n TableScan: test projection=None"; + \n Projection: #test2.a, #test2.c\ + \n TableScan: test2 projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) } diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 3c8f1ee4ceb58..0272b9f7872cf 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -216,9 +216,7 @@ fn optimize_plan( let schema = build_join_schema( optimized_left.schema(), optimized_right.schema(), - on, join_type, - join_constraint, )?; Ok(LogicalPlan::Join { @@ -499,7 +497,7 @@ mod tests { } #[test] - fn join_schema_trim() -> Result<()> { + fn join_schema_trim_full_join_column_projection() -> Result<()> { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); @@ -511,7 +509,7 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? .build()?; - // make sure projections are pushed down to table scan + // make sure projections are pushed down to both table scans let expected = "Projection: #test.a, #test.b, #test2.c1\ \n Join: #test.a = #test2.c1\ \n TableScan: test projection=Some([0, 1])\ @@ -521,7 +519,48 @@ mod tests { let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); - // make sure schema for join node doesn't include c1 column + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "c1", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + + #[test] + fn join_schema_trim_partial_join_column_projection() -> Result<()> { + // test join column push down without explicit column projections + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + // projecting joined column `a` should push the right side column `c1` projection as + // well into test2 table even though `c1` is not referenced in projection. + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to both table scans + let expected = "Projection: #test.a, #test.b\ + \n Join: #test.a = #test2.c1\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; assert_eq!( **optimized_join.schema(), @@ -535,6 +574,45 @@ mod tests { Ok(()) } + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join colums from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = + LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(&table2_scan, JoinType::Left, vec!["a"])? + .project(vec![col("a"), col("b")])? + .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: #test.a, #test.b\ + \n Join: Using #test.a = #test2.a\ + \n TableScan: test projection=Some([0, 1])\ + \n TableScan: test2 projection=Some([0])"; + + let optimized_plan = optimize(&plan)?; + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new(vec![ + DFField::new(Some("test"), "a", DataType::UInt32, false), + DFField::new(Some("test"), "b", DataType::UInt32, false), + DFField::new(Some("test2"), "a", DataType::UInt32, false), + ])?, + ); + + Ok(()) + } + #[test] fn cast() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index ae3e196c22251..1d19f0681b350 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -215,13 +215,8 @@ pub fn from_plan( on, .. 
} => { - let schema = build_join_schema( - inputs[0].schema(), - inputs[1].schema(), - on, - join_type, - join_constraint, - )?; + let schema = + build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; Ok(LogicalPlan::Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index f426bc9d3c3c2..00ca1539d714f 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -55,9 +55,10 @@ use arrow::array::{ use super::expressions::Column; use super::{ coalesce_partitions::CoalescePartitionsExec, - hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, + hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, }; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; use super::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -165,12 +166,7 @@ impl HashJoinExec { let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema( - &left_schema, - &right_schema, - &on, - join_type, - )); + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); let random_state = RandomState::with_seeds(0, 0, 0, 0); @@ -1437,16 +1433,16 @@ mod tests { join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1478,16 +1474,16 @@ mod tests { ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 5 | 9 | 20 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1555,18 +1551,18 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; 
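        // Editor's note (not part of this patch): the wider expected tables in these
        // tests follow from the simplified join-schema rule in this change: left
        // fields first, then right fields, with duplicate join-key names no longer
        // dropped. A minimal sketch of that rule, assuming the new
        // `build_join_schema(left, right, join_type)` signature from `hash_utils`:
        //
        //     let left = Schema::new(vec![Field::new("b1", DataType::UInt32, false)]);
        //     let right = Schema::new(vec![Field::new("b1", DataType::UInt32, false)]);
        //     let joined = build_join_schema(&left, &right, &JoinType::Inner);
        //     // both `b1` fields survive; disambiguation now happens through column
        //     // qualifiers at the logical-plan level rather than by pruning here
        //     assert_eq!(joined.fields().len(), 2);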
assert_batches_sorted_eq!(expected, &batches); @@ -1607,18 +1603,18 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; - assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+", - "| a1 | b2 | c1 | c2 |", - "+----+----+----+----+", - "| 1 | 1 | 7 | 70 |", - "| 2 | 2 | 8 | 80 |", - "| 2 | 2 | 9 | 80 |", - "+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1655,7 +1651,7 @@ mod tests { let join = join(left, right, on, &JoinType::Inner)?; let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part let stream = join.execute(0).await?; @@ -1663,11 +1659,11 @@ mod tests { assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1676,12 +1672,12 @@ mod tests { let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 2 | 5 | 8 | 30 | 90 |", - "| 3 | 5 | 9 | 30 | 90 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1721,21 +1717,21 @@ mod tests { let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1801,19 +1797,19 @@ mod tests { let join = join(left, right, on, &JoinType::Left).unwrap(); let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let stream = join.execute(0).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ - 
"+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | | |", - "| 2 | 5 | 8 | | |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1874,16 +1870,16 @@ mod tests { let (columns, batches) = join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -1914,16 +1910,16 @@ mod tests { &JoinType::Left, ) .await?; - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | c2 |", - "+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 70 |", - "| 2 | 5 | 8 | 20 | 80 |", - "| 3 | 7 | 9 | | |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2025,16 +2021,16 @@ mod tests { let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; - assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2062,16 +2058,16 @@ mod tests { let (columns, batches) = partitioned_join_collect(left, right, on, &JoinType::Right).await?; - assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ - "+----+----+----+----+----+", - "| a1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+", - "| | | 30 | 6 | 90 |", - "| 1 | 7 | 10 | 4 | 70 |", - "| 2 | 8 | 20 | 5 | 80 |", - "+----+----+----+----+----+", + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); diff --git 
a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 0cf0b9212cd21..9243affe9cfc3 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -21,25 +21,9 @@ use crate::error::{DataFusionError, Result}; use arrow::datatypes::{Field, Schema}; use std::collections::HashSet; +use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; -/// All valid types of joins. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum JoinType { - /// Inner Join - Inner, - /// Left Join - Left, - /// Right Join - Right, - /// Full Join - Full, - /// Semi Join - Semi, - /// Anti Join - Anti, -} - /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. @@ -104,46 +88,11 @@ fn check_join_set_is_valid( /// Creates a schema for a join operation. /// The fields from the left side are first -pub fn build_join_schema( - left: &Schema, - right: &Schema, - on: JoinOnRef, - join_type: &JoinType, -) -> Schema { +pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { - // remove right-side join keys if they have the same names as the left-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { let left_fields = left.fields().iter(); - - let right_fields = right - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinType::Right => { - // remove left-side join keys if they have the same names as the right-side - let duplicate_keys = &on - .iter() - .filter(|(l, r)| l.name() == r.name()) - .map(|on| on.1.name()) - .collect::>(); - - let left_fields = left - .fields() - .iter() - .filter(|f| !duplicate_keys.contains(f.name().as_str())); - let right_fields = right.fields().iter(); - // left then right left_fields.chain(right_fields).cloned().collect() } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index effdefcfabadc..73b2f362989f6 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -40,7 +40,6 @@ use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{hash_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; -use crate::prelude::JoinType; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; use crate::variable::VarType; @@ -661,14 +660,6 @@ impl DefaultPhysicalPlanner { let physical_left = self.create_initial_plan(left, ctx_state)?; let right_df_schema = right.schema(); let physical_right = self.create_initial_plan(right, ctx_state)?; - let physical_join_type = match join_type { - JoinType::Inner => hash_utils::JoinType::Inner, - JoinType::Left => hash_utils::JoinType::Left, - JoinType::Right => hash_utils::JoinType::Right, - JoinType::Full => hash_utils::JoinType::Full, - JoinType::Semi => hash_utils::JoinType::Semi, - JoinType::Anti => hash_utils::JoinType::Anti, - }; let join_on = keys .iter() .map(|(l, r)| { @@ -702,7 +693,7 @@ impl DefaultPhysicalPlanner { Partitioning::Hash(right_expr, 
ctx_state.config.concurrency), )?), join_on, - &physical_join_type, + join_type, PartitionMode::Partitioned, )?)) } else { @@ -710,7 +701,7 @@ impl DefaultPhysicalPlanner { physical_left, physical_right, join_on, - &physical_join_type, + join_type, PartitionMode::CollectLeft, )?)) } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index e34f0e6c9b674..f89ba3f659c88 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -27,8 +27,8 @@ use crate::datasource::TableProvider; use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ - and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, PlanType, StringifiedPlan, ToDFSchema, + and, col, lit, normalize_col, union_with_alias, Column, DFSchema, Expr, LogicalPlan, + LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, ToDFSchema, }; use crate::prelude::JoinType; use crate::scalar::ScalarValue; @@ -496,12 +496,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let right_schema = right.schema(); let mut join_keys = vec![]; for (l, r) in &possible_join_keys { - if left_schema.field_from_qualified_column(l).is_ok() - && right_schema.field_from_qualified_column(r).is_ok() + if left_schema.field_from_column(l).is_ok() + && right_schema.field_from_column(r).is_ok() { join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_qualified_column(r).is_ok() - && right_schema.field_from_qualified_column(l).is_ok() + } else if left_schema.field_from_column(r).is_ok() + && right_schema.field_from_column(l).is_ok() { join_keys.push((r.clone(), l.clone())); } @@ -579,7 +579,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // SELECT c1 AS m FROM t HAVING c1 > 10; // SELECT c1, MAX(c2) AS m FROM t GROUP BY c1 HAVING MAX(c2) > 10; // - resolve_aliases_to_exprs(&having_expr, &alias_map) + let having_expr = resolve_aliases_to_exprs(&having_expr, &alias_map)?; + normalize_col(having_expr, &projected_plan) }) .transpose()?; @@ -603,6 +604,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let group_by_expr = resolve_positions_to_exprs(&group_by_expr, &select_exprs) .unwrap_or(group_by_expr); + let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( plan.schema(), &[group_by_expr.clone()], @@ -681,13 +683,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result> { let input_schema = plan.schema(); - Ok(projection + projection .iter() .map(|expr| self.sql_select_to_rex(expr, input_schema)) .collect::>>()? 
.iter() .flat_map(|expr| expand_wildcard(expr, input_schema)) - .collect::>()) + .map(|expr| normalize_col(expr, plan)) + .collect::>>() } /// Wrap a plan in a projection @@ -835,20 +838,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { find_column_exprs(exprs) .iter() .try_for_each(|col| match col { - Expr::Column(col) => { - match &col.relation { - Some(r) => schema.field_with_qualified_name(r, &col.name), - None => schema.field_with_unqualified_name(&col.name), + Expr::Column(col) => match &col.relation { + Some(r) => { + schema.field_with_qualified_name(r, &col.name)?; + Ok(()) + } + None => { + if !schema.fields_with_unqualified_name(&col.name).is_empty() { + Ok(()) + } else { + Err(DataFusionError::Plan(format!( + "No field with unqualified name '{}'", + &col.name + ))) + } } - .map_err(|_| { - DataFusionError::Plan(format!( - "Invalid identifier '{}' for schema {}", - col, - schema.to_string() - )) - })?; - Ok(()) } + .map_err(|_: DataFusionError| { + DataFusionError::Plan(format!( + "Invalid identifier '{}' for schema {}", + col, + schema.to_string() + )) + }), _ => Err(DataFusionError::Internal("Not a column".to_string())), }) } @@ -926,11 +938,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let var_names = vec![id.value.clone()]; Ok(Expr::ScalarVariable(var_names)) } else { - Ok(Expr::Column( - schema - .field_with_unqualified_name(&id.value)? - .qualified_column(), - )) + // create a column expression based on raw user input, this column will be + // normalized with qualifer later by the SQL planner. + Ok(col(&id.value)) } } @@ -1672,7 +1682,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1730,7 +1740,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -1740,7 +1750,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'x'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#x' for schema "), )); } @@ -2211,7 +2221,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2301,7 +2311,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Column #doesnotexist not found in provided schemas"), )); } @@ -2311,7 +2321,7 @@ mod tests { let err = logical_plan(sql).expect_err("query should have failed"); assert!(matches!( err, - DataFusionError::Plan(msg) if msg.contains("No field with unqualified name 'doesnotexist'"), + DataFusionError::Plan(msg) if msg.contains("Invalid identifier '#doesnotexist' for schema "), )); } @@ -2757,7 +2767,7 @@ mod tests { JOIN person as person2 \ USING (id)"; 
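        // Editor's note (not part of this patch): the expected plan below relies on
        // the planner change above. A bare identifier is now planned as an
        // unqualified column and qualified later, roughly:
        //
        //     // identifier branch: raw user input, no schema lookup yet
        //     let id_expr = col("id");   // Column { relation: None, name: "id" }
        //     // projection / GROUP BY / HAVING planning then calls
        //     let id_expr = normalize_col(id_expr, &projected_plan)?;  // -> #person.id
        //
        // where `projected_plan` stands for whatever plan is being built at that point.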
let expected = "Projection: #person.first_name, #person.id\ - \n Join: #person.id = #person2.id\ + \n Join: Using #person.id = #person2.id\ \n TableScan: person projection=None\ \n TableScan: person2 projection=None"; quick_test(sql, expected); diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs index df3aec4a68502..b791551133e7e 100644 --- a/datafusion/src/test/mod.rs +++ b/datafusion/src/test/mod.rs @@ -110,14 +110,19 @@ pub fn aggr_test_schema() -> SchemaRef { ])) } -/// some tests share a common table -pub fn test_table_scan() -> Result { +/// some tests share a common table with different names +pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::UInt32, false), Field::new("c", DataType::UInt32, false), ]); - LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build() + LogicalPlanBuilder::scan_empty(Some(name), &schema, None)?.build() +} + +/// some tests share a common table +pub fn test_table_scan() -> Result { + test_table_scan_with_name("test") } pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { From bbc9c6c68a03a19e4f385663b7c7ab795748f16e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 7 Jul 2021 09:00:32 -0400 Subject: [PATCH 07/12] Fix test output due to logical merge conflict (#694) --- datafusion/tests/sql.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index f6f8b6f041e6e..9c7d0795edb91 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1705,7 +1705,7 @@ async fn equijoin_and_unsupported_condition() -> Result<()> { "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; let res = ctx.create_logical_plan(sql); assert!(res.is_err()); - assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t2.t2_name GtEq Utf8(\"y\")]"); + assert_eq!(format!("{}", res.unwrap_err()), "This feature is not implemented: Unsupported expressions in Left JOIN: [#t2_name GtEq Utf8(\"y\")]"); Ok(()) } From 9f8e265e6df502a3badd8f9eff2f62a47515eb7b Mon Sep 17 00:00:00 2001 From: Cui Wenzheng Date: Wed, 7 Jul 2021 21:01:05 +0800 Subject: [PATCH 08/12] fix typo in DEVELOPERS.md (#692) --- DEVELOPERS.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DEVELOPERS.md b/DEVELOPERS.md index 85384680c02eb..3ee2f7da09c5f 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -37,8 +37,8 @@ Testing setup: - `git submodule init` - `git submodule update` -- `export PARQUET_TEST_DATA=parquet_testing/` -- `export ARROW_TEST_DATA=testing/data/` +- `export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data/` +- `export ARROW_TEST_DATA=$(pwd)/testing/data/` ## How to add a new scalar function From 79d60f9b678e9a2351fc83511399663985e39cf6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 7 Jul 2021 09:43:44 -0400 Subject: [PATCH 09/12] Remove qualifiers on pushed down predicates / Fix parquet pruning (#689) * Remove qualifiers on pushed down predicates * Add test for normalizing and unnormalizing columns * Fix logical conflict --- datafusion/src/logical_plan/expr.rs | 149 ++++++++++++++++++++++-- datafusion/src/logical_plan/mod.rs | 4 +- datafusion/src/physical_plan/planner.rs | 13 ++- datafusion/tests/parquet_pruning.rs | 24 ++-- 4 files changed, 164 insertions(+), 26 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs 
b/datafusion/src/logical_plan/expr.rs index 9454d7593c3f3..59c99797e0cd8 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -90,14 +90,22 @@ impl Column { /// For example, `foo` will be normalized to `t.foo` if there is a /// column named `foo` in a relation named `t` found in `schemas` pub fn normalize(self, plan: &LogicalPlan) -> Result { + let schemas = plan.all_schemas(); + let using_columns = plan.using_columns()?; + self.normalize_with_schemas(&schemas, &using_columns) + } + + // Internal implementation of normalize + fn normalize_with_schemas( + self, + schemas: &[&Arc], + using_columns: &[HashSet], + ) -> Result { if self.relation.is_some() { return Ok(self); } - let schemas = plan.all_schemas(); - let using_columns = plan.using_columns()?; - - for schema in &schemas { + for schema in schemas { let fields = schema.fields_with_unqualified_name(&self.name); match fields.len() { 0 => continue, @@ -118,7 +126,7 @@ impl Column { // We will use the relation from the first matched field to normalize self. // Compare matched fields with one USING JOIN clause at a time - for using_col in &using_columns { + for using_col in using_columns { let all_matched = fields .iter() .all(|f| using_col.contains(&f.qualified_column())); @@ -1171,22 +1179,39 @@ pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result Result { +pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { + normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?) +} + +/// Recursively call [`Column::normalize`] on all Column expressions +/// in the `expr` expression tree. +fn normalize_col_with_schemas( + expr: Expr, + schemas: &[&Arc], + using_columns: &[HashSet], +) -> Result { struct ColumnNormalizer<'a> { - plan: &'a LogicalPlan, + schemas: &'a [&'a Arc], + using_columns: &'a [HashSet], } impl<'a> ExprRewriter for ColumnNormalizer<'a> { fn mutate(&mut self, expr: Expr) -> Result { if let Expr::Column(c) = expr { - Ok(Expr::Column(c.normalize(self.plan)?)) + Ok(Expr::Column(c.normalize_with_schemas( + self.schemas, + self.using_columns, + )?)) } else { Ok(expr) } } } - e.rewrite(&mut ColumnNormalizer { plan }) + expr.rewrite(&mut ColumnNormalizer { + schemas, + using_columns, + }) } /// Recursively normalize all Column expressions in a list of expression trees @@ -1198,6 +1223,38 @@ pub fn normalize_cols( exprs.into_iter().map(|e| normalize_col(e, plan)).collect() } +/// Recursively 'unnormalize' (remove all qualifiers) from an +/// expression tree. +/// +/// For example, if there were expressions like `foo.bar` this would +/// rewrite it to just `bar`. 
+pub fn unnormalize_col(expr: Expr) -> Expr { + struct RemoveQualifier {} + + impl ExprRewriter for RemoveQualifier { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Column(col) = expr { + //let Column { relation: _, name } = col; + Ok(Expr::Column(Column { + relation: None, + name: col.name, + })) + } else { + Ok(expr) + } + } + } + + expr.rewrite(&mut RemoveQualifier {}) + .expect("Unnormalize is infallable") +} + +/// Recursively un-normalize all Column expressions in a list of expression trees +#[inline] +pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { + exprs.into_iter().map(unnormalize_col).collect() +} + /// Create an expression to represent the min() aggregate function pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction { @@ -1810,4 +1867,78 @@ mod tests { } } } + + #[test] + fn normalize_cols() { + let expr = col("a") + col("b") + col("c"); + + // Schemas with some matching and some non matching cols + let schema_a = + DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")]) + .unwrap(); + let schema_c = + DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")]) + .unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + // non matching + let schema_f = + DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")]) + .unwrap(); + let schemas = vec![schema_c, schema_f, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!( + normalized_expr, + col("tableA.a") + col("tableB.b") + col("tableC.c") + ); + } + + #[test] + fn normalize_cols_priority() { + let expr = col("a") + col("b"); + // Schemas with multiple matches for column a, first takes priority + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap(); + let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap(); + let schemas = vec![schema_a2, schema_b, schema_a] + .into_iter() + .map(Arc::new) + .collect::>(); + let schemas = schemas.iter().collect::>(); + + let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); + assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); + } + + #[test] + fn normalize_cols_non_exist() { + // test normalizing columns when the name doesn't exist + let expr = col("a") + col("b"); + let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap(); + let schemas = vec![schema_a].into_iter().map(Arc::new).collect::>(); + let schemas = schemas.iter().collect::>(); + + let error = normalize_col_with_schemas(expr, &schemas, &[]) + .unwrap_err() + .to_string(); + assert_eq!( + error, + "Error during planning: Column #b not found in provided schemas" + ); + } + + #[test] + fn unnormalize_cols() { + let expr = col("tableA.a") + col("tableB.b"); + let unnormalized_expr = unnormalize_col(expr); + assert_eq!(unnormalized_expr, col("a") + col("b")); + } + + fn make_field(relation: &str, column: &str) -> DFField { + DFField::new(Some(relation), column, DataType::Int8, false) + } } diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 86a2f567d7de4..2c751abdad349 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -43,8 +43,8 @@ pub use expr::{ min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, 
regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, - substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr, - ExprRewriter, ExpressionVisitor, Literal, Recursion, + substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, unnormalize_cols, + upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 73b2f362989f6..df4168370003a 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -23,8 +23,9 @@ use super::{ }; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ - DFSchema, Expr, LogicalPlan, Operator, Partitioning as LogicalPartitioning, PlanType, - StringifiedPlan, UserDefinedLogicalNode, + unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator, + Partitioning as LogicalPartitioning, PlanType, StringifiedPlan, + UserDefinedLogicalNode, }; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions; @@ -311,7 +312,13 @@ impl DefaultPhysicalPlanner { filters, limit, .. - } => source.scan(projection, batch_size, filters, *limit), + } => { + // Remove all qualifiers from the scan as the provider + // doesn't know (nor should care) how the relation was + // referred to in the query + let filters = unnormalize_cols(filters.iter().cloned()); + source.scan(projection, batch_size, &filters, *limit) + } LogicalPlan::Window { input, window_expr, .. } => { diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 86b3946e47121..f5486afc7aa4a 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -44,9 +44,9 @@ async fn prune_timestamps_nanos() { .query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .await; println!("{}", output.description()); - // TODO This should prune one metrics without error - assert_eq!(output.predicate_evaluation_errors(), Some(1)); - assert_eq!(output.row_groups_pruned(), Some(0)); + // This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(1)); assert_eq!(output.result_rows, 10, "{}", output.description()); } @@ -59,9 +59,9 @@ async fn prune_timestamps_micros() { ) .await; println!("{}", output.description()); - // TODO This should prune one metrics without error - assert_eq!(output.predicate_evaluation_errors(), Some(1)); - assert_eq!(output.row_groups_pruned(), Some(0)); + // This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(1)); assert_eq!(output.result_rows, 10, "{}", output.description()); } @@ -74,9 +74,9 @@ async fn prune_timestamps_millis() { ) .await; println!("{}", output.description()); - // TODO This should prune one metrics without error - assert_eq!(output.predicate_evaluation_errors(), Some(1)); - assert_eq!(output.row_groups_pruned(), Some(0)); + // This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(1)); assert_eq!(output.result_rows, 10, "{}", output.description()); } @@ -89,9 +89,9 @@ async fn prune_timestamps_seconds() { ) 
.await; println!("{}", output.description()); - // TODO This should prune one metrics without error - assert_eq!(output.predicate_evaluation_errors(), Some(1)); - assert_eq!(output.row_groups_pruned(), Some(0)); + // This should prune one metrics without error + assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_pruned(), Some(1)); assert_eq!(output.result_rows, 10, "{}", output.description()); } From f94f6391845c844980fd4fb3171a743bf5d182b2 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 8 Jul 2021 19:37:49 +0800 Subject: [PATCH 10/12] add more integration tests (#668) --- .../sqls/self_join_with_alias.sql | 22 +++++++++++++++++++ integration-tests/sqls/simple_union_all.sql | 17 ++++++++++++++ integration-tests/test_psql_parity.py | 2 +- 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 integration-tests/sqls/self_join_with_alias.sql create mode 100644 integration-tests/sqls/simple_union_all.sql diff --git a/integration-tests/sqls/self_join_with_alias.sql b/integration-tests/sqls/self_join_with_alias.sql new file mode 100644 index 0000000000000..54c39888dffed --- /dev/null +++ b/integration-tests/sqls/self_join_with_alias.sql @@ -0,0 +1,22 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +SELECT + t1.c9 result +FROM test t1 +INNER JOIN test t2 +ON t1.c9 = t2.c9 +ORDER BY result; diff --git a/integration-tests/sqls/simple_union_all.sql b/integration-tests/sqls/simple_union_all.sql new file mode 100644 index 0000000000000..65557b8d263fd --- /dev/null +++ b/integration-tests/sqls/simple_union_all.sql @@ -0,0 +1,17 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+ +SELECT 1 num UNION ALL SELECT 2 num ORDER BY num; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index 766f403f3e543..39cfdee77fbdd 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 12, msg="tests are missed") + self.assertEqual(len(files), 14, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv( From 024bd89603dea13e63b70c92274116edbe36c4f9 Mon Sep 17 00:00:00 2001 From: Edd Robinson Date: Thu, 8 Jul 2021 12:38:23 +0100 Subject: [PATCH 11/12] perf: Improve materialisation performance of SortPreservingMergeExec (#691) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: add benchmarks for SortPreservingMergeExec * perf: minimise array data extend calls The `SortPreservingMergeStream` operator merges two streams together by creating an output record batch that is build from the contents of the input. Previously each row of input would be pushed into the output sink even if though the API supports pushing batches of rows. This commit implements the logic to push batches of rows from inputs where possible. Performance benchmarks show an improvement of between 3-12%. ``` group master pr ----- ------ -- interleave_batches 1.04 637.5±51.84µs ? ?/sec 1.00 615.5±12.13µs ? ?/sec merge_batches_no_overlap_large 1.12 454.9±2.90µs ? ?/sec 1.00 404.9±10.94µs ? ?/sec merge_batches_no_overlap_small 1.14 485.1±6.67µs ? ?/sec 1.00 425.7±9.33µs ? ?/sec merge_batches_small_into_large 1.14 263.0±8.85µs ? ?/sec 1.00 229.7±5.23µs ? ?/sec merge_batches_some_overlap_large 1.05 532.5±8.33µs ? ?/sec 1.00 508.3±14.24µs ? ?/sec merge_batches_some_overlap_small 1.06 546.9±12.82µs ? ?/sec 1.00 516.9±13.20µs ? ?/sec ``` * test: more test coverage * refactor: update batch size --- datafusion/Cargo.toml | 4 + datafusion/benches/physical_plan.rs | 176 +++++++++++++++ .../physical_plan/sort_preserving_merge.rs | 202 ++++++++++++++---- 3 files changed, 341 insertions(+), 41 deletions(-) create mode 100644 datafusion/benches/physical_plan.rs diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index f1a77741064e4..845de6213f4d3 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -95,3 +95,7 @@ harness = false [[bench]] name = "scalar" harness = false + +[[bench]] +name = "physical_plan" +harness = false \ No newline at end of file diff --git a/datafusion/benches/physical_plan.rs b/datafusion/benches/physical_plan.rs new file mode 100644 index 0000000000000..9222ae131b8ff --- /dev/null +++ b/datafusion/benches/physical_plan.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::{BatchSize, Criterion}; +extern crate arrow; +extern crate datafusion; + +use std::{iter::FromIterator, sync::Arc}; + +use arrow::{ + array::{ArrayRef, Int64Array, StringArray}, + record_batch::RecordBatch, +}; +use tokio::runtime::Runtime; + +use datafusion::physical_plan::{ + collect, + expressions::{col, PhysicalSortExpr}, + memory::MemoryExec, + sort_preserving_merge::SortPreservingMergeExec, +}; + +// Initialise the operator using the provided record batches and the sort key +// as inputs. All record batches must have the same schema. +fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { + let schema = batches[0].schema(); + + let sort = sort + .iter() + .map(|name| PhysicalSortExpr { + expr: col(name, &schema).unwrap(), + options: Default::default(), + }) + .collect::>(); + + let exec = MemoryExec::try_new( + &batches.into_iter().map(|rb| vec![rb]).collect::>(), + schema, + None, + ) + .unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 8192)); + + let rt = Runtime::new().unwrap(); + rt.block_on(collect(merge)).unwrap(); +} + +// Produces `n` record batches of row size `m`. Each record batch will have +// identical contents except for if the `batch_offset` is set. In that case the +// values for column "d" in each subsequent record batch will be offset in +// value. +// +// The `rows_per_key` value controls how many rows are generated per "key", +// which is defined as columns a, b and c. +fn batches( + n: usize, + m: usize, + rows_per_sort_key: usize, + batch_offset: usize, +) -> Vec { + let mut rbs = Vec::with_capacity(n); + let mut curr_batch_offset = 0; + + for _ in 0..n { + let mut col_a = Vec::with_capacity(m); + let mut col_b = Vec::with_capacity(m); + let mut col_c = Vec::with_capacity(m); + let mut col_d = Vec::with_capacity(m); + + let mut j = 0; + let mut current_rows_per_sort_key = 0; + + for i in 0..m { + if current_rows_per_sort_key == rows_per_sort_key { + current_rows_per_sort_key = 0; + j = i; + } + + col_a.push(Some(format!("a-{:?}", j))); + col_b.push(Some(format!("b-{:?}", j))); + col_c.push(Some(format!("c-{:?}", j))); + col_d.push(Some((i + curr_batch_offset) as i64)); + + current_rows_per_sort_key += 1; + } + + col_a.sort(); + col_b.sort(); + col_c.sort(); + + let col_a: ArrayRef = Arc::new(StringArray::from_iter(col_a)); + let col_b: ArrayRef = Arc::new(StringArray::from_iter(col_b)); + let col_c: ArrayRef = Arc::new(StringArray::from_iter(col_c)); + let col_d: ArrayRef = Arc::new(Int64Array::from(col_d)); + + let rb = RecordBatch::try_from_iter(vec![ + ("a", col_a), + ("b", col_b), + ("c", col_c), + ("d", col_d), + ]) + .unwrap(); + rbs.push(rb); + + curr_batch_offset += batch_offset; + } + + rbs +} + +fn criterion_benchmark(c: &mut Criterion) { + let small_batch = batches(1, 100, 10, 0).remove(0); + let large_batch = batches(1, 1000, 1, 0).remove(0); + + let benches = vec![ + // Two batches with identical rows. They will need to be merged together + // with one row from each batch being taken until both batches are + // drained. 
+ ("interleave_batches", batches(2, 1000, 10, 1)), + // Two batches with a small overlapping region of rows for each unique + // sort key. + ("merge_batches_some_overlap_small", batches(2, 1000, 10, 5)), + // Two batches with a large overlapping region of rows for each unique + // sort key. + ( + "merge_batches_some_overlap_large", + batches(2, 1000, 250, 125), + ), + // Two batches with no overlapping region of rows for each unique + // sort key. For a given unique sort key all rows are drained from one + // batch, then all the rows for the same key from the second batch. + // This repeats until all rows are drained. There are a small number of + // rows (10) for each unique sort key. + ("merge_batches_no_overlap_small", batches(2, 1000, 10, 12)), + // As above but this time there are a larger number of rows (250) for + // each unique sort key - still no overlaps. + ("merge_batches_no_overlap_large", batches(2, 1000, 250, 252)), + // Merges two batches where one batch is significantly larger than the + // other. + ( + "merge_batches_small_into_large", + vec![large_batch, small_batch], + ), + ]; + + for (name, input) in benches { + c.bench_function(name, move |b| { + b.iter_batched( + || input.clone(), + |input| { + sort_preserving_merge_operator(input, &["a", "b", "c", "d"]); + }, + BatchSize::LargeInput, + ) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs index 316f0509960dd..0949c3c6a8cf6 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sort_preserving_merge.rs @@ -24,22 +24,23 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::array::{ArrayRef, MutableArrayData}; -use arrow::compute::SortOptions; +use arrow::{ + array::{make_array as make_arrow_array, ArrayRef, MutableArrayData}, + compute::SortOptions, + datatypes::SchemaRef, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, +}; use async_trait::async_trait; use futures::channel::mpsc; use futures::stream::FusedStream; use futures::{Stream, StreamExt}; -use crate::arrow::datatypes::SchemaRef; -use crate::arrow::error::ArrowError; -use crate::arrow::{error::Result as ArrowResult, record_batch::RecordBatch}; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::common::spawn_execution; -use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ - DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, - RecordBatchStream, SendableRecordBatchStream, + common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, + Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, + SendableRecordBatchStream, }; /// Sort preserving merge execution plan @@ -425,19 +426,38 @@ impl SortPreservingMergeStream { self.in_progress.len(), ); - for row_index in &self.in_progress { - let buffer_idx = + if self.in_progress.is_empty() { + return make_arrow_array(array_data.freeze()); + } + + let first = &self.in_progress[0]; + let mut buffer_idx = + stream_to_buffer_idx[first.stream_idx] + first.cursor_idx; + let mut start_row_idx = first.row_idx; + let mut end_row_idx = start_row_idx + 1; + + for row_index in self.in_progress.iter().skip(1) { + let next_buffer_idx = stream_to_buffer_idx[row_index.stream_idx] + row_index.cursor_idx; - // TODO: Coalesce contiguous writes - array_data.extend( - 
buffer_idx, - row_index.row_idx, - row_index.row_idx + 1, - ); + if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx { + // subsequent row in same batch + end_row_idx += 1; + continue; + } + + // emit current batch of rows for current buffer + array_data.extend(buffer_idx, start_row_idx, end_row_idx); + + // start new batch of rows + buffer_idx = next_buffer_idx; + start_row_idx = row_index.row_idx; + end_row_idx = start_row_idx + 1; } - arrow::array::make_array(array_data.freeze()) + // emit final batch of rows + array_data.extend(buffer_idx, start_row_idx, end_row_idx); + make_arrow_array(array_data.freeze()) }) .collect(); @@ -555,7 +575,54 @@ mod tests { use tokio_stream::StreamExt; #[tokio::test] - async fn test_merge() { + async fn test_merge_interleave() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("c"), + Some("e"), + Some("g"), + Some("j"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("b"), + Some("d"), + Some("f"), + Some("h"), + Some("j"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + _test_merge( + b1, + b2, + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01 00:00:00.000000008 |", + "| 10 | b | 1970-01-01 00:00:00.000000004 |", + "| 2 | c | 1970-01-01 00:00:00.000000007 |", + "| 20 | d | 1970-01-01 00:00:00.000000006 |", + "| 7 | e | 1970-01-01 00:00:00.000000006 |", + "| 70 | f | 1970-01-01 00:00:00.000000002 |", + "| 9 | g | 1970-01-01 00:00:00.000000005 |", + "| 90 | h | 1970-01-01 00:00:00.000000002 |", + "| 30 | j | 1970-01-01 00:00:00.000000006 |", // input b2 before b1 + "| 3 | j | 1970-01-01 00:00:00.000000008 |", + "+----+---+-------------------------------+", + ], + ) + .await; + } + + #[tokio::test] + async fn test_merge_some_overlap() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -564,21 +631,92 @@ mod tests { Some("d"), Some("e"), ])); - let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 4])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("c"), + Some("d"), + Some("e"), + Some("f"), + Some("g"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + _test_merge( + b1, + b2, + &[ + "+-----+---+-------------------------------+", + "| a | b | c |", + "+-----+---+-------------------------------+", + "| 1 | a | 1970-01-01 00:00:00.000000008 |", + "| 2 | b | 1970-01-01 00:00:00.000000007 |", + "| 70 | c | 1970-01-01 00:00:00.000000004 |", + "| 7 | c | 1970-01-01 00:00:00.000000006 |", + "| 9 | d | 1970-01-01 
00:00:00.000000005 |", + "| 90 | d | 1970-01-01 00:00:00.000000006 |", + "| 30 | e | 1970-01-01 00:00:00.000000002 |", + "| 3 | e | 1970-01-01 00:00:00.000000008 |", + "| 100 | f | 1970-01-01 00:00:00.000000002 |", + "| 110 | g | 1970-01-01 00:00:00.000000006 |", + "+-----+---+-------------------------------+", + ], + ) + .await; + } + + #[tokio::test] + async fn test_merge_no_overlap() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("a"), + Some("b"), + Some("c"), Some("d"), Some("e"), + ])); + let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ + Some("f"), Some("g"), Some("h"), Some("i"), + Some("j"), ])); let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); - let schema = b1.schema(); + _test_merge( + b1, + b2, + &[ + "+----+---+-------------------------------+", + "| a | b | c |", + "+----+---+-------------------------------+", + "| 1 | a | 1970-01-01 00:00:00.000000008 |", + "| 2 | b | 1970-01-01 00:00:00.000000007 |", + "| 7 | c | 1970-01-01 00:00:00.000000006 |", + "| 9 | d | 1970-01-01 00:00:00.000000005 |", + "| 3 | e | 1970-01-01 00:00:00.000000008 |", + "| 10 | f | 1970-01-01 00:00:00.000000004 |", + "| 20 | g | 1970-01-01 00:00:00.000000006 |", + "| 70 | h | 1970-01-01 00:00:00.000000002 |", + "| 90 | i | 1970-01-01 00:00:00.000000002 |", + "| 30 | j | 1970-01-01 00:00:00.000000006 |", + "+----+---+-------------------------------+", + ], + ) + .await; + } + + async fn _test_merge(b1: RecordBatch, b2: RecordBatch, exp: &[&str]) { + let schema = b1.schema(); let sort = vec![ PhysicalSortExpr { expr: col("b", &schema).unwrap(), @@ -595,25 +733,7 @@ mod tests { let collected = collect(merge).await.unwrap(); assert_eq!(collected.len(), 1); - assert_batches_eq!( - &[ - "+---+---+-------------------------------+", - "| a | b | c |", - "+---+---+-------------------------------+", - "| 1 | a | 1970-01-01 00:00:00.000000008 |", - "| 2 | b | 1970-01-01 00:00:00.000000007 |", - "| 7 | c | 1970-01-01 00:00:00.000000006 |", - "| 1 | d | 1970-01-01 00:00:00.000000004 |", - "| 9 | d | 1970-01-01 00:00:00.000000005 |", - "| 3 | e | 1970-01-01 00:00:00.000000004 |", - "| 2 | e | 1970-01-01 00:00:00.000000006 |", - "| 3 | g | 1970-01-01 00:00:00.000000002 |", - "| 4 | h | 1970-01-01 00:00:00.000000002 |", - "| 5 | i | 1970-01-01 00:00:00.000000006 |", - "+---+---+-------------------------------+", - ], - collected.as_slice() - ); + assert_batches_eq!(exp, collected.as_slice()); } async fn sorted_merge( From 7378bb4de1dbcb008e68a01e1f0f046c6a17cade Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 8 Jul 2021 07:52:33 -0400 Subject: [PATCH 12/12] Fix build with 1.52.1 (#696) --- datafusion/src/physical_plan/parquet.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index f31b921d663b0..63e11d5106bac 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -514,7 +514,7 @@ impl ExecutionPlan for ParquetExec { self.partitions .iter() .flat_map(|p| { - [ + vec![ ( format!( "numPredicateEvaluationErrors for {}",
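
The final hunk above replaces an array literal with `vec![...]` inside a `flat_map` closure. The most likely reason is that the by-value `IntoIterator` implementation for arrays was only stabilized in Rust 1.53, so on rustc 1.52.1 a closure returning `[ ... ]` does not satisfy `flat_map`'s `IntoIterator` bound, while returning a `Vec` compiles on both toolchains. The sketch below is a minimal, self-contained illustration of that failure mode under that assumption; it is not taken from the patch, and the metric names are stand-ins rather than the real `ParquetExec` metrics.

```rust
fn main() {
    let partitions = vec!["part-0", "part-1"];

    // Compiles on rustc 1.52.1 and later: the closure returns a `Vec`, which
    // always implements `IntoIterator`, so `flat_map` can consume it.
    let metrics: Vec<(String, usize)> = partitions
        .iter()
        .enumerate()
        .flat_map(|(i, p)| {
            vec![
                // Illustrative metric names only.
                (format!("numPredicateEvaluationErrors for {}", p), i),
                (format!("rowGroupsPruned for {}", p), i),
            ]
        })
        .collect();

    // Returning `[(..), (..)]` from the closure instead of `vec![..]` fails to
    // compile on 1.52.1, because `[T; N]` only gained a by-value
    // `IntoIterator` implementation in Rust 1.53.
    assert_eq!(metrics.len(), 4);
}
```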