Skip to content

Commit

Permalink
fix: AnyValue Series from Categorical/Enum (#18893)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Sep 24, 2024
1 parent fa84194 commit 162c1e6
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 13 deletions.
104 changes: 104 additions & 0 deletions crates/polars-core/src/chunked_array/logical/enum_/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use std::sync::Arc;

use arrow::array::UInt32Vec;
use polars_error::{polars_bail, polars_err, PolarsResult};
use polars_utils::aliases::{InitHashMaps, PlHashMap};
use polars_utils::pl_str::PlSmallStr;
use polars_utils::IdxSize;

use super::{CategoricalChunked, CategoricalOrdering, DataType, Field, RevMapping, UInt32Chunked};

pub struct EnumChunkedBuilder {
name: PlSmallStr,
enum_builder: UInt32Vec,

rev: Arc<RevMapping>,
ordering: CategoricalOrdering,

// Mapping to amortize the costs of lookups.
mapping: PlHashMap<PlSmallStr, u32>,
strict: bool,
}

impl EnumChunkedBuilder {
pub fn new(
name: PlSmallStr,
capacity: usize,
rev: Arc<RevMapping>,
ordering: CategoricalOrdering,
strict: bool,
) -> Self {
Self {
name,
enum_builder: UInt32Vec::with_capacity(capacity),

rev,
ordering,

mapping: PlHashMap::new(),
strict,
}
}

pub fn append_str(&mut self, v: &str) -> PolarsResult<&mut Self> {
match self.mapping.get(v) {
Some(v) => self.enum_builder.push(Some(*v)),
None => {
let Some(iv) = self.rev.find(v) else {
if self.strict {
polars_bail!(InvalidOperation: "cannot append '{v}' to enum without that variant");
} else {
self.enum_builder.push(None);
return Ok(self);
}
};
self.mapping.insert(v.into(), iv);
self.enum_builder.push(Some(iv));
},
}

Ok(self)
}

pub fn append_null(&mut self) -> &mut Self {
self.enum_builder.push(None);
self
}

pub fn append_enum(&mut self, v: u32, rev: &RevMapping) -> PolarsResult<&mut Self> {
if !self.rev.same_src(rev) {
if self.strict {
return Err(polars_err!(ComputeError: "incompatible enum types"));
} else {
self.enum_builder.push(None);
}
} else {
self.enum_builder.push(Some(v));
}

Ok(self)
}

pub fn finish(self) -> CategoricalChunked {
let arr = self.enum_builder.freeze();
let null_count = arr.validity().map_or(0, |a| a.unset_bits()) as IdxSize;
let length = arr.len() as IdxSize;
let ca = unsafe {
UInt32Chunked::new_with_dims(
Arc::new(Field::new(
self.name,
DataType::Enum(Some(self.rev.clone()), self.ordering),
)),
vec![Box::new(arr)],
length,
null_count,
)
};

// SAFETY: keys and values are in bounds
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(ca, self.rev, true, self.ordering)
}
.with_fast_unique(true)
}
}
2 changes: 2 additions & 0 deletions crates/polars-core/src/chunked_array/logical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod duration;
pub use duration::*;
#[cfg(feature = "dtype-categorical")]
pub mod categorical;
#[cfg(feature = "dtype-categorical")]
pub mod enum_;
#[cfg(feature = "dtype-time")]
mod time;

Expand Down
99 changes: 87 additions & 12 deletions crates/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use std::fmt::Write;

use arrow::bitmap::MutableBitmap;

#[cfg(feature = "dtype-categorical")]
use crate::chunked_array::cast::CastOptions;
#[cfg(feature = "object")]
use crate::chunked_array::object::registry::ObjectRegistry;
use crate::prelude::*;
Expand Down Expand Up @@ -146,9 +144,9 @@ impl Series {
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => any_values_to_duration(values, *tu, strict)?.into_series(),
#[cfg(feature = "dtype-categorical")]
dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => {
any_values_to_categorical(values, dt, strict)?
},
dt @ DataType::Categorical(_, _) => any_values_to_categorical(values, dt, strict)?,
#[cfg(feature = "dtype-categorical")]
dt @ DataType::Enum(_, _) => any_values_to_enum(values, dt, strict)?,
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(precision, scale) => {
any_values_to_decimal(values, *precision, *scale, strict)?.into_series()
Expand Down Expand Up @@ -445,14 +443,91 @@ fn any_values_to_categorical(
dtype: &DataType,
strict: bool,
) -> PolarsResult<Series> {
// TODO: Handle AnyValues of type Categorical/Enum.
// TODO: Avoid materializing to String before casting to Categorical/Enum.
let ca = any_values_to_string(values, strict)?;
if strict {
ca.into_series().strict_cast(dtype)
} else {
ca.cast_with_options(dtype, CastOptions::NonStrict)
let ordering = match dtype {
DataType::Categorical(_, ordering) => ordering,
_ => panic!("any_values_to_categorical with dtype={dtype:?}"),
};

let mut builder = CategoricalChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), *ordering);

let mut owned = String::new(); // Amortize allocations.
for av in values {
match av {
AnyValue::String(s) => builder.append_value(s),
AnyValue::StringOwned(s) => builder.append_value(s),

AnyValue::Enum(s, rev, _) => builder.append_value(rev.get(*s)),
AnyValue::EnumOwned(s, rev, _) => builder.append_value(rev.get(*s)),

AnyValue::Categorical(s, rev, _) => builder.append_value(rev.get(*s)),
AnyValue::CategoricalOwned(s, rev, _) => builder.append_value(rev.get(*s)),

AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(),
AnyValue::Null => builder.append_null(),

av => {
if strict {
return Err(invalid_value_error(&DataType::String, av));
}

owned.clear();
write!(owned, "{av}").unwrap();
builder.append_value(&owned);
},
}
}

let ca = builder.finish();

Ok(ca.into_series())
}

#[cfg(feature = "dtype-categorical")]
fn any_values_to_enum(values: &[AnyValue], dtype: &DataType, strict: bool) -> PolarsResult<Series> {
use self::enum_::EnumChunkedBuilder;

let (rev, ordering) = match dtype {
DataType::Enum(rev, ordering) => (rev.clone(), ordering),
_ => panic!("any_values_to_categorical with dtype={dtype:?}"),
};

let Some(rev) = rev else {
polars_bail!(nyi = "Not yet possible to create enum series without a rev-map");
};

let mut builder =
EnumChunkedBuilder::new(PlSmallStr::EMPTY, values.len(), rev, *ordering, strict);

let mut owned = String::new(); // Amortize allocations.
for av in values {
match av {
AnyValue::String(s) => builder.append_str(s)?,
AnyValue::StringOwned(s) => builder.append_str(s)?,

AnyValue::Enum(s, rev, _) => builder.append_enum(*s, rev)?,
AnyValue::EnumOwned(s, rev, _) => builder.append_enum(*s, rev)?,

AnyValue::Categorical(s, rev, _) => builder.append_str(rev.get(*s))?,
AnyValue::CategoricalOwned(s, rev, _) => builder.append_str(rev.get(*s))?,

AnyValue::Binary(_) | AnyValue::BinaryOwned(_) if !strict => builder.append_null(),
AnyValue::Null => builder.append_null(),

av => {
if strict {
return Err(invalid_value_error(&DataType::String, av));
}

owned.clear();
write!(owned, "{av}").unwrap();
builder.append_str(&owned)?
},
};
}

let ca = builder.finish();

Ok(ca.into_series())
}

#[cfg(feature = "dtype-decimal")]
Expand Down
20 changes: 19 additions & 1 deletion py-polars/tests/unit/constructors/test_any_value_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import polars as pl
from polars._utils.wrap import wrap_s
from polars.polars import PySeries
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType
Expand Down Expand Up @@ -379,7 +380,9 @@ def test_fallback_with_dtype_strict_failure_enum_casting() -> None:
dtype = pl.Enum(["a", "b"])
values = ["a", "b", "c", None]

with pytest.raises(TypeError, match="conversion from `str` to `enum` failed"):
with pytest.raises(
TypeError, match="cannot append 'c' to enum without that variant"
):
PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True)


Expand All @@ -391,3 +394,18 @@ def test_fallback_with_dtype_strict_failure_decimal_precision() -> None:
TypeError, match="decimal precision 3 can't fit values with 5 digits"
):
PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True)


def test_categorical_lit_18874() -> None:
with pl.StringCache():
assert_frame_equal(
pl.DataFrame(
{"a": [1, 2, 3]},
).with_columns(b=pl.lit("foo").cast(pl.Categorical)),
pl.DataFrame(
[
pl.Series("a", [1, 2, 3]),
pl.Series("b", ["foo"] * 3, pl.Categorical),
]
),
)

0 comments on commit 162c1e6

Please sign in to comment.