From 162c1e60dd90c4ba9d143875ca269583578d48f1 Mon Sep 17 00:00:00 2001
From: Gijs Burghoorn
Date: Tue, 24 Sep 2024 16:48:57 +0200
Subject: [PATCH] fix: AnyValue Series from Categorical/Enum (#18893)

---
 .../src/chunked_array/logical/enum_/mod.rs   | 126 ++++++++++++++++++
 .../src/chunked_array/logical/mod.rs         |   2 +
 crates/polars-core/src/series/any_value.rs   | 103 +++++++++++++++--
 .../constructors/test_any_value_fallbacks.py |  20 +++-
 4 files changed, 238 insertions(+), 13 deletions(-)
 create mode 100644 crates/polars-core/src/chunked_array/logical/enum_/mod.rs

diff --git a/crates/polars-core/src/chunked_array/logical/enum_/mod.rs b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs
new file mode 100644
index 000000000000..e143a59a7f7b
--- /dev/null
+++ b/crates/polars-core/src/chunked_array/logical/enum_/mod.rs
@@ -0,0 +1,126 @@
+use std::sync::Arc;
+
+use arrow::array::UInt32Vec;
+use polars_error::{polars_bail, polars_err, PolarsResult};
+use polars_utils::aliases::{InitHashMaps, PlHashMap};
+use polars_utils::pl_str::PlSmallStr;
+use polars_utils::IdxSize;
+
+use super::{CategoricalChunked, CategoricalOrdering, DataType, Field, RevMapping, UInt32Chunked};
+
+/// Builder that collects `u32` category indices into an Enum-typed [`CategoricalChunked`].
+///
+/// Every appended value is resolved against the single [`RevMapping`] given to
+/// [`EnumChunkedBuilder::new`]. In strict mode an unresolvable value is an error; otherwise it
+/// is appended as null.
+pub struct EnumChunkedBuilder {
+    name: PlSmallStr,
+    enum_builder: UInt32Vec,
+
+    rev: Arc<RevMapping>,
+    ordering: CategoricalOrdering,
+
+    // Mapping to amortize the costs of lookups.
+    mapping: PlHashMap<PlSmallStr, u32>,
+    strict: bool,
+}
+
+impl EnumChunkedBuilder {
+    /// Creates a builder with room for `capacity` values and an empty lookup cache.
+    pub fn new(
+        name: PlSmallStr,
+        capacity: usize,
+        rev: Arc<RevMapping>,
+        ordering: CategoricalOrdering,
+        strict: bool,
+    ) -> Self {
+        Self {
+            name,
+            enum_builder: UInt32Vec::with_capacity(capacity),
+
+            rev,
+            ordering,
+
+            mapping: PlHashMap::new(),
+            strict,
+        }
+    }
+
+    /// Appends the variant named `v`, looking it up in the cache first and the rev-map second.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidOperation` if the rev-map has no variant `v` and the builder is strict.
+    pub fn append_str(&mut self, v: &str) -> PolarsResult<&mut Self> {
+        match self.mapping.get(v) {
+            Some(v) => self.enum_builder.push(Some(*v)),
+            None => {
+                let Some(iv) = self.rev.find(v) else {
+                    if self.strict {
+                        polars_bail!(InvalidOperation: "cannot append '{v}' to enum without that variant");
+                    } else {
+                        self.enum_builder.push(None);
+                        return Ok(self);
+                    }
+                };
+                self.mapping.insert(v.into(), iv);
+                self.enum_builder.push(Some(iv));
+            },
+        }
+
+        Ok(self)
+    }
+
+    /// Appends a null value.
+    pub fn append_null(&mut self) -> &mut Self {
+        self.enum_builder.push(None);
+        self
+    }
+
+    /// Appends the index `v` of an enum value whose rev-map is `rev`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `ComputeError` if `rev` fails the `same_src` check and the builder is strict.
+    pub fn append_enum(&mut self, v: u32, rev: &RevMapping) -> PolarsResult<&mut Self> {
+        if !self.rev.same_src(rev) {
+            if self.strict {
+                return Err(polars_err!(ComputeError: "incompatible enum types"));
+            } else {
+                self.enum_builder.push(None);
+            }
+        } else {
+            self.enum_builder.push(Some(v));
+        }
+
+        Ok(self)
+    }
+
+    /// Consumes the builder and returns the collected values as a [`CategoricalChunked`].
+    pub fn finish(self) -> CategoricalChunked {
+        let arr = self.enum_builder.freeze();
+        let null_count = arr.validity().map_or(0, |a| a.unset_bits()) as IdxSize;
+        let length = arr.len() as IdxSize;
+        // NOTE(review): presumably sound because `length` and `null_count` are derived from
+        // `arr` itself; confirm against the contract of `new_with_dims`.
+        let ca = unsafe {
+            UInt32Chunked::new_with_dims(
+                Arc::new(Field::new(
+                    self.name,
+                    DataType::Enum(Some(self.rev.clone()), self.ordering),
+                )),
+                vec![Box::new(arr)],
+                length,
+                null_count,
+            )
+        };
+
+        // NOTE(review): fast-unique is enabled unconditionally, although not every variant of
+        // `rev` has necessarily been appended; confirm the flag's contract allows this.
+        // SAFETY: keys and values are in bounds
+        unsafe {
+            CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering)
+        }
+        .with_fast_unique(true)
+    }
+}
diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs
index d4e6d4eb84aa..0baa286e9f1d 100644
--- a/crates/polars-core/src/chunked_array/logical/mod.rs
+++ b/crates/polars-core/src/chunked_array/logical/mod.rs
@@ -16,6 +16,8 @@ mod duration;
 pub use duration::*;
 #[cfg(feature = "dtype-categorical")]
 pub mod categorical;
+#[cfg(feature = "dtype-categorical")]
+pub mod enum_;
 #[cfg(feature = "dtype-time")]
 mod time;
 
diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs
index 1a8c1df02b65..1f6562b5397a 100644
--- a/crates/polars-core/src/series/any_value.rs
+++ b/crates/polars-core/src/series/any_value.rs
@@ -2,8 +2,6 @@ use std::fmt::Write;
 
 use arrow::bitmap::MutableBitmap;
 
-#[cfg(feature = "dtype-categorical")]
-use crate::chunked_array::cast::CastOptions;
 #[cfg(feature = "object")]
 use crate::chunked_array::object::registry::ObjectRegistry;
 use crate::prelude::*;
@@ -146,9 +144,9 @@ impl Series {
             #[cfg(feature = "dtype-duration")]
             DataType::Duration(tu) => any_values_to_duration(values, *tu, strict)?.into_series(),
             #[cfg(feature = "dtype-categorical")]
-            dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => {
-                any_values_to_categorical(values, dt, strict)?
-            },
+            dt @ DataType::Categorical(_, _) => any_values_to_categorical(values, dt, strict)?,
+            #[cfg(feature = "dtype-categorical")]
+            dt @ DataType::Enum(_, _) => any_values_to_enum(values, dt, strict)?,
             #[cfg(feature = "dtype-decimal")]
             DataType::Decimal(precision, scale) => {
                 any_values_to_decimal(values, *precision, *scale, strict)?.into_series()
@@ -445,14 +443,95 @@ fn any_values_to_categorical(
     dtype: &DataType,
     strict: bool,
 ) -> PolarsResult<Series> {
-    // TODO: Handle AnyValues of type Categorical/Enum.
-    // TODO: Avoid materializing to String before casting to Categorical/Enum.
-    let ca = any_values_to_string(values, strict)?;
-    if strict {
-        ca.into_series().strict_cast(dtype)
-    } else {
-        ca.cast_with_options(dtype, CastOptions::NonStrict)
+    let ordering = match dtype {
+        DataType::Categorical(_, ordering) => ordering,
+        _ => panic!("any_values_to_categorical with dtype={dtype:?}"),
+    };
+
+    let mut builder = CategoricalChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), *ordering);
+
+    let mut owned = String::new(); // Amortize allocations.
+    for av in values {
+        match av {
+            AnyValue::String(s) => builder.append_value(s),
+            AnyValue::StringOwned(s) => builder.append_value(s),
+
+            AnyValue::Enum(s, rev, _) => builder.append_value(rev.get(*s)),
+            AnyValue::EnumOwned(s, rev, _) => builder.append_value(rev.get(*s)),
+
+            AnyValue::Categorical(s, rev, _) => builder.append_value(rev.get(*s)),
+            AnyValue::CategoricalOwned(s, rev, _) => builder.append_value(rev.get(*s)),
+
+            AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(),
+            AnyValue::Null => builder.append_null(),
+
+            av => {
+                if strict {
+                    return Err(invalid_value_error(&DataType::String, av));
+                }
+
+                owned.clear();
+                write!(owned, "{av}").unwrap();
+                builder.append_value(&owned);
+            },
+        }
+    }
+
+    let ca = builder.finish();
+
+    Ok(ca.into_series())
+}
+
+/// Builds an Enum [`Series`] from `values`, resolving them against the rev-map in `dtype`.
+///
+/// String and Categorical values are looked up by name and Enum values are appended by index.
+/// In non-strict mode, binary values become null and other values are formatted as strings.
+#[cfg(feature = "dtype-categorical")]
+fn any_values_to_enum(values: &[AnyValue], dtype: &DataType, strict: bool) -> PolarsResult<Series> {
+    use self::enum_::EnumChunkedBuilder;
+
+    let (rev, ordering) = match dtype {
+        DataType::Enum(rev, ordering) => (rev.clone(), ordering),
+        _ => panic!("any_values_to_enum with dtype={dtype:?}"),
+    };
+
+    let Some(rev) = rev else {
+        polars_bail!(nyi = "Not yet possible to create enum series without a rev-map");
+    };
+
+    let mut builder =
+        EnumChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), rev, *ordering, strict);
+
+    let mut owned = String::new(); // Amortize allocations.
+    for av in values {
+        match av {
+            AnyValue::String(s) => builder.append_str(s)?,
+            AnyValue::StringOwned(s) => builder.append_str(s)?,
+
+            AnyValue::Enum(s, rev, _) => builder.append_enum(*s, rev)?,
+            AnyValue::EnumOwned(s, rev, _) => builder.append_enum(*s, rev)?,
+
+            AnyValue::Categorical(s, rev, _) => builder.append_str(rev.get(*s))?,
+            AnyValue::CategoricalOwned(s, rev, _) => builder.append_str(rev.get(*s))?,
+
+            AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(),
+            AnyValue::Null => builder.append_null(),
+
+            av => {
+                if strict {
+                    return Err(invalid_value_error(&DataType::String, av));
+                }
+
+                owned.clear();
+                write!(owned, "{av}").unwrap();
+                builder.append_str(&owned)?
+            },
+        };
     }
+
+    let ca = builder.finish();
+
+    Ok(ca.into_series())
 }
 
 #[cfg(feature = "dtype-decimal")]
diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py
index 0f3c40f21925..8e6f8581072f 100644
--- a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py
+++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py
@@ -11,6 +11,7 @@
 import polars as pl
 from polars._utils.wrap import wrap_s
 from polars.polars import PySeries
+from polars.testing import assert_frame_equal
 
 if TYPE_CHECKING:
     from polars._typing import PolarsDataType
@@ -379,7 +380,9 @@ def test_fallback_with_dtype_strict_failure_enum_casting() -> None:
     dtype = pl.Enum(["a", "b"])
     values = ["a", "b", "c", None]
 
-    with pytest.raises(TypeError, match="conversion from `str` to `enum` failed"):
+    with pytest.raises(
+        TypeError, match="cannot append 'c' to enum without that variant"
+    ):
         PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True)
 
 
@@ -391,3 +394,18 @@ def test_fallback_with_dtype_strict_failure_decimal_precision() -> None:
         TypeError, match="decimal precision 3 can't fit values with 5 digits"
     ):
         PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True)
+
+
+def test_categorical_lit_18874() -> None:
+    with pl.StringCache():
+        assert_frame_equal(
+            pl.DataFrame(
+                {"a": [1, 2, 3]},
+            ).with_columns(b=pl.lit("foo").cast(pl.Categorical)),
+            pl.DataFrame(
+                [
+                    pl.Series("a", [1, 2, 3]),
+                    pl.Series("b", ["foo"] * 3, pl.Categorical),
+                ]
+            ),
+        )