diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index a173eb86d3..d8520a334b 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -1358,6 +1358,11 @@ def test_head_overshot_sample_axis(self): result = self.ts.head(20, axis="sample") self.assertEqual(10, result.n_samples) + def test_head_numeric_time_index(self): + s = TimeSeries.from_values(self.ts.values()) + # taking the head should not crash + s.head() + def test_tail_overshot_time_axis(self): result = self.ts.tail(20) self.assertEqual(10, result.n_timesteps) @@ -1371,6 +1376,11 @@ def test_tail_overshot_sample_axis(self): result = self.ts.tail(20, axis="sample") self.assertEqual(10, result.n_samples) + def test_tail_numeric_time_index(self): + s = TimeSeries.from_values(self.ts.values()) + # taking the tail should not crash + s.tail() + class TimeSeriesFromDataFrameTestCase(DartsBaseTestClass): def test_from_dataframe_sunny_day(self): diff --git a/darts/timeseries.py b/darts/timeseries.py index 8f31a8d80f..26d581afe4 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -1261,8 +1261,12 @@ def head( """ axis_str = self._get_dim_name(axis) - display_n = range(min(size, self._xa.sizes[axis_str])) - return self.__class__(self._xa[{axis_str: display_n}]) + display_n = min(size, self._xa.sizes[axis_str]) + + if axis_str == self._time_dim: + return self[:display_n] + else: + return self.__class__(self._xa[{axis_str: range(display_n)}]) def tail( self, size: Optional[int] = 5, axis: Optional[Union[int, str]] = 0 @@ -1284,8 +1288,12 @@ def tail( """ axis_str = self._get_dim_name(axis) - display_n = range(-min(size, self._xa.sizes[axis_str]), 0) - return self.__class__(self._xa[{axis_str: display_n}]) + display_n = min(size, self._xa.sizes[axis_str]) + + if axis_str == self._time_dim: + return self[-display_n:] + else: + return self.__class__(self._xa[{axis_str: range(-display_n, 0)}]) def concatenate( self,