From 818efa12d9ac60cff46c7ccd5b78787c2e92803c Mon Sep 17 00:00:00 2001 From: "johnny.kwan" Date: Thu, 6 Jul 2023 08:42:08 +0000 Subject: [PATCH] pytest for plotting components --- neuralprophet/plot_utils.py | 58 +++++++++++--------- tests/test_plotting.py | 105 ++++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 26 deletions(-) diff --git a/neuralprophet/plot_utils.py b/neuralprophet/plot_utils.py index 6ad7d10e8..ca93e1e57 100644 --- a/neuralprophet/plot_utils.py +++ b/neuralprophet/plot_utils.py @@ -193,7 +193,7 @@ def check_if_configured(m, components, error_flag=False): # move to utils if "lagged_regressors" in components and m.config_lagged_regressors is None: components.remove("lagged_regressors") invalid_components.append("lagged_regressors") - if "events" in components and (m.config_events and m.config_country_holidays) is None: + if "events" in components and (m.config_events is None and m.config_country_holidays is None): components.remove("events") invalid_components.append("events") if "future_regressors" in components and m.config_regressors is None: @@ -419,31 +419,37 @@ def get_valid_configuration( # move to utils if "events" in components: additive_events_flag = False muliplicative_events_flag = False - for event, configs in m.config_events.items(): - if validator == "plot_components" and configs.mode == "additive": - additive_events_flag = True - elif validator == "plot_components" and configs.mode == "multiplicative": - muliplicative_events_flag = True - elif validator == "plot_parameters": - event_params = m.model.get_event_weights(event) - weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()] - if configs.mode == "additive": - additive_events = additive_events + weight_list - elif configs.mode == "multiplicative": - multiplicative_events = multiplicative_events + weight_list - - for country_holiday in m.config_country_holidays.holiday_names: - if validator == "plot_components" and m.config_country_holidays.mode == "additive": - additive_events_flag = True - elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative": - muliplicative_events_flag = True - elif validator == "plot_parameters": - event_params = m.model.get_event_weights(country_holiday) - weight_list = [(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()] - if m.config_country_holidays.mode == "additive": - additive_events = additive_events + weight_list - elif m.config_country_holidays.mode == "multiplicative": - multiplicative_events = multiplicative_events + weight_list + if m.config_events is not None: + for event, configs in m.config_events.items(): + if validator == "plot_components" and configs.mode == "additive": + additive_events_flag = True + elif validator == "plot_components" and configs.mode == "multiplicative": + muliplicative_events_flag = True + elif validator == "plot_parameters": + event_params = m.model.get_event_weights(event) + weight_list = [ + (key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items() + ] + if configs.mode == "additive": + additive_events = additive_events + weight_list + elif configs.mode == "multiplicative": + multiplicative_events = multiplicative_events + weight_list + + if m.config_country_holidays is not None: + for country_holiday in m.config_country_holidays.holiday_names: + if validator == "plot_components" and m.config_country_holidays.mode == "additive": + additive_events_flag = True + elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative": + muliplicative_events_flag = True + elif validator == "plot_parameters": + event_params = m.model.get_event_weights(country_holiday) + weight_list = [ + (key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items() + ] + if m.config_country_holidays.mode == "additive": + additive_events = additive_events + weight_list + elif m.config_country_holidays.mode == "multiplicative": + multiplicative_events = multiplicative_events + weight_list if additive_events_flag: plot_components.append( diff --git a/tests/test_plotting.py b/tests/test_plotting.py index fc372aeb3..499c9e70e 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -271,6 +271,73 @@ def test_plot_events(plotting_backend): ["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="multiplicative", regularization=0.5 ) # add the country specific holidays + m = m.add_country_holidays("US", mode="multiplicative", regularization=0.5) + m.add_country_holidays("Indonesia") + m.add_country_holidays("Thailand") + m.add_country_holidays("Philippines") + m.add_country_holidays("Pakistan") + m.add_country_holidays("Belarus") + history_df = m.create_df_with_events(df, events_df) + m.fit(history_df, freq="D") + future = m.make_future_dataframe(df=history_df, events_df=events_df, periods=30, n_historic_predictions=90) + forecast = m.predict(df=future) + log.debug(f"Event Parameters:: {m.model.event_params}") + + fig1 = m.plot_components(forecast, plotting_backend=plotting_backend) + fig2 = m.plot(forecast, plotting_backend=plotting_backend) + fig3 = m.plot_parameters(plotting_backend=plotting_backend) + + if PLOT: + fig1.show() + fig2.show() + fig3.show() + + +@pytest.mark.parametrize(*decorator_input) +def test_plot_events_additive(plotting_backend): + log.info(f"testing: Plotting with events with {plotting_backend}") + df = pd.read_csv(PEYTON_FILE)[-NROWS:] + playoffs = pd.DataFrame( + { + "event": "playoff", + "ds": pd.to_datetime( + [ + "2008-01-13", + "2009-01-03", + "2010-01-16", + "2010-01-24", + "2010-02-07", + "2011-01-08", + "2013-01-12", + "2014-01-12", + "2014-01-19", + "2014-02-02", + "2015-01-11", + "2016-01-17", + "2016-01-24", + "2016-02-07", + ] + ), + } + ) + superbowls = pd.DataFrame( + { + "event": "superbowl", + "ds": pd.to_datetime(["2010-02-07", "2014-02-02", "2016-02-07"]), + } + ) + events_df = pd.concat((playoffs, superbowls)) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_lags=2, + n_forecasts=30, + daily_seasonality=False, + ) + # set event windows + m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="additive", regularization=0.5) + # add the country specific holidays m = m.add_country_holidays("US", mode="additive", regularization=0.5) m.add_country_holidays("Indonesia") m.add_country_holidays("Thailand") @@ -293,6 +360,44 @@ def test_plot_events(plotting_backend): fig3.show() +@pytest.mark.parametrize(*decorator_input) +def test_plot_events_components(plotting_backend): + log.info(f"testing: Plotting components with events with {plotting_backend}") + df = pd.read_csv(PEYTON_FILE)[-NROWS:] + + events_df = pd.DataFrame( + { + "event": "superbowl", + "ds": pd.to_datetime(["2010-02-07", "2014-02-02", "2016-02-07"]), + } + ) + m = NeuralProphet( + epochs=EPOCHS, + batch_size=BATCH_SIZE, + learning_rate=LR, + n_forecasts=7, + n_lags=14, + yearly_seasonality=True, + weekly_seasonality=True, + ) + # set event windows + m = m.add_events(["superbowl"], lower_window=-1, upper_window=1, mode="additive", regularization=0.5) + history_df = m.create_df_with_events(df, events_df) + m.fit(history_df, freq="D") + future = m.make_future_dataframe(df=history_df, events_df=events_df, periods=30, n_historic_predictions=90) + forecast = m.predict(df=future) + log.debug(f"Event Parameters:: {m.model.event_params}") + + fig1 = m.plot_components(forecast, plotting_backend=plotting_backend) + fig2 = m.plot(forecast, plotting_backend=plotting_backend) + fig3 = m.plot_parameters(plotting_backend=plotting_backend) + + if PLOT: + fig1.show() + fig2.show() + fig3.show() + + @pytest.mark.parametrize(*decorator_input) def test_plot_trend(plotting_backend): log.info(f"testing: Plotting linear trend with {plotting_backend}")