diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index d523c39ee01e..9c500ec07293 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -849,6 +849,23 @@ impl SessionContext {
         let table_paths = table_paths.to_urls()?;
         let session_config = self.copied_config();
         let listing_options = options.to_listing_options(&session_config);
+
+        let option_extension = listing_options.file_extension.clone();
+
+        if table_paths.is_empty() {
+            return exec_err!("No table paths were provided");
+        }
+
+        // check if the file extension matches the expected extension
+        for path in &table_paths {
+            let file_name = path.prefix().filename().unwrap_or_default();
+            if !path.as_str().ends_with(&option_extension) && file_name.contains('.') {
+                return exec_err!(
+                    "File '{file_name}' does not match the expected extension '{option_extension}'"
+                );
+            }
+        }
+
         let resolved_schema = options
             .get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
             .await?;
diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs
index dc202b9903f5..ef1f0143543d 100644
--- a/datafusion/core/src/execution/context/parquet.rs
+++ b/datafusion/core/src/execution/context/parquet.rs
@@ -74,6 +74,11 @@ impl SessionContext {
 mod tests {
     use async_trait::async_trait;
 
+    use crate::arrow::array::{Float32Array, Int32Array};
+    use crate::arrow::datatypes::{DataType, Field, Schema};
+    use crate::arrow::record_batch::RecordBatch;
+    use crate::dataframe::DataFrameWriteOptions;
+    use crate::parquet::basic::Compression;
     use crate::test_util::parquet_test_data;
 
     use super::*;
@@ -132,6 +137,124 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn read_from_different_file_extension() -> Result<()> {
+        let ctx = SessionContext::new();
+
+        // Make up a new dataframe.
+        let write_df = ctx.read_batch(RecordBatch::try_new(
+            Arc::new(Schema::new(vec![
+                Field::new("purchase_id", DataType::Int32, false),
+                Field::new("price", DataType::Float32, false),
+                Field::new("quantity", DataType::Int32, false),
+            ])),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+                Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])),
+                Arc::new(Int32Array::from(vec![1, 3, 2, 4, 3])),
+            ],
+        )?)?;
+
+        // Write the dataframe to a parquet file named 'output1.parquet'
+        write_df
+            .clone()
+            .write_parquet(
+                "output1.parquet",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Write the dataframe to a parquet file named 'output2.parquet.snappy'
+        write_df
+            .clone()
+            .write_parquet(
+                "output2.parquet.snappy",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
+        write_df
+            .write_parquet(
+                "output3.parquet.snappy.parquet",
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
+        // Read the dataframe from 'output1.parquet' with the default file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output1.parquet",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+
+        // Read the dataframe from 'output2.parquet.snappy' with the correct file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output2.parquet.snappy",
+                ParquetReadOptions {
+                    file_extension: "snappy",
+                    ..Default::default()
+                },
+            )
+            .await?;
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+
+        // Read the dataframe from 'output2.parquet.snappy' with the wrong (default) file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output2.parquet.snappy",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await;
+
+        assert_eq!(
+            read_df.unwrap_err().strip_backtrace(),
+            "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
+        );
+
+        // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
+        let read_df = ctx
+            .read_parquet(
+                "output3.parquet.snappy.parquet",
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+        Ok(())
+    }
+
     // Test for compilation error when calling read_* functions from an #[async_trait] function.
     // See https://github.com/apache/arrow-datafusion/issues/1154
     #[async_trait]