Skip to content

Commit

Permalink
fix(rust,python): address multiple issues caused by casting is_in v…
Browse files Browse the repository at this point in the history
…alues to the column type being searched
  • Loading branch information
alexander-beedie committed Sep 30, 2023
1 parent dcd0229 commit 3136ed0
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 21 deletions.
72 changes: 56 additions & 16 deletions crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ impl OptimizationRule for TypeCoercionRule {
op,
right: node_right,
} => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right),

#[cfg(feature = "is_in")]
AExpr::Function {
function: FunctionExpr::Boolean(BooleanFunction::IsIn),
Expand All @@ -356,30 +357,69 @@ impl OptimizationRule for TypeCoercionRule {
AExpr::Cast {
expr: other_node,
data_type: DataType::Categorical(None),
// does not matter
strict: false,
}
},
(dt, DataType::Utf8) => {
polars_bail!(ComputeError: "cannot compare {:?} to {:?} type in 'is_in' operation", dt, type_other)
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => {
polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
},
// can't check for more granular time_unit in less-granular time_unit data,
// or we'll cast away valid/necessary precision (eg: nanosecs to millisecs)
(DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => {
match (lhs_unit, rhs_unit) {
(TimeUnit::Nanoseconds, _) => return Ok(None),
(TimeUnit::Microseconds, TimeUnit::Microseconds | TimeUnit::Milliseconds) => return Ok(None),
(TimeUnit::Milliseconds, TimeUnit::Milliseconds) => return Ok(None),
_ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit)
}
},
(DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => {
match (lhs_unit, rhs_unit) {
(TimeUnit::Nanoseconds, _) => return Ok(None),
(TimeUnit::Microseconds, TimeUnit::Microseconds | TimeUnit::Milliseconds) => return Ok(None),
(TimeUnit::Milliseconds, TimeUnit::Milliseconds) => return Ok(None),
_ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit)
}
},
// don't attempt to cast between obviously mismatched types;
// we should error early/explicitly on invalid comparisons
(
_,
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Date
| DataType::Time
| DataType::Boolean
| DataType::Binary
| DataType::Utf8,
)
| (
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Date
| DataType::Time
| DataType::Boolean
| DataType::Binary
| DataType::Utf8,
_,
) => {
match type_other {
// all-null can represent anything (and/or empty list), so cast to target dtype
DataType::Null => AExpr::Cast {expr: other_node, data_type: type_left, strict: false},
_ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
}
},
(DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None),
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None),
// if right is another type, we cast it to left
// we do not use super-type as an `is_in` operation should not
// cast the whole column implicitly.
(DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None),
// if rhs is another type, we cast it to lhs (we do not use supertype
// as `is_in` operation should not implicitly cast the whole column)
(a, b)
if a != b
// For integer/ float comparison we let them use supertypes.
&& !(a.is_integer() && b.is_float()) =>
// for integer/float comparison we let them use supertypes.
if !(a.is_integer() && b.is_float()) =>
{
AExpr::Cast {
expr: other_node,
data_type: type_left,
// does not matter
strict: false,
}
AExpr::Cast {expr: other_node, data_type: type_left, strict: false }
},
// do nothing
_ => return Ok(None),
Expand Down
5 changes: 3 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
FLOAT_DTYPES,
INTEGER_DTYPES,
Categorical,
Null,
Struct,
UInt32,
Utf8,
Expand Down Expand Up @@ -4930,8 +4931,8 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Self:
if isinstance(other, Collection) and not isinstance(other, str):
if isinstance(other, (Set, FrozenSet)):
other = list(other)
other = F.lit(pl.Series(other))
other = other._pyexpr
implied_dtype = Null if len(other) == 0 else None
other = F.lit(pl.Series(other, dtype=implied_dtype))._pyexpr
else:
other = parse_as_expression(other)
return self._from_pyexpr(self._pyexpr.is_in(other))
Expand Down
12 changes: 10 additions & 2 deletions py-polars/tests/unit/operations/test_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ def test_is_in_null_prop() -> None:
.item()
is None
)
assert pl.Series([None], dtype=pl.Boolean).is_in(pl.Series([42])).item() is None
with pytest.raises(
pl.InvalidOperationError,
match="`is_in` cannot check for Int64 values in Boolean data",
):
_res = pl.Series([None], dtype=pl.Boolean).is_in(pl.Series([42])).item()

assert (
pl.Series([{"a": None}], dtype=pl.Struct({"a": pl.Boolean}))
.is_in(pl.Series([{"a": 42}]))
Expand Down Expand Up @@ -136,7 +141,10 @@ def test_is_in_series() -> None:
]
assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height

with pytest.raises(pl.ComputeError, match=r"cannot compare"):
with pytest.raises(
pl.InvalidOperationError,
match=r"`is_in` cannot check for Utf8 values in Int64 data",
):
df.select(pl.col("b").is_in(["x", "x"]))

# check we don't shallow-copy and accidentally modify 'a' (see: #10072)
Expand Down
51 changes: 50 additions & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import io
import re
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING
from decimal import Decimal
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest

import polars as pl
from polars.datatypes.convert import dtype_to_py_type
from polars.exceptions import InvalidOperationError

if TYPE_CHECKING:
from polars.type_aliases import ConcatMethod
Expand Down Expand Up @@ -693,6 +695,53 @@ def test_empty_inputs_error() -> None:
df.select(pl.sum_horizontal(pl.exclude("col1")))


@pytest.mark.parametrize(
("colname", "values", "expected"),
[
("a", [2], [False, True, False]),
("a", [True, False], None),
("a", ["2", "3", "4"], None),
("b", [Decimal("3.14")], None),
("c", [-2, -1, 0, 1, 2], None),
(
"d",
pl.datetime_range(
datetime.now(),
datetime.now(),
interval="2345ns",
time_unit="ns",
eager=True,
),
None,
),
("d", [time(10, 30)], None),
("e", [datetime(1999, 12, 31, 10, 30)], None),
("f", ["xx", "zz"], None),
],
)
def test_invalid_is_in_dtypes(
colname: str, values: list[Any], expected: list[Any] | None
) -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3],
"b": [-2.5, 0.0, 2.5],
"c": [True, None, False],
"d": [datetime(2001, 10, 30), None, datetime(2009, 7, 5)],
"e": [date(2029, 12, 31), date(1999, 12, 31), None],
"f": [b"xx", b"yy", b"zz"],
}
)
if expected is None:
with pytest.raises(
InvalidOperationError,
match="`is_in` cannot check for .*? values in .*? data",
):
df.select(pl.col(colname).is_in(values))
else:
assert df.select(pl.col(colname).is_in(values))[colname].to_list() == expected


def test_sort_by_error() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 3136ed0

Please sign in to comment.