Skip to content

Commit

Permalink
ENH: multi-column explode (pandas-dev#39240)
Browse files Browse the repository at this point in the history
  • Loading branch information
iynehz committed Jun 20, 2021
1 parent 3659eda commit 5f2bff9
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 24 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ Other enhancements
- Add keyword ``dropna`` to :meth:`DataFrame.value_counts` to allow counting rows that include ``NA`` values (:issue:`41325`)
- :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`)
- Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`)
- :meth:`DataFrame.explode` now supports exploding multiple columns. Its ``column`` argument now also accepts a list of str or tuples for exploding on multiple columns at the same time (:issue:`39240`)

.. ---------------------------------------------------------------------------
Expand Down
97 changes: 74 additions & 23 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8151,16 +8151,27 @@ def stack(self, level: Level = -1, dropna: bool = True):

return result.__finalize__(self, method="stack")

def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
def explode(
self,
column: str | tuple | list[str | tuple],
ignore_index: bool = False,
) -> DataFrame:
"""
Transform each element of a list-like to a row, replicating index values.
.. versionadded:: 0.25.0
Parameters
----------
column : str or tuple
Column to explode.
column : str or tuple or list thereof
Column(s) to explode.
For multiple columns, specify a non-empty list in which each element
is a str or tuple; the list-like data of all specified columns must
have matching length on each row of the frame.
.. versionadded:: 1.3.0
Multi-column explode
ignore_index : bool, default False
If True, the resulting index will be labeled 0, 1, …, n - 1.
Expand All @@ -8175,7 +8186,10 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
Raises
------
ValueError :
if columns of the frame are not unique.
* If columns of the frame are not unique.
* If the list of columns to explode is empty.
* If the columns to explode do not have matching element counts
row-wise in the frame.
See Also
--------
Expand All @@ -8194,32 +8208,69 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame:
Examples
--------
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
... 'B': 1,
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
>>> df
A B
0 [1, 2, 3] 1
1 foo 1
2 [] 1
3 [3, 4] 1
A B C
0 [0, 1, 2] 1 [a, b, c]
1 foo 1 NaN
2 [] 1 []
3 [3, 4] 1 [d, e]
Single-column explode.
>>> df.explode('A')
A B
0 1 1
0 2 1
0 3 1
1 foo 1
2 NaN 1
3 3 1
3 4 1
"""
if not (is_scalar(column) or isinstance(column, tuple)):
raise ValueError("column must be a scalar")
A B C
0 0 1 [a, b, c]
0 1 1 [a, b, c]
0 2 1 [a, b, c]
1 foo 1 NaN
2 NaN 1 []
3 3 1 [d, e]
3 4 1 [d, e]
Multi-column explode.
>>> df.explode(list('AC'))
A B C
0 0 1 a
0 1 1 b
0 2 1 c
1 foo 1 NaN
2 NaN 1 NaN
3 3 1 d
3 4 1 e
"""
if not self.columns.is_unique:
raise ValueError("columns must be unique")

columns: list[str | tuple]
if is_scalar(column) or isinstance(column, tuple):
assert isinstance(column, (str, tuple))
columns = [column]
elif isinstance(column, list) and all(
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
):
if not column:
raise ValueError("column must be nonempty")
if len(column) > len(set(column)):
raise ValueError("column must be unique")
columns = column
else:
raise ValueError("column must be a scalar, tuple, or list thereof")

df = self.reset_index(drop=True)
result = df[column].explode()
result = df.drop([column], axis=1).join(result)
if len(columns) == 1:
result = df[columns[0]].explode()
else:
mylen = lambda x: len(x) if is_list_like(x) else -1
counts0 = self[columns[0]].apply(mylen)
for c in columns[1:]:
if not all(counts0 == self[c].apply(mylen)):
raise ValueError("columns must have matching element counts")
result = DataFrame({c: df[c].explode() for c in columns})
result = df.drop(columns, axis=1).join(result)
if ignore_index:
result.index = ibase.default_index(len(result))
else:
Expand Down
93 changes: 92 additions & 1 deletion pandas/tests/frame/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,50 @@ def test_error():
df = pd.DataFrame(
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
)
with pytest.raises(ValueError, match="column must be a scalar"):
with pytest.raises(
ValueError, match="column must be a scalar, tuple, or list thereof"
):
df.explode([list("AA")])

with pytest.raises(ValueError, match="column must be unique"):
df.explode(list("AA"))

df.columns = list("AA")
with pytest.raises(ValueError, match="columns must be unique"):
df.explode("A")


@pytest.mark.parametrize(
    "input_subset, error_message",
    [
        # "A" and "C" disagree on the last row: (3, 4) vs. ["d", "e", "f"].
        (
            list("AC"),
            "columns must have matching element counts",
        ),
        # An empty list of columns to explode is rejected.
        (
            [],
            "column must be nonempty",
        ),
    ],
)
def test_error_multi_columns(input_subset, error_message):
    """
    Check that invalid multi-column ``explode`` requests raise ``ValueError``.

    Parameters
    ----------
    input_subset : list
        Value passed as the ``column`` argument of ``DataFrame.explode``.
    error_message : str
        Pattern the raised ``ValueError`` message must match.
    """
    # GH 39240
    df = pd.DataFrame(
        {
            "A": [[0, 1, 2], np.nan, [], (3, 4)],
            "B": 1,
            "C": [["a", "b", "c"], "foo", [], ["d", "e", "f"]],
        },
        index=list("abcd"),
    )
    with pytest.raises(ValueError, match=error_message):
        df.explode(input_subset)


def test_basic():
df = pd.DataFrame(
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
Expand Down Expand Up @@ -180,3 +216,58 @@ def test_explode_sets():
result = df.explode(column="a").sort_values(by="a")
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    ["input_subset", "expected_dict", "expected_index"],
    [
        # Two columns exploded together: "A" and "C" are both flattened.
        (
            ["A", "C"],
            {
                "A": pd.Series(
                    [0, 1, 2, np.nan, np.nan, 3, 4, np.nan],
                    index=["a", "a", "a", "b", "c", "d", "d", "e"],
                    dtype=object,
                ),
                "B": 1,
                "C": ["a", "b", "c", "foo", np.nan, "d", "e", np.nan],
            },
            ["a", "a", "a", "b", "c", "d", "d", "e"],
        ),
        # One-element list: only "A" is flattened, "C" keeps its row values.
        (
            ["A"],
            {
                "A": pd.Series(
                    [0, 1, 2, np.nan, np.nan, 3, 4, np.nan],
                    index=["a", "a", "a", "b", "c", "d", "d", "e"],
                    dtype=object,
                ),
                "B": 1,
                "C": [
                    ["a", "b", "c"],
                    ["a", "b", "c"],
                    ["a", "b", "c"],
                    "foo",
                    [],
                    ["d", "e"],
                    ["d", "e"],
                    np.nan,
                ],
            },
            ["a", "a", "a", "b", "c", "d", "d", "e"],
        ),
    ],
)
def test_multi_columns(input_subset, expected_dict, expected_index):
    """
    Compare ``DataFrame.explode`` on a list of columns with a hand-built frame.

    Parameters
    ----------
    input_subset : list of str
        Columns passed to ``explode``.
    expected_dict : dict
        Column data of the expected frame.
    expected_index : list of str
        Index labels of the expected frame.
    """
    # GH 39240
    frame = pd.DataFrame(
        {
            "A": [[0, 1, 2], np.nan, [], (3, 4), np.nan],
            "B": 1,
            "C": [["a", "b", "c"], "foo", [], ["d", "e"], np.nan],
        },
        index=["a", "b", "c", "d", "e"],
    )
    expected = pd.DataFrame(expected_dict, index=expected_index)
    exploded = frame.explode(input_subset)
    tm.assert_frame_equal(exploded, expected)

0 comments on commit 5f2bff9

Please sign in to comment.