Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] fix ValueError when plotting events components #1368

58 changes: 32 additions & 26 deletions neuralprophet/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
if "lagged_regressors" in components and m.config_lagged_regressors is None:
components.remove("lagged_regressors")
invalid_components.append("lagged_regressors")
if "events" in components and (m.config_events and m.config_country_holidays) is None:
if "events" in components and (m.config_events is None and m.config_country_holidays is None):
components.remove("events")
invalid_components.append("events")
if "future_regressors" in components and m.config_regressors is None:
Expand Down Expand Up @@ -423,31 +423,37 @@
if "events" in components:
additive_events_flag = False
muliplicative_events_flag = False
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()]
if configs.mode == "additive":
additive_events = additive_events + weight_list
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list
if m.config_events is not None:
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True

Check warning on line 429 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L429

Added line #L429 was not covered by tests
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if configs.mode == "additive":
additive_events = additive_events + weight_list

Check warning on line 438 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L438

Added line #L438 was not covered by tests
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

if m.config_country_holidays is not None:
for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True

Check warning on line 447 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L447

Added line #L447 was not covered by tests
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

Check warning on line 456 in neuralprophet/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

neuralprophet/plot_utils.py#L455-L456

Added lines #L455 - L456 were not covered by tests

if additive_events_flag:
plot_components.append(
Expand Down
Loading