Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Adds str.length_bytes() function #2775

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ class PyExpr:
def utf8_extract_all(self, pattern: PyExpr, index: int) -> PyExpr: ...
def utf8_replace(self, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ...
def utf8_length(self) -> PyExpr: ...
def utf8_length_bytes(self) -> PyExpr: ...
def utf8_lower(self) -> PyExpr: ...
def utf8_upper(self) -> PyExpr: ...
def utf8_lstrip(self) -> PyExpr: ...
Expand Down Expand Up @@ -1331,6 +1332,7 @@ class PySeries:
def utf8_extract_all(self, pattern: PySeries, index: int) -> PySeries: ...
def utf8_replace(self, pattern: PySeries, replacement: PySeries, regex: bool) -> PySeries: ...
def utf8_length(self) -> PySeries: ...
def utf8_length_bytes(self) -> PySeries: ...
def utf8_lower(self) -> PySeries: ...
def utf8_upper(self) -> PySeries: ...
def utf8_lstrip(self) -> PySeries: ...
Expand Down
27 changes: 27 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,33 @@ def length(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_length())

def length_bytes(self) -> Expression:
    """Computes the size, in bytes, of every value in a UTF-8 string column.

    Example:
        >>> import daft
        >>> df = daft.from_pydict({"x": ["😉test", "hey̆", "baz"]})
        >>> df = df.select(df["x"].str.length_bytes())
        >>> df.show()
        ╭────────╮
        │ x      │
        │ ---    │
        │ UInt64 │
        ╞════════╡
        │ 8      │
        ├╌╌╌╌╌╌╌╌┤
        │ 5      │
        ├╌╌╌╌╌╌╌╌┤
        │ 3      │
        ╰────────╯
        <BLANKLINE>
        (Showing first 3 of 3 rows)

    Returns:
        Expression: a UInt64 expression holding the byte length of each string
    """
    # Delegate to the native expression, then wrap it back into the Python API type.
    byte_length_expr = self._expr.utf8_length_bytes()
    return Expression._from_pyexpr(byte_length_expr)

def lower(self) -> Expression:
"""Convert UTF-8 string to all lowercase

Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,10 @@ def length(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_length())

def length_bytes(self) -> Series:
    """Return a new Series with the byte length of each string in this Series."""
    assert self._series is not None
    # The native series does the work; re-wrap its result in the Python Series type.
    native_lengths = self._series.utf8_length_bytes()
    return Series._from_pyseries(native_lengths)

def lower(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_lower())
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.extract_all
Expression.str.replace
Expression.str.length
Expression.str.length_bytes
Expression.str.lower
Expression.str.upper
Expression.str.lstrip
Expand Down
13 changes: 13 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,19 @@ impl Utf8Array {
Ok(UInt64Array::from((self.name(), Box::new(arrow_result))))
}

/// Returns the length of every string in this array measured in bytes.
///
/// Null entries map to null in the result; the validity of the input array is
/// attached to the output explicitly.
pub fn length_bytes(&self) -> DaftResult<UInt64Array> {
    let self_arrow = self.as_arrow();
    // `len()` on the string value counts bytes, not characters.
    let byte_lengths: arrow2::array::UInt64Array = self_arrow
        .iter()
        .map(|val| val.map(|s| s.len() as u64))
        .collect();
    let arrow_result = byte_lengths.with_validity(self_arrow.validity().cloned());
    Ok(UInt64Array::from((self.name(), Box::new(arrow_result))))
}

pub fn lower(&self) -> DaftResult<Utf8Array> {
self.unary_broadcasted_op(|val| val.to_lowercase().into())
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ impl PySeries {
Ok(self.series.utf8_length()?.into())
}

/// Python-facing wrapper: byte length of each string in the wrapped series.
pub fn utf8_length_bytes(&self) -> PyResult<Self> {
    let lengths = self.series.utf8_length_bytes()?;
    Ok(lengths.into())
}

pub fn utf8_lower(&self) -> PyResult<Self> {
Ok(self.series.utf8_lower()?.into())
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ impl Series {
self.with_utf8_array(|arr| Ok(arr.length()?.into_series()))
}

/// Computes the byte length of each string by delegating to the underlying
/// UTF-8 array's `length_bytes` and converting the result into a `Series`.
pub fn utf8_length_bytes(&self) -> DaftResult<Series> {
    self.with_utf8_array(|arr| {
        let lengths = arr.length_bytes()?;
        Ok(lengths.into_series())
    })
}

pub fn utf8_lower(&self) -> DaftResult<Series> {
self.with_utf8_array(|arr| Ok(arr.lower()?.into_series()))
}
Expand Down
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/utf8/length_bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::functions::FunctionExpr;
use crate::ExprRef;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator backing the `length_bytes` UTF-8 function expression.
pub(super) struct LengthBytesEvaluator {}

impl FunctionEvaluator for LengthBytesEvaluator {
    /// Name under which this function is reported.
    fn fn_name(&self) -> &'static str {
        "length_bytes"
    }

    /// Resolves the output field for a single UTF-8 input.
    ///
    /// # Errors
    /// * `DaftError::SchemaMismatch` when the number of inputs is not exactly one.
    /// * `DaftError::TypeError` when the input does not resolve to `DataType::Utf8`.
    /// * Any error produced while resolving the input expression's own field.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        if let [data] = inputs {
            let data_field = data.to_field(schema)?;
            if matches!(data_field.dtype, DataType::Utf8) {
                // Keep the input's name; only the dtype changes to UInt64.
                Ok(Field::new(data_field.name, DataType::UInt64))
            } else {
                Err(DaftError::TypeError(format!(
                    "Expects input to length_bytes to be utf8, but received {data_field}",
                )))
            }
        } else {
            Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input args, got {}",
                inputs.len()
            )))
        }
    }

    /// Evaluates the function on exactly one input series.
    ///
    /// # Errors
    /// Returns `DaftError::ValueError` when the number of inputs is not exactly one,
    /// or whatever error `utf8_length_bytes` reports for the given series.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        if let [data] = inputs {
            return data.utf8_length_bytes();
        }
        Err(DaftError::ValueError(format!(
            "Expected 1 input args, got {}",
            inputs.len()
        )))
    }
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod find;
mod ilike;
mod left;
mod length;
mod length_bytes;
mod like;
mod lower;
mod lpad;
Expand Down Expand Up @@ -36,6 +37,7 @@ use find::FindEvaluator;
use ilike::IlikeEvaluator;
use left::LeftEvaluator;
use length::LengthEvaluator;
use length_bytes::LengthBytesEvaluator;
use like::LikeEvaluator;
use lower::LowerEvaluator;
use lpad::LpadEvaluator;
Expand Down Expand Up @@ -70,6 +72,7 @@ pub enum Utf8Expr {
ExtractAll(usize),
Replace(bool),
Length,
LengthBytes,
Lower,
Upper,
Lstrip,
Expand Down Expand Up @@ -104,6 +107,7 @@ impl Utf8Expr {
ExtractAll(_) => &ExtractAllEvaluator {},
Replace(_) => &ReplaceEvaluator {},
Length => &LengthEvaluator {},
LengthBytes => &LengthBytesEvaluator {},
Lower => &LowerEvaluator {},
Upper => &UpperEvaluator {},
Lstrip => &LstripEvaluator {},
Expand Down Expand Up @@ -198,6 +202,14 @@ pub fn length(data: ExprRef) -> ExprRef {
.into()
}

/// Builds the function expression that applies `Utf8Expr::LengthBytes` to `data`.
pub fn length_bytes(data: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Utf8(Utf8Expr::LengthBytes);
    let inputs = vec![data];
    Expr::Function { func, inputs }.into()
}

pub fn lower(data: ExprRef) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Lower),
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,11 @@ impl PyExpr {
Ok(length(self.into()).into())
}

/// Python-facing wrapper that builds a `length_bytes` expression over this expression.
pub fn utf8_length_bytes(&self) -> PyResult<Self> {
    let byte_length_expr = crate::functions::utf8::length_bytes(self.into());
    Ok(byte_length_expr.into())
}

pub fn utf8_lower(&self) -> PyResult<Self> {
use crate::functions::utf8::lower;
Ok(lower(self.into()).into())
Expand Down
4 changes: 4 additions & 0 deletions src/daft-sql/src/modules/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ pub(crate) fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<Exp
ensure!(args.len() == 1, "length takes exactly one argument");
Ok(length(args[0].clone()))
}
LengthBytes => {
ensure!(args.len() == 1, "length_bytes takes exactly one argument");
Ok(length_bytes(args[0].clone()))
}
Lower => {
ensure!(args.len() == 1, "lower takes exactly one argument");
Ok(lower(args[0].clone()))
Expand Down
6 changes: 6 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ def test_series_utf8_length_unicode() -> None:
assert result.to_pylist() == [5, 4]


def test_series_utf8_length_bytes() -> None:
    """Byte lengths cover multi-byte characters, combining marks, nulls and empty strings."""
    values = ["😉test", "hey̆", "baz", None, ""]
    expected = [8, 5, 3, None, 0]
    series = Series.from_arrow(pa.array(values))
    assert series.str.length_bytes().to_pylist() == expected


@pytest.mark.parametrize(
["data", "expected"],
[
Expand Down
Loading