Skip to content

Commit

Permalink
fix issue with slice_n_points on integer indexes (#1482)
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn committed Jan 12, 2023
1 parent 16f3a9f commit ff9aa90
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 8 additions & 0 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ def helper_test_slice(test_case, test_series: TimeSeries):
test_case.assertEqual(seriesG.start_time(), pd.Timestamp("20130101"))
test_case.assertEqual(seriesG.end_time(), pd.Timestamp("20130107"))

# test slice_n_points_after and slice_n_points_before with integer-indexed series
s = TimeSeries.from_times_and_values(pd.RangeIndex(6, 10), np.arange(16, 20))
sliced_idx = s.slice_n_points_after(7, 2).time_index
test_case.assertTrue(all(sliced_idx == pd.RangeIndex(7, 9)))

sliced_idx = s.slice_n_points_before(8, 2).time_index
test_case.assertTrue(all(sliced_idx == pd.RangeIndex(7, 9)))

# integer indexed series, step = 1, timestamps not in series
values = np.random.rand(30)
idx = pd.RangeIndex(start=0, stop=30, step=1)
Expand Down
4 changes: 2 additions & 2 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,7 +2272,7 @@ def slice_n_points_after(
self._raise_if_not_within(start_ts)

if isinstance(start_ts, (int, np.int64)):
return self[start_ts : start_ts + n]
return self[pd.RangeIndex(start=start_ts, stop=start_ts + n)]
elif isinstance(start_ts, pd.Timestamp):
# get first timestamp greater or equal to start_ts
tss = self._get_first_timestamp_after(start_ts)
Expand Down Expand Up @@ -2308,7 +2308,7 @@ def slice_n_points_before(
self._raise_if_not_within(end_ts)

if isinstance(end_ts, (int, np.int64)):
return self[end_ts - n + 1 : end_ts + 1]
return self[pd.RangeIndex(start=end_ts - n + 1, stop=end_ts + 1)]
elif isinstance(end_ts, pd.Timestamp):
# get last timestamp smaller or equal to start_ts
tss = self._get_last_timestamp_before(end_ts)
Expand Down

0 comments on commit ff9aa90

Please sign in to comment.