Skip to content

Commit

Permalink
use attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Oct 21, 2019
1 parent 930aa9d commit 05e238d
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 37 deletions.
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5316,8 +5316,9 @@ def _arith_op(left, right):
with np.errstate(all="ignore"):
res_values = _arith_op(this.values, other.values)
new_data = dispatch_fill_zeros(func, this.values, other.values, res_values)
# XXX: pass them here.
return this._construct_result(new_data)
return this._construct_result(new_data).__finalize__(
(self, other), method="combine_frame"
)

def _combine_match_index(self, other, func):
# at this point we have `self.index.equals(other.index)`
Expand Down
60 changes: 33 additions & 27 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class NDFrame(PandasObject, SelectionMixin):
"ix",
]
) # type: FrozenSet[str]
_metadata = ["allows_duplicate_labels"] # type: List[str]
_metadata = [] # type: List[str]
_is_copy = None
_data = None # type: BlockManager

Expand Down Expand Up @@ -224,12 +224,16 @@ def __init__(
object.__setattr__(self, "_is_copy", None)
object.__setattr__(self, "_data", data)
object.__setattr__(self, "_item_cache", {})
object.__setattr__(self, "allows_duplicate_labels", allow_duplicate_labels)
if attrs is None:
attrs = {}
else:
attrs = dict(attrs)
# need to add it to the dict here, since NDFrame.__setattr__
# also calls NDFrame.__getattr__...
# attrs['allows_duplicate_labels'] = allow_duplicate_labels
object.__setattr__(self, "_attrs", attrs)
object.__setattr__(self, "allows_duplicate_labels", allow_duplicate_labels)
# self.allows_duplicate_labels = allow_duplicate_labels

def _init_mgr(self, mgr, axes=None, dtype=None, copy=False):
""" passed a manager and a axes dict """
Expand All @@ -251,21 +255,6 @@ def _init_mgr(self, mgr, axes=None, dtype=None, copy=False):
# ----------------------------------------------------------------------

@property
def allows_duplicate_labels(self):
"""
Whether this object allows duplicate labels.
"""
return self._allows_duplicate_labels

@allows_duplicate_labels.setter
def allows_duplicate_labels(self, value: bool):
value = bool(value)
if not value:
for ax in self.axes:
ax._maybe_check_unique()

self._allows_duplicate_labels = value

def attrs(self) -> Dict[Hashable, Any]:
"""
Dictionary of global attributes on this object.
Expand All @@ -278,6 +267,22 @@ def attrs(self) -> Dict[Hashable, Any]:
def attrs(self, value: Mapping[Hashable, Any]) -> None:
self._attrs = dict(value)

@property
def allows_duplicate_labels(self) -> bool:
    """
    Whether this object allows duplicate labels.

    The flag is stored in ``self.attrs`` under the key
    ``"allows_duplicate_labels"``.

    Returns
    -------
    bool
        The stored flag, or True when ``attrs`` holds no such entry.
    """
    # ``attrs`` can be replaced wholesale through its setter, which drops
    # this key; fall back to the permissive default instead of raising
    # KeyError on a plain attribute read.
    return self.attrs.get("allows_duplicate_labels", True)


@allows_duplicate_labels.setter
def allows_duplicate_labels(self, value: bool) -> None:
    # Coerce first so truthy/falsy inputs are stored as a real bool.
    value = bool(value)
    if not value:
        # Check every axis before storing the flag, so a failing check
        # leaves the previous setting in place.
        # NOTE(review): presumably ``_maybe_check_unique`` raises when the
        # axis has duplicate labels -- confirm against the Index code.
        for ax in self.axes:
            ax._maybe_check_unique()

    self.attrs["allows_duplicate_labels"] = value

@property
def is_copy(self):
"""
Expand Down Expand Up @@ -5249,6 +5254,7 @@ def pipe(self, func, *args, **kwargs):
>>> s = pd.Series(range(3))
>>> s
0 0
1 1
2 2
dtype: int64
Expand Down Expand Up @@ -5287,18 +5293,18 @@ def finalize_name(objs):

duplicate_labels = "allows_duplicate_labels"

# import pdb; pdb.set_trace()
if isinstance(other, NDFrame):
for name in other.attrs:
self.attrs[name] = other.attrs[name]
for name, value in other.attrs.items():
# Need to think about this...
if name == "allows_duplicate_labels":
self.allows_duplicate_labels = value
elif name in self.attrs:
self.attrs[name] = other.attrs[name]

# For subclasses using _metadata.
for name in self._metadata:
if name == "name" and getattr(other, "ndim", None) == 1:
# Calling hasattr(other, 'name') is bad for DataFrames with
# a name column.
object.__setattr__(self, name, getattr(other, name, None))
elif name != "name":
object.__setattr__(self, name, getattr(other, name, None))
object.__setattr__(self, name, getattr(other, name, None))

elif method == "concat":
assert isinstance(other, _Concatenator)
self.allows_duplicate_labels = merge_all(other.objs, duplicate_labels)
Expand All @@ -5307,7 +5313,7 @@ def finalize_name(objs):
self.allows_duplicate_labels = merge_all(
(other.left, other.right), duplicate_labels
)
elif method in {"combine_const", "combine_frame"}:
elif method in {"combine_const", "combine_frame", "combine_series_frame"}:
assert isinstance(other, tuple)
self.allows_duplicate_labels = merge_all(other, duplicate_labels)
elif method == "align_series":
Expand Down
12 changes: 8 additions & 4 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def aggregate(self, func=None, *args, **kwargs):

if isinstance(func, str):
return getattr(self, func)(*args, **kwargs).__finalize__(
method="groupby-aggregate"
self, method="groupby-aggregate"
)

if isinstance(func, abc.Iterable):
Expand Down Expand Up @@ -275,12 +275,16 @@ def aggregate(self, func=None, *args, **kwargs):
print("Warning, ignoring as_index=True")

# _level handled at higher
if not _level and isinstance(ret, dict):
if not _level and isinstance(ret, (dict, OrderedDict)):
from pandas import concat

ret = concat(ret, axis=1)

return ret.__finalize__(self, method="groupby-aggregate")
if isinstance(ret, NDFrame):
# TODO: when is this *not* an NDFrame?
# pandas/tests/resample/test_resample_api.py::test_agg_nested_dicts
ret = ret.__finalize__(self, method="groupby-aggregate")
return ret

agg = aggregate

Expand Down Expand Up @@ -876,7 +880,7 @@ def aggregate(self, func=None, *args, **kwargs):

result, how = self._aggregate(func, _level=_level, *args, **kwargs)
if how is None:
return result
return result.__finalize__(self, method="groupby-aggregate")

if result is None:

Expand Down
6 changes: 4 additions & 2 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,9 @@ def _combine_series_frame(self, other, func, fill_value=None, axis=None, level=N
else:
new_data = dispatch_to_series(left, right, func, axis="columns")

return left._construct_result(new_data)
return left._construct_result(new_data).__finalize__(
(self, other), method="combine_series_frame"
)


def _align_method_FRAME(left, right, axis):
Expand Down Expand Up @@ -724,7 +726,7 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
self = self.fillna(fill_value)

new_data = dispatch_to_series(self, other, op)
return self._construct_result(new_data)
return self._construct_result(new_data).__finalize__(self)

f.__name__ = op_name

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/test_duplicate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class TestRaises:
)
def test_construction_with_duplicates(self, cls, axes):
result = cls(**axes)
assert result._allows_duplicate_labels is True
assert result.allows_duplicate_labels is True

with pytest.raises(pandas.errors.DuplicateLabelError):
cls(**axes, allow_duplicate_labels=False)
Expand Down
2 changes: 1 addition & 1 deletion pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2800,7 +2800,7 @@ def inner(*args, **kwargs):


class SubclassedSeries(Series):
_metadata = ["testattr", "name"]
_metadata = Series._metadata + ["testattr"]

@property
def _constructor(self):
Expand Down

0 comments on commit 05e238d

Please sign in to comment.