Skip to content

Commit

Permalink
fix: Scalar checks (#18627)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 11, 2024
1 parent 1ee85a3 commit 12cce97
Show file tree
Hide file tree
Showing 32 changed files with 496 additions and 294 deletions.
8 changes: 8 additions & 0 deletions crates/polars-expr/src/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,10 @@ impl PhysicalExpr for AggregationExpr {
}
}

/// `AggregationExpr` unconditionally reports itself as scalar.
fn is_scalar(&self) -> bool {
true
}

/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down Expand Up @@ -742,6 +746,10 @@ impl PhysicalExpr for AggQuantileExpr {
/// Resolves the output field against `input_schema` by delegating to the
/// `input` expression; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// `AggQuantileExpr` unconditionally reports itself as scalar.
fn is_scalar(&self) -> bool {
true
}
}

/// Simple wrapper to parallelize functions that can be divided over threads aggregated and
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ impl PhysicalExpr for AliasExpr {
))
}

/// Forwards the scalar query to the wrapped `physical_expr`, so this
/// expression reports whatever the wrapped expression reports.
fn is_scalar(&self) -> bool {
self.physical_expr.is_scalar()
}

/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
18 changes: 13 additions & 5 deletions crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub struct ApplyExpr {
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
returns_scalar: bool,
function_returns_scalar: bool,
function_operates_on_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
Expand All @@ -29,6 +30,7 @@ pub struct ApplyExpr {
}

impl ApplyExpr {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
Expand All @@ -37,6 +39,7 @@ impl ApplyExpr {
allow_threading: bool,
input_schema: Option<SchemaRef>,
output_dtype: Option<DataType>,
returns_scalar: bool,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise)
Expand All @@ -50,7 +53,8 @@ impl ApplyExpr {
function,
expr,
collect_groups: options.collect_groups,
returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_returns_scalar: options.flags.contains(FunctionFlags::RETURNS_SCALAR),
function_operates_on_scalar: returns_scalar,
allow_rename: options.flags.contains(FunctionFlags::ALLOW_RENAME),
pass_name_to_apply: options.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY),
input_schema,
Expand All @@ -72,7 +76,8 @@ impl ApplyExpr {
function,
expr,
collect_groups,
returns_scalar: false,
function_returns_scalar: false,
function_operates_on_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
Expand Down Expand Up @@ -104,7 +109,7 @@ impl ApplyExpr {
ca: ListChunked,
) -> PolarsResult<AggregationContext<'a>> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.returns_scalar {
if all_unit_len && self.function_returns_scalar {
ac.with_agg_state(AggState::AggregatedScalar(
ca.explode().unwrap().into_series(),
));
Expand Down Expand Up @@ -253,7 +258,7 @@ impl ApplyExpr {
let mut ac = acs.swap_remove(0);
ac.with_update_groups(UpdateGroups::No);

let agg_state = if self.returns_scalar {
let agg_state = if self.function_returns_scalar {
AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype))
} else {
match self.collect_groups {
Expand Down Expand Up @@ -426,6 +431,9 @@ impl PhysicalExpr for ApplyExpr {
None
}
}
/// Reports whether this apply expression is scalar.
///
/// The answer is `true` when either stored flag is set:
/// `function_returns_scalar` or `function_operates_on_scalar`.
fn is_scalar(&self) -> bool {
    // Either flag alone is sufficient; check them in the original order.
    if self.function_returns_scalar {
        return true;
    }
    self.function_operates_on_scalar
}
}

fn apply_multiple_elementwise<'a>(
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool,
}

impl BinaryExpr {
Expand All @@ -25,6 +26,7 @@ impl BinaryExpr {
expr: Expr,
has_literal: bool,
allow_threading: bool,
is_scalar: bool,
) -> Self {
Self {
left,
Expand All @@ -33,6 +35,7 @@ impl BinaryExpr {
expr,
has_literal,
allow_threading,
is_scalar,
}
}
}
Expand Down Expand Up @@ -254,6 +257,10 @@ impl PhysicalExpr for BinaryExpr {
self.expr.to_field(input_schema, Context::Default)
}

/// Returns the `is_scalar` flag stored on this expression when it was
/// constructed; nothing is recomputed here.
fn is_scalar(&self) -> bool {
self.is_scalar
}

/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ impl PhysicalExpr for CastExpr {
})
}

/// Forwards the scalar query to the `input` expression, so this expression
/// reports whatever its input reports.
fn is_scalar(&self) -> bool {
self.input.is_scalar()
}

/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-expr/src/expressions/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ impl PhysicalExpr for ColumnExpr {
)
})
}
/// `ColumnExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}

impl PartitionedAggregation for ColumnExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl PhysicalExpr for CountExpr {
/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

/// `CountExpr` unconditionally reports itself as scalar.
fn is_scalar(&self) -> bool {
true
}
}

impl PartitionedAggregation for CountExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,8 @@ impl PhysicalExpr for FilterExpr {
/// Resolves the output field against `input_schema` by delegating to the
/// `input` expression; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// `FilterExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl PhysicalExpr for GatherExpr {
/// Resolves the output field against `input_schema` by delegating to
/// `phys_expr`; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.phys_expr.to_field(input_schema)
}

/// Returns the `returns_scalar` flag stored on this expression; nothing is
/// recomputed here.
fn is_scalar(&self) -> bool {
self.returns_scalar
}
}

impl GatherExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ impl PhysicalExpr for LiteralExpr {
/// `LiteralExpr` always identifies itself as a literal.
fn is_literal(&self) -> bool {
true
}

/// Forwards the scalar query to the value held in tuple field `0`, so the
/// answer depends on the wrapped value rather than being fixed.
fn is_scalar(&self) -> bool {
self.0.is_scalar()
}
}

impl PartitionedAggregation for LiteralExpr {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ pub trait PhysicalExpr: Send + Sync {
/// Whether this expression is a literal.
///
/// The provided default is `false`; implementors override it to opt in.
fn is_literal(&self) -> bool {
false
}
/// Whether this expression is scalar.
///
/// Required method: there is no default body, so every implementor must
/// state its answer explicitly.
// NOTE(review): "scalar" presumably means the expression yields a single
// value rather than a full-length column — confirm against the callers.
fn is_scalar(&self) -> bool;
}

impl Display for &dyn PhysicalExpr {
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/rolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,8 @@ impl PhysicalExpr for RollingExpr {
/// Returns a reference to the stored `expr`; always `Some`.
fn as_expression(&self) -> Option<&Expr> {
Some(&self.expr)
}

/// `RollingExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,8 @@ impl PhysicalExpr for SliceExpr {
/// Resolves the output field against `input_schema` by delegating to the
/// `input` expression; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// `SliceExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,8 @@ impl PhysicalExpr for SortExpr {
/// Resolves the output field against `input_schema` by delegating to
/// `physical_expr`; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.physical_expr.to_field(input_schema)
}

/// `SortExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,8 @@ impl PhysicalExpr for SortByExpr {
/// Resolves the output field against `input_schema` by delegating to the
/// `input` expression; the delegate's result is returned unchanged.
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
self.input.to_field(input_schema)
}

/// `SortByExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}
}
7 changes: 7 additions & 0 deletions crates/polars-expr/src/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TernaryExpr {
expr: Expr,
// Can be expensive on small data to run literals in parallel.
run_par: bool,
returns_scalar: bool,
}

impl TernaryExpr {
Expand All @@ -21,13 +22,15 @@ impl TernaryExpr {
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
run_par: bool,
returns_scalar: bool,
) -> Self {
Self {
predicate,
truthy,
falsy,
expr,
run_par,
returns_scalar,
}
}
}
Expand Down Expand Up @@ -322,6 +325,10 @@ impl PhysicalExpr for TernaryExpr {
/// Exposes this expression as a `PartitionedAggregation`.
///
/// Always returns `Some`, with `self` acting as the aggregator.
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

/// Returns the `returns_scalar` flag stored on this expression when it was
/// constructed; nothing is recomputed here.
fn is_scalar(&self) -> bool {
self.returns_scalar
}
}

impl PartitionedAggregation for TernaryExpr {
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-expr/src/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,10 @@ impl PhysicalExpr for WindowExpr {
match self.determine_map_strategy(ac.agg_state(), sorted_keys, &gb)? {
Nothing => {
let mut out = ac.flat_naive().into_owned();

if ac.is_literal() {
out = out.new_from_index(0, df.height())
}
cache_gb(gb, state, &cache_key);
if let Some(name) = &self.out_name {
out.rename(name.clone());
Expand Down Expand Up @@ -630,6 +634,10 @@ impl PhysicalExpr for WindowExpr {
self.function.to_field(input_schema, Context::Default)
}

/// `WindowExpr` never reports itself as scalar.
fn is_scalar(&self) -> bool {
false
}

#[allow(clippy::ptr_arg)]
fn evaluate_on_groups<'a>(
&self,
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ fn create_physical_expr_inner(
)))
},
BinaryExpr { left, op, right } => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let lhs = create_physical_expr_inner(*left, ctxt, expr_arena, schema, state)?;
let rhs = create_physical_expr_inner(*right, ctxt, expr_arena, schema, state)?;
Ok(Arc::new(phys_expr::BinaryExpr::new(
Expand All @@ -302,6 +303,7 @@ fn create_physical_expr_inner(
node_to_expr(expression, expr_arena),
state.local.has_lit,
state.allow_threading,
is_scalar,
)))
},
Column(column) => Ok(Arc::new(ColumnExpr::new(
Expand Down Expand Up @@ -444,6 +446,7 @@ fn create_physical_expr_inner(
truthy,
falsy,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let mut lit_count = 0u8;
state.reset();
let predicate =
Expand All @@ -461,6 +464,7 @@ fn create_physical_expr_inner(
falsy,
node_to_expr(expression, expr_arena),
lit_count < 2,
is_scalar,
)))
},
AnonymousFunction {
Expand All @@ -469,6 +473,7 @@ fn create_physical_expr_inner(
output_type: _,
options,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
expr_arena
.get(expression)
Expand Down Expand Up @@ -500,6 +505,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar,
)))
},
Function {
Expand All @@ -508,6 +514,7 @@ fn create_physical_expr_inner(
options,
..
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype = schema.and_then(|schema| {
expr_arena
.get(expression)
Expand Down Expand Up @@ -538,6 +545,7 @@ fn create_physical_expr_inner(
state.allow_threading,
schema.cloned(),
output_dtype,
is_scalar,
)))
},
Slice {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-mem-engine/src/executors/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(selected_cols, df.is_empty(), self.options)
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)
});

let df = POOL.install(|| iter.collect::<PolarsResult<Vec<_>>>())?;
Expand All @@ -53,7 +53,7 @@ impl ProjectionExec {
self.has_windows,
self.options.run_parallel,
)?;
check_expand_literals(selected_cols, df.is_empty(), self.options)?
check_expand_literals(&df, &self.expr, selected_cols, df.is_empty(), self.options)?
};

// this only runs during testing and check if the runtime type matches the predicted schema
Expand Down
Loading

0 comments on commit 12cce97

Please sign in to comment.