
feat: Support determining extensions from names like `foo.parquet.snappy` as well as `foo.parquet` (#7972)

* feat: read files based on the file extention

* fix: some file extensions may start with '.' and some may not

* fix: rename extention to extension

* chore: use exec_err

* chore: rename extention to extension

* chore: rename extention to extension

* chore: simplify the code

* fix: check table is empty

* ci: fix test

* fix: add err info

* refactor: extract the logic to infer_types

* fix: add tests for different extensions

* fix: ci clippy

* fix: add more tests

* fix: simplify the logic

* fix: ci
Weijun-H authored Nov 7, 2023
1 parent 06fd26b commit 56f6437
Showing 2 changed files with 140 additions and 0 deletions.
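
In practice, this change lets reads succeed on files whose names carry a compound suffix such as `foo.parquet.snappy`, provided the configured `file_extension` matches, and fail fast with an execution error otherwise. A minimal usage sketch, assuming a local file named `data.parquet.snappy` (a hypothetical name, not part of this commit):

use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // With an explicit extension, a compound name like `data.parquet.snappy`
    // is accepted; with the default `.parquet` extension it would error.
    let df = ctx
        .read_parquet(
            "data.parquet.snappy",
            ParquetReadOptions {
                file_extension: "snappy",
                ..Default::default()
            },
        )
        .await?;
    df.show().await?;
    Ok(())
}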
17 changes: 17 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
@@ -849,6 +849,23 @@ impl SessionContext {
let table_paths = table_paths.to_urls()?;
let session_config = self.copied_config();
let listing_options = options.to_listing_options(&session_config);

let option_extension = listing_options.file_extension.clone();

if table_paths.is_empty() {
return exec_err!("No table paths were provided");
}

// check if the file extension matches the expected extension
for path in &table_paths {
let file_name = path.prefix().filename().unwrap_or_default();
if !path.as_str().ends_with(&option_extension) && file_name.contains('.') {
return exec_err!(
"File '{file_name}' does not match the expected extension '{option_extension}'"
);
}
}

let resolved_schema = options
.get_resolved_schema(&session_config, self.state(), table_paths[0].clone())
.await?;
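The check above accepts a path when it ends with the expected extension, and also lets through names with no '.' in their final segment (directories or extension-less files); only a dotted name with the wrong suffix is rejected. A standalone sketch of that predicate, using a hypothetical `matches_extension` helper and plain strings in place of `ListingTableUrl`:

fn matches_extension(path: &str, expected: &str) -> bool {
    // Mirror the hunk above: reject only when the path misses the expected
    // suffix AND its file name contains a '.'.
    let file_name = path.rsplit('/').next().unwrap_or_default();
    path.ends_with(expected) || !file_name.contains('.')
}

fn main() {
    assert!(matches_extension("out/foo.parquet", ".parquet")); // exact suffix
    assert!(matches_extension("out/data_dir", ".parquet")); // no '.', passes through
    assert!(!matches_extension("out/foo.parquet.snappy", ".parquet")); // rejected
}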
123 changes: 123 additions & 0 deletions datafusion/core/src/execution/context/parquet.rs
@@ -74,6 +74,11 @@ impl SessionContext {
mod tests {
use async_trait::async_trait;

use crate::arrow::array::{Float32Array, Int32Array};
use crate::arrow::datatypes::{DataType, Field, Schema};
use crate::arrow::record_batch::RecordBatch;
use crate::dataframe::DataFrameWriteOptions;
use crate::parquet::basic::Compression;
use crate::test_util::parquet_test_data;

use super::*;
@@ -132,6 +137,124 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn read_from_different_file_extension() -> Result<()> {
let ctx = SessionContext::new();

// Create a new dataframe to write out.
let write_df = ctx.read_batch(RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("purchase_id", DataType::Int32, false),
Field::new("price", DataType::Float32, false),
Field::new("quantity", DataType::Int32, false),
])),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])),
Arc::new(Int32Array::from(vec![1, 3, 2, 4, 3])),
],
)?)?;

// Write the dataframe to a parquet file named 'output1.parquet'
write_df
.clone()
.write_parquet(
"output1.parquet",
DataFrameWriteOptions::new().with_single_file_output(true),
Some(
WriterProperties::builder()
.set_compression(Compression::SNAPPY)
.build(),
),
)
.await?;

// Write the dataframe to a parquet file named 'output2.parquet.snappy'
write_df
.clone()
.write_parquet(
"output2.parquet.snappy",
DataFrameWriteOptions::new().with_single_file_output(true),
Some(
WriterProperties::builder()
.set_compression(Compression::SNAPPY)
.build(),
),
)
.await?;

// Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
write_df
.write_parquet(
"output3.parquet.snappy.parquet",
DataFrameWriteOptions::new().with_single_file_output(true),
Some(
WriterProperties::builder()
.set_compression(Compression::SNAPPY)
.build(),
),
)
.await?;

// Read the dataframe from 'output1.parquet' with the default file extension.
let read_df = ctx
.read_parquet(
"output1.parquet",
ParquetReadOptions {
..Default::default()
},
)
.await?;

let results = read_df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 5);

// Read the dataframe from 'output2.parquet.snappy' with the correct file extension.
let read_df = ctx
.read_parquet(
"output2.parquet.snappy",
ParquetReadOptions {
file_extension: "snappy",
..Default::default()
},
)
.await?;
let results = read_df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 5);

// Read the dataframe from 'output2.parquet.snappy' with the default (wrong) file extension.
let read_df = ctx
.read_parquet(
"output2.parquet.snappy",
ParquetReadOptions {
..Default::default()
},
)
.await;

assert_eq!(
read_df.unwrap_err().strip_backtrace(),
"Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
);

// Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
let read_df = ctx
.read_parquet(
"output3.parquet.snappy.parquet",
ParquetReadOptions {
..Default::default()
},
)
.await?;

let results = read_df.collect().await?;
let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
assert_eq!(total_rows, 5);
Ok(())
}

// Test for compilation error when calling read_* functions from an #[async_trait] function.
// See https://github.com/apache/arrow-datafusion/issues/1154
#[async_trait]
