diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
index ac51f5b8fabe..28c54b73675d 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs
@@ -331,6 +331,7 @@ impl OptimizationRule for TypeCoercionRule {
                 op,
                 right: node_right,
             } => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right),
+
             #[cfg(feature = "is_in")]
             AExpr::Function {
                 function: FunctionExpr::Boolean(BooleanFunction::IsIn),
@@ -356,30 +357,69 @@ impl OptimizationRule for TypeCoercionRule {
                         AExpr::Cast {
                             expr: other_node,
                             data_type: DataType::Categorical(None),
-                            // does not matter
                             strict: false,
                         }
                     },
-                    (dt, DataType::Utf8) => {
-                        polars_bail!(ComputeError: "cannot compare {:?} to {:?} type in 'is_in' operation", dt, type_other)
+                    #[cfg(feature = "dtype-decimal")]
+                    (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => {
+                        polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
+                    },
+                    // can't check for more granular time_unit in less-granular time_unit data,
+                    // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs)
+                    (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => {
+                        match (lhs_unit, rhs_unit) {
+                            (TimeUnit::Nanoseconds, _) => return Ok(None),
+                            (TimeUnit::Microseconds, TimeUnit::Microseconds | TimeUnit::Milliseconds) => return Ok(None),
+                            (TimeUnit::Milliseconds, TimeUnit::Milliseconds) => return Ok(None),
+                            _ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit)
+                        }
+                    },
+                    (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => {
+                        match (lhs_unit, rhs_unit) {
+                            (TimeUnit::Nanoseconds, _) => return Ok(None),
+                            (TimeUnit::Microseconds, TimeUnit::Microseconds | TimeUnit::Milliseconds) => return Ok(None),
+                            (TimeUnit::Milliseconds, TimeUnit::Milliseconds) => return Ok(None),
+                            _ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit)
+                        }
+                    },
+                    // don't attempt to cast between obviously mismatched types;
+                    // we should error early/explicitly on invalid comparisons
+                    (
+                        _,
+                        | DataType::Datetime(_, _)
+                        | DataType::Duration(_)
+                        | DataType::Date
+                        | DataType::Time
+                        | DataType::Boolean
+                        | DataType::Binary
+                        | DataType::Utf8,
+                    )
+                    | (
+                        | DataType::Datetime(_, _)
+                        | DataType::Duration(_)
+                        | DataType::Date
+                        | DataType::Time
+                        | DataType::Boolean
+                        | DataType::Binary
+                        | DataType::Utf8,
+                        _,
+                    ) => {
+                        match type_other {
+                            // all-null can represent anything (and/or empty list), so cast to target dtype
+                            DataType::Null => AExpr::Cast {expr: other_node, data_type: type_left, strict: false},
+                            _ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left)
+                        }
                     },
-                    (DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None),
                     #[cfg(feature = "dtype-struct")]
                     (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None),
-                    // if right is another type, we cast it to left
-                    // we do not use super-type as an `is_in` operation should not
-                    // cast the whole column implicitly.
+                    (DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None),
+                    // if rhs is another type, we cast it to lhs (we do not use supertype
+                    // as `is_in` operation should not implicitly cast the whole column)
                     (a, b)
-                        if a != b
-                        // For integer/ float comparison we let them use supertypes.
-                        && !(a.is_integer() && b.is_float()) =>
+                        // for integer/float comparison we let them use supertypes.
+                        if !(a.is_integer() && b.is_float()) =>
                     {
-                        AExpr::Cast {
-                            expr: other_node,
-                            data_type: type_left,
-                            // does not matter
-                            strict: false,
-                        }
+                        AExpr::Cast {expr: other_node, data_type: type_left, strict: false }
                     },
                     // do nothing
                     _ => return Ok(None),
diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py
index d12193685bd8..0648b5c6952e 100644
--- a/py-polars/polars/expr/expr.py
+++ b/py-polars/polars/expr/expr.py
@@ -27,6 +27,7 @@
     FLOAT_DTYPES,
     INTEGER_DTYPES,
     Categorical,
+    Null,
     Struct,
     UInt32,
     Utf8,
@@ -4930,8 +4931,8 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Self:
         if isinstance(other, Collection) and not isinstance(other, str):
             if isinstance(other, (Set, FrozenSet)):
                 other = list(other)
-            other = F.lit(pl.Series(other))
-            other = other._pyexpr
+            implied_dtype = Null if len(other) == 0 else None
+            other = F.lit(pl.Series(other, dtype=implied_dtype))._pyexpr
         else:
             other = parse_as_expression(other)
         return self._from_pyexpr(self._pyexpr.is_in(other))
diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py
index 24b6d86fdf2f..5b277a60b752 100644
--- a/py-polars/tests/unit/operations/test_is_in.py
+++ b/py-polars/tests/unit/operations/test_is_in.py
@@ -80,7 +80,12 @@ def test_is_in_null_prop() -> None:
         .item()
         is None
     )
-    assert pl.Series([None], dtype=pl.Boolean).is_in(pl.Series([42])).item() is None
+    with pytest.raises(
+        pl.InvalidOperationError,
+        match="`is_in` cannot check for Int64 values in Boolean data",
+    ):
+        _res = pl.Series([None], dtype=pl.Boolean).is_in(pl.Series([42])).item()
+
     assert (
         pl.Series([{"a": None}], dtype=pl.Struct({"a": pl.Boolean}))
         .is_in(pl.Series([{"a": 42}]))
@@ -136,7 +141,10 @@ def test_is_in_series() -> None:
     ]
     assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height
 
-    with pytest.raises(pl.ComputeError, match=r"cannot compare"):
+    with pytest.raises(
+        pl.InvalidOperationError,
+        match=r"`is_in` cannot check for Utf8 values in Int64 data",
+    ):
         df.select(pl.col("b").is_in(["x", "x"]))
 
     # check we don't shallow-copy and accidentally modify 'a' (see: #10072)
diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py
index 846d4074495f..1ac1ebef3e06 100644
--- a/py-polars/tests/unit/test_errors.py
+++ b/py-polars/tests/unit/test_errors.py
@@ -3,13 +3,15 @@
 import io
 import re
 from datetime import date, datetime, time, timedelta
-from typing import TYPE_CHECKING
+from decimal import Decimal
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
 import pytest
 
 import polars as pl
 from polars.datatypes.convert import dtype_to_py_type
+from polars.exceptions import InvalidOperationError
 
 if TYPE_CHECKING:
     from polars.type_aliases import ConcatMethod
@@ -693,6 +695,54 @@ def test_empty_inputs_error() -> None:
         df.select(pl.sum_horizontal(pl.exclude("col1")))
 
 
+@pytest.mark.parametrize(
+    ("colname", "values", "expected"),
+    [
+        ("a", [2], [False, True, False]),
+        ("a", [True, False], None),
+        ("a", ["2", "3", "4"], None),
+        ("b", [Decimal("3.14")], None),
+        ("c", [-2, -1, 0, 1, 2], None),
+        (
+            "d",
+            # fixed bounds (not `datetime.now()`) keep this case reproducible
+            pl.datetime_range(
+                datetime(2023, 1, 1),
+                datetime(2023, 1, 1, 0, 0, 0, 10),
+                interval="2345ns",
+                time_unit="ns",
+                eager=True,
+            ),
+            None,
+        ),
+        ("d", [time(10, 30)], None),
+        ("e", [datetime(1999, 12, 31, 10, 30)], None),
+        ("f", ["xx", "zz"], None),
+    ],
+)
+def test_invalid_is_in_dtypes(
+    colname: str, values: list[Any], expected: list[Any] | None
+) -> None:
+    df = pl.DataFrame(
+        {
+            "a": [1, 2, 3],
+            "b": [-2.5, 0.0, 2.5],
+            "c": [True, None, False],
+            "d": [datetime(2001, 10, 30), None, datetime(2009, 7, 5)],
+            "e": [date(2029, 12, 31), date(1999, 12, 31), None],
+            "f": [b"xx", b"yy", b"zz"],
+        }
+    )
+    if expected is None:
+        with pytest.raises(
+            InvalidOperationError,
+            match="`is_in` cannot check for .*? values in .*? data",
+        ):
+            df.select(pl.col(colname).is_in(values))
+    else:
+        assert df.select(pl.col(colname).is_in(values))[colname].to_list() == expected
+
+
 def test_sort_by_error() -> None:
     df = pl.DataFrame(
         {