Skip to content

Commit

Permalink
refactor(rust!): Move schema resolving to IR phase. (pola-rs#15714)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 18, 2024
1 parent 2db0ba6 commit 8454d9c
Show file tree
Hide file tree
Showing 34 changed files with 1,027 additions and 1,141 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ either = "1.9"
ethnum = "1.3.2"
fallible-streaming-iterator = "0.1.9"
futures = "0.3.25"
hashbrown = { version = "0.14", features = ["rayon", "ahash"] }
hashbrown = { version = "0.14", features = ["rayon", "ahash", "serde"] }
hex = "0.4.3"
indexmap = { version = "2", features = ["std"] }
itoa = "1.0.6"
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ pub(crate) fn concat_impl<L: AsRef<[LazyFrame]>>(
else {
unreachable!()
};
let mut schema = inputs[0].schema()?.as_ref().as_ref().clone();
let mut schema = inputs[0].compute_schema()?.as_ref().clone();

let mut changed = false;
for input in inputs[1..].iter() {
changed |= schema.to_supertype(input.schema()?.as_ref().as_ref())?;
changed |= schema.to_supertype(input.compute_schema()?.as_ref())?;
}

let mut placeholder = DslPlan::default();
if changed {
let mut exprs = vec![];
for input in &mut inputs {
std::mem::swap(input, &mut placeholder);
let input_schema = placeholder.schema()?;
let input_schema = placeholder.compute_schema()?;

exprs.clear();
let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map(
Expand Down
179 changes: 37 additions & 142 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ mod exitable;
#[cfg(feature = "pivot")]
pub mod pivot;

use std::borrow::Cow;
#[cfg(any(
feature = "parquet",
feature = "ipc",
Expand Down Expand Up @@ -93,7 +92,7 @@ impl LazyFrame {
/// Returns an `Err` if the logical plan has already encountered an error (i.e., if
/// `self.collect()` would fail), `Ok` otherwise.
pub fn schema(&self) -> PolarsResult<SchemaRef> {
self.logical_plan.schema().map(|schema| schema.into_owned())
self.logical_plan.compute_schema()
}

pub(crate) fn get_plan_builder(self) -> DslBuilder {
Expand Down Expand Up @@ -378,34 +377,6 @@ impl LazyFrame {
self.select(vec![col("*").reverse()])
}

/// Check the if the `names` are available in the `schema`, if not
/// return a `LogicalPlan` that raises an `Error`.
fn check_names(&self, names: &[SmartString], schema: Option<&SchemaRef>) -> Option<Self> {
let schema = schema
.map(Cow::Borrowed)
.unwrap_or_else(|| Cow::Owned(self.schema().unwrap()));

let mut opt_not_found = None;
names.iter().for_each(|name| {
let invalid = schema.get(name).is_none();

if invalid && opt_not_found.is_none() {
opt_not_found = Some(name)
}
});

if let Some(name) = opt_not_found {
let lp = self
.clone()
.get_plan_builder()
.add_err(polars_err!(SchemaFieldNotFound: "{}", name))
.build();
Some(Self::from_logical_plan(lp, self.opt_state))
} else {
None
}
}

/// Rename columns in the DataFrame.
///
/// `existing` and `new` are iterables of the same length containing the old and
Expand Down Expand Up @@ -435,19 +406,10 @@ impl LazyFrame {
}
}

// a column gets swapped
let schema = &self.schema().unwrap();
let swapping = new_vec.iter().any(|name| schema.get(name).is_some());

if let Some(lp) = self.check_names(&existing_vec, Some(schema)) {
lp
} else {
self.map_private(FunctionNode::Rename {
existing: existing_vec.into(),
new: new_vec.into(),
swapping,
})
}
self.map_private(DslFunction::Rename {
existing: existing_vec.into(),
new: new_vec.into(),
})
}

/// Removes columns from the DataFrame.
Expand Down Expand Up @@ -1339,37 +1301,18 @@ impl LazyFrame {
Self::from_logical_plan(lp, opt_state)
}

fn stats_helper<F, E>(self, condition: F, expr: E) -> PolarsResult<LazyFrame>
where
F: Fn(&DataType) -> bool,
E: Fn(&str) -> Expr,
{
let exprs = self
.schema()?
.iter()
.map(|(name, dt)| {
if condition(dt) {
expr(name)
} else {
lit(NULL).cast(dt.clone()).alias(name)
}
})
.collect::<Vec<_>>();
Ok(self.select(exprs))
}

/// Aggregate all the columns as their maximum values.
///
/// Aggregated columns will have the same names as the original columns.
pub fn max(self) -> PolarsResult<LazyFrame> {
self.stats_helper(|dt| dt.is_ord(), |name| col(name).max())
pub fn max(self) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Max))
}

/// Aggregate all the columns as their minimum values.
///
/// Aggregated columns will have the same names as the original columns.
pub fn min(self) -> PolarsResult<LazyFrame> {
self.stats_helper(|dt| dt.is_ord(), |name| col(name).min())
pub fn min(self) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Min))
}

/// Aggregate all the columns as their sum values.
Expand All @@ -1381,68 +1324,33 @@ impl LazyFrame {
/// in `debug` mode, overflows will panic, whereas in `release` mode overflows will
/// silently wrap.
/// - String columns will sum to None.
pub fn sum(self) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| {
dt.is_numeric()
|| dt.is_decimal()
|| matches!(dt, DataType::Boolean | DataType::Duration(_))
},
|name| col(name).sum(),
)
pub fn sum(self) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Sum))
}

/// Aggregate all the columns as their mean values.
///
/// - Boolean and integer columns are converted to `f64` before computing the mean.
/// - String columns will have a mean of None.
pub fn mean(self) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|name| col(name).mean(),
)
pub fn mean(self) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Mean))
}

/// Aggregate all the columns as their median values.
///
/// - Boolean and integer results are converted to `f64`. However, they are still
/// susceptible to overflow before this conversion occurs.
/// - String columns will sum to None.
pub fn median(self) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|name| col(name).median(),
)
pub fn median(self) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Median))
}

/// Aggregate all the columns as their quantile values.
pub fn quantile(
self,
quantile: Expr,
interpol: QuantileInterpolOptions,
) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| dt.is_numeric(),
|name| col(name).quantile(quantile.clone(), interpol),
)
pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Quantile {
quantile,
interpol,
}))
}

/// Aggregate all the columns as their standard deviation values.
Expand All @@ -1457,11 +1365,8 @@ impl LazyFrame {
/// > standard deviation per se.
///
/// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.std.html#)
pub fn std(self, ddof: u8) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| dt.is_numeric() || dt.is_bool(),
|name| col(name).std(ddof),
)
pub fn std(self, ddof: u8) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Std { ddof }))
}

/// Aggregate all the columns as their variance values.
Expand All @@ -1473,11 +1378,8 @@ impl LazyFrame {
/// > likelihood estimate of the variance for normally distributed variables.
///
/// Source: [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.var.html#)
pub fn var(self, ddof: u8) -> PolarsResult<LazyFrame> {
self.stats_helper(
|dt| dt.is_numeric() || dt.is_bool(),
|name| col(name).var(ddof),
)
pub fn var(self, ddof: u8) -> Self {
self.map_private(DslFunction::Stats(StatsFunction::Var { ddof }))
}

/// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode).
Expand Down Expand Up @@ -1592,7 +1494,7 @@ impl LazyFrame {
/// See [`MeltArgs`] for information on how to melt a DataFrame.
pub fn melt(self, args: MeltArgs) -> LazyFrame {
let opt_state = self.get_opt_state();
let lp = self.get_plan_builder().melt(Arc::new(args)).build();
let lp = self.get_plan_builder().melt(args).build();
Self::from_logical_plan(lp, opt_state)
}

Expand Down Expand Up @@ -1655,7 +1557,7 @@ impl LazyFrame {
Self::from_logical_plan(lp, opt_state)
}

pub(crate) fn map_private(self, function: FunctionNode) -> LazyFrame {
pub(crate) fn map_private(self, function: DslFunction) -> LazyFrame {
let opt_state = self.get_opt_state();
let lp = self.get_plan_builder().map_private(function).build();
Self::from_logical_plan(lp, opt_state)
Expand All @@ -1673,30 +1575,22 @@ impl LazyFrame {
let add_row_index_in_map = match &mut self.logical_plan {
DslPlan::Scan {
file_options: options,
file_info,
scan_type,
..
} if !matches!(scan_type, FileScan::Anonymous { .. }) => {
options.row_index = Some(RowIndex {
name: name.to_string(),
offset: offset.unwrap_or(0),
});
file_info.schema = Arc::new(
file_info
.schema
.new_inserting_at_index(0, name.into(), IDX_DTYPE)
.unwrap(),
);
false
},
_ => true,
};

if add_row_index_in_map {
self.map_private(FunctionNode::RowIndex {
self.map_private(DslFunction::RowIndex {
name: Arc::from(name),
offset,
schema: Default::default(),
})
} else {
self
Expand All @@ -1712,9 +1606,9 @@ impl LazyFrame {
/// inserted as columns.
#[cfg(feature = "dtype-struct")]
pub fn unnest<I: IntoIterator<Item = S>, S: AsRef<str>>(self, cols: I) -> Self {
self.map_private(FunctionNode::Unnest {
self.map_private(DslFunction::FunctionNode(FunctionNode::Unnest {
columns: cols.into_iter().map(|s| Arc::from(s.as_ref())).collect(),
})
}))
}

#[cfg(feature = "merge_sorted")]
Expand All @@ -1723,7 +1617,7 @@ impl LazyFrame {
// this indicates until which chunk the data is from the left df
// this trick allows us to reuse the `Union` architecture to get map over
// two DataFrames
let left = self.map_private(FunctionNode::Rechunk);
let left = self.map_private(DslFunction::FunctionNode(FunctionNode::Rechunk));
let q = concat(
&[left, other],
UnionArgs {
Expand All @@ -1732,9 +1626,11 @@ impl LazyFrame {
..Default::default()
},
)?;
Ok(q.map_private(FunctionNode::MergeSorted {
column: Arc::from(key),
}))
Ok(
q.map_private(DslFunction::FunctionNode(FunctionNode::MergeSorted {
column: Arc::from(key),
})),
)
}
}

Expand Down Expand Up @@ -1847,10 +1743,9 @@ impl LazyGroupBy {

let lp = DslPlan::GroupBy {
input: Arc::new(self.logical_plan),
keys: Arc::new(self.keys),
keys: self.keys,
aggs: vec![],
schema,
apply: Some(Arc::new(f)),
apply: Some((Arc::new(f), schema)),
maintain_order: self.maintain_order,
options: Arc::new(options),
};
Expand Down
Loading

0 comments on commit 8454d9c

Please sign in to comment.