Skip to content

Commit

Permalink
parquet_derive: Match fields by name, support reading selected fields…
Browse files Browse the repository at this point in the history
… rather than all (#6269)

* support reading pruned parquet

* add pruned parquet reading test

* better unit test

* update comments

* deref instead of clone

* do not panic

* copy integer

* restore struct name

* update comments

---------

Co-authored-by: Ye Yuan <yuanye_ptr@qq.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
3 people committed Aug 31, 2024
1 parent 0c15191 commit 3a1f67f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
22 changes: 17 additions & 5 deletions parquet_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke
/// Derive flat, simple RecordReader implementations. Works by parsing
/// a struct tagged with `#[derive(ParquetRecordReader)]` and emitting
/// the correct reading code for each field of the struct. Column readers
/// are generated in the order they are defined.
/// are generated by matching names in the schema to the names in the struct.
///
/// It is up to the programmer to keep the order of the struct
/// fields lined up with the schema.
/// It is up to the programmer to ensure that the names of the struct
/// fields match the column names in the schema.
///
/// Example:
///
Expand Down Expand Up @@ -189,7 +189,6 @@ pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::Toke
let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
let reader_snippets: Vec<proc_macro2::TokenStream> =
field_infos.iter().map(|x| x.reader_snippet()).collect();
let i: Vec<_> = (0..reader_snippets.len()).collect();

let derived_for = input.ident;
let generics = input.generics;
Expand All @@ -206,6 +205,12 @@ pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::Toke

let mut row_group_reader = row_group_reader;

// key: parquet file column name, value: column index
let mut name_to_index = std::collections::HashMap::new();
for (idx, col) in row_group_reader.metadata().schema_descr().columns().iter().enumerate() {
name_to_index.insert(col.name().to_string(), idx);
}

for _ in 0..num_records {
self.push(#derived_for {
#(
Expand All @@ -218,7 +223,14 @@ pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::Toke

#(
{
if let Ok(mut column_reader) = row_group_reader.get_column_reader(#i) {
let idx: usize = match name_to_index.get(stringify!(#field_names)) {
Some(&col_idx) => col_idx,
None => {
let error_msg = format!("column name '{}' is not found in parquet file!", stringify!(#field_names));
return Err(::parquet::errors::ParquetError::General(error_msg));
}
};
if let Ok(mut column_reader) = row_group_reader.get_column_reader(idx) {
#reader_snippets
} else {
return Err(::parquet::errors::ParquetError::General("Failed to get next column".into()))
Expand Down
84 changes: 74 additions & 10 deletions parquet_derive_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ struct APartiallyCompleteRecord {
struct APartiallyOptionalRecord {
pub bool: bool,
pub string: String,
pub maybe_i16: Option<i16>,
pub maybe_i32: Option<i32>,
pub maybe_u64: Option<u64>,
pub i16: Option<i16>,
pub i32: Option<i32>,
pub u64: Option<u64>,
pub isize: isize,
pub float: f32,
pub double: f64,
Expand All @@ -85,6 +85,22 @@ struct APartiallyOptionalRecord {
pub byte_vec: Vec<u8>,
}

// This struct removes several fields from `APartiallyCompleteRecord` and
// shuffles the order of the remaining fields. Because the derived
// `ParquetRecordReader` matches columns by name rather than by position,
// we should still be able to load it from an `APartiallyCompleteRecord`
// parquet file.
#[derive(PartialEq, ParquetRecordReader, Debug)]
struct APrunedRecord {
    pub bool: bool,
    pub string: String,
    pub byte_vec: Vec<u8>,
    pub float: f32,
    pub double: f64,
    pub i16: i16,
    pub i32: i32,
    pub u64: u64,
    pub isize: isize,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -240,12 +256,12 @@ mod tests {
#[test]
fn test_parquet_derive_read_optional_but_valid_column() {
let file = get_temp_file("test_parquet_derive_read_optional", &[]);
let drs: Vec<APartiallyOptionalRecord> = vec![APartiallyOptionalRecord {
let drs = vec![APartiallyOptionalRecord {
bool: true,
string: "a string".into(),
maybe_i16: Some(-45),
maybe_i32: Some(456),
maybe_u64: Some(4563424),
i16: Some(-45),
i32: Some(456),
u64: Some(4563424),
isize: -365,
float: 3.5,
double: f64::NAN,
Expand Down Expand Up @@ -273,9 +289,57 @@ mod tests {
let mut row_group = reader.get_row_group(0).unwrap();
out.read_from_row_group(&mut *row_group, 1).unwrap();

assert_eq!(drs[0].maybe_i16.unwrap(), out[0].i16);
assert_eq!(drs[0].maybe_i32.unwrap(), out[0].i32);
assert_eq!(drs[0].maybe_u64.unwrap(), out[0].u64);
assert_eq!(drs[0].i16.unwrap(), out[0].i16);
assert_eq!(drs[0].i32.unwrap(), out[0].i32);
assert_eq!(drs[0].u64.unwrap(), out[0].u64);
}

#[test]
fn test_parquet_derive_read_pruned_and_shuffled_columns() {
    let file = get_temp_file("test_parquet_derive_read_pruned", &[]);

    // Write a single row using the full record type.
    let written = vec![APartiallyCompleteRecord {
        bool: true,
        string: "a string".into(),
        i16: -45,
        i32: 456,
        u64: 4563424,
        isize: -365,
        float: 3.5,
        double: f64::NAN,
        now: chrono::Utc::now().naive_local(),
        date: chrono::naive::NaiveDate::from_ymd_opt(2015, 3, 14).unwrap(),
        uuid: uuid::Uuid::new_v4(),
        byte_vec: vec![0x65, 0x66, 0x67],
    }];

    let schema = written.as_slice().schema().unwrap();
    let props = Default::default();
    let mut writer =
        SerializedFileWriter::new(file.try_clone().unwrap(), schema, props).unwrap();
    let mut row_group = writer.next_row_group().unwrap();
    written.as_slice().write_to_row_group(&mut row_group).unwrap();
    row_group.close().unwrap();
    writer.close().unwrap();

    // Read it back into the pruned, reordered record type; the derived
    // reader should locate each column by name.
    use parquet::file::{reader::FileReader, serialized_reader::SerializedFileReader};
    let reader = SerializedFileReader::new(file).unwrap();
    let mut row_group = reader.get_row_group(0).unwrap();
    let mut read_back: Vec<APrunedRecord> = Vec::new();
    read_back.read_from_row_group(&mut *row_group, 1).unwrap();

    let (expected, actual) = (&written[0], &read_back[0]);
    assert_eq!(expected.bool, actual.bool);
    assert_eq!(expected.string, actual.string);
    assert_eq!(expected.byte_vec, actual.byte_vec);
    assert_eq!(expected.float, actual.float);
    // NaN never compares equal to itself, so verify both sides are NaN.
    assert!(expected.double.is_nan());
    assert!(actual.double.is_nan());
    assert_eq!(expected.i16, actual.i16);
    assert_eq!(expected.i32, actual.i32);
    assert_eq!(expected.u64, actual.u64);
    assert_eq!(expected.isize, actual.isize);
}

/// Returns file handle for a temp file in 'target' directory with a provided content
Expand Down

0 comments on commit 3a1f67f

Please sign in to comment.