Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] Pytest for fix ValueError when plotting events components #1434

Merged
merged 2 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions neuralprophet/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
if "lagged_regressors" in components and m.config_lagged_regressors is None:
components.remove("lagged_regressors")
invalid_components.append("lagged_regressors")
if "events" in components and (m.config_events and m.config_country_holidays) is None:
if "events" in components and (m.config_events is None and m.config_country_holidays is None):
components.remove("events")
invalid_components.append("events")
if "future_regressors" in components and m.config_regressors is None:
Expand Down Expand Up @@ -419,31 +419,37 @@
if "events" in components:
additive_events_flag = False
muliplicative_events_flag = False
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()]
if configs.mode == "additive":
additive_events = additive_events + weight_list
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list
if m.config_events is not None:
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if configs.mode == "additive":
additive_events = additive_events + weight_list
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

if m.config_country_holidays is not None:
for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True

Check warning on line 443 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L443

Added line #L443 was not covered by tests
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

Check warning on line 452 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L451-L452

Added lines #L451 - L452 were not covered by tests

if additive_events_flag:
plot_components.append(
Expand Down
105 changes: 105 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,73 @@ def test_plot_events(plotting_backend):
["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="multiplicative", regularization=0.5
)
# add the country specific holidays
m = m.add_country_holidays("US", mode="multiplicative", regularization=0.5)
m.add_country_holidays("Indonesia")
m.add_country_holidays("Thailand")
m.add_country_holidays("Philippines")
m.add_country_holidays("Pakistan")
m.add_country_holidays("Belarus")
history_df = m.create_df_with_events(df, events_df)
m.fit(history_df, freq="D")
future = m.make_future_dataframe(df=history_df, events_df=events_df, periods=30, n_historic_predictions=90)
forecast = m.predict(df=future)
log.debug(f"Event Parameters:: {m.model.event_params}")

fig1 = m.plot_components(forecast, plotting_backend=plotting_backend)
fig2 = m.plot(forecast, plotting_backend=plotting_backend)
fig3 = m.plot_parameters(plotting_backend=plotting_backend)

if PLOT:
fig1.show()
fig2.show()
fig3.show()


@pytest.mark.parametrize(*decorator_input)
def test_plot_events_additive(plotting_backend):
log.info(f"testing: Plotting with events with {plotting_backend}")
df = pd.read_csv(PEYTON_FILE)[-NROWS:]
playoffs = pd.DataFrame(
{
"event": "playoff",
"ds": pd.to_datetime(
[
"2008-01-13",
"2009-01-03",
"2010-01-16",
"2010-01-24",
"2010-02-07",
"2011-01-08",
"2013-01-12",
"2014-01-12",
"2014-01-19",
"2014-02-02",
"2015-01-11",
"2016-01-17",
"2016-01-24",
"2016-02-07",
]
),
}
)
superbowls = pd.DataFrame(
{
"event": "superbowl",
"ds": pd.to_datetime(["2010-02-07", "2014-02-02", "2016-02-07"]),
}
)
events_df = pd.concat((playoffs, superbowls))
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
n_lags=2,
n_forecasts=30,
daily_seasonality=False,
)
# set event windows
m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="additive", regularization=0.5)
# add the country specific holidays
m = m.add_country_holidays("US", mode="additive", regularization=0.5)
m.add_country_holidays("Indonesia")
m.add_country_holidays("Thailand")
Expand All @@ -293,6 +360,44 @@ def test_plot_events(plotting_backend):
fig3.show()


@pytest.mark.parametrize(*decorator_input)
def test_plot_events_components(plotting_backend):
log.info(f"testing: Plotting components with events with {plotting_backend}")
df = pd.read_csv(PEYTON_FILE)[-NROWS:]

events_df = pd.DataFrame(
{
"event": "superbowl",
"ds": pd.to_datetime(["2010-02-07", "2014-02-02", "2016-02-07"]),
}
)
m = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
n_forecasts=7,
n_lags=14,
yearly_seasonality=True,
weekly_seasonality=True,
)
# set event windows
m = m.add_events(["superbowl"], lower_window=-1, upper_window=1, mode="additive", regularization=0.5)
history_df = m.create_df_with_events(df, events_df)
m.fit(history_df, freq="D")
future = m.make_future_dataframe(df=history_df, events_df=events_df, periods=30, n_historic_predictions=90)
forecast = m.predict(df=future)
log.debug(f"Event Parameters:: {m.model.event_params}")

fig1 = m.plot_components(forecast, plotting_backend=plotting_backend)
fig2 = m.plot(forecast, plotting_backend=plotting_backend)
fig3 = m.plot_parameters(plotting_backend=plotting_backend)

if PLOT:
fig1.show()
fig2.show()
fig3.show()


@pytest.mark.parametrize(*decorator_input)
def test_plot_trend(plotting_backend):
log.info(f"testing: Plotting linear trend with {plotting_backend}")
Expand Down
Loading