From af008218b67c6768005a73ecdaa7c0326529430a Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 14 Dec 2022 16:47:38 +0000
Subject: [PATCH 1/4] Make DataFrame API consuming (#4621)

---
 .../examples/custom_datasource.rs             |   2 +-
 datafusion-examples/examples/dataframe.rs     |   5 +-
 .../examples/deserialize_to_struct.rs         |   2 +-
 datafusion-examples/examples/flight_server.rs |   4 +-
 datafusion/core/src/dataframe.rs              | 361 +++++++++---------
 datafusion/core/src/datasource/view.rs        |   4 +-
 datafusion/core/src/execution/context.rs      |  77 ++--
 datafusion/core/src/scheduler/mod.rs          |   2 +-
 datafusion/core/tests/dataframe.rs            |   4 +-
 datafusion/core/tests/dataframe_functions.rs  |   2 +-
 datafusion/core/tests/sql/select.rs           |   2 +-
 11 files changed, 227 insertions(+), 238 deletions(-)

diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index 3be17d950070..db4fed494ca1 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -69,7 +69,7 @@ async fn search_accounts(
     )?
     .build()?;
 
-    let mut dataframe = DataFrame::new(ctx.state, &logical_plan)
+    let mut dataframe = DataFrame::new(ctx.state, logical_plan)
         .select_columns(&["id", "bank_account"])?;
 
     if let Some(f) = filter {
diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs
index a212387e2162..9ec03c594bfa 100644
--- a/datafusion-examples/examples/dataframe.rs
+++ b/datafusion-examples/examples/dataframe.rs
@@ -19,7 +19,6 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema};
 use datafusion::error::Result;
 use datafusion::prelude::*;
 use std::fs;
-use std::sync::Arc;
 
 /// This example demonstrates executing a simple query against an Arrow data source (Parquet) and
 /// fetching results, using the DataFrame trait
@@ -64,7 +63,7 @@ a2,"08 9, 2013",2,1376006400,4.5"#;
 }
 
 // Example to read data from a csv file with inferred schema
-async fn example_read_csv_file_with_inferred_schema() -> Arc<DataFrame> {
+async fn example_read_csv_file_with_inferred_schema() -> DataFrame {
     let path = "example.csv";
     // Create a csv file using the predefined function
     create_csv_file(path.to_string());
@@ -75,7 +74,7 @@ async fn example_read_csv_file_with_inferred_schema() -> Arc<DataFrame> {
 }
 
 // Example to read csv file with a defined schema for the csv file
-async fn example_read_csv_file_with_schema() -> Arc<DataFrame> {
+async fn example_read_csv_file_with_schema() -> DataFrame {
     let path = "example.csv";
     // Create a csv file using the predefined function
     create_csv_file(path.to_string());
diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs
index ce89b64d2796..dba2e04b1bf6 100644
--- a/datafusion-examples/examples/deserialize_to_struct.rs
+++ b/datafusion-examples/examples/deserialize_to_struct.rs
@@ -57,7 +57,7 @@ impl Data {
             .sql("SELECT int_col, double_col FROM alltypes_plain")
             .await?;
 
-        df.show().await?;
+        df.clone().show().await?;
 
         df.collect().await?
    };
diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs
index d10f11a43d22..66dcd4583ed2 100644
--- a/datafusion-examples/examples/flight_server.rs
+++ b/datafusion-examples/examples/flight_server.rs
@@ -114,6 +114,7 @@ impl FlightService for FlightServiceImpl {
         let df = ctx.sql(sql).await.map_err(to_tonic_err)?;
 
         // execute the query
+        let schema = df.schema().clone().into();
         let results = df.collect().await.map_err(to_tonic_err)?;
         if results.is_empty() {
             return Err(Status::internal("There were no results from ticket"));
@@ -121,8 +122,7 @@ impl FlightService for FlightServiceImpl {
 
         // add an initial FlightData message that sends schema
         let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default();
-        let schema_flight_data =
-            SchemaAsIpc::new(&df.schema().clone().into(), &options).into();
+        let schema_flight_data = SchemaAsIpc::new(&schema, &options).into();
 
         let mut flights: Vec<Result<FlightData, Status>> = vec![Ok(schema_flight_data)];
 
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index d7dd04f886ce..ac77dc1c62b0 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -17,6 +17,16 @@
 
 //! DataFrame API for building and executing query plans.
 
+use std::any::Any;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use parking_lot::RwLock;
+use parquet::file::properties::WriterProperties;
+
+use datafusion_common::{Column, DFSchema};
+use datafusion_expr::TableProviderFilterPushDown;
+
 use crate::arrow::datatypes::Schema;
 use crate::arrow::datatypes::SchemaRef;
 use crate::arrow::record_batch::RecordBatch;
@@ -36,13 +46,6 @@ use crate::physical_plan::SendableRecordBatchStream;
 use crate::physical_plan::{collect, collect_partitioned};
 use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
 use crate::prelude::SessionContext;
-use async_trait::async_trait;
-use datafusion_common::{Column, DFSchema};
-use datafusion_expr::TableProviderFilterPushDown;
-use parking_lot::RwLock;
-use parquet::file::properties::WriterProperties;
-use std::any::Any;
-use std::sync::Arc;
 
 /// DataFrame represents a logical set of rows with the same named columns.
 /// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or
@@ -69,7 +72,7 @@ use std::sync::Arc;
 /// # Ok(())
 /// # }
 /// ```
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct DataFrame {
     session_state: Arc<RwLock<SessionState>>,
     plan: LogicalPlan,
@@ -77,15 +80,15 @@ pub struct DataFrame {
 
 impl DataFrame {
     /// Create a new Table based on an existing logical plan
-    pub fn new(session_state: Arc<RwLock<SessionState>>, plan: &LogicalPlan) -> Self {
+    pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self {
         Self {
             session_state,
-            plan: plan.clone(),
+            plan,
         }
     }
 
     /// Create a physical plan
-    pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
+    pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
         // this function is copied from SessionContext function of the
         // same name
         let state_cloned = {
@@ -121,7 +124,7 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn select_columns(&self, columns: &[&str]) -> Result<Arc<DataFrame>> {
+    pub fn select_columns(self, columns: &[&str]) -> Result<DataFrame> {
         let fields = columns
             .iter()
             .map(|name| self.plan.schema().field_with_unqualified_name(name))
@@ -146,7 +149,7 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<DataFrame>> {
+    pub fn select(&self, expr_list: Vec<Expr>) -> Result<DataFrame> {
         let window_func_exprs = find_window_exprs(&expr_list);
         let plan = if window_func_exprs.is_empty() {
             self.plan.clone()
@@ -155,10 +158,7 @@ impl DataFrame {
         };
         let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
 
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &project_plan,
-        )))
+        Ok(DataFrame::new(self.session_state.clone(), project_plan))
     }
 
     /// Filter a DataFrame to only include rows that match the specified filter expression.
@@ -174,11 +174,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn filter(&self, predicate: Expr) -> Result<Arc<DataFrame>> {
+    pub fn filter(&self, predicate: Expr) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::from(self.plan.clone())
             .filter(predicate)?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state.clone(), plan))
     }
 
     /// Perform an aggregate query with optional grouping expressions.
@@ -203,11 +203,11 @@ impl DataFrame {
         &self,
         group_expr: Vec<Expr>,
         aggr_expr: Vec<Expr>,
-    ) -> Result<Arc<DataFrame>> {
+    ) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::from(self.plan.clone())
             .aggregate(group_expr, aggr_expr)?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state.clone(), plan))
     }
 
     /// Limit the number of rows returned from this DataFrame.
@@ -226,11 +226,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn limit(&self, skip: usize, fetch: Option<usize>) -> Result<Arc<DataFrame>> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+    pub fn limit(self, skip: usize, fetch: Option<usize>) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
             .limit(skip, fetch)?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Calculate the union of two [`DataFrame`]s, preserving duplicate rows. The
     /// two [`DataFrame`]s must have exactly the same schema
@@ -247,11 +247,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn union(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
-            .union(dataframe.plan.clone())?
+    pub fn union(self, dataframe: DataFrame) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
+            .union(dataframe.plan)?
            .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Calculate the distinct union of two [`DataFrame`]s. The
     /// two [`DataFrame`]s must have exactly the same schema
@@ -268,13 +268,13 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn union_distinct(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &LogicalPlanBuilder::from(self.plan.clone())
-                .union_distinct(dataframe.plan.clone())?
+    pub fn union_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
+        Ok(DataFrame::new(
+            self.session_state,
+            LogicalPlanBuilder::from(self.plan)
+                .union_distinct(dataframe.plan)?
                 .build()?,
-        )))
+        ))
     }
 
     /// Filter out duplicate rows
@@ -290,13 +290,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn distinct(&self) -> Result<Arc<DataFrame>> {
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &LogicalPlanBuilder::from(self.plan.clone())
-                .distinct()?
-                .build()?,
-        )))
+    pub fn distinct(self) -> Result<DataFrame> {
+        Ok(DataFrame::new(
+            self.session_state,
+            LogicalPlanBuilder::from(self.plan).distinct()?.build()?,
+        ))
     }
 
     /// Sort the DataFrame by the specified sorting expressions. Any expression can be turned into
@@ -313,11 +311,9 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn sort(&self, expr: Vec<Expr>) -> Result<Arc<DataFrame>> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
-            .sort(expr)?
-            .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+    pub fn sort(self, expr: Vec<Expr>) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?;
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Join this DataFrame with another DataFrame using the specified columns as join keys.
@@ -344,22 +340,22 @@ impl DataFrame {
     /// # }
     /// ```
     pub fn join(
-        &self,
-        right: Arc<DataFrame>,
+        self,
+        right: DataFrame,
         join_type: JoinType,
         left_cols: &[&str],
         right_cols: &[&str],
         filter: Option<Expr>,
-    ) -> Result<Arc<DataFrame>> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+    ) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
             .join(
-                &right.plan.clone(),
+                &right.plan,
                 join_type,
                 (left_cols.to_vec(), right_cols.to_vec()),
                 filter,
             )?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Repartition a DataFrame based on a logical partitioning scheme.
@@ -375,14 +371,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn repartition(
-        &self,
-        partitioning_scheme: Partitioning,
-    ) -> Result<Arc<DataFrame>> {
+    pub fn repartition(self, partitioning_scheme: Partitioning) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::from(self.plan.clone())
             .repartition(partitioning_scheme)?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Convert the logical plan represented by this DataFrame into a physical plan and
@@ -399,9 +392,9 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn collect(&self) -> Result<Vec<RecordBatch>> {
+    pub async fn collect(self) -> Result<Vec<RecordBatch>> {
+        let task_ctx = Arc::new(self.task_ctx());
         let plan = self.create_physical_plan().await?;
-        let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
         collect(plan, task_ctx).await
     }
 
@@ -418,7 +411,7 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn show(&self) -> Result<()> {
+    pub async fn show(self) -> Result<()> {
         let results = self.collect().await?;
         Ok(pretty::print_batches(&results)?)
    }
@@ -436,11 +429,16 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn show_limit(&self, num: usize) -> Result<()> {
+    pub async fn show_limit(self, num: usize) -> Result<()> {
         let results = self.limit(0, Some(num))?.collect().await?;
         Ok(pretty::print_batches(&results)?)
     }
 
+    fn task_ctx(&self) -> TaskContext {
+        let lock = self.session_state.read();
+        TaskContext::from(&*lock)
+    }
+
     /// Executes this DataFrame and returns a stream over a single partition
     ///
     /// ```
@@ -454,9 +452,9 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
+    pub async fn execute_stream(self) -> Result<SendableRecordBatchStream> {
+        let task_ctx = Arc::new(self.task_ctx());
         let plan = self.create_physical_plan().await?;
-        let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
         execute_stream(plan, task_ctx).await
     }
 
@@ -474,9 +472,9 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
+    pub async fn collect_partitioned(self) -> Result<Vec<Vec<RecordBatch>>> {
+        let task_ctx = Arc::new(self.task_ctx());
         let plan = self.create_physical_plan().await?;
-        let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
         collect_partitioned(plan, task_ctx).await
     }
 
@@ -494,10 +492,10 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn execute_stream_partitioned(
-        &self,
+        self,
     ) -> Result<Vec<SendableRecordBatchStream>> {
+        let task_ctx = Arc::new(self.task_ctx());
         let plan = self.create_physical_plan().await?;
-        let task_ctx = Arc::new(TaskContext::from(&self.session_state.read().clone()));
         execute_stream_partitioned(plan, task_ctx).await
     }
 
@@ -520,12 +518,12 @@ impl DataFrame {
     }
 
     /// Return the unoptimized logical plan represented by this DataFrame.
-    pub fn to_unoptimized_plan(&self) -> LogicalPlan {
-        self.plan.clone()
+    pub fn to_unoptimized_plan(self) -> LogicalPlan {
+        self.plan
     }
 
     /// Return the optimized logical plan represented by this DataFrame.
-    pub fn to_logical_plan(&self) -> Result<LogicalPlan> {
+    pub fn to_logical_plan(self) -> Result<LogicalPlan> {
         // Optimize the plan first for better UX
         let state = self.session_state.read().clone();
         state.optimize(&self.plan)
@@ -546,11 +544,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn explain(&self, verbose: bool, analyze: bool) -> Result<Arc<DataFrame>> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+    pub fn explain(self, verbose: bool, analyze: bool) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
             .explain(verbose, analyze)?
             .build()?;
-        Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan)))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Return a `FunctionRegistry` used to plan udf's calls
@@ -585,13 +583,13 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn intersect(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
-        let left_plan = self.plan.clone();
-        let right_plan = dataframe.plan.clone();
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
-        )))
+    pub fn intersect(self, dataframe: DataFrame) -> Result<DataFrame> {
+        let left_plan = self.plan;
+        let right_plan = dataframe.plan;
+        Ok(DataFrame::new(
+            self.session_state,
+            LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
+        ))
     }
 
     /// Calculate the exception of two [`DataFrame`]s.
    /// The two [`DataFrame`]s must have exactly the same schema
@@ -607,38 +605,38 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn except(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
-        let left_plan = self.plan.clone();
-        let right_plan = dataframe.plan.clone();
+    pub fn except(self, dataframe: DataFrame) -> Result<DataFrame> {
+        let left_plan = self.plan;
+        let right_plan = dataframe.plan;
 
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &LogicalPlanBuilder::except(left_plan, right_plan, true)?,
-        )))
+        Ok(DataFrame::new(
+            self.session_state,
+            LogicalPlanBuilder::except(left_plan, right_plan, true)?,
+        ))
     }
 
     /// Write a `DataFrame` to a CSV file.
-    pub async fn write_csv(&self, path: &str) -> Result<()> {
-        let plan = self.create_physical_plan().await?;
+    pub async fn write_csv(self, path: &str) -> Result<()> {
         let state = self.session_state.read().clone();
+        let plan = self.create_physical_plan().await?;
         plan_to_csv(&state, plan, path).await
     }
 
     /// Write a `DataFrame` to a Parquet file.
     pub async fn write_parquet(
-        &self,
+        self,
         path: &str,
         writer_properties: Option<WriterProperties>,
     ) -> Result<()> {
-        let plan = self.create_physical_plan().await?;
         let state = self.session_state.read().clone();
+        let plan = self.create_physical_plan().await?;
         plan_to_parquet(&state, plan, path, writer_properties).await
     }
 
     /// Executes a query and writes the results to a partitioned JSON file.
-    pub async fn write_json(&self, path: impl AsRef<str>) -> Result<()> {
-        let plan = self.create_physical_plan().await?;
+    pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
         let state = self.session_state.read().clone();
+        let plan = self.create_physical_plan().await?;
         plan_to_json(&state, plan, path).await
     }
 
@@ -655,12 +653,12 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn with_column(&self, name: &str, expr: Expr) -> Result<Arc<DataFrame>> {
+    pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
         let window_func_exprs = find_window_exprs(&[expr.clone()]);
         let plan = if window_func_exprs.is_empty() {
-            self.plan.clone()
+            self.plan
         } else {
-            LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
+            LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
         };
 
         let new_column = expr.alias(name);
@@ -688,10 +686,7 @@ impl DataFrame {
 
         let project_plan = LogicalPlanBuilder::from(plan).project(fields)?.build()?;
 
-        Ok(Arc::new(DataFrame::new(
-            self.session_state.clone(),
-            &project_plan,
-        )))
+        Ok(DataFrame::new(self.session_state, project_plan))
     }
 
     /// Rename one column by applying a new projection. This is a no-op if the column to be
@@ -709,10 +704,10 @@ impl DataFrame {
     /// # }
     /// ```
     pub fn with_column_renamed(
-        &self,
+        self,
         old_name: &str,
         new_name: &str,
-    ) -> Result<Arc<DataFrame>> {
+    ) -> Result<DataFrame> {
         let mut projection = vec![];
         let mut rename_applied = false;
         for field in self.plan.schema().fields() {
@@ -728,15 +723,9 @@ impl DataFrame {
             }
         }
         if rename_applied {
             let project_plan = LogicalPlanBuilder::from(self.plan.clone())
                 .project(projection)?
                .build()?;
-            Ok(Arc::new(DataFrame::new(
-                self.session_state.clone(),
-                &project_plan,
-            )))
+            Ok(DataFrame::new(self.session_state, project_plan))
         } else {
-            Ok(Arc::new(DataFrame::new(
-                self.session_state.clone(),
-                &self.plan,
-            )))
+            Ok(DataFrame::new(self.session_state, self.plan))
         }
     }
 
@@ -753,14 +742,14 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn cache(&self) -> Result<Arc<DataFrame>> {
+    pub async fn cache(self) -> Result<DataFrame> {
+        let context = SessionContext::with_state(self.session_state.read().clone());
         let mem_table = MemTable::try_new(
             SchemaRef::from(self.schema().clone()),
             self.collect_partitioned().await?,
         )?;
 
-        SessionContext::with_state(self.session_state.read().clone())
-            .read_table(Arc::new(mem_table))
+        context.read_table(Arc::new(mem_table))
     }
 }
 
@@ -799,38 +788,29 @@ impl TableProvider for DataFrame {
         filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut expr = projection
-            // construct projections
-            .map_or_else(
-                || {
-                    Ok(Arc::new(Self::new(self.session_state.clone(), &self.plan))
-                        as Arc<_>)
-                },
-                |projection| {
-                    let schema = TableProvider::schema(self).project(projection)?;
-                    let names = schema
-                        .fields()
-                        .iter()
-                        .map(|field| field.name().as_str())
-                        .collect::<Vec<&str>>();
-                    self.select_columns(names.as_slice())
-                },
-            )?;
+        let mut expr = self.clone();
+        if let Some(p) = projection {
+            let schema = TableProvider::schema(&expr).project(p)?;
+            let names = schema
+                .fields()
+                .iter()
+                .map(|field| field.name().as_str())
+                .collect::<Vec<&str>>();
+            expr = expr.select_columns(names.as_slice())?;
+        }
+
         // Add filter when given
         let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
         if let Some(filter) = filter {
             expr = expr.filter(filter)?
         }
+        if let Some(l) = limit {
+            expr = expr.limit(0, Some(l))?
+        }
 
         // add a limit if given
-        Self::new(
-            self.session_state.clone(),
-            &limit
-                .map_or_else(|| Ok(expr.clone()), |n| expr.limit(0, Some(n)))?
-                .plan
-                .clone(),
-        )
-        .create_physical_plan()
-        .await
+        Self::new(self.session_state.clone(), expr.plan)
+            .create_physical_plan()
+            .await
     }
 }
 
@@ -838,17 +818,9 @@ mod tests {
     use std::vec;
 
-    use super::*;
-    use crate::execution::context::SessionConfig;
-    use crate::execution::options::{CsvReadOptions, ParquetReadOptions};
-    use crate::physical_plan::ColumnarValue;
-    use crate::physical_plan::Partitioning;
-    use crate::physical_plan::PhysicalExpr;
-    use crate::test_util;
-    use crate::test_util::parquet_test_data;
-    use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
     use arrow::array::Int32Array;
     use arrow::datatypes::DataType;
+
     use datafusion_expr::{
         avg, cast, count, count_distinct, create_udf, lit, max, min, sum,
         BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
@@ -856,6 +828,17 @@ mod tests {
     };
     use datafusion_physical_expr::expressions::Column;
 
+    use crate::execution::context::SessionConfig;
+    use crate::execution::options::{CsvReadOptions, ParquetReadOptions};
+    use crate::physical_plan::ColumnarValue;
+    use crate::physical_plan::Partitioning;
+    use crate::physical_plan::PhysicalExpr;
+    use crate::test_util;
+    use crate::test_util::parquet_test_data;
+    use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
+
+    use super::*;
+
     #[tokio::test]
     async fn select_columns() -> Result<()> {
         // build plan using Table API
@@ -973,8 +956,8 @@ mod tests {
         let right = test_table_with_name("c2")
             .await?
            .select_columns(&["c1", "c3"])?;
-        let left_rows = left.collect().await?;
-        let right_rows = right.collect().await?;
+        let left_rows = left.clone().collect().await?;
+        let right_rows = right.clone().collect().await?;
         let join = left.join(right, JoinType::Inner, &["c1"], &["c1"], None)?;
         let join_rows = join.collect().await?;
         assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
@@ -1045,14 +1028,13 @@ mod tests {
         let f = df.registry();
 
         let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
-        let plan = df.plan.clone();
 
         // build query using SQL
         let sql_plan =
             ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;
 
         // the two plans should be identical
-        assert_same_plan(&plan, &sql_plan);
+        assert_same_plan(&df.plan, &sql_plan);
 
         Ok(())
     }
@@ -1071,7 +1053,8 @@ mod tests {
     #[tokio::test]
     async fn intersect() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c3"])?;
-        let plan = df.intersect(df.clone())?;
+        let d2 = df.clone();
+        let plan = df.intersect(d2)?;
         let result = plan.plan.clone();
         let expected = create_plan(
             "SELECT c1, c3 FROM aggregate_test_100
@@ -1085,7 +1068,8 @@ mod tests {
     #[tokio::test]
     async fn except() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c3"])?;
-        let plan = df.except(df.clone())?;
+        let d2 = df.clone();
+        let plan = df.except(d2)?;
         let result = plan.plan.clone();
         let expected = create_plan(
             "SELECT c1, c3 FROM aggregate_test_100
@@ -1100,7 +1084,7 @@ mod tests {
     async fn register_table() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c12"])?;
         let ctx = SessionContext::new();
-        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
+        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), df.plan.clone()));
 
         // register a dataframe as a table
         ctx.register_table("test_table", df_impl.clone())?;
@@ -1163,13 +1147,13 @@ mod tests {
         ctx.create_logical_plan(sql)
     }
 
-    async fn test_table_with_name(name: &str) -> Result<Arc<DataFrame>> {
+    async fn test_table_with_name(name: &str) -> Result<DataFrame> {
        let mut ctx = SessionContext::new();
        register_aggregate_csv(&mut ctx, name).await?;
        ctx.table(name)
     }
 
-    async fn test_table() -> Result<Arc<DataFrame>> {
+    async fn test_table() -> Result<DataFrame> {
         test_table_with_name("aggregate_test_100").await
     }
 
@@ -1192,14 +1176,14 @@ mod tests {
     async fn with_column() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
         let ctx = SessionContext::new();
-        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.plan.clone()));
+        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
 
-        let df = &df_impl
+        let df = df_impl
             .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
             .with_column("sum", col("c2") + col("c3"))?;
 
         // check that new column added
-        let df_results = df.collect().await?;
+        let df_results = df.clone().collect().await?;
 
         assert_batches_sorted_eq!(
             vec![
@@ -1219,6 +1203,7 @@ mod tests {
 
         // check that col with the same name overwritten
         let df_results_overwrite = df
+            .clone()
             .with_column("c1", col("c2") + col("c3"))?
             .collect()
             .await?;
@@ -1240,8 +1225,11 @@ mod tests {
         );
 
         // check that col with the same name overwritten using same name as reference
-        let df_results_overwrite_self =
-            df.with_column("c2", col("c2") + lit(1))?.collect().await?;
+        let df_results_overwrite_self = df
+            .clone()
+            .with_column("c2", col("c2") + lit(1))?
+            .collect()
+            .await?;
 
         assert_batches_sorted_eq!(
             vec![
@@ -1297,8 +1285,10 @@ mod tests {
     async fn with_column_renamed_join() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
         let ctx = SessionContext::new();
-        ctx.register_table("t1", df.clone())?;
-        ctx.register_table("t2", df)?;
+
+        let table = Arc::new(df);
+        ctx.register_table("t1", table.clone())?;
+        ctx.register_table("t2", table)?;
         let df = ctx
             .table("t1")?
             .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)?
@@ -1313,7 +1303,7 @@ mod tests {
             ])?
             .limit(0, Some(1))?;
 
-        let df_results = df.collect().await?;
+        let df_results = df.clone().collect().await?;
         assert_batches_sorted_eq!(
             vec![
                 "+----+----+-----+----+----+-----+",
                 "| c1 | c2 | c3  | c1 | c2 | c3  |",
                 "+----+----+-----+----+----+-----+",
                 "| a  | 1  | -85 | a  | 1  | -85 |",
                 "+----+----+-----+----+----+-----+",
             ],
            &df_results
        );
 
-        let df_renamed = df.with_column_renamed("t1.c1", "AAA")?;
+        let df_renamed = df.clone().with_column_renamed("t1.c1", "AAA")?;
 
         assert_eq!("\
        Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
        \n  Limit: skip=0, fetch=1\
        \n    Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST\
        \n      Inner Join: t1.c1 = t2.c1\
        \n        TableScan: t1\
        \n        TableScan: t2",
-            format!("{:?}", df_renamed.to_unoptimized_plan())
+            format!("{:?}", df_renamed.clone().to_unoptimized_plan())
        );
 
         assert_eq!("\
        Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\
        \n  Limit: skip=0, fetch=1\
        \n    Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST\
        \n      Inner Join: t1.c1 = t2.c1\
        \n        SubqueryAlias: t1\
        \n          Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
        \n            TableScan: aggregate_test_100 projection=[c1, c2, c3]\
        \n        SubqueryAlias: t2\
        \n          Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
        \n            TableScan: aggregate_test_100 projection=[c1, c2, c3]",
-            format!("{:?}", df_renamed.to_logical_plan()?)
+            format!("{:?}", df_renamed.clone().to_logical_plan()?)
        );
 
         let df_results = df_renamed.collect().await?;
@@ -1378,7 +1368,7 @@ mod tests {
         )
         .await?;
 
-        ctx.register_table("t1", ctx.table("test")?)?;
+        ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
 
         let df = ctx
             .table("t1")?
@@ -1401,8 +1391,8 @@ mod tests {
             .limit(0, Some(1))?
             .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
 
-        let df_results = df.collect().await?;
-        df.show().await?;
+        let df_results = df.clone().collect().await?;
+        df.clone().show().await?;
         assert_batches_sorted_eq!(
             vec![
                 "+----+----+-----+",
@@ -1491,11 +1481,11 @@ mod tests {
             .limit(0, Some(1))?
             .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
 
-        let cached_df = df.cache().await?;
+        let cached_df = df.clone().cache().await?;
 
         assert_eq!(
             "TableScan: ?table? projection=[c2, c3, sum]",
-            format!("{:?}", cached_df.to_logical_plan()?)
+            format!("{:?}", cached_df.clone().to_logical_plan()?)
         );
 
         let df_results = df.collect().await?;
@@ -1524,15 +1514,20 @@ mod tests {
             .select_columns(&["c1", "c3"])?
             .with_column_renamed("c2.c1", "c2_c1")?;
 
-        let left_rows = left.collect().await?;
-        let right_rows = right.collect().await?;
-        let join1 =
-            left.join(right.clone(), JoinType::Inner, &["c1"], &["c2_c1"], None)?;
+        let left_rows = left.clone().collect().await?;
+        let right_rows = right.clone().collect().await?;
+        let join1 = left.clone().join(
+            right.clone(),
+            JoinType::Inner,
+            &["c1"],
+            &["c2_c1"],
+            None,
+        )?;
         let join2 = left.join(right, JoinType::Inner, &["c1"], &["c2_c1"], None)?;
 
         let union = join1.union(join2)?;
 
-        let union_rows = union.collect().await?;
+        let union_rows = union.clone().collect().await?;
 
         assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
         assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
@@ -1566,9 +1561,9 @@ mod tests {
             .with_column_renamed("c2.c1", "c2_c1")?
            .with_column_renamed("c2.c2", "c2_c2")?;
 
-        let left_rows = left.collect().await?;
-        let right_rows = right.collect().await?;
-        let join1 = left.join(
+        let left_rows = left.clone().collect().await?;
+        let right_rows = right.clone().collect().await?;
+        let join1 = left.clone().join(
             right.clone(),
             JoinType::Inner,
             &["c1", "c2"],
@@ -1587,7 +1582,7 @@ mod tests {
 
         let union = join1.union(join2)?;
 
-        let union_rows = union.collect().await?;
+        let union_rows = union.clone().collect().await?;
 
         assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
         assert_eq!(100, right_rows.iter().map(|x| x.num_rows()).sum::<usize>());
@@ -1626,7 +1621,7 @@ mod tests {
 
         let default_partition_count = SessionConfig::new().target_partitions();
 
         for join_type in all_join_types {
-            let join = left.join(
+            let join = left.clone().join(
                 right.clone(),
                 join_type,
                 &["c1", "c2"],
diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs
index ce262f55ddfc..c50bc61783ee 100644
--- a/datafusion/core/src/datasource/view.rs
+++ b/datafusion/core/src/datasource/view.rs
@@ -431,7 +431,7 @@ mod tests {
         )
         .await?;
 
-        ctx.register_table("t1", ctx.table("test")?)?;
+        ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
 
         ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;
 
@@ -460,7 +460,7 @@ mod tests {
         )
         .await?;
 
-        ctx.register_table("t1", ctx.table("test")?)?;
+        ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
 
         ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;
 
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 759cc79a8bf8..99539a06cdbb 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -252,7 +252,7 @@ impl SessionContext {
     ///
     /// This method is `async` because queries of type `CREATE EXTERNAL TABLE`
     /// might require the schema to be inferred.
-    pub async fn sql(&self, sql: &str) -> Result<Arc<DataFrame>> {
+    pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
         let plan = self.create_logical_plan(sql)?;
         match plan {
             LogicalPlan::CreateExternalTable(cmd) => {
@@ -265,20 +265,18 @@ impl SessionContext {
                 if_not_exists,
                 or_replace,
             }) => {
+                let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone());
                 let table = self.table(&name);
 
                 match (if_not_exists, or_replace, table) {
                     (true, false, Ok(_)) => self.return_empty_dataframe(),
                     (false, true, Ok(_)) => {
                         self.deregister_table(&name)?;
-                        let physical =
-                            Arc::new(DataFrame::new(self.state.clone(), &input));
+                        let schema = Arc::new(input.schema().as_ref().into());
+                        let physical = DataFrame::new(self.state.clone(), input);
 
                         let batches: Vec<_> = physical.collect_partitioned().await?;
-                        let table = Arc::new(MemTable::try_new(
-                            Arc::new(input.schema().as_ref().into()),
-                            batches,
-                        )?);
+                        let table = Arc::new(MemTable::try_new(schema, batches)?);
 
                         self.register_table(&name, table)?;
                         self.return_empty_dataframe()
@@ -287,14 +285,11 @@ impl SessionContext {
                         "'IF NOT EXISTS' cannot coexist with 'REPLACE'".to_string(),
                     )),
                     (_, _, Err(_)) => {
-                        let physical =
-                            Arc::new(DataFrame::new(self.state.clone(), &input));
+                        let schema = Arc::new(input.schema().as_ref().into());
+                        let physical = DataFrame::new(self.state.clone(), input);
 
                         let batches: Vec<_> = physical.collect_partitioned().await?;
-                        let table = Arc::new(MemTable::try_new(
-                            Arc::new(input.schema().as_ref().into()),
-                            batches,
-                        )?);
+                        let table = Arc::new(MemTable::try_new(schema, batches)?);
 
                         self.register_table(&name, table)?;
                         self.return_empty_dataframe()
@@ -480,20 +475,20 @@ impl SessionContext {
                 }
             }
 
-            plan => Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))),
+            plan => Ok(DataFrame::new(self.state.clone(), plan)),
         }
     }
 
     // return an empty dataframe
-    fn return_empty_dataframe(&self) -> Result<Arc<DataFrame>> {
+    fn return_empty_dataframe(&self) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::empty(false).build()?;
-        Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
+        Ok(DataFrame::new(self.state.clone(), plan))
     }
 
     async fn create_external_table(
         &self,
         cmd: &CreateExternalTable,
-    ) -> Result<Arc<DataFrame>> {
+    ) -> Result<DataFrame> {
         let table_provider: Arc<dyn TableProvider> =
             self.create_custom_table(cmd).await?;
@@ -614,7 +609,7 @@ impl SessionContext {
         &self,
         table_path: impl AsRef<str>,
         options: AvroReadOptions<'_>,
-    ) -> Result<Arc<DataFrame>> {
+    ) -> Result<DataFrame> {
         let table_path = ListingTableUrl::parse(table_path)?;
         let target_partitions = self.copied_config().target_partitions();
@@ -641,7 +636,7 @@ impl SessionContext {
         &mut self,
         table_path: impl AsRef<str>,
         options: NdJsonReadOptions<'_>,
-    ) -> Result<Arc<DataFrame>> {
+    ) -> Result<DataFrame> {
         let table_path = ListingTableUrl::parse(table_path)?;
         let target_partitions = self.copied_config().target_partitions();
@@ -664,11 +659,11 @@ impl SessionContext {
     }
 
     /// Creates an empty DataFrame.
-    pub fn read_empty(&self) -> Result<Arc<DataFrame>> {
-        Ok(Arc::new(DataFrame::new(
+    pub fn read_empty(&self) -> Result<DataFrame> {
+        Ok(DataFrame::new(
             self.state.clone(),
-            &LogicalPlanBuilder::empty(true).build()?,
-        )))
+            LogicalPlanBuilder::empty(true).build()?,
+        ))
     }
 
     /// Creates a [`DataFrame`] for reading a CSV data source.
@@ -676,7 +671,7 @@ pub async fn read_csv(
@@ -676,7 +671,7 @@ impl SessionContext { &self, table_path: impl AsRef, options: CsvReadOptions<'_>, - ) -> Result> { + ) -> Result { let table_path = ListingTableUrl::parse(table_path)?; let target_partitions = self.copied_config().target_partitions(); let listing_options = options.to_listing_options(target_partitions); @@ -701,7 +696,7 @@ impl SessionContext { &self, table_path: impl AsRef, options: ParquetReadOptions<'_>, - ) -> Result> { + ) -> Result { let table_path = ListingTableUrl::parse(table_path)?; let listing_options = options.to_listing_options(&self.state.read().config); @@ -719,26 +714,26 @@ impl SessionContext { } /// Creates a [`DataFrame`] for reading a custom [`TableProvider`]. - pub fn read_table(&self, provider: Arc) -> Result> { - Ok(Arc::new(DataFrame::new( + pub fn read_table(&self, provider: Arc) -> Result { + Ok(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? + LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? .build()?, - ))) + )) } /// Creates a [`DataFrame`] for reading a [`RecordBatch`] - pub fn read_batch(&self, batch: RecordBatch) -> Result> { + pub fn read_batch(&self, batch: RecordBatch) -> Result { let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - Ok(Arc::new(DataFrame::new( + Ok(DataFrame::new( self.state.clone(), - &LogicalPlanBuilder::scan( + LogicalPlanBuilder::scan( UNNAMED_TABLE, provider_as_source(Arc::new(provider)), None, )? .build()?, - ))) + )) } /// Registers a [`ListingTable]` that can assemble multiple files @@ -942,7 +937,7 @@ impl SessionContext { pub fn table<'a>( &self, table_ref: impl Into>, - ) -> Result> { + ) -> Result { let table_ref = table_ref.into(); let provider = self.table_provider(table_ref)?; let plan = LogicalPlanBuilder::scan( @@ -951,7 +946,7 @@ impl SessionContext { None, )? .build()?; - Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))) + Ok(DataFrame::new(self.state.clone(), plan)) } /// Return a [`TabelProvider`] for the specified table. 
@@ -2679,28 +2674,28 @@ mod tests {
     // See https://github.com/apache/arrow-datafusion/issues/1154
     #[async_trait]
     trait CallReadTrait {
-        async fn call_read_csv(&self) -> Arc<DataFrame>;
-        async fn call_read_avro(&self) -> Arc<DataFrame>;
-        async fn call_read_parquet(&self) -> Arc<DataFrame>;
+        async fn call_read_csv(&self) -> DataFrame;
+        async fn call_read_avro(&self) -> DataFrame;
+        async fn call_read_parquet(&self) -> DataFrame;
     }
 
     struct CallRead {}
 
     #[async_trait]
     impl CallReadTrait for CallRead {
-        async fn call_read_csv(&self) -> Arc<DataFrame> {
+        async fn call_read_csv(&self) -> DataFrame {
             let ctx = SessionContext::new();
             ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap()
         }
 
-        async fn call_read_avro(&self) -> Arc<DataFrame> {
+        async fn call_read_avro(&self) -> DataFrame {
             let ctx = SessionContext::new();
             ctx.read_avro("dummy", AvroReadOptions::default())
                 .await
                 .unwrap()
         }
 
-        async fn call_read_parquet(&self) -> Arc<DataFrame> {
+        async fn call_read_parquet(&self) -> DataFrame {
             let ctx = SessionContext::new();
             ctx.read_parquet("dummy", ParquetReadOptions::default())
                 .await
diff --git a/datafusion/core/src/scheduler/mod.rs b/datafusion/core/src/scheduler/mod.rs
index 24c30dc3e495..e9e8aa7755ae 100644
--- a/datafusion/core/src/scheduler/mod.rs
+++ b/datafusion/core/src/scheduler/mod.rs
@@ -354,7 +354,7 @@ mod tests {
 
         let query = context.sql(sql).await.unwrap();
 
-        let plan = query.create_physical_plan().await.unwrap();
+        let plan = query.clone().create_physical_plan().await.unwrap();
 
         info!("Plan: {}", displayable(plan.as_ref()).indent());
 
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 57df7e32afdb..b2d599727e6b 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -380,7 +380,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
     Ok(())
 }
 
-fn create_test_table() -> Result<Arc<DataFrame>> {
+fn create_test_table() -> Result<DataFrame> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Utf8, false),
         Field::new("b", DataType::Int32, false),
@@ -407,7 +407,7 @@ fn create_test_table() -> Result<DataFrame> {
     ctx.table("test")
 }
 
-async fn aggregates_table(ctx: &SessionContext) -> Result<Arc<DataFrame>> {
+async fn aggregates_table(ctx: &SessionContext) -> Result<DataFrame> {
     let testdata = datafusion::test_util::arrow_test_data();
 
     ctx.read_csv(
diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs
index 5b643e12892b..624291a952df 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -34,7 +34,7 @@ use datafusion::execution::context::SessionContext;
 use datafusion::assert_batches_eq;
 use datafusion_expr::{approx_median, cast};
 
-fn create_test_table() -> Result<Arc<DataFrame>> {
+fn create_test_table() -> Result<DataFrame> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Utf8, false),
         Field::new("b", DataType::Int32, false),
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 35d8dc6d6e08..d756fad5a381 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -1390,7 +1390,7 @@ async fn unprojected_filter() {
         .select(vec![col("i") + col("i")])
         .unwrap();
 
-    let plan = df.to_logical_plan().unwrap();
+    let plan = df.clone().to_logical_plan().unwrap();
     println!("{}", plan.display_indent());
 
     let results = df.collect().await.unwrap();

From 7ec70885225f7f1f24dcedca3251b934ebcab43f Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 14 Dec 2022 17:34:08 +0000
Subject: [PATCH 2/4] Fix doc

---
 datafusion/core/src/dataframe.rs | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index ac77dc1c62b0..6296da180dee 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -243,7 +243,8 @@ impl DataFrame {
     /// # async fn main() -> Result<()> {
     /// let ctx = SessionContext::new();
     /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
-    /// let df = df.union(df.clone())?;
+    /// let d2 = df.clone();
+    /// let df = df.union(d2)?;
     /// # Ok(())
     /// # }
     /// ```
@@ -264,7 +265,8 @@ impl DataFrame {
     /// # async fn main() -> Result<()> {
     /// let ctx = SessionContext::new();
     /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
-    /// let df = df.union_distinct(df.clone())?;
+    /// let d2 = df.clone();
+    /// let df = df.union_distinct(d2)?;
     /// # Ok(())
     /// # }
     /// ```
@@ -579,7 +581,8 @@ impl DataFrame {
     /// # async fn main() -> Result<()> {
     /// let ctx = SessionContext::new();
     /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
-    /// let df = df.intersect(df.clone())?;
+    /// let d2 = df.clone();
+    /// let df = df.intersect(d2)?;
     /// # Ok(())
     /// # }
     /// ```
@@ -601,7 +604,8 @@ impl DataFrame {
     /// # async fn main() -> Result<()> {
     /// let ctx = SessionContext::new();
     /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
-    /// let df = df.except(df.clone())?;
+    /// let d2 = df.clone();
+    /// let df = df.except(d2)?;
     /// # Ok(())
     /// # }
     /// ```

From eed0dd7358cbe3c24768bc55e29bfb1380917b84 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 14 Dec 2022 20:32:27 +0000
Subject: [PATCH 3/4] More methods

---
 datafusion/core/src/dataframe.rs | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 6296da180dee..302a3070d84a 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -149,16 +149,16 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn select(&self, expr_list: Vec<Expr>) -> Result<DataFrame> {
+    pub fn select(self, expr_list: Vec<Expr>) -> Result<DataFrame> {
         let window_func_exprs = find_window_exprs(&expr_list);
         let plan = if window_func_exprs.is_empty() {
-            self.plan.clone()
+            self.plan
         } else {
-            LogicalPlanBuilder::window_plan(self.plan.clone(), window_func_exprs)?
+            LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
         };
         let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
 
-        Ok(DataFrame::new(self.session_state.clone(), project_plan))
+        Ok(DataFrame::new(self.session_state, project_plan))
     }
 
     /// Filter a DataFrame to only include rows that match the specified filter expression.
@@ -174,11 +174,11 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn filter(&self, predicate: Expr) -> Result<DataFrame> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+    pub fn filter(self, predicate: Expr) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
             .filter(predicate)?
             .build()?;
-        Ok(DataFrame::new(self.session_state.clone(), plan))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Perform an aggregate query with optional grouping expressions.
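
The doc examples rewritten in patch 2 and the `self`-taking signatures in patch 3 both point at the same usage pattern: once a combinator consumes its receiver, reusing a `DataFrame` on both sides of a binary operation goes through an explicit `clone()` first. A short sketch in the spirit of those doc examples, assuming the same `tests/example.csv` fixture they reference:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn self_union(ctx: &SessionContext) -> Result<DataFrame> {
    let df = ctx
        .read_csv("tests/example.csv", CsvReadOptions::new())
        .await?;

    // `union` consumes both operands, so clone first to keep a usable
    // left-hand side.
    let d2 = df.clone();
    df.union(d2)
}
```
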
@@ -200,14 +200,14 @@ impl DataFrame {
     /// # }
     /// ```
     pub fn aggregate(
-        &self,
+        self,
         group_expr: Vec<Expr>,
         aggr_expr: Vec<Expr>,
     ) -> Result<DataFrame> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+        let plan = LogicalPlanBuilder::from(self.plan)
             .aggregate(group_expr, aggr_expr)?
             .build()?;
-        Ok(DataFrame::new(self.session_state.clone(), plan))
+        Ok(DataFrame::new(self.session_state, plan))
     }
 
     /// Limit the number of rows returned from this DataFrame.
@@ -374,7 +374,7 @@ impl DataFrame {
     /// # }
     /// ```
     pub fn repartition(self, partitioning_scheme: Partitioning) -> Result<DataFrame> {
-        let plan = LogicalPlanBuilder::from(self.plan.clone())
+        let plan = LogicalPlanBuilder::from(self.plan)
             .repartition(partitioning_scheme)?
             .build()?;
         Ok(DataFrame::new(self.session_state, plan))
     }
@@ -724,7 +724,7 @@ impl DataFrame {
             }
         }
         if rename_applied {
-            let project_plan = LogicalPlanBuilder::from(self.plan.clone())
+            let project_plan = LogicalPlanBuilder::from(self.plan)
                 .project(projection)?
                 .build()?;
             Ok(DataFrame::new(self.session_state, project_plan))
@@ -1088,10 +1088,10 @@ mod tests {
     async fn register_table() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c12"])?;
         let ctx = SessionContext::new();
-        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), df.plan.clone()));
+        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
 
         // register a dataframe as a table
-        ctx.register_table("test_table", df_impl.clone())?;
+        ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
 
         // pull the table out
         let table = ctx.table("test_table")?;
@@ -1100,7 +1100,7 @@ mod tests {
         let aggr_expr = vec![sum(col("c12"))];
 
         // check that we correctly read from the table
-        let df_results = &df_impl
+        let df_results = df_impl
             .aggregate(group_expr.clone(), aggr_expr.clone())?
             .collect()
             .await?;
@@ -1118,7 +1118,7 @@ mod tests {
                 "| e  | 10.206140546981722          |",
                 "+----+-----------------------------+",
             ],
-            df_results
+            &df_results
         );
 
         // the results are the same as the results from the view, modulo the leaf table name

From e38e76d4c163bc860fcabe329238621c501955e9 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 14 Dec 2022 22:26:35 +0000
Subject: [PATCH 4/4] Fix doc

---
 datafusion/core/src/dataframe.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 3902f904cccf..5e615be607c3 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -192,7 +192,7 @@ impl DataFrame {
     /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
     ///
     /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a"
-    /// let _ = df.aggregate(vec![col("a")], vec![min(col("b"))])?;
+    /// let _ = df.clone().aggregate(vec![col("a")], vec![min(col("b"))])?;
     ///
     /// // The following use is the equivalent of "SELECT MIN(b)"
     /// let _ = df.aggregate(vec![], vec![min(col("b"))])?;
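
Taken as a whole, the series leaves every `DataFrame` combinator and executor consuming `self`, with the new `#[derive(Clone)]` as the escape hatch when a frame is needed twice. A minimal end-to-end sketch of the resulting API, assuming a CSV laid out like the `tests/example.csv` fixture with columns `a` and `b` used in the doc examples above:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx
        .read_csv("tests/example.csv", CsvReadOptions::new())
        .await?;

    // Chained calls need no clones: each combinator consumes the
    // previous frame and returns a new owned one.
    let df = df
        .filter(col("a").lt_eq(col("b")))?
        .aggregate(vec![col("a")], vec![min(col("b"))])?
        .limit(0, Some(100))?;

    // `show` and `collect` are also consuming, so clone before the first
    // of two executions to keep the frame alive for the second.
    df.clone().show().await?;
    let batches = df.collect().await?;
    println!("collected {} record batches", batches.len());
    Ok(())
}
```
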