fix: Raise proper error for mismatching parquet schema instead of panicking (pola-rs#17321)

nameexhaustion authored Jul 2, 2024
1 parent f73937a commit 8eef76e
Showing 10 changed files with 233 additions and 137 deletions.
19 changes: 19 additions & 0 deletions crates/polars-arrow/src/datatypes/schema.rs
@@ -1,5 +1,6 @@
use std::sync::Arc;

use polars_error::{polars_bail, PolarsResult};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

@@ -62,6 +63,24 @@ impl ArrowSchema {
metadata: self.metadata,
}
}

pub fn try_project(&self, indices: &[usize]) -> PolarsResult<Self> {
let fields = indices.iter().map(|&i| {
let Some(out) = self.fields.get(i) else {
polars_bail!(
SchemaFieldNotFound: "projection index {} is out of bounds for schema of length {}",
i, self.fields.len()
);
};

Ok(out.clone())
}).collect::<PolarsResult<Vec<_>>>()?;

Ok(ArrowSchema {
fields,
metadata: self.metadata.clone(),
})
}
}

impl From<Vec<Field>> for ArrowSchema {
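
As a rough usage sketch of the new helper (hypothetical column names; assuming Field, ArrowDataType, and ArrowSchema are importable from polars_arrow::datatypes), try_project now turns an out-of-range projection index into an error instead of a panic:

use polars_arrow::datatypes::{ArrowDataType, ArrowSchema, Field};

fn main() {
    // Hypothetical two-column schema.
    let schema = ArrowSchema::from(vec![
        Field::new("a", ArrowDataType::Int64, true),
        Field::new("b", ArrowDataType::Utf8, true),
    ]);

    // A valid projection keeps only the requested fields, in the given order.
    let projected = schema.try_project(&[1]).unwrap();
    assert_eq!(projected.fields.len(), 1);

    // An out-of-bounds index is reported as a SchemaFieldNotFound error
    // rather than panicking on a slice access.
    assert!(schema.try_project(&[5]).is_err());
}
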
97 changes: 97 additions & 0 deletions crates/polars-core/src/schema.rs
@@ -446,6 +446,9 @@ pub trait IndexOfSchema: Debug {
/// Get a vector of all column names.
fn get_names(&self) -> Vec<&str>;

/// Get a vector of (name, dtype) pairs
fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)>;

fn try_index_of(&self, name: &str) -> PolarsResult<usize> {
self.index_of(name).ok_or_else(|| {
polars_err!(
@@ -464,6 +467,13 @@ impl IndexOfSchema for Schema {
fn get_names(&self) -> Vec<&str> {
self.iter_names().map(|name| name.as_str()).collect()
}

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)> {
self.inner
.iter()
.map(|(name, dtype)| (name.as_str(), dtype.clone()))
.collect()
}
}

impl IndexOfSchema for ArrowSchema {
@@ -474,6 +484,45 @@ impl IndexOfSchema for ArrowSchema {
fn get_names(&self) -> Vec<&str> {
self.fields.iter().map(|f| f.name.as_str()).collect()
}

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, DataType)> {
self.fields
.iter()
.map(|x| (x.name.as_str(), DataType::from_arrow(&x.data_type, true)))
.collect()
}
}

pub trait SchemaNamesAndDtypes {
const IS_ARROW: bool;
type DataType: Debug + PartialEq;

/// Get a vector of (name, dtype) pairs
fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)>;
}

impl SchemaNamesAndDtypes for Schema {
const IS_ARROW: bool = false;
type DataType = DataType;

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
self.inner
.iter()
.map(|(name, dtype)| (name.as_str(), dtype.clone()))
.collect()
}
}

impl SchemaNamesAndDtypes for ArrowSchema {
const IS_ARROW: bool = true;
type DataType = ArrowDataType;

fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> {
self.fields
.iter()
.map(|x| (x.name.as_str(), x.data_type.clone()))
.collect()
}
}

impl From<&ArrowSchema> for Schema {
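
As a minimal sketch of what the shared trait enables (hypothetical helper name; assuming SchemaNamesAndDtypes is reachable via polars_core::schema), one generic routine can walk either a polars Schema or an ArrowSchema:

use polars_core::schema::SchemaNamesAndDtypes;

// Hypothetical debug helper: works for both Schema (DataType) and
// ArrowSchema (ArrowDataType) through the shared trait.
fn dump_schema<S: SchemaNamesAndDtypes>(schema: &S) {
    for (name, dtype) in schema.get_names_and_dtypes() {
        println!("{name}: {dtype:?}");
    }
}
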
@@ -498,3 +547,51 @@ impl From<&ArrowSchemaRef> for Schema {
Self::from(value.as_ref())
}
}

pub fn ensure_matching_schema<S: SchemaNamesAndDtypes>(lhs: &S, rhs: &S) -> PolarsResult<()> {
let lhs = lhs.get_names_and_dtypes();
let rhs = rhs.get_names_and_dtypes();

if lhs.len() != rhs.len() {
polars_bail!(
SchemaMismatch:
"schemas contained differing number of columns: {} != {}",
lhs.len(), rhs.len(),
);
}

for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.iter().zip(&rhs).enumerate() {
if l_name != r_name {
polars_bail!(
SchemaMismatch:
"schema names differ at index {}: {} != {}",
i, l_name, r_name
)
}
if l_dtype != r_dtype
&& (!S::IS_ARROW
|| unsafe {
// For timezone normalization. Easier than writing out the entire PartialEq.
DataType::from_arrow(
std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
l_dtype,
),
true,
) != DataType::from_arrow(
std::mem::transmute::<&<S as SchemaNamesAndDtypes>::DataType, &ArrowDataType>(
r_dtype,
),
true,
)
})
{
polars_bail!(
SchemaMismatch:
"schema dtypes differ at index {} for column {}: {:?} != {:?}",
i, l_name, l_dtype, r_dtype
)
}
}

Ok(())
}
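
A hedged sketch of the new error path (hypothetical column names; assuming Schema, Field, and DataType from polars_core::prelude, and that Schema collects from an iterator of Fields as in recent polars releases):

use polars_core::prelude::{DataType, Field, Schema};
use polars_core::schema::ensure_matching_schema;

fn main() {
    let expected: Schema = [Field::new("a", DataType::Int64)].into_iter().collect();
    let found: Schema = [Field::new("a", DataType::String)].into_iter().collect();

    // The dtype mismatch surfaces as a SchemaMismatch error naming the column
    // and both dtypes, instead of a panic further down the read path.
    let err = ensure_matching_schema(&expected, &found).unwrap_err();
    eprintln!("{err}");
}
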
61 changes: 48 additions & 13 deletions crates/polars-io/src/parquet/read/reader.rs
@@ -80,22 +80,38 @@ impl<R: MmapBytesReader> ParquetReader<R> {
self
}

/// Set the [`Schema`] if already known. This must be exactly the same as
/// the schema in the file itself.
pub fn with_schema(mut self, schema: Option<ArrowSchemaRef>) -> Self {
self.schema = schema;
self
/// Ensure the schema of the file matches the given schema. Calling this
/// after setting the projection will ensure only the projected indices
/// are checked.
pub fn check_schema(mut self, schema: &ArrowSchema) -> PolarsResult<Self> {
let self_schema = self.schema()?;
let self_schema = self_schema.as_ref();

if let Some(ref projection) = self.projection {
let projection = projection.as_slice();

ensure_matching_schema(
&schema.try_project(projection)?,
&self_schema.try_project(projection)?,
)?;
} else {
ensure_matching_schema(schema, self_schema)?;
}

Ok(self)
}

/// [`Schema`] of the file.
pub fn schema(&mut self) -> PolarsResult<ArrowSchemaRef> {
match &self.schema {
Some(schema) => Ok(schema.clone()),
self.schema = Some(match &self.schema {
Some(schema) => schema.clone(),
None => {
let metadata = self.get_metadata()?;
Ok(Arc::new(read::infer_schema(metadata)?))
Arc::new(read::infer_schema(metadata)?)
},
}
});

Ok(self.schema.clone().unwrap())
}

/// Use statistics in the parquet to determine if pages
@@ -226,7 +242,6 @@ impl ParquetAsyncReader {
pub async fn from_uri(
uri: &str,
cloud_options: Option<&CloudOptions>,
schema: Option<ArrowSchemaRef>,
metadata: Option<FileMetaDataRef>,
) -> PolarsResult<ParquetAsyncReader> {
Ok(ParquetAsyncReader {
@@ -238,20 +253,40 @@ impl ParquetAsyncReader {
predicate: None,
use_statistics: true,
hive_partition_columns: None,
schema,
schema: None,
parallel: Default::default(),
})
}

pub async fn check_schema(mut self, schema: &ArrowSchema) -> PolarsResult<Self> {
let self_schema = self.schema().await?;
let self_schema = self_schema.as_ref();

if let Some(ref projection) = self.projection {
let projection = projection.as_slice();

ensure_matching_schema(
&schema.try_project(projection)?,
&self_schema.try_project(projection)?,
)?;
} else {
ensure_matching_schema(schema, self_schema)?;
}

Ok(self)
}

pub async fn schema(&mut self) -> PolarsResult<ArrowSchemaRef> {
Ok(match self.schema.as_ref() {
self.schema = Some(match self.schema.as_ref() {
Some(schema) => Arc::clone(schema),
None => {
let metadata = self.reader.get_metadata().await?;
let arrow_schema = polars_parquet::arrow::read::infer_schema(metadata)?;
Arc::new(arrow_schema)
},
})
});

Ok(self.schema.clone().unwrap())
}

pub async fn num_rows(&mut self) -> PolarsResult<usize> {
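
A sketch of how a caller might drive the renamed check (hypothetical file path and helper; assuming polars-io is built with the parquet feature and that ParquetReader and the SerReader trait are exported from polars_io::prelude):

use std::fs::File;

use polars_arrow::datatypes::ArrowSchema;
use polars_core::prelude::{DataFrame, PolarsResult};
use polars_io::prelude::*;

// Hypothetical helper: verify the file's (projected) schema up front, then read it.
fn read_checked(expected: &ArrowSchema) -> PolarsResult<DataFrame> {
    let file = File::open("data.parquet")?; // hypothetical path
    ParquetReader::new(file)
        // Only columns 0 and 1 will be read, and only they are schema-checked.
        .with_projection(Some(vec![0, 1]))
        .check_schema(expected)?
        .finish()
}
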
47 changes: 0 additions & 47 deletions crates/polars-io/src/utils.rs
@@ -270,53 +270,6 @@ pub fn materialize_projection(
}
}

pub fn check_projected_schema_impl(
a: &Schema,
b: &Schema,
projected_names: Option<&[String]>,
msg: &str,
) -> PolarsResult<()> {
if !projected_names
.map(|projected_names| {
projected_names
.iter()
.all(|name| a.get(name) == b.get(name))
})
.unwrap_or_else(|| a == b)
{
polars_bail!(ComputeError: "{msg}\n\n\
Expected: {:?}\n\n\
Got: {:?}", a, b)
}
Ok(())
}

/// Checks if the projected columns are equal
pub fn check_projected_arrow_schema(
a: &ArrowSchema,
b: &ArrowSchema,
projected_names: Option<&[String]>,
msg: &str,
) -> PolarsResult<()> {
if a != b {
let a = Schema::from(a);
let b = Schema::from(b);
check_projected_schema_impl(&a, &b, projected_names, msg)
} else {
Ok(())
}
}

/// Checks if the projected columns are equal
pub fn check_projected_schema(
a: &Schema,
b: &Schema,
projected_names: Option<&[String]>,
msg: &str,
) -> PolarsResult<()> {
check_projected_schema_impl(a, b, projected_names, msg)
}

/// Split DataFrame into chunks in preparation for writing. The chunks have a
/// maximum number of rows per chunk to ensure reasonable memory efficiency when
/// reading the resulting file, and a minimum size per chunk to ensure