diff --git a/pyam/core.py b/pyam/core.py index a047000ee..397a2c681 100755 --- a/pyam/core.py +++ b/pyam/core.py @@ -91,6 +91,64 @@ logger = logging.getLogger(__name__) +class IamSlice(pd.Series): + @property + def _constructor(self): + return IamSlice + + _internal_names = pd.Series._internal_names + ["_iamcache"] + _internal_names_set = set(_internal_names) + + def __init__(self, data=None, index=None, **kwargs): + super().__init__(data, index, **kwargs) + self._iamcache = dict() + + def __dir__(self): + return self.dimensions + super().__dir__() + + def __getattr__(self, attr): + ret = object.__getattribute__(self, "_iamcache").get(attr) + if ret is not None: + return ret.tolist() if attr != "time" else ret + + if attr in self.dimensions: + ret = self._iamcache[attr] = self.index[self].unique(level=attr) + return ret.tolist() if attr != "time" else ret + + return super().__getattr__(attr) + + def __len__(self): + return self.sum() + + @property + def dimensions(self): + return self.index.names + + def __repr__(self): + return self.info() + "\n\n" + super().__repr__() + + def info(self, n=80): + """Print a summary of the represented index dimensions + + Parameters + ---------- + n : int + The maximum line length + """ + # concatenate list of index dimensions and levels + info = f"{type(self)}\nIndex dimensions:\n" + c1 = max([len(i) for i in self.dimensions]) + 1 + c2 = n - c1 - 5 + info += "\n".join( + [ + f" * {i:{c1}}: {print_list(getattr(self, i), c2)}" + for i in self.dimensions + ] + ) + + return info + + class IamDataFrame(object): """Scenario timeseries data and meta indicators @@ -255,7 +313,9 @@ def _finalize(self, data, append, **args): def __getitem__(self, key): _key_check = [key] if isstr(key) else key - if key == "value": + if isinstance(key, IamSlice): + return IamDataFrame(self._data.loc[key]) + elif key == "value": return pd.Series(self._data.values, name="value") elif set(_key_check).issubset(self.meta.columns): return self.meta.__getitem__(key) @@ -420,7 +480,9 @@ def time(self): - :class:`pandas.Index` if the time domain is 'mixed' """ if self._time is None: - self._time = pd.Index(get_index_levels(self._data, self.time_col)) + self._time = pd.Index( + self._data.index.unique(level=self.time_col).values, name="time" + ) return self._time @@ -1712,6 +1774,19 @@ def _exclude_on_fail(self, df): ) ) + def slice(self, keep=True, **kwargs): + if not isinstance(keep, bool): + raise ValueError(f"Cannot filter by `keep={keep}`, must be a boolean!") + + _keep = self._apply_filters(**kwargs) + _keep = _keep if keep else ~_keep + + return ( + IamSlice(_keep) + if isinstance(_keep, pd.Series) + else IamSlice(_keep, self._data.index) + ) + def filter(self, keep=True, inplace=False, **kwargs): """Return a (copy of a) filtered (downselected) IamDataFrame @@ -1736,12 +1811,9 @@ def filter(self, keep=True, inplace=False, **kwargs): ('month', 'hour', 'time') - 'regexp=True' disables pseudo-regexp syntax in `pattern_match()` """ - if not isinstance(keep, bool): - raise ValueError(f"Cannot filter by `keep={keep}`, must be a boolean!") - # downselect `data` rows and clean up index - _keep = self._apply_filters(**kwargs) - _keep = _keep if keep else ~_keep + _keep = self.slice(keep=keep, **kwargs) + ret = self.copy() if not inplace else self ret._data = ret._data[_keep] ret._data.index = ret._data.index.remove_unused_levels() diff --git a/tests/conftest.py b/tests/conftest.py index 5428c5d36..f596779ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,7 +38,7 @@ DTS_MAPPING = {2005: TEST_DTS[0], 2010: TEST_DTS[1]} -EXP_DATETIME_INDEX = pd.DatetimeIndex(["2005-06-17T00:00:00"]) +EXP_DATETIME_INDEX = pd.DatetimeIndex(["2005-06-17T00:00:00"], name="time") TEST_DF = pd.DataFrame( diff --git a/tests/test_filter.py b/tests/test_filter.py index 26c667e75..54283039e 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -11,19 +11,22 @@ from .conftest import EXP_DATETIME_INDEX -def test_filter_error_illegal_column(test_df): +@pytest.mark.parametrize("method", ("filter", "slice")) +def test_filter_error_illegal_column(test_df, method): # filtering by column `foo` is not valid - pytest.raises(ValueError, test_df.filter, foo="test") + pytest.raises(ValueError, getattr(test_df, method), foo="test") -def test_filter_error_keep(test_df): +@pytest.mark.parametrize("method", ("filter", "slice")) +def test_filter_error_keep(test_df, method): # string or non-starred dict was mis-interpreted as `keep` kwarg, see #253 - pytest.raises(ValueError, test_df.filter, model="foo", keep=1) - pytest.raises(ValueError, test_df.filter, dict(model="foo")) + pytest.raises(ValueError, getattr(test_df, method), model="foo", keep=1) + pytest.raises(ValueError, getattr(test_df, method), dict(model="foo")) -def test_filter_year(test_df): - obs = test_df.filter(year=2005) +@pytest.mark.parametrize("method", ("filter", "slice")) +def test_filter_year(test_df, method): + obs = getattr(test_df, method)(year=2005) if test_df.time_col == "year": assert obs.year == [2005] else: @@ -45,14 +48,14 @@ def test_filter_mixed_time_domain(test_df_mixed, arg_year, arg_time): # filtering to datetime-only works as expected obs = test_df_mixed.filter(**arg_time) assert obs.time_domain == "datetime" - pdt.assert_index_equal(obs.time, pd.DatetimeIndex(["2010-07-21"])) + pdt.assert_index_equal(obs.time, pd.DatetimeIndex(["2010-07-21"], name="time")) # filtering to year-only works as expected including changing of time domain obs = test_df_mixed.filter(**arg_year) assert obs.time_col == "year" assert obs.time_domain == "year" assert obs.year == [2005] - pdt.assert_index_equal(obs.time, pd.Int64Index([2005])) + pdt.assert_index_equal(obs.time, pd.Int64Index([2005], name="time")) def test_filter_time_domain_raises(test_df_year): diff --git a/tests/test_slice.py b/tests/test_slice.py new file mode 100644 index 000000000..7adbf018d --- /dev/null +++ b/tests/test_slice.py @@ -0,0 +1,4 @@ +def test_slice_len(test_df_year): + """Check the length of a slice""" + + assert len(test_df_year.slice(scenario="scen_a")) == 4 diff --git a/tests/test_time.py b/tests/test_time.py index 1d2f4c896..4aad10ec8 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -38,11 +38,11 @@ def get_subannual_df(date1, date2): @pytest.mark.parametrize( "time, domain, index", [ - (TEST_YEARS, "year", pd.Int64Index([2005, 2010])), - (TEST_DTS, "datetime", pd.DatetimeIndex(TEST_DTS)), - (TEST_TIME_STR, "datetime", pd.DatetimeIndex(TEST_DTS)), - (TEST_TIME_STR_HR, "datetime", pd.DatetimeIndex(TEST_TIME_STR_HR)), - (TEST_TIME_MIXED, "mixed", pd.Index(TEST_TIME_MIXED)), + (TEST_YEARS, "year", pd.Int64Index([2005, 2010], name="time")), + (TEST_DTS, "datetime", pd.DatetimeIndex(TEST_DTS, name="time")), + (TEST_TIME_STR, "datetime", pd.DatetimeIndex(TEST_DTS, name="time")), + (TEST_TIME_STR_HR, "datetime", pd.DatetimeIndex(TEST_TIME_STR_HR, name="time")), + (TEST_TIME_MIXED, "mixed", pd.Index(TEST_TIME_MIXED, name="time")), ], ) def test_time_domain(test_pd_df, time, domain, index): @@ -74,7 +74,7 @@ def test_swap_time_to_year(test_df, inplace): obs = test_df assert_iamframe_equal(obs, exp) - pdt.assert_index_equal(obs.time, pd.Index([2005, 2010])) + pdt.assert_index_equal(obs.time, pd.Index([2005, 2010], name="time")) @pytest.mark.parametrize(