From 28ca88db45967dfe94bb406464df157b93920b59 Mon Sep 17 00:00:00 2001 From: Julien Herzen Date: Mon, 12 Sep 2022 15:45:00 +0200 Subject: [PATCH] adapt prophet calls with vectorized=True (#1208) --- darts/models/forecasting/prophet_model.py | 6 +++--- requirements/core.txt | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index 4dd37ca9d2..5380ccb815 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -155,7 +155,7 @@ def _predict( predict_df = self._generate_predict_df(n=n, future_covariates=future_covariates) if num_samples == 1: - forecast = self.model.predict(predict_df)["yhat"].values + forecast = self.model.predict(predict_df, vectorized=True)["yhat"].values else: forecast = np.expand_dims( self._stochastic_samples(predict_df, n_samples=num_samples), axis=1 @@ -203,7 +203,7 @@ def _stochastic_samples(self, predict_df, n_samples) -> np.ndarray: predict_df["trend"] = self.model.predict_trend(predict_df) - forecast = self.model.sample_posterior_predictive(predict_df) + forecast = self.model.sample_posterior_predictive(predict_df, vectorized=True) # reset default number of uncertainty_samples self.model.uncertainty_samples = n_samples_default @@ -221,7 +221,7 @@ def predict_raw( predict_df = self._generate_predict_df(n=n, future_covariates=future_covariates) - return self.model.predict(predict_df) + return self.model.predict(predict_df, vectorized=True) def add_seasonality( self, diff --git a/requirements/core.txt b/requirements/core.txt index 0f245dcaae..89eedc9e2e 100644 --- a/requirements/core.txt +++ b/requirements/core.txt @@ -8,7 +8,7 @@ nfoursid>=1.0.0 numpy>=1.19.0 pandas>=1.0.5 pmdarima>=1.8.0 -prophet>=1.1 +prophet>=1.1.1 requests>=2.22.0 scikit-learn>=1.0.1 scipy>=1.3.2