From 205355c5a6be13df1158a5f9943b22eb9bc7a864 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 5 Oct 2023 16:58:13 +0400 Subject: [PATCH] fix(rust,python): streamline `is_in` handling of mismatched dtypes and fix a minor regression (#11533) --- .../optimizer/type_coercion/mod.rs | 85 ++++++++----------- py-polars/tests/unit/operations/test_is_in.py | 45 +++++++++- 2 files changed, 79 insertions(+), 51 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index be80b4b68f44..a69d5e1a9afb 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -350,77 +350,62 @@ impl OptimizationRule for TypeCoercionRule { let casted_expr = match (&type_left, &type_other) { // types are equal, do nothing (a, b) if a == b => return Ok(None), + // all-null can represent anything (and/or empty list), so cast to target dtype + (_, DataType::Null) => AExpr::Cast { + expr: other_node, + data_type: type_left, + strict: false, + }, // cast both local and global string cache // note that there might not yet be a rev #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_), DataType::Utf8) => { - AExpr::Cast { - expr: other_node, - data_type: DataType::Categorical(None), - strict: false, - } + (DataType::Categorical(_), DataType::Utf8) => AExpr::Cast { + expr: other_node, + data_type: DataType::Categorical(None), + strict: false, }, #[cfg(feature = "dtype-decimal")] (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) }, - // can't check for more granular time_unit in less-granular time_unit data, + // can't check for more granular time_unit in less-granular time_unit data, // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { - if lhs_unit <= rhs_unit { return Ok(None) } - else { - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) } }, (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { - if lhs_unit <= rhs_unit { return Ok(None) } - else { - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) } }, - // don't attempt to cast between obviously mismatched types; - // we should error early/explicitly on invalid comparisons - ( - _, - | DataType::Datetime(_, _) - | DataType::Duration(_) - | DataType::Date - | DataType::Time - | DataType::Boolean - | DataType::Binary - | DataType::Utf8, - ) - | ( - | DataType::Datetime(_, _) - | DataType::Duration(_) - | DataType::Date - | DataType::Time - | DataType::Boolean - | DataType::Binary - | DataType::Utf8, - _, - ) => { - match type_other { - // all-null can represent anything (and/or empty list), so cast to target dtype - DataType::Null => AExpr::Cast {expr: other_node, data_type: type_left, strict: false}, - _ => polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) + (_, DataType::List(other_inner)) => { + if other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); } + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_left, &type_other) }, #[cfg(feature = "dtype-struct")] (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), - (DataType::List(_), _) | (_, DataType::List(_)) => return Ok(None), - // if rhs is another type, we cast it to lhs (we do not use supertype - // as `is_in` operation should not implicitly cast the whole column) - (a, b) - // for integer/float comparison we let them use supertypes. - if !(a.is_integer() && b.is_float()) => - { - AExpr::Cast {expr: other_node, data_type: type_left, strict: false } + + // don't attempt to cast between obviously mismatched types, but + // allow integer/float comparison (will use their supertypes). + (a, b) => { + if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { + return Ok(None); + } + polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) }, - // do nothing - _ => return Ok(None), }; - let mut input = input.clone(); let other_input = expr_arena.add(casted_expr); input[1] = other_input; diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index 5b277a60b752..7c4a90c104fe 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -105,7 +105,6 @@ def test_is_in_float_list_10764() -> None: "n": [3.0, 2.0], } ) - assert df.select(pl.col("n").is_in("lst").alias("is_in")).to_dict(False) == { "is_in": [True, False] } @@ -165,3 +164,47 @@ def test_is_in_null() -> None: def test_is_in_invalid_shape() -> None: with pytest.raises(pl.ComputeError): pl.Series("a", [1, 2, 3]).is_in([[]]) + + +@pytest.mark.parametrize( + ("df", "matches", "expected_error"), + [ + ( + pl.DataFrame({"a": [1, 2], "b": [[1.0, 2.5], [3.0, 4.0]]}), + [True, False], + None, + ), + ( + pl.DataFrame({"a": [2.5, 3.0], "b": [[1, 2], [3, 4]]}), + [False, True], + None, + ), + ( + pl.DataFrame( + {"a": [None, None], "b": [[1, 2], [3, 4]]}, + schema_overrides={"a": pl.Null}, + ), + [None, None], + None, + ), + ( + pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}), + None, + r"`is_in` cannot check for Utf8 values in List\(Int64\) data", + ), + ( + pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}), + None, + r"`is_in` cannot check for Date values in List\(Int64\) data", + ), + ], +) +def test_is_in_expr_list_series( + df: pl.DataFrame, matches: list[bool] | None, expected_error: str | None +) -> None: + expr_is_in = pl.col("a").is_in(pl.col("b")) + if matches: + assert df.select(expr_is_in).to_series().to_list() == matches + else: + with pytest.raises(pl.InvalidOperationError, match=expected_error): + df.select(expr_is_in)