From 15e250ed24fdcc69ac9cf46c9db08691653d1959 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Thu, 12 Jan 2023 11:25:30 +0100 Subject: [PATCH] fix issue with slice_n_points on integer indexes --- darts/tests/test_timeseries.py | 8 ++++++++ darts/timeseries.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index 23b9c80e8c..3adf416728 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -336,6 +336,14 @@ def helper_test_slice(test_case, test_series: TimeSeries): test_case.assertEqual(seriesG.start_time(), pd.Timestamp("20130101")) test_case.assertEqual(seriesG.end_time(), pd.Timestamp("20130107")) + # test slice_n_points_after and slice_n_points_before with integer-indexed series + s = TimeSeries.from_times_and_values(pd.RangeIndex(6, 10), np.arange(16, 20)) + sliced_idx = s.slice_n_points_after(7, 2).time_index + test_case.assertTrue(all(sliced_idx == pd.RangeIndex(7, 9))) + + sliced_idx = s.slice_n_points_before(8, 2).time_index + test_case.assertTrue(all(sliced_idx == pd.RangeIndex(7, 9))) + # integer indexed series, step = 1, timestamps not in series values = np.random.rand(30) idx = pd.RangeIndex(start=0, stop=30, step=1) diff --git a/darts/timeseries.py b/darts/timeseries.py index 7eca55eb94..2be06b8960 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -2272,7 +2272,7 @@ def slice_n_points_after( self._raise_if_not_within(start_ts) if isinstance(start_ts, (int, np.int64)): - return self[start_ts : start_ts + n] + return self[pd.RangeIndex(start=start_ts, stop=start_ts + n)] elif isinstance(start_ts, pd.Timestamp): # get first timestamp greater or equal to start_ts tss = self._get_first_timestamp_after(start_ts) @@ -2308,7 +2308,7 @@ def slice_n_points_before( self._raise_if_not_within(end_ts) if isinstance(end_ts, (int, np.int64)): - return self[end_ts - n + 1 : end_ts + 1] + return self[pd.RangeIndex(start=end_ts - n + 1, stop=end_ts + 1)] elif isinstance(end_ts, pd.Timestamp): # get last timestamp smaller or equal to start_ts tss = self._get_last_timestamp_before(end_ts)