Skip to content

Commit

Permalink
fix: Don't vertically parallelize cse contexts (#18177)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 14, 2024
1 parent c6cb8be commit 7b3ab69
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 108 deletions.
18 changes: 2 additions & 16 deletions crates/polars-mem-engine/src/executors/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,8 @@ impl StackExec {
self.has_windows,
self.options.run_parallel,
)?;
if !self.options.should_broadcast {
debug_assert!(
res.iter()
.all(|column| column.name().starts_with("__POLARS_CSER_0x")),
"non-broadcasting hstack should only be used for CSE columns"
);
// Safety: this case only appears as a result
// of CSE optimization, and the usage there
// produces new, unique column names. It is
// immediately followed by a projection which
// pulls out the possibly mismatching column
// lengths.
unsafe { df.get_columns_mut().extend(res) };
} else {
df._add_columns(res, schema)?;
}
// We don't have to do a broadcast check as cse is not allowed to hit this.
df._add_columns(res, schema)?;
Ok(df)
});

Expand Down
6 changes: 4 additions & 2 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ fn create_physical_plan_impl(
state.expr_depth,
);

let streamable = all_streamable(&expr, expr_arena, Context::Default);
let streamable =
options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default);
let phys_expr = create_physical_expressions_from_irs(
&expr,
Context::Default,
Expand Down Expand Up @@ -629,7 +630,8 @@ fn create_physical_plan_impl(
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?;

let streamable = all_streamable(&exprs, expr_arena, Context::Default);
let streamable =
options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default);

let mut state = ExpressionConversionState::new(
POOL.current_num_threads() > exprs.len(),
Expand Down
91 changes: 91 additions & 0 deletions crates/polars-ops/src/series/ops/duration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS, SECONDS_IN_DAY};
use polars_core::datatypes::{AnyValue, DataType, TimeUnit};
use polars_core::prelude::Series;
use polars_error::PolarsResult;

/// Assemble a `Duration` Series in the requested `time_unit` from its parts.
///
/// `s` holds exactly eight component Series, in order:
/// `[weeks, days, hours, minutes, seconds, milliseconds, microseconds, nanoseconds]`.
/// A component may be a scalar (length 1) or a column; zero scalars are skipped
/// entirely so they cost no columnar arithmetic.
pub fn impl_duration(s: &[Series], time_unit: TimeUnit) -> PolarsResult<Series> {
    // Any empty component makes the whole result empty; the output keeps the
    // name of the first component (`s[0]`).
    if s.iter().any(|s| s.is_empty()) {
        return Ok(Series::new_empty(
            s[0].name(),
            &DataType::Duration(time_unit),
        ));
    }

    // TODO: Handle overflow for UInt64
    // Normalize every component to Int64 so the arithmetic below is uniform.
    // NOTE(review): the `unwrap`s assume all inputs cast cleanly to Int64; a
    // non-numeric dtype would panic here — confirm upstream validation.
    let weeks = s[0].cast(&DataType::Int64).unwrap();
    let days = s[1].cast(&DataType::Int64).unwrap();
    let hours = s[2].cast(&DataType::Int64).unwrap();
    let minutes = s[3].cast(&DataType::Int64).unwrap();
    let seconds = s[4].cast(&DataType::Int64).unwrap();
    let mut milliseconds = s[5].cast(&DataType::Int64).unwrap();
    let mut microseconds = s[6].cast(&DataType::Int64).unwrap();
    let mut nanoseconds = s[7].cast(&DataType::Int64).unwrap();

    // A length-1 Series acts as a scalar; a zero scalar contributes nothing
    // to the sum and can be skipped entirely.
    let is_scalar = |s: &Series| s.len() == 1;
    let is_zero_scalar = |s: &Series| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0);

    // Process subseconds: fold milliseconds/microseconds/nanoseconds into the
    // component matching the target time unit. The base component is broadcast
    // to the longest component length *first* so the later additions align.
    let max_len = s.iter().map(|s| s.len()).max().unwrap();
    let mut duration = match time_unit {
        TimeUnit::Microseconds => {
            if is_scalar(&microseconds) {
                microseconds = microseconds.new_from_index(0, max_len);
            }
            // Finer units are truncated (not rounded) into the target unit.
            if !is_zero_scalar(&nanoseconds) {
                microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?;
            }
            if !is_zero_scalar(&milliseconds) {
                microseconds = (microseconds + (milliseconds * 1_000))?;
            }
            microseconds
        },
        TimeUnit::Nanoseconds => {
            if is_scalar(&nanoseconds) {
                nanoseconds = nanoseconds.new_from_index(0, max_len);
            }
            if !is_zero_scalar(&microseconds) {
                nanoseconds = (nanoseconds + (microseconds * 1_000))?;
            }
            if !is_zero_scalar(&milliseconds) {
                nanoseconds = (nanoseconds + (milliseconds * 1_000_000))?;
            }
            nanoseconds
        },
        TimeUnit::Milliseconds => {
            if is_scalar(&milliseconds) {
                milliseconds = milliseconds.new_from_index(0, max_len);
            }
            // Finer units are truncated (not rounded) into the target unit.
            if !is_zero_scalar(&nanoseconds) {
                milliseconds = (milliseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000_000)))?;
            }
            if !is_zero_scalar(&microseconds) {
                milliseconds = (milliseconds + (microseconds.wrapping_trunc_div_scalar(1_000)))?;
            }
            milliseconds
        },
    };

    // Process other duration specifiers: seconds and coarser units are scaled
    // by the number of target-unit ticks per second.
    let multiplier = match time_unit {
        TimeUnit::Nanoseconds => NANOSECONDS,
        TimeUnit::Microseconds => MICROSECONDS,
        TimeUnit::Milliseconds => MILLISECONDS,
    };
    if !is_zero_scalar(&seconds) {
        duration = (duration + seconds * multiplier)?;
    }
    if !is_zero_scalar(&minutes) {
        duration = (duration + minutes * (multiplier * 60))?;
    }
    if !is_zero_scalar(&hours) {
        duration = (duration + hours * (multiplier * 60 * 60))?;
    }
    if !is_zero_scalar(&days) {
        duration = (duration + days * (multiplier * SECONDS_IN_DAY))?;
    }
    if !is_zero_scalar(&weeks) {
        duration = (duration + weeks * (multiplier * SECONDS_IN_DAY * 7))?;
    }

    // The accumulated Int64 tick counts become a Duration of the target unit.
    duration.cast(&DataType::Duration(time_unit))
}
5 changes: 5 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ pub use to_dummies::*;
pub use unique::*;
pub use various::*;
mod not;

#[cfg(feature = "dtype-duration")]
pub(crate) mod duration;
#[cfg(feature = "dtype-duration")]
pub use duration::*;
pub use not::*;

pub trait SeriesSealed {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ dtype-i16 = ["polars-core/dtype-i16"]
dtype-decimal = ["polars-core/dtype-decimal"]
dtype-date = ["polars-time/dtype-date", "temporal"]
dtype-datetime = ["polars-time/dtype-datetime", "temporal"]
dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration", "temporal"]
dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration", "temporal", "polars-ops/dtype-duration"]
dtype-time = ["polars-time/dtype-time", "temporal"]
dtype-array = ["polars-core/dtype-array", "polars-ops/dtype-array"]
dtype-categorical = ["polars-core/dtype-categorical"]
Expand Down
88 changes: 0 additions & 88 deletions crates/polars-plan/src/dsl/function_expr/datetime.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS, SECONDS_IN_DAY};
#[cfg(feature = "timezones")]
use chrono_tz::Tz;
#[cfg(feature = "timezones")]
Expand Down Expand Up @@ -501,90 +500,3 @@ pub(super) fn round(s: &[Series]) -> PolarsResult<Series> {
dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"),
})
}

/// Assemble a `Duration` Series in the requested `time_unit` from its parts.
///
/// `s` holds exactly eight component Series, in order:
/// `[weeks, days, hours, minutes, seconds, milliseconds, microseconds, nanoseconds]`.
/// A component may be a scalar (length 1) or a column; zero scalars are skipped
/// entirely so they cost no columnar arithmetic.
pub(super) fn duration(s: &[Series], time_unit: TimeUnit) -> PolarsResult<Series> {
    // Any empty component makes the whole result empty; the output keeps the
    // name of the first component (`s[0]`).
    if s.iter().any(|s| s.is_empty()) {
        return Ok(Series::new_empty(
            s[0].name(),
            &DataType::Duration(time_unit),
        ));
    }

    // TODO: Handle overflow for UInt64
    // Normalize every component to Int64 so the arithmetic below is uniform.
    // NOTE(review): the `unwrap`s assume all inputs cast cleanly to Int64; a
    // non-numeric dtype would panic here — confirm upstream validation.
    let weeks = s[0].cast(&DataType::Int64).unwrap();
    let days = s[1].cast(&DataType::Int64).unwrap();
    let hours = s[2].cast(&DataType::Int64).unwrap();
    let minutes = s[3].cast(&DataType::Int64).unwrap();
    let seconds = s[4].cast(&DataType::Int64).unwrap();
    let mut milliseconds = s[5].cast(&DataType::Int64).unwrap();
    let mut microseconds = s[6].cast(&DataType::Int64).unwrap();
    let mut nanoseconds = s[7].cast(&DataType::Int64).unwrap();

    // A length-1 Series acts as a scalar; a zero scalar contributes nothing
    // to the sum and can be skipped entirely.
    let is_scalar = |s: &Series| s.len() == 1;
    let is_zero_scalar = |s: &Series| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0);

    // Process subseconds: fold milliseconds/microseconds/nanoseconds into the
    // component matching the target time unit. The base component is broadcast
    // to the longest component length *first* so the later additions align.
    let max_len = s.iter().map(|s| s.len()).max().unwrap();
    let mut duration = match time_unit {
        TimeUnit::Microseconds => {
            if is_scalar(&microseconds) {
                microseconds = microseconds.new_from_index(0, max_len);
            }
            // Finer units are truncated (not rounded) into the target unit.
            if !is_zero_scalar(&nanoseconds) {
                microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?;
            }
            if !is_zero_scalar(&milliseconds) {
                microseconds = (microseconds + (milliseconds * 1_000))?;
            }
            microseconds
        },
        TimeUnit::Nanoseconds => {
            if is_scalar(&nanoseconds) {
                nanoseconds = nanoseconds.new_from_index(0, max_len);
            }
            if !is_zero_scalar(&microseconds) {
                nanoseconds = (nanoseconds + (microseconds * 1_000))?;
            }
            if !is_zero_scalar(&milliseconds) {
                nanoseconds = (nanoseconds + (milliseconds * 1_000_000))?;
            }
            nanoseconds
        },
        TimeUnit::Milliseconds => {
            if is_scalar(&milliseconds) {
                milliseconds = milliseconds.new_from_index(0, max_len);
            }
            // Finer units are truncated (not rounded) into the target unit.
            if !is_zero_scalar(&nanoseconds) {
                milliseconds = (milliseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000_000)))?;
            }
            if !is_zero_scalar(&microseconds) {
                milliseconds = (milliseconds + (microseconds.wrapping_trunc_div_scalar(1_000)))?;
            }
            milliseconds
        },
    };

    // Process other duration specifiers: seconds and coarser units are scaled
    // by the number of target-unit ticks per second.
    let multiplier = match time_unit {
        TimeUnit::Nanoseconds => NANOSECONDS,
        TimeUnit::Microseconds => MICROSECONDS,
        TimeUnit::Milliseconds => MILLISECONDS,
    };
    if !is_zero_scalar(&seconds) {
        duration = (duration + seconds * multiplier)?;
    }
    if !is_zero_scalar(&minutes) {
        duration = (duration + minutes * (multiplier * 60))?;
    }
    if !is_zero_scalar(&hours) {
        duration = (duration + hours * (multiplier * 60 * 60))?;
    }
    if !is_zero_scalar(&days) {
        duration = (duration + days * (multiplier * SECONDS_IN_DAY))?;
    }
    if !is_zero_scalar(&weeks) {
        duration = (duration + weeks * (multiplier * SECONDS_IN_DAY * 7))?;
    }

    // The accumulated Int64 tick counts become a Duration of the target unit.
    duration.cast(&DataType::Duration(time_unit))
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl From<TemporalFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Quarter => map!(datetime::quarter),
Week => map!(datetime::week),
WeekDay => map!(datetime::weekday),
Duration(tu) => map_as_slice!(datetime::duration, tu),
Duration(tu) => map_as_slice!(impl_duration, tu),
Day => map!(datetime::day),
OrdinalDay => map!(datetime::ordinal_day),
Time => map!(datetime::time),
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,21 @@ def test_cse_non_scalar_length_mismatch_17732() -> None:
)

assert_frame_equal(expect, got)


def test_cse_chunks_18124() -> None:
    # Regression test for issue 18124: CSE over a multi-chunk frame.
    # Build a 2-row frame of durations, then concatenate it with itself
    # without rechunking so the result holds two chunks.
    base = pl.DataFrame(
        {
            "ts_diff": [timedelta(seconds=60)] * 2,
            "ts_diff_after": [timedelta(seconds=120)] * 2,
        }
    )
    chunked = pl.concat([base, base], rechunk=False)

    # Both new columns compare against the same `pl.duration(seconds=0)`
    # subexpression, which is a CSE candidate; collect and check the shape.
    result = (
        chunked.lazy()
        .with_columns(
            ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0),
            ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0),
        )
        .filter(pl.col("ts_diff") > 1)
        .collect()
    )
    assert result.shape == (4, 4)

0 comments on commit 7b3ab69

Please sign in to comment.