Skip to content

Commit

Permalink
BUG: DataFrame.where does not handle Series slice correctly (fixes #1…
Browse files Browse the repository at this point in the history
  • Loading branch information
mortada committed Aug 23, 2015
1 parent 42ca8cd commit a25a664
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 7 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.17.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,4 @@ Bug Fixes
- Bug in ``read_msgpack`` where encoding is not respected (:issue:`10580`)
- Bug preventing access to the first index when using ``iloc`` with a list containing the appropriate negative integer (:issue:`10547`, :issue:`10779`)
- Bug in ``TimedeltaIndex`` formatter causing error while trying to save ``DataFrame`` with ``TimedeltaIndex`` using ``to_csv`` (:issue:`10833`)
- BUG in ``DataFrame.where`` when handling Series slicing (:issue:`10218`, :issue:`9558`)
8 changes: 8 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,14 @@ def _reindex_multi(self, axes, copy, fill_value):
copy=copy,
fill_value=fill_value)

@Appender(_shared_docs['align'] % _shared_doc_kwargs)
def align(self, other, join='outer', axis=None, level=None, copy=True,
fill_value=None, method=None, limit=None, fill_axis=0,
broadcast_axis=None):
return super(DataFrame, self).align(other, join=join, axis=axis, level=level, copy=copy,
fill_value=fill_value, method=method, limit=limit,
fill_axis=fill_axis, broadcast_axis=broadcast_axis)

@Appender(_shared_docs['reindex'] % _shared_doc_kwargs)
def reindex(self, index=None, columns=None, **kwargs):
return super(DataFrame, self).reindex(index=index, columns=columns,
Expand Down
42 changes: 35 additions & 7 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3447,8 +3447,7 @@ def last(self, offset):
start = self.index.searchsorted(start_date, side='right')
return self.ix[start:]

def align(self, other, join='outer', axis=None, level=None, copy=True,
fill_value=None, method=None, limit=None, fill_axis=0):
_shared_docs['align'] = (
"""
Align two object on their axes with the
specified join method for each axis Index
Expand All @@ -3470,17 +3469,46 @@ def align(self, other, join='outer', axis=None, level=None, copy=True,
"compatible" value
method : str, default None
limit : int, default None
fill_axis : {0, 1}, default 0
fill_axis : %(axes_single_arg)s, default 0
Filling axis, method and limit
broadcast_axis : %(axes_single_arg)s, default None
Broadcast values along this axis, if aligning two objects of
different dimensions
.. versionadded:: 0.17.0
Returns
-------
(left, right) : (type of input, type of other)
(left, right) : (%(klass)s, type of other)
Aligned objects
"""
)

@Appender(_shared_docs['align'] % _shared_doc_kwargs)
def align(self, other, join='outer', axis=None, level=None, copy=True,
fill_value=None, method=None, limit=None, fill_axis=0,
broadcast_axis=None):
from pandas import DataFrame, Series
method = com._clean_fill_method(method)

if broadcast_axis == 1 and self.ndim != other.ndim:
if isinstance(self, Series):
# this means other is a DataFrame, and we need to broadcast self
df = DataFrame(dict((c, self) for c in other.columns),
**other._construct_axes_dict())
return df._align_frame(other, join=join, axis=axis, level=level,
copy=copy, fill_value=fill_value,
method=method, limit=limit,
fill_axis=fill_axis)
elif isinstance(other, Series):
# this means self is a DataFrame, and we need to broadcast other
df = DataFrame(dict((c, other) for c in self.columns),
**self._construct_axes_dict())
return self._align_frame(df, join=join, axis=axis, level=level,
copy=copy, fill_value=fill_value,
method=method, limit=limit,
fill_axis=fill_axis)

if axis is not None:
axis = self._get_axis_number(axis)
if isinstance(other, DataFrame):
Expand Down Expand Up @@ -3516,11 +3544,11 @@ def _align_frame(self, other, join='outer', axis=None, level=None,
self.columns.join(other.columns, how=join, level=level,
return_indexers=True)

left = self._reindex_with_indexers({0: [join_index, ilidx],
left = self._reindex_with_indexers({0: [join_index, ilidx],
1: [join_columns, clidx]},
copy=copy, fill_value=fill_value,
allow_dups=True)
right = other._reindex_with_indexers({0: [join_index, iridx],
right = other._reindex_with_indexers({0: [join_index, iridx],
1: [join_columns, cridx]},
copy=copy, fill_value=fill_value,
allow_dups=True)
Expand Down Expand Up @@ -3624,7 +3652,7 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
try_cast=False, raise_on_error=True):

if isinstance(cond, NDFrame):
cond = cond.reindex(**self._construct_axes_dict())
cond, _ = cond.align(self, join='right', broadcast_axis=1)
else:
if not hasattr(cond, 'shape'):
raise ValueError('where requires an ndarray like object for '
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,9 @@ def _needs_reindex_multi(self, axes, method, level):
""" don't allow a multi reindex on Panel or above ndim """
return False

def align(self, other, **kwargs):
raise NotImplementedError

def dropna(self, axis=0, how='any', inplace=False):
"""
Drop 2D from panel, holding passed axis constant
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,14 @@ def _needs_reindex_multi(self, axes, method, level):
"""
return False

@Appender(generic._shared_docs['align'] % _shared_doc_kwargs)
def align(self, other, join='outer', axis=None, level=None, copy=True,
fill_value=None, method=None, limit=None, fill_axis=0,
broadcast_axis=None):
return super(Series, self).align(other, join=join, axis=axis, level=level, copy=copy,
fill_value=fill_value, method=method, limit=limit,
fill_axis=fill_axis, broadcast_axis=broadcast_axis)

@Appender(generic._shared_docs['rename'] % _shared_doc_kwargs)
def rename(self, index=None, **kwargs):
return super(Series, self).rename(index=index, **kwargs)
Expand Down
35 changes: 35 additions & 0 deletions pandas/tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10065,6 +10065,34 @@ def test_align(self):
self.assertRaises(ValueError, self.frame.align, af.ix[0, :3],
join='inner', axis=2)

# align dataframe to series with broadcast or not
idx = self.frame.index
s = Series(range(len(idx)), index=idx)

left, right = self.frame.align(s, axis=0)
tm.assert_index_equal(left.index, self.frame.index)
tm.assert_index_equal(right.index, self.frame.index)
self.assertTrue(isinstance(right, Series))

left, right = self.frame.align(s, broadcast_axis=1)
tm.assert_index_equal(left.index, self.frame.index)
expected = {}
for c in self.frame.columns:
expected[c] = s
expected = DataFrame(expected, index=self.frame.index,
columns=self.frame.columns)
assert_frame_equal(right, expected)

# GH 9558
df = DataFrame({'a':[1,2,3], 'b':[4,5,6]})
result = df[df['a'] == 2]
expected = DataFrame([[2, 5]], index=[1], columns=['a', 'b'])
assert_frame_equal(result, expected)

result = df.where(df['a'] == 2, 0)
expected = DataFrame({'a':[0, 2, 0], 'b':[0, 5, 0]})
assert_frame_equal(result, expected)

def _check_align(self, a, b, axis, fill_axis, how, method, limit=None):
aa, ab = a.align(b, axis=axis, join=how, method=method, limit=limit,
fill_axis=fill_axis)
Expand Down Expand Up @@ -10310,6 +10338,13 @@ def _check_set(df, cond, check_dtypes = True):
cond = (df >= 0)[1:]
_check_set(df, cond)

# GH 10218
# test DataFrame.where with Series slicing
df = DataFrame({'a': range(3), 'b': range(4, 7)})
result = df.where(df['a'] == 1)
expected = df[df['a'] == 1].reindex(df.index)
assert_frame_equal(result, expected)

def test_where_bug(self):

# GH 2793
Expand Down

0 comments on commit a25a664

Please sign in to comment.