Skip to content

Commit

Permalink
[ENH]: Expand types allowed in Series.struct.field
Browse files Browse the repository at this point in the history
This expands the set of types allowed by Series.struct.field to allow
those allowed by pyarrow.
  • Loading branch information
TomAugspurger committed Nov 19, 2023
1 parent 4514636 commit ea6e848
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 14 deletions.
102 changes: 89 additions & 13 deletions pandas/core/arrays/arrow/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
pa_version_under11p0,
)

from pandas.core.dtypes.common import is_list_like

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -267,7 +269,16 @@ def dtypes(self) -> Series:
names = [struct.name for struct in pa_type]
return Series(types, index=Index(names))

def field(self, name_or_index: str | int) -> Series:
def field(
self,
name_or_index: list[str]
| list[bytes]
| list[int]
| pc.Expression
| bytes
| str
| int,
) -> Series:
"""
Extract a child field of a struct as a Series.
Expand All @@ -281,6 +292,17 @@ def field(self, name_or_index: str | int) -> Series:
pandas.Series
The data corresponding to the selected child field.
Notes
-----
The name of the resulting Series will be set using the following
rules:
- For string, bytes, or integer `name_or_index` (or a list of these, for
a nested selection), the Series name is set to the selected
field's name.
- For a :class:`pyarrow.compute.Expression`, this is set to
the string form of the expression.
See Also
--------
Series.struct.explode : Return all child fields as a DataFrame.
Expand Down Expand Up @@ -314,27 +336,81 @@ def field(self, name_or_index: str | int) -> Series:
1 2
2 1
Name: version, dtype: int64[pyarrow]
Or an expression
>>> import pyarrow.compute as pc
>>> s.struct.field(pc.field("project"))
0 pandas
1 pandas
2 numpy
Name: project, dtype: string[pyarrow]
For nested struct types, you can
>>> version_type = pa.struct([
... ("major", pa.int64()),
... ("minor", pa.int64()),
... ])
>>> s = pd.Series(
... [
... {"version": {"major": 1, "minor": 5}, "project": "pandas"},
... {"version": {"major": 2, "minor": 1}, "project": "pandas"},
... {"version": {"major": 1, "minor": 26}, "project": "numpy"},
... ],
... dtype=pd.ArrowDtype(pa.struct(
... [("version", version_type), ("project", pa.string())]
... ))
... )
>>> s.struct.field(["version", "minor"])
0 5
1 1
2 26
Name: minor, dtype: int64[pyarrow]
>>> s.struct.field([0, 0])
0 1
1 2
2 1
Name: major, dtype: int64[pyarrow]
"""
from pandas import Series

def get_name(level_name_or_index, data):
if isinstance(level_name_or_index, int):
index = data.type.field(level_name_or_index).name
elif isinstance(level_name_or_index, (str, bytes)):
# index = pa_arr.type.get_field_index(level_name_or_index)
index = level_name_or_index
elif isinstance(level_name_or_index, pc.Expression):
index = str(level_name_or_index)
elif is_list_like(level_name_or_index):
# For nested input like [2, 1, 2]
# iteratively get the struct and field name. The last
# one is used for the name of the index.
level_name_or_index = list(reversed(level_name_or_index))
selected = data
while level_name_or_index:
name_or_index = level_name_or_index.pop()
name = get_name(name_or_index, selected)
selected = selected.type.field(selected.type.get_field_index(name))
index = selected.name
return index
else:
raise ValueError(
"name_or_index must be an int, str, bytes, "
"pyarrow.compute.Expression, or list of those"
)
return index

pa_arr = self._data.array._pa_array
if isinstance(name_or_index, int):
index = name_or_index
elif isinstance(name_or_index, str):
index = pa_arr.type.get_field_index(name_or_index)
else:
raise ValueError(
"name_or_index must be an int or str, "
f"got {type(name_or_index).__name__}"
)
name = get_name(name_or_index, pa_arr)
field_arr = pc.struct_field(pa_arr, name_or_index)

pa_field = pa_arr.type[index]
field_arr = pc.struct_field(pa_arr, [index])
return Series(
field_arr,
dtype=ArrowDtype(field_arr.type),
index=self._data.index,
name=pa_field.name,
name=name,
)

def explode(self) -> DataFrame:
Expand Down
40 changes: 39 additions & 1 deletion pandas/tests/series/accessors/test_struct_accessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re

import pyarrow.compute as pc
import pytest

from pandas import (
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_struct_accessor_field():
def test_struct_accessor_field_with_invalid_name_or_index():
ser = Series([], dtype=ArrowDtype(pa.struct([("field", pa.int64())])))

with pytest.raises(ValueError, match="name_or_index must be an int or str"):
with pytest.raises(ValueError, match="name_or_index must be an int, str,"):
ser.struct.field(1.1)


Expand Down Expand Up @@ -148,3 +149,40 @@ def test_struct_accessor_api_for_invalid(invalid):
),
):
invalid.struct


@pytest.mark.parametrize(
["indices", "name"],
[
(0, "int_col"),
([1, 2], "str_col"),
(pc.field("int_col"), "int_col"),
("int_col", "int_col"),
(b"string_col", b"string_col"),
([b"string_col"], "string_col"),
],
)
def test_struct_accessor_field_expanded(indices, name):
arrow_type = pa.struct(
[
("int_col", pa.int64()),
(
"struct_col",
pa.struct(
[
("int_col", pa.int64()),
("float_col", pa.float64()),
("str_col", pa.string()),
]
),
),
(b"string_col", pa.string()),
]
)

data = pa.array([], type=arrow_type)
ser = Series(data, dtype=ArrowDtype(arrow_type))
expected = pc.struct_field(data, indices)
result = ser.struct.field(indices)
tm.assert_equal(result.array._pa_array.combine_chunks(), expected)
assert result.name == name

0 comments on commit ea6e848

Please sign in to comment.