Skip to content

Commit

Permalink
[FEAT] Adds str.length_bytes() function (#2775)
Browse files Browse the repository at this point in the history
Resolves #2770
  • Loading branch information
thomasjpfan authored Sep 5, 2024
1 parent fa6a482 commit 6fe408c
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 0 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ class PyExpr:
def utf8_extract_all(self, pattern: PyExpr, index: int) -> PyExpr: ...
def utf8_replace(self, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ...
def utf8_length(self) -> PyExpr: ...
# Builds an expression producing the length, in bytes, of each UTF-8 string value.
def utf8_length_bytes(self) -> PyExpr: ...
def utf8_lower(self) -> PyExpr: ...
def utf8_upper(self) -> PyExpr: ...
def utf8_lstrip(self) -> PyExpr: ...
Expand Down Expand Up @@ -1339,6 +1340,7 @@ class PySeries:
def utf8_extract_all(self, pattern: PySeries, index: int) -> PySeries: ...
def utf8_replace(self, pattern: PySeries, replacement: PySeries, regex: bool) -> PySeries: ...
def utf8_length(self) -> PySeries: ...
# Returns a series holding the length, in bytes, of each UTF-8 string value.
def utf8_length_bytes(self) -> PySeries: ...
def utf8_lower(self) -> PySeries: ...
def utf8_upper(self) -> PySeries: ...
def utf8_lstrip(self) -> PySeries: ...
Expand Down
27 changes: 27 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,33 @@ def length(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_length())

def length_bytes(self) -> Expression:
    """Computes the size, in bytes, of every value in a UTF-8 string column.

    Multi-byte characters contribute one count per encoded byte, so the
    result can be larger than the number of characters in the string.

    Example:
        >>> import daft
        >>> df = daft.from_pydict({"x": ["😉test", "hey̆", "baz"]})
        >>> df = df.select(df["x"].str.length_bytes())
        >>> df.show()
        ╭────────╮
        │ x      │
        │ ---    │
        │ UInt64 │
        ╞════════╡
        │ 8      │
        ├╌╌╌╌╌╌╌╌┤
        │ 5      │
        ├╌╌╌╌╌╌╌╌┤
        │ 3      │
        ╰────────╯
        <BLANKLINE>
        (Showing first 3 of 3 rows)

    Returns:
        Expression: a UInt64 expression holding the byte length of each string
    """
    byte_length_expr = self._expr.utf8_length_bytes()
    return Expression._from_pyexpr(byte_length_expr)

def lower(self) -> Expression:
"""Convert UTF-8 string to all lowercase
Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,10 @@ def length(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.utf8_length())

def length_bytes(self) -> Series:
    """Returns a new Series with the length, in bytes, of each string value."""
    assert self._series is not None
    byte_lengths = self._series.utf8_length_bytes()
    return Series._from_pyseries(byte_lengths)

def lower(self) -> Series:
    """Returns a new Series built from the underlying series' ``utf8_lower`` result."""
    assert self._series is not None
    lowered = self._series.utf8_lower()
    return Series._from_pyseries(lowered)
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.extract_all
Expression.str.replace
Expression.str.length
Expression.str.length_bytes
Expression.str.lower
Expression.str.upper
Expression.str.lstrip
Expand Down
13 changes: 13 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,19 @@ impl Utf8Array {
Ok(UInt64Array::from((self.name(), Box::new(arrow_result))))
}

/// Computes the length, in bytes, of every string in this array.
///
/// Null entries produce null outputs. The resulting `UInt64Array` is
/// constructed with this array's name.
pub fn length_bytes(&self) -> DaftResult<UInt64Array> {
    let strings = self.as_arrow();
    // `len()` on a string slice counts bytes, not characters.
    let byte_lengths: arrow2::array::UInt64Array = strings
        .iter()
        .map(|maybe_str| maybe_str.map(|s| s.len() as u64))
        .collect();
    // Carry the input's validity bitmap over to the result unchanged.
    let byte_lengths = byte_lengths.with_validity(strings.validity().cloned());
    Ok(UInt64Array::from((self.name(), Box::new(byte_lengths))))
}

/// Produces a new array by applying `to_lowercase` to each value via
/// `unary_broadcasted_op`.
pub fn lower(&self) -> DaftResult<Utf8Array> {
    self.unary_broadcasted_op(|text| {
        let lowered = text.to_lowercase();
        lowered.into()
    })
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ impl PySeries {
Ok(self.series.utf8_length()?.into())
}

// Python-facing wrapper: computes per-value byte lengths on the wrapped
// series and converts the result back into a `PySeries`.
pub fn utf8_length_bytes(&self) -> PyResult<Self> {
    let byte_lengths = self.series.utf8_length_bytes()?;
    Ok(byte_lengths.into())
}

// Python-facing wrapper: calls `utf8_lower` on the wrapped series and
// converts the result back into a `PySeries`.
pub fn utf8_lower(&self) -> PyResult<Self> {
    let lowered = self.series.utf8_lower()?;
    Ok(lowered.into())
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ impl Series {
self.with_utf8_array(|arr| Ok(arr.length()?.into_series()))
}

/// Runs `Utf8Array::length_bytes` on this series through `with_utf8_array`
/// and wraps the resulting array back into a `Series`.
pub fn utf8_length_bytes(&self) -> DaftResult<Series> {
    self.with_utf8_array(|utf8| {
        let byte_lengths = utf8.length_bytes()?;
        Ok(byte_lengths.into_series())
    })
}

/// Runs `Utf8Array::lower` on this series through `with_utf8_array` and
/// wraps the resulting array back into a `Series`.
pub fn utf8_lower(&self) -> DaftResult<Series> {
    self.with_utf8_array(|utf8| {
        let lowered = utf8.lower()?;
        Ok(lowered.into_series())
    })
}
Expand Down
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/utf8/length_bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::functions::FunctionExpr;
use crate::ExprRef;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator for the `length_bytes` UTF-8 function: takes exactly one Utf8
/// input and yields a UInt64 column of per-value byte lengths.
pub(super) struct LengthBytesEvaluator {}

impl FunctionEvaluator for LengthBytesEvaluator {
    /// Name this function is registered under.
    fn fn_name(&self) -> &'static str {
        "length_bytes"
    }

    /// Resolves the output field for the given inputs.
    ///
    /// Returns a `UInt64` field carrying the input field's name.
    ///
    /// # Errors
    /// - `DaftError::SchemaMismatch` when the number of inputs is not exactly one.
    /// - `DaftError::TypeError` when the single input is not `Utf8`.
    /// - Any error raised while resolving the input's own field.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        // Guard: this function is strictly unary.
        let data = match inputs {
            [data] => data,
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input args, got {}",
                    inputs.len()
                )))
            }
        };

        let data_field = data.to_field(schema)?;
        if matches!(data_field.dtype, DataType::Utf8) {
            Ok(Field::new(data_field.name, DataType::UInt64))
        } else {
            Err(DaftError::TypeError(format!(
                "Expects input to length_bytes to be utf8, but received {data_field}",
            )))
        }
    }

    /// Evaluates the function on a single input series.
    ///
    /// # Errors
    /// `DaftError::ValueError` when the number of inputs is not exactly one;
    /// otherwise whatever `Series::utf8_length_bytes` returns.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        if let [data] = inputs {
            return data.utf8_length_bytes();
        }
        Err(DaftError::ValueError(format!(
            "Expected 1 input args, got {}",
            inputs.len()
        )))
    }
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod find;
mod ilike;
mod left;
mod length;
mod length_bytes;
mod like;
mod lower;
mod lpad;
Expand Down Expand Up @@ -36,6 +37,7 @@ use find::FindEvaluator;
use ilike::IlikeEvaluator;
use left::LeftEvaluator;
use length::LengthEvaluator;
use length_bytes::LengthBytesEvaluator;
use like::LikeEvaluator;
use lower::LowerEvaluator;
use lpad::LpadEvaluator;
Expand Down Expand Up @@ -70,6 +72,7 @@ pub enum Utf8Expr {
ExtractAll(usize),
Replace(bool),
Length,
LengthBytes,
Lower,
Upper,
Lstrip,
Expand Down Expand Up @@ -104,6 +107,7 @@ impl Utf8Expr {
ExtractAll(_) => &ExtractAllEvaluator {},
Replace(_) => &ReplaceEvaluator {},
Length => &LengthEvaluator {},
LengthBytes => &LengthBytesEvaluator {},
Lower => &LowerEvaluator {},
Upper => &UpperEvaluator {},
Lstrip => &LstripEvaluator {},
Expand Down Expand Up @@ -198,6 +202,14 @@ pub fn length(data: ExprRef) -> ExprRef {
.into()
}

/// Wraps `data` in a `Utf8Expr::LengthBytes` function expression.
pub fn length_bytes(data: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Utf8(Utf8Expr::LengthBytes);
    let inputs = vec![data];
    Expr::Function { func, inputs }.into()
}

pub fn lower(data: ExprRef) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Lower),
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,11 @@ impl PyExpr {
Ok(length(self.into()).into())
}

// Python-facing wrapper: builds the `length_bytes` UTF-8 function expression
// over this expression and converts it back into a `PyExpr`.
pub fn utf8_length_bytes(&self) -> PyResult<Self> {
    let byte_length_expr = crate::functions::utf8::length_bytes(self.into());
    Ok(byte_length_expr.into())
}

pub fn utf8_lower(&self) -> PyResult<Self> {
use crate::functions::utf8::lower;
Ok(lower(self.into()).into())
Expand Down
4 changes: 4 additions & 0 deletions src/daft-sql/src/modules/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
ensure!(args.len() == 1, "length takes exactly one argument");
Ok(length(args[0].clone()))
}
LengthBytes => {
ensure!(args.len() == 1, "length_bytes takes exactly one argument");
Ok(length_bytes(args[0].clone()))
}
Lower => {
ensure!(args.len() == 1, "lower takes exactly one argument");
Ok(lower(args[0].clone()))
Expand Down
6 changes: 6 additions & 0 deletions tests/series/test_utf8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ def test_series_utf8_length_unicode() -> None:
assert result.to_pylist() == [5, 4]


def test_series_utf8_length_bytes() -> None:
    # Covers a 4-byte emoji, a combining mark, plain ASCII, a null, and the empty string.
    values = ["😉test", "hey̆", "baz", None, ""]
    expected_byte_lengths = [8, 5, 3, None, 0]
    series = Series.from_arrow(pa.array(values))
    assert series.str.length_bytes().to_pylist() == expected_byte_lengths


@pytest.mark.parametrize(
["data", "expected"],
[
Expand Down

0 comments on commit 6fe408c

Please sign in to comment.