diff --git a/experiments/mfles/README.md b/experiments/mfles/README.md new file mode 100644 index 00000000..6f50b48b --- /dev/null +++ b/experiments/mfles/README.md @@ -0,0 +1,22 @@ +# MFLES +A method to forecast time series based on Gradient Boosted Time Series Decomposition which treats traditional decomposition as the base estimator in the boosting process. Unlike normal gradient boosting, slight learning rates are applied at the component level (trend/seasonality/exogenous). + +The method derives its name from some of the underlying estimators that can enter into the boosting procedure, specifically: a simple Median, Fourier functions for seasonality, a simple/piecewise Linear trend, and Exponential Smoothing. + +## Gradient Boosted Time Series Decomposition Theory +The idea is pretty simple, take a process like decomposition and view it as +a type of 'psuedo' gradient boosting since we are passing residuals around +simlar to standard gradient boosting. Then apply gradient boosting approaches +such as iterating with a global mechanism to control the process and introduce +learning rates for each of the components in the process such as trend or +seasonality or exogenous. By doing this we graduate from this 'psuedo' approach +to full blown gradient boosting. + +## Some Benchmarks +Average SMAPE from a few M4 datasets +| Dataset | AutoMFLES | AutoETS | +| -------- | ------- | ------- | +| Monthly | 12.91 | 13.59* | +| Hourly | 11.73 | 17.19 | +| Weekly | 8.18 | 8.64 | +| Quarterly | 10.72 | 10.25 | diff --git a/experiments/mfles/statsforecast_auto_mfles_benchmark.ipynb b/experiments/mfles/statsforecast_auto_mfles_benchmark.ipynb new file mode 100644 index 00000000..e876c98a --- /dev/null +++ b/experiments/mfles/statsforecast_auto_mfles_benchmark.ipynb @@ -0,0 +1,556 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sPVGa3t5gaM-", + "outputId": "786da4cb-43c6-4c92-ffe2-4005452a12b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/tblume1992/statsforecast.git@feature/mfles\n", + " Cloning https://github.com/tblume1992/statsforecast.git (to revision feature/mfles) to /tmp/pip-req-build-y37borof\n", + " Running command git clone --filter=blob:none --quiet https://github.com/tblume1992/statsforecast.git /tmp/pip-req-build-y37borof\n", + " Running command git checkout -b feature/mfles --track origin/feature/mfles\n", + " Switched to a new branch 'feature/mfles'\n", + " Branch 'feature/mfles' set up to track remote branch 'feature/mfles' from 'origin'.\n", + " Resolved https://github.com/tblume1992/statsforecast.git to commit 5fc9e76fcd177256e6dcc617f231d3c3f02cdae2\n", + " Running command git submodule update --init --recursive -q\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (2.2.1)\n", + "Collecting coreforecast>=0.0.9 (from statsforecast==1.7.5)\n", + " Downloading coreforecast-0.0.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (223 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m223.4/223.4 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numba>=0.55.0 in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (0.58.1)\n", + "Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (1.25.2)\n", + "Requirement already satisfied: pandas>=1.3.5 in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (2.0.3)\n", + "Requirement already satisfied: scipy>=1.7.3 in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (1.11.4)\n", + "Requirement already satisfied: statsmodels>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (0.14.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (4.66.4)\n", + "Collecting fugue>=0.8.1 (from statsforecast==1.7.5)\n", + " Downloading fugue-0.9.0-py3-none-any.whl (278 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m278.2/278.2 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting utilsforecast>=0.1.4 (from statsforecast==1.7.5)\n", + " Downloading utilsforecast-0.1.10-py3-none-any.whl (40 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: threadpoolctl in /usr/local/lib/python3.10/dist-packages (from statsforecast==1.7.5) (3.5.0)\n", + "Collecting triad>=0.9.6 (from fugue>=0.8.1->statsforecast==1.7.5)\n", + " Downloading triad-0.9.6-py3-none-any.whl (62 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.1/62.1 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting adagio>=0.2.4 (from fugue>=0.8.1->statsforecast==1.7.5)\n", + " Downloading adagio-0.2.4-py3-none-any.whl (26 kB)\n", + "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.55.0->statsforecast==1.7.5) (0.41.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->statsforecast==1.7.5) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->statsforecast==1.7.5) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.3.5->statsforecast==1.7.5) (2024.1)\n", + "Requirement already satisfied: patsy>=0.5.6 in /usr/local/lib/python3.10/dist-packages (from statsmodels>=0.13.2->statsforecast==1.7.5) (0.5.6)\n", + "Requirement already satisfied: packaging>=21.3 in /usr/local/lib/python3.10/dist-packages (from statsmodels>=0.13.2->statsforecast==1.7.5) (24.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from patsy>=0.5.6->statsmodels>=0.13.2->statsforecast==1.7.5) (1.16.0)\n", + "Requirement already satisfied: pyarrow>=6.0.1 in /usr/local/lib/python3.10/dist-packages (from triad>=0.9.6->fugue>=0.8.1->statsforecast==1.7.5) (14.0.2)\n", + "Requirement already satisfied: fsspec>=2022.5.0 in /usr/local/lib/python3.10/dist-packages (from triad>=0.9.6->fugue>=0.8.1->statsforecast==1.7.5) (2023.6.0)\n", + "Collecting fs (from triad>=0.9.6->fugue>=0.8.1->statsforecast==1.7.5)\n", + " Downloading fs-2.4.16-py2.py3-none-any.whl (135 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.3/135.3 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting appdirs~=1.4.3 (from fs->triad>=0.9.6->fugue>=0.8.1->statsforecast==1.7.5)\n", + " Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from fs->triad>=0.9.6->fugue>=0.8.1->statsforecast==1.7.5) (67.7.2)\n", + "Building wheels for collected packages: statsforecast\n", + " Building wheel for statsforecast (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for statsforecast: filename=statsforecast-1.7.5-py3-none-any.whl size=133406 sha256=decc0ffe89ec2e1127dd6f22fabde985460a87aad8639e5ccae07b64a95a3284\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-9xisvg_t/wheels/2d/9b/1b/8efa519ae3dd0763c2ae88bcb5220b40c0c701aa006eeed286\n", + "Successfully built statsforecast\n", + "Installing collected packages: appdirs, fs, coreforecast, utilsforecast, triad, adagio, fugue, statsforecast\n", + "Successfully installed adagio-0.2.4 appdirs-1.4.4 coreforecast-0.0.9 fs-2.4.16 fugue-0.9.0 statsforecast-1.7.5 triad-0.9.6 utilsforecast-0.1.10\n" + ] + } + ], + "source": [ + "pip install git+https://github.com/tblume1992/statsforecast.git@feature/mfles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "J-dBpWnVgtCV", + "outputId": "289b838b-a828-4903-ba83-6f841251595f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting mfles\n", + " Downloading MFLES-0.2.4-py3-none-any.whl (14 kB)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from mfles) (1.25.2)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from mfles) (2.0.3)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from mfles) (4.66.4)\n", + "Requirement already satisfied: numba in /usr/local/lib/python3.10/dist-packages (from mfles) (0.58.1)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from mfles) (3.7.1)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (4.51.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (1.4.5)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (24.0)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (9.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->mfles) (2.8.2)\n", + "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba->mfles) (0.41.1)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mfles) (2023.4)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mfles) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->mfles) (1.16.0)\n", + "Installing collected packages: mfles\n", + "Successfully installed mfles-0.2.4\n" + ] + } + ], + "source": [ + "pip install mfles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rtpQx2tthr7L", + "outputId": "e7d70891-4af7-4cd5-b961-d91aac7e33b4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/statsforecast/core.py:27: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:41: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def lasso_nb(X, y, alpha, tol=0.001, maxiter=10000):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:164: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def median(y, seasonal_period):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:180: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def ols(X, y):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:185: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def wls(X, y, weights):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:191: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def _ols(X, y):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/Model.py:196: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def ridge(X, y, lam):\n", + "/usr/local/lib/python3.10/dist-packages/MFLES/FeatureEngineering.py:81: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\n", + " def get_future_basis(basis_functions, forecast_horizon):\n" + ] + } + ], + "source": [ + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import numpy as np\n", + "from statsforecast.models import MSTL, AutoETS, AutoTBATS, AutoMFLES\n", + "from MFLES.Forecaster import MFLES as old_MFLES\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oLSZ-MniqfvX" + }, + "outputs": [], + "source": [ + "train_df = pd.read_csv('https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/Train/Hourly-train.csv')\n", + "test_df = pd.read_csv('https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/Test/Hourly-test.csv')\n", + "\n", + "train_df.index = train_df['V1']\n", + "train_df = train_df.drop('V1', axis = 1)\n", + "test_df.index = test_df['V1']\n", + "test_df = test_df.drop('V1', axis = 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "background_save": true, + "base_uri": "https://localhost:8080/" + }, + "id": "oCUNMkc7dWnn", + "outputId": "f4e55987-ba8b-4ba2-c0b7-879389aaaa3e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/414 [00:00 Dict[str, Any]:\n", + " model = _MFLES(verbose=self.verbose, robust=self.robust)\n", + " fitted = model.fit(\n", + " y=y,\n", + " X=X,\n", + " seasonal_period=self.season_length,\n", + " fourier_order=self.fourier_order,\n", + " ma=self.ma,\n", + " alpha=self.alpha,\n", + " decay=self.decay,\n", + " n_changepoints=self.n_changepoints,\n", + " seasonal_lr=self.seasonal_lr,\n", + " linear_lr=self.trend_lr, \n", + " exogenous_lr=self.exogenous_lr, \n", + " rs_lr=self.residuals_lr,\n", + " cov_threshold=self.cov_threshold,\n", + " moving_medians=self.moving_medians,\n", + " max_rounds=self.max_rounds,\n", + " min_alpha=self.min_alpha,\n", + " max_alpha=self.max_alpha,\n", + " trend_penalty=self.trend_penalty,\n", + " multiplicative=self.multiplicative,\n", + " changepoints=self.changepoints,\n", + " smoother=self.smoother,\n", + " )\n", + " return {'model': model, 'fitted': fitted}\n", + " \n", + " def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> 'MFLES':\n", + " \"\"\"Fit the model\n", + "\n", + " Parameters\n", + " ----------\n", + " y : numpy.array\n", + " Clean time series of shape (t, ).\n", + " X : array-like, optional (default=None)\n", + " Exogenous of shape (t, n_x).\n", + "\n", + " Returns\n", + " -------\n", + " self : MFLES\n", + " Fitted MFLES object.\n", + " \"\"\"\n", + " self.model_ = self._fit(y=y, X=X)\n", + " self._store_cs(y=y, X=X)\n", + " residuals = y - self.model_['fitted']\n", + " self.model_['sigma'] = _calculate_sigma(residuals, y.size)\n", + " return self\n", + "\n", + " def predict(\n", + " self,\n", + " h: int,\n", + " X: Optional[np.ndarray] = None,\n", + " level: Optional[List[int]] = None,\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"Predict with fitted MFLES.\n", + "\n", + " Parameters\n", + " ----------\n", + " h : int\n", + " Forecast horizon.\n", + " X : array-like, optional (default=None)\n", + " Exogenous of shape (h, n_x).\n", + " level: List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\"\n", + " res = {\"mean\": self.model_[\"model\"].predict(forecast_horizon=h, X=X)}\n", + " if level is None:\n", + " return res\n", + " level = sorted(level)\n", + " if self.prediction_intervals is not None:\n", + " res = self._add_predict_conformal_intervals(res, level)\n", + " else:\n", + " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", + " return res\n", + "\n", + " def predict_in_sample(self, level: Optional[List[int]] = None) -> Dict[str, Any]:\n", + " \"\"\"Access fitted SklearnModel insample predictions.\n", + "\n", + " Parameters\n", + " ----------\n", + " level : List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\"\n", + " res = {'fitted': self.model_['fitted']}\n", + " if level is not None:\n", + " level = sorted(level)\n", + " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " return res\n", + "\n", + " def forecast(\n", + " self,\n", + " y: np.ndarray,\n", + " h: int,\n", + " X: Optional[np.ndarray] = None,\n", + " X_future: Optional[np.ndarray] = None,\n", + " level: Optional[List[int]] = None,\n", + " fitted: bool = False,\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"Memory Efficient MFLES predictions.\n", + "\n", + " This method avoids memory burden due from object storage.\n", + " It is analogous to `fit_predict` without storing information.\n", + " It assumes you know the forecast horizon in advance.\n", + "\n", + " Parameters\n", + " ----------\n", + " y : numpy.array\n", + " Clean time series of shape (t, ).\n", + " h : int\n", + " Forecast horizon.\n", + " X : array-like\n", + " Insample exogenous of shape (t, n_x).\n", + " X_future : array-like\n", + " Exogenous of shape (h, n_x).\n", + " level : List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + " fitted : bool\n", + " Whether or not to return insample predictions.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\"\n", + " model = self._fit(y=y, X=X)\n", + " res = {\"mean\": model['model'].predict(forecast_horizon=h, X=X_future)}\n", + " if fitted:\n", + " res[\"fitted\"] = model['fitted']\n", + " if level is not None:\n", + " level = sorted(level)\n", + " if self.prediction_intervals is not None:\n", + " res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level)\n", + " else:\n", + " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", + " if fitted:\n", + " residuals = y - res[\"fitted\"]\n", + " sigma = _calculate_sigma(residuals, y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(MFLES)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(MFLES.fit)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(MFLES.predict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(MFLES.predict_in_sample)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(MFLES.forecast)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "h = 12\n", + "X = np.random.rand(ap.size, 2)\n", + "X_future = np.random.rand(h, 2)\n", + "\n", + "mfles = MFLES()\n", + "test_class(mfles, x=deg_ts, X=X, X_future=X_future, h=h, skip_insample=False, test_forward=False)\n", + "\n", + "mfles = MFLES(prediction_intervals=ConformalIntervals(h=h, n_windows=2))\n", + "test_class(mfles, x=ap, X=X, X_future=X_future, h=h, skip_insample=False, level=[80, 95], test_forward=False)\n", + "fcst_mfles = mfles.forecast(ap, h, X=X, X_future=X_future, fitted=True, level=[80, 95])\n", + "_plot_insample_pi(fcst_mfles)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## AutoMFLES" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class AutoMFLES(_TS):\n", + " \"\"\"AutoMFLES\n", + " \n", + " Parameters\n", + " ----------\n", + " test_size : int\n", + " Forecast horizon used during cross validation.\n", + " season_length : int or list of int, optional (default=None)\n", + " Number of observations per unit of time. Ex: 24 Hourly data.\n", + " n_windows : int (default=2)\n", + " Number of windows used for cross validation.\n", + " config : dict, optional (default=None)\n", + " Mapping from parameter name (from the init arguments of MFLES) to a list of values to try.\n", + " If `None`, will use defaults.\n", + " step_size : int, optional (default=None)\n", + " Step size between each cross validation window. If `None` will be set to test_size.\n", + " metric : str (default='smape')\n", + " Metric used to select the best model. Possible options are: 'smape', 'mape', 'mse' and 'mae'.\n", + " verbose : bool (default=False)\n", + " Print debugging information.\n", + " prediction_intervals : Optional[ConformalIntervals]\n", + " Information to compute conformal prediction intervals.\n", + " This is required for generating future prediction intervals.\n", + " alias : str (default='AutoMFLES')\n", + " Custom name of the model.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " test_size: int,\n", + " season_length: Optional[Union[int, List[int]]] = None,\n", + " n_windows: int = 2,\n", + " config: Optional[Dict[str, Any]] = None,\n", + " step_size: Optional[int] = None,\n", + " metric: str = 'smape',\n", + " verbose: bool = False,\n", + " prediction_intervals: Optional[ConformalIntervals] = None,\n", + " alias: str = 'AutoMFLES',\n", + " ):\n", + " try:\n", + " import sklearn # noqa: F401\n", + " except ImportError:\n", + " raise ImportError(\"MFLES requires scikit-learn.\") from None\n", + " self.season_length = season_length\n", + " self.n_windows = n_windows\n", + " self.test_size = test_size\n", + " self.config = config\n", + " self.step_size = step_size if step_size is not None else test_size\n", + " self.metric = metric\n", + " self.verbose = verbose\n", + " self.prediction_intervals = prediction_intervals\n", + " self.alias = alias\n", + "\n", + " def _fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> Dict[str, Any]:\n", + " model = _MFLES(verbose=self.verbose)\n", + " optim_params = model.optimize(\n", + " y=y,\n", + " X=X,\n", + " test_size=self.test_size,\n", + " n_steps=self.n_windows,\n", + " step_size=self.step_size,\n", + " seasonal_period=self.season_length,\n", + " metric=self.metric,\n", + " params=self.config,\n", + " )\n", + " # the seasonal_period may've been found during the optimization\n", + " seasonal_period = optim_params.pop('seasonal_period', self.season_length)\n", + " fitted = model.fit(\n", + " y=y,\n", + " X=X,\n", + " seasonal_period=seasonal_period,\n", + " **optim_params,\n", + " )\n", + " return {'model': model, 'fitted': fitted}\n", + "\n", + " def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> 'AutoMFLES':\n", + " \"\"\"Fit the model\n", + "\n", + " Parameters\n", + " ----------\n", + " y : numpy.array\n", + " Clean time series of shape (t, ).\n", + " X : array-like, optional (default=None)\n", + " Exogenous of shape (t, n_x).\n", + "\n", + " Returns\n", + " -------\n", + " self : AutoMFLES\n", + " Fitted AutoMFLES object.\n", + " \"\"\" \n", + " self.model_ = self._fit(y=y, X=X)\n", + " self._store_cs(y=y, X=X)\n", + " residuals = y - self.model_[\"fitted\"]\n", + " self.model_[\"sigma\"] = _calculate_sigma(residuals, y.size)\n", + " return self\n", + "\n", + " def predict(\n", + " self,\n", + " h: int,\n", + " X: Optional[np.ndarray] = None,\n", + " level: Optional[List[int]] = None,\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"Predict with fitted AutoMFLES.\n", + "\n", + " Parameters\n", + " ----------\n", + " h : int\n", + " Forecast horizon.\n", + " X : array-like, optional (default=None)\n", + " Exogenous of shape (h, n_x).\n", + " level: List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\" \n", + " res = {\"mean\": self.model_[\"model\"].predict(forecast_horizon=h, X=X)}\n", + " if level is None:\n", + " return res\n", + " level = sorted(level)\n", + " if self.prediction_intervals is not None:\n", + " res = self._add_predict_conformal_intervals(res, level)\n", + " else:\n", + " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", + " return res\n", + "\n", + " def predict_in_sample(self, level: Optional[List[int]] = None) -> Dict[str, Any]:\n", + " \"\"\"Access fitted AutoMFLES insample predictions.\n", + "\n", + " Parameters\n", + " ----------\n", + " level : List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\"\n", + " res = {\"fitted\": self.model_[\"fitted\"]}\n", + " if level is not None:\n", + " level = sorted(level)\n", + " res = _add_fitted_pi(res=res, se=self.model_[\"sigma\"], level=level)\n", + " return res\n", + "\n", + " def forecast(\n", + " self,\n", + " y: np.ndarray,\n", + " h: int,\n", + " X: Optional[np.ndarray] = None,\n", + " X_future: Optional[np.ndarray] = None,\n", + " level: Optional[List[int]] = None,\n", + " fitted: bool = False,\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"Memory Efficient AutoMFLES predictions.\n", + "\n", + " This method avoids memory burden due from object storage.\n", + " It is analogous to `fit_predict` without storing information.\n", + " It assumes you know the forecast horizon in advance.\n", + "\n", + " Parameters\n", + " ----------\n", + " y : numpy.array\n", + " Clean time series of shape (t, ).\n", + " h : int\n", + " Forecast horizon.\n", + " X : array-like\n", + " Insample exogenous of shape (t, n_x).\n", + " X_future : array-like\n", + " Exogenous of shape (h, n_x).\n", + " level : List[int]\n", + " Confidence levels (0-100) for prediction intervals.\n", + " fitted : bool\n", + " Whether or not to return insample predictions.\n", + "\n", + " Returns\n", + " -------\n", + " forecasts : dict\n", + " Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n", + " \"\"\"\n", + " model = self._fit(y=y, X=X)\n", + " res = {\"mean\": model[\"model\"].predict(forecast_horizon=h, X=X_future)}\n", + " if fitted:\n", + " res[\"fitted\"] = model[\"fitted\"]\n", + " if level is not None:\n", + " level = sorted(level)\n", + " if self.prediction_intervals is not None:\n", + " res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level)\n", + " else:\n", + " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", + " if fitted:\n", + " residuals = y - res[\"fitted\"]\n", + " sigma = _calculate_sigma(residuals, y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "h = 12\n", + "X = np.random.rand(ap.size, 2)\n", + "X_future = np.random.rand(h, 2)\n", + "\n", + "auto_mfles = AutoMFLES(test_size=h, season_length=12)\n", + "test_class(auto_mfles, x=deg_ts, X=X, X_future=X_future, h=h, skip_insample=False, test_forward=False)\n", + "\n", + "auto_mfles = AutoMFLES(test_size=h, season_length=12, prediction_intervals=ConformalIntervals(h=h, n_windows=2))\n", + "test_class(auto_mfles, x=ap, X=X, X_future=X_future, h=h, skip_insample=False, level=[80, 95], test_forward=False)\n", + "fcst_auto_mfles = auto_mfles.forecast(ap, h, X=X, X_future=X_future, fitted=True, level=[80, 95])\n", + "_plot_insample_pi(fcst_auto_mfles)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -11889,9 +12382,6 @@ " self.constant = constant\n", " self.alias = alias\n", " \n", - " def __repr__(self):\n", - " return self.alias\n", - " \n", " def fit(\n", " self, \n", " y: np.ndarray,\n", diff --git a/nbs/src/mfles.ipynb b/nbs/src/mfles.ipynb new file mode 100644 index 00000000..d1453c82 --- /dev/null +++ b/nbs/src/mfles.ipynb @@ -0,0 +1,836 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b714247f-d5f5-4937-a0ce-ea07da9915a8", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a433f56c-136e-404c-a488-ddb05ab947c9", + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp mfles" + ] + }, + { + "cell_type": "markdown", + "id": "eb487f24-09bd-480d-b672-f8242f6f3850", + "metadata": {}, + "source": [ + "# MFLES model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "577a1635-ea8a-454e-a77e-a133e894baa8", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "import itertools\n", + "import warnings\n", + "\n", + "import numpy as np\n", + "from coreforecast.exponentially_weighted import exponentially_weighted_mean\n", + "from coreforecast.rolling import rolling_mean\n", + "from numba import njit\n", + "\n", + "from statsforecast.utils import _ensure_float" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba9d2a56-4a7d-4062-9cf9-dcaebb218ccf", + "metadata": {}, + "outputs": [], + "source": [ + "#| exporti\n", + "# utility functions\n", + "def calc_mse(y_true, y_pred):\n", + " sq_err = (y_true - y_pred) ** 2\n", + " return np.mean(sq_err)\n", + "\n", + "def calc_mae(y_true, y_pred):\n", + " abs_err = np.abs(y_true - y_pred)\n", + " return np.mean(abs_err)\n", + "\n", + "def calc_mape(y_true, y_pred):\n", + " pct_err = np.abs((y_true - y_pred) / (y_pred + 1e-6))\n", + " return np.mean(pct_err)\n", + "\n", + "def calc_smape(y_true, y_pred):\n", + " pct_err = 2 * np.abs(y_true - y_pred) / np.abs(y_true + y_pred + 1e-6)\n", + " return np.mean(pct_err)\n", + "\n", + "_metric2fn = {\n", + " \"mse\": calc_mse,\n", + " \"mae\": calc_mae,\n", + " \"mape\": calc_mape,\n", + " \"smape\": calc_smape,\n", + "}\n", + "\n", + "def cross_validation(y, X, test_size, n_splits, model_obj, metric, step_size=1, **kwargs):\n", + " metrics = []\n", + " metric_fn = _metric2fn[metric]\n", + " residuals = []\n", + " if X is None:\n", + " exogenous = None\n", + " else:\n", + " exogenous = X.copy()\n", + " for split in range(n_splits):\n", + " train_y = y[:-(split*step_size + test_size)]\n", + " test_y = y[len(train_y): len(train_y) + test_size]\n", + " if exogenous is not None:\n", + " train_X = exogenous[:-(split*step_size + test_size), :]\n", + " test_X = exogenous[len(train_y): len(train_y) + test_size, :]\n", + " else:\n", + " train_X = None\n", + " test_X = None\n", + " model_obj.fit(train_y, X=train_X, **kwargs)\n", + " prediction = model_obj.predict(test_size, X=test_X)\n", + " metrics.append(metric_fn(test_y, prediction))\n", + " residuals.append(test_y - prediction)\n", + " return {'metric': np.mean(metrics), 'residuals': residuals}\n", + "\n", + "def logic_check(keys_to_check, keys):\n", + " return set(keys_to_check).issubset(keys)\n", + "\n", + "def logic_layer(param_dict):\n", + " keys = param_dict.keys()\n", + " # if param_dict['n_changepoints'] is None:\n", + " # if param_dict['decay'] != -1:\n", + " # return False\n", + " if logic_check(['seasonal_period', 'max_rounds'], keys):\n", + " if param_dict['seasonal_period'] is None:\n", + " if param_dict['max_rounds'] < 4:\n", + " return False\n", + " if logic_check(['smoother', 'ma'], keys):\n", + " if param_dict['smoother']:\n", + " if param_dict['ma'] is not None:\n", + " return False\n", + " if logic_check(['seasonal_period', 'seasonality_weights'], keys):\n", + " if param_dict['seasonality_weights']:\n", + " if param_dict['seasonal_period'] is None:\n", + " return False\n", + " return True\n", + "\n", + "def default_configs(seasonal_period, configs=None):\n", + " if configs is None:\n", + " if seasonal_period is not None:\n", + " if not isinstance(seasonal_period, list):\n", + " seasonal_period = [seasonal_period]\n", + " configs = {\n", + " 'seasonality_weights': [True, False],\n", + " 'smoother': [True, False],\n", + " 'ma': [int(min(seasonal_period)), int(min(seasonal_period)/2),None],\n", + " 'seasonal_period': [None, seasonal_period],\n", + " }\n", + " else:\n", + " configs = {\n", + " 'smoother': [True, False],\n", + " 'cov_threshold': [.5, -1],\n", + " 'max_rounds': [5, 20],\n", + " 'seasonal_period': [None],\n", + " }\n", + " keys = configs.keys()\n", + " combinations = itertools.product(*configs.values())\n", + " ds = [dict(zip(keys,cc)) for cc in combinations]\n", + " ds = [i for i in ds if logic_layer(i)]\n", + " return ds\n", + "\n", + "def cap_outliers(series, outlier_cap=3):\n", + " mean = np.mean(series)\n", + " std = np.std(series)\n", + " return series.clip(\n", + " min=mean - outlier_cap * std,\n", + " max=mean + outlier_cap * std\n", + " )\n", + "\n", + "def set_fourier(period):\n", + " if period < 10:\n", + " fourier = 5\n", + " elif period < 70:\n", + " fourier = 10\n", + " else:\n", + " fourier = 15\n", + " return fourier\n", + "\n", + "def calc_trend_strength(resids, deseasonalized):\n", + " return max(0, 1 - (np.var(resids) / np.var(deseasonalized)))\n", + "\n", + "def calc_seas_strength(resids, detrended):\n", + " return max(0, 1 - (np.var(resids) / np.var(detrended)))\n", + "\n", + "def calc_rsq(y, fitted):\n", + " try:\n", + " mean_y = np.mean(y)\n", + " ssres = np.sum((y - fitted) ** 2)\n", + " sstot = np.sum((y - mean_y) ** 2)\n", + " return 1 - (ssres / sstot)\n", + " except:\n", + " return 0\n", + "\n", + "def calc_cov(y, mult=1):\n", + " if mult:\n", + " # source http://medcraveonline.com/MOJPB/MOJPB-06-00200.pdf\n", + " res = np.sqrt(np.exp(np.log(10)*(np.std(y)**2) - 1))\n", + " else:\n", + " res = np.std(y)\n", + " mean = np.mean(y)\n", + " if mean != 0:\n", + " res = res / mean\n", + " return res\n", + "\n", + "def get_seasonality_weights(y, seasonal_period):\n", + " return 1 + np.arange(y.size) // seasonal_period\n", + "\n", + "# feature engineering functions\n", + "def get_fourier_series(length, seasonal_period, fourier_order):\n", + " x = 2 * np.pi * np.arange(1, fourier_order + 1) / seasonal_period\n", + " t = np.arange(1, length + 1).reshape(-1, 1)\n", + " x = x * t\n", + " return np.hstack([np.cos(x), np.sin(x)])\n", + "\n", + "@njit\n", + "def get_basis(y, n_changepoints, decay=-1, gradient_strategy=0):\n", + " if n_changepoints < 1:\n", + " return np.arange(y.size, dtype=np.float64).reshape(-1, 1)\n", + " y = y.copy()\n", + " y -= y[0]\n", + " n = len(y)\n", + " if gradient_strategy:\n", + " gradients = np.abs(y[:-1] - y[1:])\n", + " initial_point = y[0]\n", + " final_point = y[-1]\n", + " mean_y = np.mean(y)\n", + " changepoints = np.empty(shape=(len(y), n_changepoints + 1))\n", + " array_splits = []\n", + " for i in range(1, n_changepoints + 1):\n", + " i = n_changepoints - i + 1\n", + " if gradient_strategy:\n", + " cps = np.argsort(-gradients)\n", + " cps = cps[cps > 0.1 * len(gradients)]\n", + " cps = cps[cps < 0.9 * len(gradients)]\n", + " split_point = cps[i-1]\n", + " array_splits.append(y[:split_point])\n", + " else:\n", + " split_point = len(y)//i\n", + " array_splits.append(y[:split_point])\n", + " y = y[split_point:]\n", + " len_splits = 0\n", + " for i in range(n_changepoints):\n", + " if gradient_strategy:\n", + " len_splits = len(array_splits[i])\n", + " else:\n", + " len_splits += len(array_splits[i])\n", + " moving_point = array_splits[i][-1]\n", + " left_basis = np.linspace(initial_point, moving_point, len_splits)\n", + " if decay is None:\n", + " end_point = final_point\n", + " else:\n", + " if decay == -1:\n", + " dd = moving_point**2\n", + " if mean_y != 0:\n", + " dd /= mean_y**2\n", + " if dd > 0.99:\n", + " dd = 0.99\n", + " if dd < 0.001:\n", + " dd = 0.001\n", + " end_point = moving_point - ((moving_point - final_point) * (1 - dd))\n", + " else:\n", + " end_point = moving_point - ((moving_point - final_point) * (1 - decay))\n", + " right_basis = np.linspace(moving_point, end_point, n - len_splits + 1)\n", + " changepoints[:, i] = np.append(left_basis, right_basis[1:])\n", + " changepoints[:, i+1] = np.ones(n)\n", + " return changepoints\n", + "\n", + "\n", + "def get_future_basis(basis_functions, forecast_horizon):\n", + " n_components = np.shape(basis_functions)[1]\n", + " slopes = np.gradient(basis_functions)[0][-1, :]\n", + " future_basis = np.arange(0, forecast_horizon + 1)\n", + " future_basis += len(basis_functions)\n", + " future_basis = np.transpose([future_basis] * n_components)\n", + " future_basis = future_basis * slopes\n", + " future_basis = future_basis + (basis_functions[-1, :] - future_basis[0, :])\n", + " return future_basis[1:, :]\n", + "\n", + "def lasso_nb(X, y, alpha, tol=0.001, maxiter=10000):\n", + " from sklearn.linear_model import Lasso\n", + " from sklearn.exceptions import ConvergenceWarning\n", + "\n", + " with warnings.catch_warnings(record=False):\n", + " warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)\n", + " lasso = Lasso(alpha=alpha, fit_intercept=False, tol=tol, max_iter=maxiter)\n", + " lasso.fit(X, y)\n", + " return lasso.coef_\n", + "\n", + "# different models\n", + "@njit\n", + "def siegel_repeated_medians(x, y):\n", + " # Siegel repeated medians regression\n", + " n = y.size\n", + " slopes = np.empty_like(y)\n", + " slopes_sub = np.empty(shape=n - 1, dtype=y.dtype)\n", + " for i in range(n):\n", + " k = 0\n", + " for j in range(n):\n", + " if i == j:\n", + " continue\n", + " xd = x[j] - x[i]\n", + " if xd == 0:\n", + " slope = 0\n", + " else:\n", + " slope = (y[j] - y[i]) / xd\n", + " slopes_sub[k] = slope\n", + " k += 1\n", + " slopes[i] = np.median(slopes_sub)\n", + " ints = y - slopes * x\n", + " return x * np.median(slopes) + np.median(ints)\n", + "\n", + "def ses_ensemble(y, min_alpha=0.05, max_alpha=1.0, smooth=False, order=1):\n", + " #bad name but does either a ses ensemble or simple moving average\n", + " if smooth:\n", + " results = np.zeros_like(y)\n", + " alphas = np.arange(min_alpha, max_alpha, 0.05)\n", + " for alpha in alphas:\n", + " results += exponentially_weighted_mean(y, alpha)\n", + " results = results / len(alphas)\n", + " else:\n", + " results = rolling_mean(y, order + 1)\n", + " results[:order + 1] = y[:order + 1]\n", + " return results\n", + "\n", + "def fast_ols(x, y):\n", + " \"\"\"Simple OLS for two data sets.\"\"\"\n", + " M = x.size\n", + " x_sum = x.sum()\n", + " y_sum = y.sum()\n", + " x_sq_sum = x @ x\n", + " x_y_sum = x @ y\n", + " slope = (M * x_y_sum - x_sum * y_sum) / (M * x_sq_sum - x_sum**2)\n", + " intercept = (y_sum - slope * x_sum) / M\n", + " return slope * x + intercept\n", + "\n", + "def median(y, seasonal_period):\n", + " if seasonal_period is None:\n", + " return np.full_like(y, np.median(y))\n", + " full_periods, resid = divmod(len(y), seasonal_period)\n", + " period_medians = np.median(\n", + " y[:full_periods * seasonal_period].reshape(full_periods, seasonal_period),\n", + " axis=1,\n", + " )\n", + " medians = np.repeat(period_medians, seasonal_period)\n", + " if resid:\n", + " remainder_median = np.median(y[-seasonal_period:])\n", + " medians = np.append(medians, np.repeat(remainder_median, resid))\n", + " return medians\n", + "\n", + "def ols(X, y):\n", + " coefs = np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y))\n", + " return X @ coefs\n", + "\n", + "def wls(X, y, weights):\n", + " weighted_X_T = X.T @ np.diag(weights)\n", + " coefs = np.linalg.pinv(weighted_X_T.dot(X)).dot(weighted_X_T.dot(y))\n", + " return X @ coefs\n", + "\n", + "def _ols(X, y):\n", + " return np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y))\n", + "\n", + "class OLS:\n", + " def fit(self, X, y):\n", + " self.coefs = _ols(X, y)\n", + " def predict(self, X):\n", + " return X @ self.coefs\n", + "\n", + "class Zeros:\n", + " def predict(self, X):\n", + " return np.zeros(X.shape[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b32c4213-8932-4f89-9f70-9375de577ca5", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class MFLES:\n", + " def __init__(self, verbose=1, robust=None):\n", + " self.penalty = None\n", + " self.trend = None\n", + " self.seasonality = None\n", + " self.robust = robust\n", + " self.const = None\n", + " self.aic = None\n", + " self.upper = None\n", + " self.lower= None\n", + " self.exogenous_models = None\n", + " self.verbose = verbose\n", + " self.predicted = None\n", + "\n", + " def fit(self,\n", + " y,\n", + " seasonal_period=None,\n", + " X=None,\n", + " fourier_order=None,\n", + " ma=None,\n", + " alpha=1.0,\n", + " decay=-1,\n", + " n_changepoints=.25,\n", + " seasonal_lr=.9,\n", + " rs_lr=1,\n", + " exogenous_lr=1,\n", + " exogenous_estimator=OLS,\n", + " exogenous_params={},\n", + " linear_lr=.9,\n", + " cov_threshold=.7,\n", + " moving_medians=False,\n", + " max_rounds=50,\n", + " min_alpha=.05,\n", + " max_alpha=1.0,\n", + " round_penalty=0.0001,\n", + " trend_penalty=True,\n", + " multiplicative=None,\n", + " changepoints=True,\n", + " smoother=False,\n", + " seasonality_weights=False):\n", + " \"\"\"\n", + "\n", + "\n", + " Parameters\n", + " ----------\n", + " y : np.array\n", + " the time series as a numpy array.\n", + " seasonal_period : int, optional\n", + " DESCRIPTION. The default is None.\n", + " fourier_order : int, optional\n", + " How many fourier sin/cos pairs to create, the larger the number the more complex of a seasonal pattern can be fitted. A lower number leads to smoother results. This is auto-set based on seasonal_period. The default is None.\n", + " ma : int, optional\n", + " The moving average order to use, this is auto-set based on internal logic. Passing 4 would fit a 4 period moving average on the residual component. The default is None.\n", + " alpha : TYPE, optional\n", + " The alpha which is used in fitting the underlying LASSO when using piecewise functions. The default is 1.0.\n", + " decay : float, optional\n", + " Effects the slopes of the piecewise-linear basis function. The default is -1.\n", + " n_changepoints : float, optional\n", + " The number of changepoint knots to place, a default of .25 with place .25 * series length number of knots. The default is .25.\n", + " seasonal_lr : float, optional\n", + " A shrinkage parameter (0 the more smooth your fit. The default is 10.\n", + " min_alpha : float, optional\n", + " The min alpha in the SES ensemble. The default is .05.\n", + " max_alpha : float, optional\n", + " The max alpha used in the SES ensemble. The default is 1.0.\n", + " trend_penalty : boolean, optional\n", + " Whether to apply a simple penalty to the lienar trend component, very useful for dealing with the potentially dangerous piecewise trend. The default is True.\n", + " multiplicative : boolean, optional\n", + " Auto-set based on internal logic, but if given True it will simply take the log of the time series. The default is None.\n", + " changepoints : boolean, optional\n", + " Whether to fit for changepoints if all other logic allows for it, by setting False then MFLES will not ever fit a piecewise trend. The default is True.\n", + " smoother : boolean, optional\n", + " If True then a simple exponential ensemble will be used rather than auto settings. The default is False.\n", + "\n", + " Returns\n", + " -------\n", + " None.\n", + "\n", + " \"\"\"\n", + " if cov_threshold == -1:\n", + " cov_threshold = 10000\n", + " n = len(y)\n", + " y = _ensure_float(y)\n", + " self.exogenous_lr = exogenous_lr\n", + " if multiplicative is None:\n", + " if seasonal_period is None:\n", + " multiplicative = False\n", + " else:\n", + " multiplicative = True\n", + " if y.min() <= 0:\n", + " multiplicative = False\n", + " if multiplicative:\n", + " self.const = y.min()\n", + " y = np.log(y)\n", + " else:\n", + " self.const = None\n", + " self.std = np.std(y)\n", + " self.mean = np.mean(y)\n", + " y = y - self.mean\n", + " if self.std > 0:\n", + " y = y / self.std\n", + " if seasonal_period is not None:\n", + " if not isinstance(seasonal_period, list):\n", + " seasonal_period = [seasonal_period] \n", + " if n < 4 or np.all(y == np.mean(y)):\n", + " if self.verbose:\n", + " if n < 4:\n", + " print('series is too short (<4), defaulting to naive')\n", + " else:\n", + " print(f'input is constant with value {y[0]}, defaulting to naive')\n", + " self.trend = np.append(y[-1], y[-1])\n", + " self.seasonality = np.zeros(len(y))\n", + " self.trend_penalty = False\n", + " self.mean = y[-1]\n", + " self.std = 0\n", + " self.exo_model = [Zeros()]\n", + " return np.tile(y[-1], len(y))\n", + " og_y = y\n", + " self.og_y = og_y\n", + " y = y.copy()\n", + " if n_changepoints is None:\n", + " changepoints = False\n", + " if isinstance(n_changepoints, float) and n_changepoints < 1:\n", + " n_changepoints = int(n_changepoints * n)\n", + " self.linear_component = np.zeros(n)\n", + " self.seasonal_component = np.zeros(n)\n", + " self.ses_component = np.zeros(n)\n", + " self.median_component = np.zeros(n)\n", + " self.exogenous_component = np.zeros(n)\n", + " self.exo_model = []\n", + " self.round_cost = []\n", + " self.trend_penalty = trend_penalty\n", + " if moving_medians and seasonal_period is not None:\n", + " fitted = median(y, max(seasonal_period))\n", + " else:\n", + " fitted = median(y, None)\n", + " self.median_component += fitted\n", + " self.trend = np.append(fitted.copy()[-1:], fitted.copy()[-1:])\n", + " mse = None\n", + " equal = 0\n", + " if ma is None:\n", + " ma_cycle = itertools.cycle([1])\n", + " else:\n", + " if not isinstance(ma, list):\n", + " ma = [ma]\n", + " ma_cycle = itertools.cycle(ma)\n", + " if seasonal_period is not None:\n", + " seasons_cycle = itertools.cycle(list(range(len(seasonal_period))))\n", + " self.seasonality = np.zeros(max(seasonal_period))\n", + " fourier_series = []\n", + " for period in seasonal_period:\n", + " if fourier_order is None:\n", + " fourier = set_fourier(period)\n", + " else:\n", + " fourier = fourier_order\n", + " fourier_series.append(get_fourier_series(n,\n", + " period,\n", + " fourier))\n", + " if seasonality_weights:\n", + " cycle_weights = []\n", + " for period in seasonal_period:\n", + " cycle_weights.append(get_seasonality_weights(y, period))\n", + " else:\n", + " self.seasonality = None\n", + " for i in range(max_rounds):\n", + " resids = y - fitted\n", + " if mse is None:\n", + " mse = calc_mse(y, fitted)\n", + " else:\n", + " if mse <= calc_mse(y, fitted):\n", + " if equal == 6:\n", + " break\n", + " equal += 1\n", + " else:\n", + " mse = calc_mse(y, fitted)\n", + " self.round_cost.append(mse)\n", + " if seasonal_period is not None:\n", + " seasonal_period_cycle = next(seasons_cycle)\n", + " if seasonality_weights:\n", + " seas = wls(fourier_series[seasonal_period_cycle],\n", + " resids,\n", + " cycle_weights[seasonal_period_cycle])\n", + " else:\n", + " seas = ols(fourier_series[seasonal_period_cycle],\n", + " resids)\n", + " seas = seas * seasonal_lr\n", + " component_mse = calc_mse(y, fitted + seas)\n", + " if mse > component_mse:\n", + " mse = component_mse\n", + " fitted += seas\n", + " resids = y - fitted\n", + " self.seasonality += np.resize(seas[-seasonal_period[seasonal_period_cycle]:],\n", + " len(self.seasonality))\n", + " self.seasonal_component += seas\n", + " if X is not None and i > 0:\n", + " model_obj = exogenous_estimator(**exogenous_params)\n", + " model_obj.fit(X, resids)\n", + " self.exo_model.append(model_obj)\n", + " _fitted_values = model_obj.predict(X) * exogenous_lr\n", + " self.exogenous_component += _fitted_values\n", + " fitted += _fitted_values\n", + " resids = y - fitted\n", + " if i % 2: #if even get linear piece, allows for multiple seasonality fitting a bit more\n", + " if self.robust:\n", + " tren = siegel_repeated_medians(x=np.arange(n, dtype=resids.dtype), y=resids)\n", + " else:\n", + " if i==1 or not changepoints:\n", + " tren = fast_ols(x=np.arange(n),\n", + " y=resids)\n", + " else:\n", + " cps = min(n_changepoints, int(.1*n))\n", + " lbf = get_basis(y=resids,\n", + " n_changepoints=cps,\n", + " decay=decay)\n", + " tren = np.dot(lbf, lasso_nb(lbf, resids, alpha=alpha))\n", + " tren = tren * linear_lr\n", + " component_mse = calc_mse(y, fitted + tren)\n", + " if mse > component_mse:\n", + " mse = component_mse\n", + " fitted += tren\n", + " self.linear_component += tren\n", + " self.trend += tren[-2:]\n", + " if i == 1:\n", + " self.penalty = calc_rsq(resids, tren)\n", + " elif i > 4 and not i % 2:\n", + " if smoother is None:\n", + " if seasonal_period is not None:\n", + " len_check = int(max(seasonal_period))\n", + " else:\n", + " len_check = 12\n", + " if resids[-1] > np.mean(resids[-len_check:-1]) + 3 * np.std(resids[-len_check:-1]):\n", + " smoother = 0\n", + " if resids[-1] < np.mean(resids[-len_check:-1]) - 3 * np.std(resids[-len_check:-1]):\n", + " smoother = 0\n", + " if resids[-2] > np.mean(resids[-len_check:-2]) + 3 * np.std(resids[-len_check:-2]):\n", + " smoother = 0\n", + " if resids[-2] < np.mean(resids[-len_check:-2]) - 3 * np.std(resids[-len_check:-2]):\n", + " smoother = 0\n", + " if smoother is None:\n", + " smoother = 1\n", + " else:\n", + " resids[-2:] = cap_outliers(resids, 3)[-2:]\n", + " tren = ses_ensemble(resids,\n", + " min_alpha=min_alpha,\n", + " max_alpha=max_alpha,\n", + " smooth=smoother*1,\n", + " order=next(ma_cycle)\n", + " )\n", + " tren = tren * rs_lr\n", + " component_mse = calc_mse(y, fitted + tren)\n", + " if mse > component_mse + round_penalty * mse:\n", + " mse = component_mse\n", + " fitted += tren\n", + " self.ses_component += tren\n", + " self.trend += tren[-1]\n", + " if i == 0: #get deasonalized cov for some heuristic logic\n", + " if self.robust is None:\n", + " try:\n", + " if calc_cov(resids, multiplicative) > cov_threshold:\n", + " self.robust = True\n", + " else:\n", + " self.robust = False\n", + " except:\n", + " self.robust = True\n", + "\n", + " if i == 1:\n", + " resids = cap_outliers(resids, 5) #cap extreme outliers after initial rounds\n", + " if multiplicative:\n", + " fitted = np.exp(fitted)\n", + " else:\n", + " fitted = self.mean + (fitted * self.std)\n", + " self.multiplicative = multiplicative\n", + " return fitted\n", + "\n", + " def predict(self, forecast_horizon, X=None):\n", + " last_point = self.trend[1]\n", + " slope = last_point - self.trend[0]\n", + " if self.trend_penalty and self.penalty is not None:\n", + " slope = slope * max(0, self.penalty)\n", + " self.predicted_trend = slope * np.arange(1, forecast_horizon + 1) + last_point\n", + " if self.seasonality is not None:\n", + " predicted = self.predicted_trend + np.resize(self.seasonality, forecast_horizon)\n", + " else:\n", + " predicted = self.predicted_trend\n", + " if X is not None:\n", + " for model in self.exo_model:\n", + " predicted += model.predict(X) * self.exogenous_lr\n", + " if self.const is not None:\n", + " predicted = np.exp(predicted)\n", + " else:\n", + " predicted = self.mean + (predicted * self.std)\n", + " return predicted\n", + "\n", + " def optimize(self, y, test_size, n_steps, step_size=1, seasonal_period=None, metric='smape', X=None, params=None):\n", + " \"\"\"\n", + " Optimization method for MFLES\n", + "\n", + " Parameters\n", + " ----------\n", + " y : np.array\n", + " Your time series as a numpy array.\n", + " test_size : int\n", + " length of the test set to hold out to calculate test error.\n", + " n_steps : int\n", + " number of train and test sets to create.\n", + " step_size : 1, optional\n", + " how many periods to move after each step. The default is 1.\n", + " seasonal_period : int or list, optional\n", + " the seasonal period to calculate for. The default is None.\n", + " metric : TYPE, optional\n", + " supported metrics are smape, mape, mse, mae. The default is 'smape'.\n", + " params : dict, optional\n", + " A user provided dictionary of params to try. The default is None.\n", + "\n", + " Returns\n", + " -------\n", + " opt_param : TYPE\n", + " DESCRIPTION.\n", + "\n", + " \"\"\"\n", + " configs = default_configs(seasonal_period, params)\n", + " # the 4 here is because with less than 4 samples the model defaults to naive\n", + " max_steps = (len(y) - test_size - 4) // step_size + 1\n", + " if max_steps < 1:\n", + " if self.verbose:\n", + " print(\n", + " 'Series does not have enough samples for a single cross validation step '\n", + " f'({test_size + 4}). Choosing the first configuration.'\n", + " )\n", + " return configs[0]\n", + " if max_steps < n_steps:\n", + " n_steps = max_steps\n", + " if self.verbose:\n", + " print(f'Series length too small, setting n_steps to {n_steps}')\n", + "\n", + " self.metrics = []\n", + " for param in configs:\n", + " cv_results = cross_validation(y,\n", + " X,\n", + " test_size,\n", + " n_steps,\n", + " MFLES(verbose=self.verbose),\n", + " step_size=step_size,\n", + " metric=metric,\n", + " **param)\n", + " self.metrics.append(cv_results['metric'])\n", + " return configs[np.argmin(self.metrics)]\n", + "\n", + " def seasonal_decompose(self, y, **kwargs):\n", + " fitted = self.fit(y, **kwargs)\n", + " trend = self.linear_component\n", + " exogenous = self.median_component + self.exogenous_component\n", + " level = self.median_component + self.ses_component\n", + " seasonality = self.seasonal_component\n", + " if self.multiplicative:\n", + " trend = np.exp(trend)\n", + " level = np.exp(level)\n", + " exogenous = np.exp(exogenous) - np.exp(self.median_component)\n", + " if kwargs['seasonal_period'] is not None:\n", + " seasonality = np.exp(seasonality)\n", + " trend = trend * level\n", + " else:\n", + " trend = self.mean + (trend * self.std)\n", + " level = self.mean + (level * self.std)\n", + " exogenous = self.mean + (exogenous * self.std)\n", + " if kwargs['seasonal_period'] is not None:\n", + " seasonality = (seasonality * self.std)\n", + " trend = trend + level - self.mean\n", + " residuals = y - fitted\n", + " self.decomposition = {'y': y,\n", + " 'trend': trend,\n", + " 'seasonality': seasonality,\n", + " 'exogenous': exogenous,\n", + " 'residuals': residuals\n", + " }\n", + " return self.decomposition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34cfab63-842b-47b0-8a7b-942bc31923d9", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01c3cae8-f111-4163-812d-6d6e35782f38", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "url = \"https://raw.githubusercontent.com/tidyverts/tsibbledata/master/data-raw/vic_elec/VIC2015/demand.csv\"\n", + "df = pd.read_csv(url)\n", + "df[\"Date\"] = df[\"Date\"].apply(\n", + " lambda x: pd.Timestamp(\"1899-12-30\") + pd.Timedelta(x, unit=\"days\")\n", + ")\n", + "df[\"ds\"] = df[\"Date\"] + pd.to_timedelta((df[\"Period\"] - 1) * 30, unit=\"m\")\n", + "timeseries = df[[\"ds\", \"OperationalLessIndustrial\"]]\n", + "timeseries.columns = [\n", + " \"ds\",\n", + " \"y\",\n", + "] # Rename to OperationalLessIndustrial to y for simplicity.\n", + "\n", + "# Filter for first 149 days of 2012.\n", + "start_date = pd.to_datetime(\"2012-01-01\")\n", + "end_date = start_date + pd.Timedelta(\"149D\")\n", + "mask = (timeseries[\"ds\"] >= start_date) & (timeseries[\"ds\"] < end_date)\n", + "timeseries = timeseries[mask]\n", + "\n", + "# Resample to hourly\n", + "timeseries = timeseries.set_index(\"ds\").resample(\"H\").sum()\n", + "timeseries.head()\n", + "\n", + "# decomposition\n", + "mfles = MFLES()\n", + "fitted = mfles.fit(y=timeseries.y.values, seasonal_period=[24, 24 * 7])\n", + "predicted = mfles.predict(forecast_horizon=24)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "T62OIfir7bpn", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "mfles = MFLES()\n", + "opt_params = mfles.optimize(timeseries.y.values,\n", + " seasonal_period=[24, 24 * 7],\n", + " n_steps=3,\n", + " test_size=24,\n", + " step_size=24)\n", + "fitted = mfles.fit(y=timeseries.y.values, seasonal_period=[24, 24 * 7])\n", + "predicted = mfles.predict(forecast_horizon=24)\n", + "print(opt_params)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/settings.ini b/settings.ini index 75683f2f..83c1373e 100644 --- a/settings.ini +++ b/settings.ini @@ -8,20 +8,20 @@ author = Nixtla author_email = business@nixtla.io copyright = Nixtla Inc. branch = main -version = 1.7.4 +version = 1.7.5 min_python = 3.8 audience = Developers language = English custom_sidebar = True license = apache2 status = 2 -requirements = cloudpickle coreforecast>=0.0.7 numba>=0.55.0 numpy>=1.21.6 pandas>=1.3.5 scipy>=1.7.3 statsmodels>=0.13.2 tqdm fugue>=0.8.1 utilsforecast>=0.1.4 threadpoolctl +requirements = cloudpickle coreforecast>=0.0.9 numba>=0.55.0 numpy>=1.21.6 pandas>=1.3.5 scipy>=1.7.3 statsmodels>=0.13.2 tqdm fugue>=0.8.1 utilsforecast>=0.1.4 threadpoolctl polars_requirements = polars ray_requirements = fugue[ray]>=0.8.1 protobuf>=3.15.3,<4.0.0 ray<2.8 dask_requirements = fugue[dask]>=0.8.1 spark_requirements = fugue[spark]>=0.8.1 plotly_requirements = plotly plotly-resampler -dev_requirements = nbdev black mypy pandas[plot] pmdarima prophet pyarrow ruff scikit-learn datasetsforecast supersmoother nbdev_plotly pre-commit +dev_requirements = nbdev black mypy pandas[plot] pmdarima prophet pyarrow ruff scikit-learn setuptools<70 datasetsforecast supersmoother nbdev_plotly pre-commit nbs_path = nbs doc_path = _docs recursive = True diff --git a/statsforecast/__init__.py b/statsforecast/__init__.py index 36b48bec..4195e404 100644 --- a/statsforecast/__init__.py +++ b/statsforecast/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.7.4" +__version__ = "1.7.5" __all__ = ["StatsForecast"] from .core import StatsForecast from .distributed import fugue # noqa diff --git a/statsforecast/_modidx.py b/statsforecast/_modidx.py index 647f519d..1e378fe5 100644 --- a/statsforecast/_modidx.py +++ b/statsforecast/_modidx.py @@ -280,11 +280,52 @@ 'statsforecast.garch.garch_sigma2': ('src/garch.html#garch_sigma2', 'statsforecast/garch.py'), 'statsforecast.garch.generate_garch_data': ( 'src/garch.html#generate_garch_data', 'statsforecast/garch.py')}, + 'statsforecast.mfles': { 'statsforecast.mfles.MFLES': ('src/mfles.html#mfles', 'statsforecast/mfles.py'), + 'statsforecast.mfles.MFLES.__init__': ('src/mfles.html#mfles.__init__', 'statsforecast/mfles.py'), + 'statsforecast.mfles.MFLES.fit': ('src/mfles.html#mfles.fit', 'statsforecast/mfles.py'), + 'statsforecast.mfles.MFLES.optimize': ('src/mfles.html#mfles.optimize', 'statsforecast/mfles.py'), + 'statsforecast.mfles.MFLES.predict': ('src/mfles.html#mfles.predict', 'statsforecast/mfles.py'), + 'statsforecast.mfles.MFLES.seasonal_decompose': ( 'src/mfles.html#mfles.seasonal_decompose', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.OLS': ('src/mfles.html#ols', 'statsforecast/mfles.py'), + 'statsforecast.mfles.OLS.fit': ('src/mfles.html#ols.fit', 'statsforecast/mfles.py'), + 'statsforecast.mfles.OLS.predict': ('src/mfles.html#ols.predict', 'statsforecast/mfles.py'), + 'statsforecast.mfles.Zeros': ('src/mfles.html#zeros', 'statsforecast/mfles.py'), + 'statsforecast.mfles.Zeros.predict': ('src/mfles.html#zeros.predict', 'statsforecast/mfles.py'), + 'statsforecast.mfles._ols': ('src/mfles.html#_ols', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_cov': ('src/mfles.html#calc_cov', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_mae': ('src/mfles.html#calc_mae', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_mape': ('src/mfles.html#calc_mape', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_mse': ('src/mfles.html#calc_mse', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_rsq': ('src/mfles.html#calc_rsq', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_seas_strength': ( 'src/mfles.html#calc_seas_strength', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_smape': ('src/mfles.html#calc_smape', 'statsforecast/mfles.py'), + 'statsforecast.mfles.calc_trend_strength': ( 'src/mfles.html#calc_trend_strength', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.cap_outliers': ('src/mfles.html#cap_outliers', 'statsforecast/mfles.py'), + 'statsforecast.mfles.cross_validation': ('src/mfles.html#cross_validation', 'statsforecast/mfles.py'), + 'statsforecast.mfles.default_configs': ('src/mfles.html#default_configs', 'statsforecast/mfles.py'), + 'statsforecast.mfles.fast_ols': ('src/mfles.html#fast_ols', 'statsforecast/mfles.py'), + 'statsforecast.mfles.get_basis': ('src/mfles.html#get_basis', 'statsforecast/mfles.py'), + 'statsforecast.mfles.get_fourier_series': ( 'src/mfles.html#get_fourier_series', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.get_future_basis': ('src/mfles.html#get_future_basis', 'statsforecast/mfles.py'), + 'statsforecast.mfles.get_seasonality_weights': ( 'src/mfles.html#get_seasonality_weights', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.lasso_nb': ('src/mfles.html#lasso_nb', 'statsforecast/mfles.py'), + 'statsforecast.mfles.logic_check': ('src/mfles.html#logic_check', 'statsforecast/mfles.py'), + 'statsforecast.mfles.logic_layer': ('src/mfles.html#logic_layer', 'statsforecast/mfles.py'), + 'statsforecast.mfles.median': ('src/mfles.html#median', 'statsforecast/mfles.py'), + 'statsforecast.mfles.ols': ('src/mfles.html#ols', 'statsforecast/mfles.py'), + 'statsforecast.mfles.ses_ensemble': ('src/mfles.html#ses_ensemble', 'statsforecast/mfles.py'), + 'statsforecast.mfles.set_fourier': ('src/mfles.html#set_fourier', 'statsforecast/mfles.py'), + 'statsforecast.mfles.siegel_repeated_medians': ( 'src/mfles.html#siegel_repeated_medians', + 'statsforecast/mfles.py'), + 'statsforecast.mfles.wls': ('src/mfles.html#wls', 'statsforecast/mfles.py')}, 'statsforecast.models': { 'statsforecast.models.ADIDA': ('src/core/models.html#adida', 'statsforecast/models.py'), 'statsforecast.models.ADIDA.__init__': ( 'src/core/models.html#adida.__init__', 'statsforecast/models.py'), - 'statsforecast.models.ADIDA.__repr__': ( 'src/core/models.html#adida.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.ADIDA.fit': ('src/core/models.html#adida.fit', 'statsforecast/models.py'), 'statsforecast.models.ADIDA.forecast': ( 'src/core/models.html#adida.forecast', 'statsforecast/models.py'), @@ -295,13 +336,9 @@ 'statsforecast.models.ARCH': ('src/core/models.html#arch', 'statsforecast/models.py'), 'statsforecast.models.ARCH.__init__': ( 'src/core/models.html#arch.__init__', 'statsforecast/models.py'), - 'statsforecast.models.ARCH.__repr__': ( 'src/core/models.html#arch.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.ARIMA': ('src/core/models.html#arima', 'statsforecast/models.py'), 'statsforecast.models.ARIMA.__init__': ( 'src/core/models.html#arima.__init__', 'statsforecast/models.py'), - 'statsforecast.models.ARIMA.__repr__': ( 'src/core/models.html#arima.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.ARIMA.fit': ('src/core/models.html#arima.fit', 'statsforecast/models.py'), 'statsforecast.models.ARIMA.forecast': ( 'src/core/models.html#arima.forecast', 'statsforecast/models.py'), @@ -314,8 +351,6 @@ 'statsforecast.models.AutoARIMA': ('src/core/models.html#autoarima', 'statsforecast/models.py'), 'statsforecast.models.AutoARIMA.__init__': ( 'src/core/models.html#autoarima.__init__', 'statsforecast/models.py'), - 'statsforecast.models.AutoARIMA.__repr__': ( 'src/core/models.html#autoarima.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.AutoARIMA.fit': ( 'src/core/models.html#autoarima.fit', 'statsforecast/models.py'), 'statsforecast.models.AutoARIMA.forecast': ( 'src/core/models.html#autoarima.forecast', @@ -329,8 +364,6 @@ 'statsforecast.models.AutoCES': ('src/core/models.html#autoces', 'statsforecast/models.py'), 'statsforecast.models.AutoCES.__init__': ( 'src/core/models.html#autoces.__init__', 'statsforecast/models.py'), - 'statsforecast.models.AutoCES.__repr__': ( 'src/core/models.html#autoces.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.AutoCES.fit': ('src/core/models.html#autoces.fit', 'statsforecast/models.py'), 'statsforecast.models.AutoCES.forecast': ( 'src/core/models.html#autoces.forecast', 'statsforecast/models.py'), @@ -343,8 +376,6 @@ 'statsforecast.models.AutoETS': ('src/core/models.html#autoets', 'statsforecast/models.py'), 'statsforecast.models.AutoETS.__init__': ( 'src/core/models.html#autoets.__init__', 'statsforecast/models.py'), - 'statsforecast.models.AutoETS.__repr__': ( 'src/core/models.html#autoets.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.AutoETS.fit': ('src/core/models.html#autoets.fit', 'statsforecast/models.py'), 'statsforecast.models.AutoETS.forecast': ( 'src/core/models.html#autoets.forecast', 'statsforecast/models.py'), @@ -354,20 +385,29 @@ 'statsforecast/models.py'), 'statsforecast.models.AutoETS.predict_in_sample': ( 'src/core/models.html#autoets.predict_in_sample', 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES': ('src/core/models.html#automfles', 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES.__init__': ( 'src/core/models.html#automfles.__init__', + 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES._fit': ( 'src/core/models.html#automfles._fit', + 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES.fit': ( 'src/core/models.html#automfles.fit', + 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES.forecast': ( 'src/core/models.html#automfles.forecast', + 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES.predict': ( 'src/core/models.html#automfles.predict', + 'statsforecast/models.py'), + 'statsforecast.models.AutoMFLES.predict_in_sample': ( 'src/core/models.html#automfles.predict_in_sample', + 'statsforecast/models.py'), 'statsforecast.models.AutoRegressive': ( 'src/core/models.html#autoregressive', 'statsforecast/models.py'), 'statsforecast.models.AutoRegressive.__init__': ( 'src/core/models.html#autoregressive.__init__', 'statsforecast/models.py'), - 'statsforecast.models.AutoRegressive.__repr__': ( 'src/core/models.html#autoregressive.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.AutoTBATS': ('src/core/models.html#autotbats', 'statsforecast/models.py'), 'statsforecast.models.AutoTBATS.__init__': ( 'src/core/models.html#autotbats.__init__', 'statsforecast/models.py'), 'statsforecast.models.AutoTheta': ('src/core/models.html#autotheta', 'statsforecast/models.py'), 'statsforecast.models.AutoTheta.__init__': ( 'src/core/models.html#autotheta.__init__', 'statsforecast/models.py'), - 'statsforecast.models.AutoTheta.__repr__': ( 'src/core/models.html#autotheta.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.AutoTheta.fit': ( 'src/core/models.html#autotheta.fit', 'statsforecast/models.py'), 'statsforecast.models.AutoTheta.forecast': ( 'src/core/models.html#autotheta.forecast', @@ -382,8 +422,6 @@ 'statsforecast/models.py'), 'statsforecast.models.ConstantModel.__init__': ( 'src/core/models.html#constantmodel.__init__', 'statsforecast/models.py'), - 'statsforecast.models.ConstantModel.__repr__': ( 'src/core/models.html#constantmodel.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.ConstantModel.fit': ( 'src/core/models.html#constantmodel.fit', 'statsforecast/models.py'), 'statsforecast.models.ConstantModel.forecast': ( 'src/core/models.html#constantmodel.forecast', @@ -398,8 +436,6 @@ 'statsforecast/models.py'), 'statsforecast.models.CrostonClassic.__init__': ( 'src/core/models.html#crostonclassic.__init__', 'statsforecast/models.py'), - 'statsforecast.models.CrostonClassic.__repr__': ( 'src/core/models.html#crostonclassic.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.CrostonClassic.fit': ( 'src/core/models.html#crostonclassic.fit', 'statsforecast/models.py'), 'statsforecast.models.CrostonClassic.forecast': ( 'src/core/models.html#crostonclassic.forecast', @@ -412,8 +448,6 @@ 'statsforecast/models.py'), 'statsforecast.models.CrostonOptimized.__init__': ( 'src/core/models.html#crostonoptimized.__init__', 'statsforecast/models.py'), - 'statsforecast.models.CrostonOptimized.__repr__': ( 'src/core/models.html#crostonoptimized.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.CrostonOptimized.fit': ( 'src/core/models.html#crostonoptimized.fit', 'statsforecast/models.py'), 'statsforecast.models.CrostonOptimized.forecast': ( 'src/core/models.html#crostonoptimized.forecast', @@ -425,8 +459,6 @@ 'statsforecast.models.CrostonSBA': ('src/core/models.html#crostonsba', 'statsforecast/models.py'), 'statsforecast.models.CrostonSBA.__init__': ( 'src/core/models.html#crostonsba.__init__', 'statsforecast/models.py'), - 'statsforecast.models.CrostonSBA.__repr__': ( 'src/core/models.html#crostonsba.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.CrostonSBA.fit': ( 'src/core/models.html#crostonsba.fit', 'statsforecast/models.py'), 'statsforecast.models.CrostonSBA.forecast': ( 'src/core/models.html#crostonsba.forecast', @@ -444,13 +476,10 @@ 'statsforecast/models.py'), 'statsforecast.models.ETS': ('src/core/models.html#ets', 'statsforecast/models.py'), 'statsforecast.models.ETS.__init__': ('src/core/models.html#ets.__init__', 'statsforecast/models.py'), - 'statsforecast.models.ETS.__repr__': ('src/core/models.html#ets.__repr__', 'statsforecast/models.py'), 'statsforecast.models.ETS._warn': ('src/core/models.html#ets._warn', 'statsforecast/models.py'), 'statsforecast.models.GARCH': ('src/core/models.html#garch', 'statsforecast/models.py'), 'statsforecast.models.GARCH.__init__': ( 'src/core/models.html#garch.__init__', 'statsforecast/models.py'), - 'statsforecast.models.GARCH.__repr__': ( 'src/core/models.html#garch.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.GARCH.fit': ('src/core/models.html#garch.fit', 'statsforecast/models.py'), 'statsforecast.models.GARCH.forecast': ( 'src/core/models.html#garch.forecast', 'statsforecast/models.py'), @@ -462,8 +491,6 @@ 'statsforecast/models.py'), 'statsforecast.models.HistoricAverage.__init__': ( 'src/core/models.html#historicaverage.__init__', 'statsforecast/models.py'), - 'statsforecast.models.HistoricAverage.__repr__': ( 'src/core/models.html#historicaverage.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.HistoricAverage.fit': ( 'src/core/models.html#historicaverage.fit', 'statsforecast/models.py'), 'statsforecast.models.HistoricAverage.forecast': ( 'src/core/models.html#historicaverage.forecast', @@ -475,18 +502,12 @@ 'statsforecast.models.Holt': ('src/core/models.html#holt', 'statsforecast/models.py'), 'statsforecast.models.Holt.__init__': ( 'src/core/models.html#holt.__init__', 'statsforecast/models.py'), - 'statsforecast.models.Holt.__repr__': ( 'src/core/models.html#holt.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.HoltWinters': ('src/core/models.html#holtwinters', 'statsforecast/models.py'), 'statsforecast.models.HoltWinters.__init__': ( 'src/core/models.html#holtwinters.__init__', 'statsforecast/models.py'), - 'statsforecast.models.HoltWinters.__repr__': ( 'src/core/models.html#holtwinters.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.IMAPA': ('src/core/models.html#imapa', 'statsforecast/models.py'), 'statsforecast.models.IMAPA.__init__': ( 'src/core/models.html#imapa.__init__', 'statsforecast/models.py'), - 'statsforecast.models.IMAPA.__repr__': ( 'src/core/models.html#imapa.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.IMAPA.fit': ('src/core/models.html#imapa.fit', 'statsforecast/models.py'), 'statsforecast.models.IMAPA.forecast': ( 'src/core/models.html#imapa.forecast', 'statsforecast/models.py'), @@ -494,11 +515,20 @@ 'statsforecast/models.py'), 'statsforecast.models.IMAPA.predict_in_sample': ( 'src/core/models.html#imapa.predict_in_sample', 'statsforecast/models.py'), + 'statsforecast.models.MFLES': ('src/core/models.html#mfles', 'statsforecast/models.py'), + 'statsforecast.models.MFLES.__init__': ( 'src/core/models.html#mfles.__init__', + 'statsforecast/models.py'), + 'statsforecast.models.MFLES._fit': ('src/core/models.html#mfles._fit', 'statsforecast/models.py'), + 'statsforecast.models.MFLES.fit': ('src/core/models.html#mfles.fit', 'statsforecast/models.py'), + 'statsforecast.models.MFLES.forecast': ( 'src/core/models.html#mfles.forecast', + 'statsforecast/models.py'), + 'statsforecast.models.MFLES.predict': ( 'src/core/models.html#mfles.predict', + 'statsforecast/models.py'), + 'statsforecast.models.MFLES.predict_in_sample': ( 'src/core/models.html#mfles.predict_in_sample', + 'statsforecast/models.py'), 'statsforecast.models.MSTL': ('src/core/models.html#mstl', 'statsforecast/models.py'), 'statsforecast.models.MSTL.__init__': ( 'src/core/models.html#mstl.__init__', 'statsforecast/models.py'), - 'statsforecast.models.MSTL.__repr__': ( 'src/core/models.html#mstl.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.MSTL.fit': ('src/core/models.html#mstl.fit', 'statsforecast/models.py'), 'statsforecast.models.MSTL.forecast': ( 'src/core/models.html#mstl.forecast', 'statsforecast/models.py'), @@ -512,8 +542,6 @@ 'statsforecast.models.Naive': ('src/core/models.html#naive', 'statsforecast/models.py'), 'statsforecast.models.Naive.__init__': ( 'src/core/models.html#naive.__init__', 'statsforecast/models.py'), - 'statsforecast.models.Naive.__repr__': ( 'src/core/models.html#naive.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.Naive.fit': ('src/core/models.html#naive.fit', 'statsforecast/models.py'), 'statsforecast.models.Naive.forecast': ( 'src/core/models.html#naive.forecast', 'statsforecast/models.py'), @@ -531,8 +559,6 @@ 'statsforecast/models.py'), 'statsforecast.models.RandomWalkWithDrift.__init__': ( 'src/core/models.html#randomwalkwithdrift.__init__', 'statsforecast/models.py'), - 'statsforecast.models.RandomWalkWithDrift.__repr__': ( 'src/core/models.html#randomwalkwithdrift.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.RandomWalkWithDrift.fit': ( 'src/core/models.html#randomwalkwithdrift.fit', 'statsforecast/models.py'), 'statsforecast.models.RandomWalkWithDrift.forecast': ( 'src/core/models.html#randomwalkwithdrift.forecast', @@ -545,8 +571,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothing.__init__': ( 'src/core/models.html#seasonalexponentialsmoothing.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SeasonalExponentialSmoothing.__repr__': ( 'src/core/models.html#seasonalexponentialsmoothing.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothing.fit': ( 'src/core/models.html#seasonalexponentialsmoothing.fit', 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothing.forecast': ( 'src/core/models.html#seasonalexponentialsmoothing.forecast', @@ -559,8 +583,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothingOptimized.__init__': ( 'src/core/models.html#seasonalexponentialsmoothingoptimized.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SeasonalExponentialSmoothingOptimized.__repr__': ( 'src/core/models.html#seasonalexponentialsmoothingoptimized.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothingOptimized.fit': ( 'src/core/models.html#seasonalexponentialsmoothingoptimized.fit', 'statsforecast/models.py'), 'statsforecast.models.SeasonalExponentialSmoothingOptimized.forecast': ( 'src/core/models.html#seasonalexponentialsmoothingoptimized.forecast', @@ -573,8 +595,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SeasonalNaive.__init__': ( 'src/core/models.html#seasonalnaive.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SeasonalNaive.__repr__': ( 'src/core/models.html#seasonalnaive.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SeasonalNaive.fit': ( 'src/core/models.html#seasonalnaive.fit', 'statsforecast/models.py'), 'statsforecast.models.SeasonalNaive.forecast': ( 'src/core/models.html#seasonalnaive.forecast', @@ -587,8 +607,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SeasonalWindowAverage.__init__': ( 'src/core/models.html#seasonalwindowaverage.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SeasonalWindowAverage.__repr__': ( 'src/core/models.html#seasonalwindowaverage.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SeasonalWindowAverage.fit': ( 'src/core/models.html#seasonalwindowaverage.fit', 'statsforecast/models.py'), 'statsforecast.models.SeasonalWindowAverage.forecast': ( 'src/core/models.html#seasonalwindowaverage.forecast', @@ -601,8 +619,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothing.__init__': ( 'src/core/models.html#simpleexponentialsmoothing.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SimpleExponentialSmoothing.__repr__': ( 'src/core/models.html#simpleexponentialsmoothing.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothing.fit': ( 'src/core/models.html#simpleexponentialsmoothing.fit', 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothing.forecast': ( 'src/core/models.html#simpleexponentialsmoothing.forecast', @@ -615,8 +631,6 @@ 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothingOptimized.__init__': ( 'src/core/models.html#simpleexponentialsmoothingoptimized.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SimpleExponentialSmoothingOptimized.__repr__': ( 'src/core/models.html#simpleexponentialsmoothingoptimized.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothingOptimized.fit': ( 'src/core/models.html#simpleexponentialsmoothingoptimized.fit', 'statsforecast/models.py'), 'statsforecast.models.SimpleExponentialSmoothingOptimized.forecast': ( 'src/core/models.html#simpleexponentialsmoothingoptimized.forecast', @@ -628,8 +642,6 @@ 'statsforecast.models.SklearnModel': ('src/core/models.html#sklearnmodel', 'statsforecast/models.py'), 'statsforecast.models.SklearnModel.__init__': ( 'src/core/models.html#sklearnmodel.__init__', 'statsforecast/models.py'), - 'statsforecast.models.SklearnModel.__repr__': ( 'src/core/models.html#sklearnmodel.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.SklearnModel.fit': ( 'src/core/models.html#sklearnmodel.fit', 'statsforecast/models.py'), 'statsforecast.models.SklearnModel.forecast': ( 'src/core/models.html#sklearnmodel.forecast', @@ -643,8 +655,6 @@ 'statsforecast.models.TBATS': ('src/core/models.html#tbats', 'statsforecast/models.py'), 'statsforecast.models.TBATS.__init__': ( 'src/core/models.html#tbats.__init__', 'statsforecast/models.py'), - 'statsforecast.models.TBATS.__repr__': ( 'src/core/models.html#tbats.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.TBATS.fit': ('src/core/models.html#tbats.fit', 'statsforecast/models.py'), 'statsforecast.models.TBATS.forecast': ( 'src/core/models.html#tbats.forecast', 'statsforecast/models.py'), @@ -654,7 +664,6 @@ 'statsforecast/models.py'), 'statsforecast.models.TSB': ('src/core/models.html#tsb', 'statsforecast/models.py'), 'statsforecast.models.TSB.__init__': ('src/core/models.html#tsb.__init__', 'statsforecast/models.py'), - 'statsforecast.models.TSB.__repr__': ('src/core/models.html#tsb.__repr__', 'statsforecast/models.py'), 'statsforecast.models.TSB.fit': ('src/core/models.html#tsb.fit', 'statsforecast/models.py'), 'statsforecast.models.TSB.forecast': ('src/core/models.html#tsb.forecast', 'statsforecast/models.py'), 'statsforecast.models.TSB.predict': ('src/core/models.html#tsb.predict', 'statsforecast/models.py'), @@ -667,8 +676,6 @@ 'statsforecast/models.py'), 'statsforecast.models.WindowAverage.__init__': ( 'src/core/models.html#windowaverage.__init__', 'statsforecast/models.py'), - 'statsforecast.models.WindowAverage.__repr__': ( 'src/core/models.html#windowaverage.__repr__', - 'statsforecast/models.py'), 'statsforecast.models.WindowAverage.fit': ( 'src/core/models.html#windowaverage.fit', 'statsforecast/models.py'), 'statsforecast.models.WindowAverage.forecast': ( 'src/core/models.html#windowaverage.forecast', @@ -681,6 +688,7 @@ 'statsforecast.models.ZeroModel.__init__': ( 'src/core/models.html#zeromodel.__init__', 'statsforecast/models.py'), 'statsforecast.models._TS': ('src/core/models.html#_ts', 'statsforecast/models.py'), + 'statsforecast.models._TS.__repr__': ('src/core/models.html#_ts.__repr__', 'statsforecast/models.py'), 'statsforecast.models._TS._add_conformal_intervals': ( 'src/core/models.html#_ts._add_conformal_intervals', 'statsforecast/models.py'), 'statsforecast.models._TS._add_predict_conformal_intervals': ( 'src/core/models.html#_ts._add_predict_conformal_intervals', diff --git a/statsforecast/mfles.py b/statsforecast/mfles.py new file mode 100644 index 00000000..f8a88547 --- /dev/null +++ b/statsforecast/mfles.py @@ -0,0 +1,764 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/src/mfles.ipynb. + +# %% auto 0 +__all__ = ['MFLES'] + +# %% ../nbs/src/mfles.ipynb 3 +import itertools +import warnings + +import numpy as np +from coreforecast.exponentially_weighted import exponentially_weighted_mean +from coreforecast.rolling import rolling_mean +from numba import njit + +from .utils import _ensure_float + +# %% ../nbs/src/mfles.ipynb 4 +# utility functions +def calc_mse(y_true, y_pred): + sq_err = (y_true - y_pred) ** 2 + return np.mean(sq_err) + + +def calc_mae(y_true, y_pred): + abs_err = np.abs(y_true - y_pred) + return np.mean(abs_err) + + +def calc_mape(y_true, y_pred): + pct_err = np.abs((y_true - y_pred) / (y_pred + 1e-6)) + return np.mean(pct_err) + + +def calc_smape(y_true, y_pred): + pct_err = 2 * np.abs(y_true - y_pred) / np.abs(y_true + y_pred + 1e-6) + return np.mean(pct_err) + + +_metric2fn = { + "mse": calc_mse, + "mae": calc_mae, + "mape": calc_mape, + "smape": calc_smape, +} + + +def cross_validation( + y, X, test_size, n_splits, model_obj, metric, step_size=1, **kwargs +): + metrics = [] + metric_fn = _metric2fn[metric] + residuals = [] + if X is None: + exogenous = None + else: + exogenous = X.copy() + for split in range(n_splits): + train_y = y[: -(split * step_size + test_size)] + test_y = y[len(train_y) : len(train_y) + test_size] + if exogenous is not None: + train_X = exogenous[: -(split * step_size + test_size), :] + test_X = exogenous[len(train_y) : len(train_y) + test_size, :] + else: + train_X = None + test_X = None + model_obj.fit(train_y, X=train_X, **kwargs) + prediction = model_obj.predict(test_size, X=test_X) + metrics.append(metric_fn(test_y, prediction)) + residuals.append(test_y - prediction) + return {"metric": np.mean(metrics), "residuals": residuals} + + +def logic_check(keys_to_check, keys): + return set(keys_to_check).issubset(keys) + + +def logic_layer(param_dict): + keys = param_dict.keys() + # if param_dict['n_changepoints'] is None: + # if param_dict['decay'] != -1: + # return False + if logic_check(["seasonal_period", "max_rounds"], keys): + if param_dict["seasonal_period"] is None: + if param_dict["max_rounds"] < 4: + return False + if logic_check(["smoother", "ma"], keys): + if param_dict["smoother"]: + if param_dict["ma"] is not None: + return False + if logic_check(["seasonal_period", "seasonality_weights"], keys): + if param_dict["seasonality_weights"]: + if param_dict["seasonal_period"] is None: + return False + return True + + +def default_configs(seasonal_period, configs=None): + if configs is None: + if seasonal_period is not None: + if not isinstance(seasonal_period, list): + seasonal_period = [seasonal_period] + configs = { + "seasonality_weights": [True, False], + "smoother": [True, False], + "ma": [int(min(seasonal_period)), int(min(seasonal_period) / 2), None], + "seasonal_period": [None, seasonal_period], + } + else: + configs = { + "smoother": [True, False], + "cov_threshold": [0.5, -1], + "max_rounds": [5, 20], + "seasonal_period": [None], + } + keys = configs.keys() + combinations = itertools.product(*configs.values()) + ds = [dict(zip(keys, cc)) for cc in combinations] + ds = [i for i in ds if logic_layer(i)] + return ds + + +def cap_outliers(series, outlier_cap=3): + mean = np.mean(series) + std = np.std(series) + return series.clip(min=mean - outlier_cap * std, max=mean + outlier_cap * std) + + +def set_fourier(period): + if period < 10: + fourier = 5 + elif period < 70: + fourier = 10 + else: + fourier = 15 + return fourier + + +def calc_trend_strength(resids, deseasonalized): + return max(0, 1 - (np.var(resids) / np.var(deseasonalized))) + + +def calc_seas_strength(resids, detrended): + return max(0, 1 - (np.var(resids) / np.var(detrended))) + + +def calc_rsq(y, fitted): + try: + mean_y = np.mean(y) + ssres = np.sum((y - fitted) ** 2) + sstot = np.sum((y - mean_y) ** 2) + return 1 - (ssres / sstot) + except: + return 0 + + +def calc_cov(y, mult=1): + if mult: + # source http://medcraveonline.com/MOJPB/MOJPB-06-00200.pdf + res = np.sqrt(np.exp(np.log(10) * (np.std(y) ** 2) - 1)) + else: + res = np.std(y) + mean = np.mean(y) + if mean != 0: + res = res / mean + return res + + +def get_seasonality_weights(y, seasonal_period): + return 1 + np.arange(y.size) // seasonal_period + + +# feature engineering functions +def get_fourier_series(length, seasonal_period, fourier_order): + x = 2 * np.pi * np.arange(1, fourier_order + 1) / seasonal_period + t = np.arange(1, length + 1).reshape(-1, 1) + x = x * t + return np.hstack([np.cos(x), np.sin(x)]) + + +@njit +def get_basis(y, n_changepoints, decay=-1, gradient_strategy=0): + if n_changepoints < 1: + return np.arange(y.size, dtype=np.float64).reshape(-1, 1) + y = y.copy() + y -= y[0] + n = len(y) + if gradient_strategy: + gradients = np.abs(y[:-1] - y[1:]) + initial_point = y[0] + final_point = y[-1] + mean_y = np.mean(y) + changepoints = np.empty(shape=(len(y), n_changepoints + 1)) + array_splits = [] + for i in range(1, n_changepoints + 1): + i = n_changepoints - i + 1 + if gradient_strategy: + cps = np.argsort(-gradients) + cps = cps[cps > 0.1 * len(gradients)] + cps = cps[cps < 0.9 * len(gradients)] + split_point = cps[i - 1] + array_splits.append(y[:split_point]) + else: + split_point = len(y) // i + array_splits.append(y[:split_point]) + y = y[split_point:] + len_splits = 0 + for i in range(n_changepoints): + if gradient_strategy: + len_splits = len(array_splits[i]) + else: + len_splits += len(array_splits[i]) + moving_point = array_splits[i][-1] + left_basis = np.linspace(initial_point, moving_point, len_splits) + if decay is None: + end_point = final_point + else: + if decay == -1: + dd = moving_point**2 + if mean_y != 0: + dd /= mean_y**2 + if dd > 0.99: + dd = 0.99 + if dd < 0.001: + dd = 0.001 + end_point = moving_point - ((moving_point - final_point) * (1 - dd)) + else: + end_point = moving_point - ((moving_point - final_point) * (1 - decay)) + right_basis = np.linspace(moving_point, end_point, n - len_splits + 1) + changepoints[:, i] = np.append(left_basis, right_basis[1:]) + changepoints[:, i + 1] = np.ones(n) + return changepoints + + +def get_future_basis(basis_functions, forecast_horizon): + n_components = np.shape(basis_functions)[1] + slopes = np.gradient(basis_functions)[0][-1, :] + future_basis = np.arange(0, forecast_horizon + 1) + future_basis += len(basis_functions) + future_basis = np.transpose([future_basis] * n_components) + future_basis = future_basis * slopes + future_basis = future_basis + (basis_functions[-1, :] - future_basis[0, :]) + return future_basis[1:, :] + + +def lasso_nb(X, y, alpha, tol=0.001, maxiter=10000): + from sklearn.linear_model import Lasso + from sklearn.exceptions import ConvergenceWarning + + with warnings.catch_warnings(record=False): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + lasso = Lasso(alpha=alpha, fit_intercept=False, tol=tol, max_iter=maxiter) + lasso.fit(X, y) + return lasso.coef_ + + +# different models +@njit +def siegel_repeated_medians(x, y): + # Siegel repeated medians regression + n = y.size + slopes = np.empty_like(y) + slopes_sub = np.empty(shape=n - 1, dtype=y.dtype) + for i in range(n): + k = 0 + for j in range(n): + if i == j: + continue + xd = x[j] - x[i] + if xd == 0: + slope = 0 + else: + slope = (y[j] - y[i]) / xd + slopes_sub[k] = slope + k += 1 + slopes[i] = np.median(slopes_sub) + ints = y - slopes * x + return x * np.median(slopes) + np.median(ints) + + +def ses_ensemble(y, min_alpha=0.05, max_alpha=1.0, smooth=False, order=1): + # bad name but does either a ses ensemble or simple moving average + if smooth: + results = np.zeros_like(y) + alphas = np.arange(min_alpha, max_alpha, 0.05) + for alpha in alphas: + results += exponentially_weighted_mean(y, alpha) + results = results / len(alphas) + else: + results = rolling_mean(y, order + 1) + results[: order + 1] = y[: order + 1] + return results + + +def fast_ols(x, y): + """Simple OLS for two data sets.""" + M = x.size + x_sum = x.sum() + y_sum = y.sum() + x_sq_sum = x @ x + x_y_sum = x @ y + slope = (M * x_y_sum - x_sum * y_sum) / (M * x_sq_sum - x_sum**2) + intercept = (y_sum - slope * x_sum) / M + return slope * x + intercept + + +def median(y, seasonal_period): + if seasonal_period is None: + return np.full_like(y, np.median(y)) + full_periods, resid = divmod(len(y), seasonal_period) + period_medians = np.median( + y[: full_periods * seasonal_period].reshape(full_periods, seasonal_period), + axis=1, + ) + medians = np.repeat(period_medians, seasonal_period) + if resid: + remainder_median = np.median(y[-seasonal_period:]) + medians = np.append(medians, np.repeat(remainder_median, resid)) + return medians + + +def ols(X, y): + coefs = np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y)) + return X @ coefs + + +def wls(X, y, weights): + weighted_X_T = X.T @ np.diag(weights) + coefs = np.linalg.pinv(weighted_X_T.dot(X)).dot(weighted_X_T.dot(y)) + return X @ coefs + + +def _ols(X, y): + return np.linalg.pinv(X.T.dot(X)).dot(X.T.dot(y)) + + +class OLS: + def fit(self, X, y): + self.coefs = _ols(X, y) + + def predict(self, X): + return X @ self.coefs + + +class Zeros: + def predict(self, X): + return np.zeros(X.shape[0]) + +# %% ../nbs/src/mfles.ipynb 5 +class MFLES: + def __init__(self, verbose=1, robust=None): + self.penalty = None + self.trend = None + self.seasonality = None + self.robust = robust + self.const = None + self.aic = None + self.upper = None + self.lower = None + self.exogenous_models = None + self.verbose = verbose + self.predicted = None + + def fit( + self, + y, + seasonal_period=None, + X=None, + fourier_order=None, + ma=None, + alpha=1.0, + decay=-1, + n_changepoints=0.25, + seasonal_lr=0.9, + rs_lr=1, + exogenous_lr=1, + exogenous_estimator=OLS, + exogenous_params={}, + linear_lr=0.9, + cov_threshold=0.7, + moving_medians=False, + max_rounds=50, + min_alpha=0.05, + max_alpha=1.0, + round_penalty=0.0001, + trend_penalty=True, + multiplicative=None, + changepoints=True, + smoother=False, + seasonality_weights=False, + ): + """ + + + Parameters + ---------- + y : np.array + the time series as a numpy array. + seasonal_period : int, optional + DESCRIPTION. The default is None. + fourier_order : int, optional + How many fourier sin/cos pairs to create, the larger the number the more complex of a seasonal pattern can be fitted. A lower number leads to smoother results. This is auto-set based on seasonal_period. The default is None. + ma : int, optional + The moving average order to use, this is auto-set based on internal logic. Passing 4 would fit a 4 period moving average on the residual component. The default is None. + alpha : TYPE, optional + The alpha which is used in fitting the underlying LASSO when using piecewise functions. The default is 1.0. + decay : float, optional + Effects the slopes of the piecewise-linear basis function. The default is -1. + n_changepoints : float, optional + The number of changepoint knots to place, a default of .25 with place .25 * series length number of knots. The default is .25. + seasonal_lr : float, optional + A shrinkage parameter (0 the more smooth your fit. The default is 10. + min_alpha : float, optional + The min alpha in the SES ensemble. The default is .05. + max_alpha : float, optional + The max alpha used in the SES ensemble. The default is 1.0. + trend_penalty : boolean, optional + Whether to apply a simple penalty to the lienar trend component, very useful for dealing with the potentially dangerous piecewise trend. The default is True. + multiplicative : boolean, optional + Auto-set based on internal logic, but if given True it will simply take the log of the time series. The default is None. + changepoints : boolean, optional + Whether to fit for changepoints if all other logic allows for it, by setting False then MFLES will not ever fit a piecewise trend. The default is True. + smoother : boolean, optional + If True then a simple exponential ensemble will be used rather than auto settings. The default is False. + + Returns + ------- + None. + + """ + if cov_threshold == -1: + cov_threshold = 10000 + n = len(y) + y = _ensure_float(y) + self.exogenous_lr = exogenous_lr + if multiplicative is None: + if seasonal_period is None: + multiplicative = False + else: + multiplicative = True + if y.min() <= 0: + multiplicative = False + if multiplicative: + self.const = y.min() + y = np.log(y) + else: + self.const = None + self.std = np.std(y) + self.mean = np.mean(y) + y = y - self.mean + if self.std > 0: + y = y / self.std + if seasonal_period is not None: + if not isinstance(seasonal_period, list): + seasonal_period = [seasonal_period] + if n < 4 or np.all(y == np.mean(y)): + if self.verbose: + if n < 4: + print("series is too short (<4), defaulting to naive") + else: + print(f"input is constant with value {y[0]}, defaulting to naive") + self.trend = np.append(y[-1], y[-1]) + self.seasonality = np.zeros(len(y)) + self.trend_penalty = False + self.mean = y[-1] + self.std = 0 + self.exo_model = [Zeros()] + return np.tile(y[-1], len(y)) + og_y = y + self.og_y = og_y + y = y.copy() + if n_changepoints is None: + changepoints = False + if isinstance(n_changepoints, float) and n_changepoints < 1: + n_changepoints = int(n_changepoints * n) + self.linear_component = np.zeros(n) + self.seasonal_component = np.zeros(n) + self.ses_component = np.zeros(n) + self.median_component = np.zeros(n) + self.exogenous_component = np.zeros(n) + self.exo_model = [] + self.round_cost = [] + self.trend_penalty = trend_penalty + if moving_medians and seasonal_period is not None: + fitted = median(y, max(seasonal_period)) + else: + fitted = median(y, None) + self.median_component += fitted + self.trend = np.append(fitted.copy()[-1:], fitted.copy()[-1:]) + mse = None + equal = 0 + if ma is None: + ma_cycle = itertools.cycle([1]) + else: + if not isinstance(ma, list): + ma = [ma] + ma_cycle = itertools.cycle(ma) + if seasonal_period is not None: + seasons_cycle = itertools.cycle(list(range(len(seasonal_period)))) + self.seasonality = np.zeros(max(seasonal_period)) + fourier_series = [] + for period in seasonal_period: + if fourier_order is None: + fourier = set_fourier(period) + else: + fourier = fourier_order + fourier_series.append(get_fourier_series(n, period, fourier)) + if seasonality_weights: + cycle_weights = [] + for period in seasonal_period: + cycle_weights.append(get_seasonality_weights(y, period)) + else: + self.seasonality = None + for i in range(max_rounds): + resids = y - fitted + if mse is None: + mse = calc_mse(y, fitted) + else: + if mse <= calc_mse(y, fitted): + if equal == 6: + break + equal += 1 + else: + mse = calc_mse(y, fitted) + self.round_cost.append(mse) + if seasonal_period is not None: + seasonal_period_cycle = next(seasons_cycle) + if seasonality_weights: + seas = wls( + fourier_series[seasonal_period_cycle], + resids, + cycle_weights[seasonal_period_cycle], + ) + else: + seas = ols(fourier_series[seasonal_period_cycle], resids) + seas = seas * seasonal_lr + component_mse = calc_mse(y, fitted + seas) + if mse > component_mse: + mse = component_mse + fitted += seas + resids = y - fitted + self.seasonality += np.resize( + seas[-seasonal_period[seasonal_period_cycle] :], + len(self.seasonality), + ) + self.seasonal_component += seas + if X is not None and i > 0: + model_obj = exogenous_estimator(**exogenous_params) + model_obj.fit(X, resids) + self.exo_model.append(model_obj) + _fitted_values = model_obj.predict(X) * exogenous_lr + self.exogenous_component += _fitted_values + fitted += _fitted_values + resids = y - fitted + if ( + i % 2 + ): # if even get linear piece, allows for multiple seasonality fitting a bit more + if self.robust: + tren = siegel_repeated_medians( + x=np.arange(n, dtype=resids.dtype), y=resids + ) + else: + if i == 1 or not changepoints: + tren = fast_ols(x=np.arange(n), y=resids) + else: + cps = min(n_changepoints, int(0.1 * n)) + lbf = get_basis(y=resids, n_changepoints=cps, decay=decay) + tren = np.dot(lbf, lasso_nb(lbf, resids, alpha=alpha)) + tren = tren * linear_lr + component_mse = calc_mse(y, fitted + tren) + if mse > component_mse: + mse = component_mse + fitted += tren + self.linear_component += tren + self.trend += tren[-2:] + if i == 1: + self.penalty = calc_rsq(resids, tren) + elif i > 4 and not i % 2: + if smoother is None: + if seasonal_period is not None: + len_check = int(max(seasonal_period)) + else: + len_check = 12 + if resids[-1] > np.mean(resids[-len_check:-1]) + 3 * np.std( + resids[-len_check:-1] + ): + smoother = 0 + if resids[-1] < np.mean(resids[-len_check:-1]) - 3 * np.std( + resids[-len_check:-1] + ): + smoother = 0 + if resids[-2] > np.mean(resids[-len_check:-2]) + 3 * np.std( + resids[-len_check:-2] + ): + smoother = 0 + if resids[-2] < np.mean(resids[-len_check:-2]) - 3 * np.std( + resids[-len_check:-2] + ): + smoother = 0 + if smoother is None: + smoother = 1 + else: + resids[-2:] = cap_outliers(resids, 3)[-2:] + tren = ses_ensemble( + resids, + min_alpha=min_alpha, + max_alpha=max_alpha, + smooth=smoother * 1, + order=next(ma_cycle), + ) + tren = tren * rs_lr + component_mse = calc_mse(y, fitted + tren) + if mse > component_mse + round_penalty * mse: + mse = component_mse + fitted += tren + self.ses_component += tren + self.trend += tren[-1] + if i == 0: # get deasonalized cov for some heuristic logic + if self.robust is None: + try: + if calc_cov(resids, multiplicative) > cov_threshold: + self.robust = True + else: + self.robust = False + except: + self.robust = True + + if i == 1: + resids = cap_outliers( + resids, 5 + ) # cap extreme outliers after initial rounds + if multiplicative: + fitted = np.exp(fitted) + else: + fitted = self.mean + (fitted * self.std) + self.multiplicative = multiplicative + return fitted + + def predict(self, forecast_horizon, X=None): + last_point = self.trend[1] + slope = last_point - self.trend[0] + if self.trend_penalty and self.penalty is not None: + slope = slope * max(0, self.penalty) + self.predicted_trend = slope * np.arange(1, forecast_horizon + 1) + last_point + if self.seasonality is not None: + predicted = self.predicted_trend + np.resize( + self.seasonality, forecast_horizon + ) + else: + predicted = self.predicted_trend + if X is not None: + for model in self.exo_model: + predicted += model.predict(X) * self.exogenous_lr + if self.const is not None: + predicted = np.exp(predicted) + else: + predicted = self.mean + (predicted * self.std) + return predicted + + def optimize( + self, + y, + test_size, + n_steps, + step_size=1, + seasonal_period=None, + metric="smape", + X=None, + params=None, + ): + """ + Optimization method for MFLES + + Parameters + ---------- + y : np.array + Your time series as a numpy array. + test_size : int + length of the test set to hold out to calculate test error. + n_steps : int + number of train and test sets to create. + step_size : 1, optional + how many periods to move after each step. The default is 1. + seasonal_period : int or list, optional + the seasonal period to calculate for. The default is None. + metric : TYPE, optional + supported metrics are smape, mape, mse, mae. The default is 'smape'. + params : dict, optional + A user provided dictionary of params to try. The default is None. + + Returns + ------- + opt_param : TYPE + DESCRIPTION. + + """ + configs = default_configs(seasonal_period, params) + # the 4 here is because with less than 4 samples the model defaults to naive + max_steps = (len(y) - test_size - 4) // step_size + 1 + if max_steps < 1: + if self.verbose: + print( + "Series does not have enough samples for a single cross validation step " + f"({test_size + 4}). Choosing the first configuration." + ) + return configs[0] + if max_steps < n_steps: + n_steps = max_steps + if self.verbose: + print(f"Series length too small, setting n_steps to {n_steps}") + + self.metrics = [] + for param in configs: + cv_results = cross_validation( + y, + X, + test_size, + n_steps, + MFLES(verbose=self.verbose), + step_size=step_size, + metric=metric, + **param, + ) + self.metrics.append(cv_results["metric"]) + return configs[np.argmin(self.metrics)] + + def seasonal_decompose(self, y, **kwargs): + fitted = self.fit(y, **kwargs) + trend = self.linear_component + exogenous = self.median_component + self.exogenous_component + level = self.median_component + self.ses_component + seasonality = self.seasonal_component + if self.multiplicative: + trend = np.exp(trend) + level = np.exp(level) + exogenous = np.exp(exogenous) - np.exp(self.median_component) + if kwargs["seasonal_period"] is not None: + seasonality = np.exp(seasonality) + trend = trend * level + else: + trend = self.mean + (trend * self.std) + level = self.mean + (level * self.std) + exogenous = self.mean + (exogenous * self.std) + if kwargs["seasonal_period"] is not None: + seasonality = seasonality * self.std + trend = trend + level - self.mean + residuals = y - fitted + self.decomposition = { + "y": y, + "trend": trend, + "seasonality": seasonality, + "exogenous": exogenous, + "residuals": residuals, + } + return self.decomposition diff --git a/statsforecast/models.py b/statsforecast/models.py index 58bb72c5..41602a2a 100644 --- a/statsforecast/models.py +++ b/statsforecast/models.py @@ -6,8 +6,8 @@ 'SeasonalExponentialSmoothingOptimized', 'Holt', 'HoltWinters', 'HistoricAverage', 'Naive', 'RandomWalkWithDrift', 'SeasonalNaive', 'WindowAverage', 'SeasonalWindowAverage', 'ADIDA', 'CrostonClassic', 'CrostonOptimized', 'CrostonSBA', 'IMAPA', 'TSB', 'MSTL', 'TBATS', 'AutoTBATS', 'Theta', 'OptimizedTheta', - 'DynamicTheta', 'DynamicOptimizedTheta', 'GARCH', 'ARCH', 'SklearnModel', 'ConstantModel', 'ZeroModel', - 'NaNModel'] + 'DynamicTheta', 'DynamicOptimizedTheta', 'GARCH', 'ARCH', 'SklearnModel', 'MFLES', 'AutoMFLES', + 'ConstantModel', 'ZeroModel', 'NaNModel'] # %% ../nbs/src/core/models.ipynb 5 import warnings @@ -35,6 +35,7 @@ forecast_ets, forward_ets, ) +from .mfles import MFLES as _MFLES from .mstl import mstl from .theta import auto_theta, forecast_theta, forward_theta from .garch import garch_model, garch_forecast @@ -118,6 +119,9 @@ def new(self): b.__dict__.update(self.__dict__) return b + def __repr__(self): + return self.alias + def _conformity_scores( self, y: np.ndarray, @@ -326,9 +330,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -660,9 +661,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -889,9 +887,6 @@ def __init__( prediction_intervals=prediction_intervals, ) - def __repr__(self): - return self.alias - # %% ../nbs/src/core/models.ipynb 53 class AutoCES(_TS): """Complex Exponential Smoothing model. @@ -938,9 +933,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -1189,9 +1181,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -1447,9 +1436,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -1742,9 +1728,6 @@ def __init__( prediction_intervals=prediction_intervals, ) - def __repr__(self): - return self.alias - # %% ../nbs/src/core/models.ipynb 117 @njit(nogil=NOGIL, cache=CACHE) def _ses_fcst_mse(x: np.ndarray, alpha: float) -> Tuple[float, float, np.ndarray]: @@ -1866,9 +1849,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -2037,9 +2017,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -2231,9 +2208,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -2430,9 +2404,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -2604,9 +2575,6 @@ def __init__( season_length, model, alias=alias, prediction_intervals=prediction_intervals ) - def __repr__(self): - return self.alias - # %% ../nbs/src/core/models.ipynb 188 class HoltWinters(AutoETS): """Holt-Winters' method. @@ -2647,9 +2615,6 @@ def __init__( season_length, model, alias=alias, prediction_intervals=prediction_intervals ) - def __repr__(self): - return self.alias - # %% ../nbs/src/core/models.ipynb 203 def _historic_average( y: np.ndarray, # time series @@ -2692,9 +2657,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -2870,9 +2832,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -3104,9 +3063,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -3281,9 +3237,6 @@ def __init__( self.alias = alias self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -3488,9 +3441,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -3665,9 +3615,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -3921,9 +3868,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -4118,9 +4062,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -4325,9 +4266,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -4500,9 +4438,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -4697,9 +4632,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -4907,9 +4839,6 @@ def __init__( self.prediction_intervals = prediction_intervals self.only_conformal_intervals = True - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -5116,9 +5045,6 @@ def __init__( self.trend_forecaster.prediction_intervals = prediction_intervals self.stl_kwargs = dict() if stl_kwargs is None else stl_kwargs - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -5388,9 +5314,6 @@ def __init__( self.use_arma_errors = use_arma_errors self.alias = alias - def __repr__(self): - return self.alias - def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None): """Fit TBATS model. @@ -5809,9 +5732,6 @@ def __init__( self.alias = alias + "(" + str(p) + ")" self.prediction_intervals = prediction_intervals - def __repr__(self): - return self.alias - def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None): """Fit GARCH model. @@ -5990,9 +5910,6 @@ def __init__( self.alias = alias super().__init__(p, q=0, alias=alias) - def __repr__(self): - return self.alias - # %% ../nbs/src/core/models.ipynb 479 class SklearnModel(_TS): """scikit-learn model wrapper @@ -6016,12 +5933,7 @@ def __init__( ): self.model = model self.prediction_intervals = prediction_intervals - self.alias = alias - - def __repr__(self): - if self.alias is not None: - return self.alias - return self.model.__class__.__name__ + self.alias = alias if alias is not None else model.__class__.__name__ def fit( self, @@ -6204,7 +6116,484 @@ def forward( res = _add_fitted_pi(res=res, se=se, level=level) return res -# %% ../nbs/src/core/models.ipynb 490 +# %% ../nbs/src/core/models.ipynb 489 +class MFLES(_TS): + """MFLES model. + + A method to forecast time series based on Gradient Boosted Time Series Decomposition + which treats traditional decomposition as the base estimator in the boosting + process. Unlike normal gradient boosting, slight learning rates are applied at the + component level (trend/seasonality/exogenous). + + The method derives its name from some of the underlying estimators that can + enter into the boosting procedure, specifically: a simple Median, Fourier + functions for seasonality, a simple/piecewise Linear trend, and Exponential + Smoothing. + + Parameters + ---------- + season_length : int or list of int, optional (default=None) + Number of observations per unit of time. Ex: 24 Hourly data. + fourier_order : int, optional (default=None) + How many fourier sin/cos pairs to create, the larger the number the more complex of a seasonal pattern can be fitted. + A lower number leads to smoother results. + This is auto-set based on seasonal_period. + max_rounds : int (default=50) + The max number of boosting rounds. The boosting will auto-stop but depending on other parameters such as rs_lr you may want more rounds. + Generally more rounds means a smoother fit. + ma : int, optional (default=None) + The moving average order to use, this is auto-set based on internal logic. + Passing 4 would fit a 4 period moving average on the residual component. + alpha : float (default=1.0) + The alpha which is used in fitting the underlying LASSO when using piecewise functions. + decay : float (default=-1.0) + Effects the slopes of the piecewise-linear basis function. + changepoints : boolean (default=True) + Whether to fit for changepoints if all other logic allows for it. If False, MFLES will not ever fit a piecewise trend. + n_changepoints : int or float (default=0.25) + Number (if int) or proportion (if float) of changepoint knots to place. The default of 0.25 will place 0.25 * (series length) number of knots. + seasonal_lr : float (default=0.9) + A shrinkage parameter (0 < seasonal_lr <= 1) which penalizes the seasonal fit. + A value of 0.9 will flatly multiply the seasonal fit by 0.9 each boosting round, this can be used to allow more signal to the exogenous component. + trend_lr : float (default=0.9) + A shrinkage parameter (0 < trend_lr <= 1) which penalizes the linear trend fit + A value of 0.9 will flatly multiply the linear fit by 0.9 each boosting round, this can be used to allow more signal to the seasonality or exogenous components. + exogenous_lr : float (default=1.0) + The shrinkage parameter (0 < exogenous_lr <= 1) which controls how much of the exogenous signal is carried to the next round. + residuals_lr : float (default=1.0) + A shrinkage parameter (0 < residuals_lr <= 1) which penalizes the residual smoothing. + A value of 0.9 will flatly multiply the residual fit by 0.9 each boosting round, this can be used to allow more signal to the seasonality or linear components. + cov_threshold : float (default=0.7) + The deseasonalized cov is used to auto-set some logic, lowering the cov_threshold will result in simpler and less complex residual smoothing. + If you pass something like 1000 then there will be no safeguards applied. + moving_medians : bool (default=False) + The default behavior is to fit an initial median to the time series. If True, then it will fit a median per seasonal period. + min_alpha : float (default=0.05) + The minimum alpha in the SES ensemble. + max_alpha : float (default=1.0) + The maximum alpha used in the SES ensemble. + trend_penalty : bool (default=True) + Whether to apply a simple penalty to the linear trend component, very useful for dealing with the potentially dangerous piecewise trend. + multiplicative : bool, optional (default=None) + Auto-set based on internal logic. If True, it will simply take the log of the time series. + smoother : bool (default=False) + If True, then a simple exponential ensemble will be used rather than auto settings. + robust : bool, optional (default=None) + If True then MFLES will fit using more reserved methods, i.e. not using piecewise trend or moving average residual smoother. + Auto-set based on internal logic. + verbose : bool (default=False) + Print debugging information. + prediction_intervals : Optional[ConformalIntervals] + Information to compute conformal prediction intervals. + This is required for generating future prediction intervals. + alias : str (default='MFLES') + Custom name of the model. + """ + + def __init__( + self, + season_length: Optional[Union[int, List[int]]] = None, + fourier_order: Optional[int] = None, + max_rounds: int = 50, + ma: Optional[int] = None, + alpha: float = 1.0, + decay: float = -1.0, + changepoints: bool = True, + n_changepoints: Union[float, int] = 0.25, + seasonal_lr: float = 0.9, + trend_lr: float = 0.9, + exogenous_lr: float = 1.0, + residuals_lr: float = 1.0, + cov_threshold: float = 0.7, + moving_medians: bool = False, + min_alpha: float = 0.05, + max_alpha: float = 1.0, + trend_penalty: bool = True, + multiplicative: Optional[bool] = None, + smoother: bool = False, + robust: Optional[bool] = None, + verbose: bool = False, + prediction_intervals: Optional[ConformalIntervals] = None, + alias: str = "MFLES", + ): + try: + import sklearn # noqa: F401 + except ImportError: + raise ImportError("MFLES requires scikit-learn.") from None + self.season_length = season_length + self.fourier_order = fourier_order + self.max_rounds = max_rounds + self.ma = ma + self.alpha = alpha + self.decay = decay + self.changepoints = changepoints + self.n_changepoints = n_changepoints + self.seasonal_lr = seasonal_lr + self.trend_lr = trend_lr + self.exogenous_lr = exogenous_lr + self.residuals_lr = residuals_lr + self.cov_threshold = cov_threshold + self.moving_medians = moving_medians + self.min_alpha = min_alpha + self.max_alpha = max_alpha + self.trend_penalty = trend_penalty + self.multiplicative = multiplicative + self.smoother = smoother + self.robust = robust + self.verbose = verbose + self.prediction_intervals = prediction_intervals + self.alias = alias + + def _fit(self, y: np.ndarray, X: Optional[np.ndarray]) -> Dict[str, Any]: + model = _MFLES(verbose=self.verbose, robust=self.robust) + fitted = model.fit( + y=y, + X=X, + seasonal_period=self.season_length, + fourier_order=self.fourier_order, + ma=self.ma, + alpha=self.alpha, + decay=self.decay, + n_changepoints=self.n_changepoints, + seasonal_lr=self.seasonal_lr, + linear_lr=self.trend_lr, + exogenous_lr=self.exogenous_lr, + rs_lr=self.residuals_lr, + cov_threshold=self.cov_threshold, + moving_medians=self.moving_medians, + max_rounds=self.max_rounds, + min_alpha=self.min_alpha, + max_alpha=self.max_alpha, + trend_penalty=self.trend_penalty, + multiplicative=self.multiplicative, + changepoints=self.changepoints, + smoother=self.smoother, + ) + return {"model": model, "fitted": fitted} + + def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> "MFLES": + """Fit the model + + Parameters + ---------- + y : numpy.array + Clean time series of shape (t, ). + X : array-like, optional (default=None) + Exogenous of shape (t, n_x). + + Returns + ------- + self : MFLES + Fitted MFLES object. + """ + self.model_ = self._fit(y=y, X=X) + self._store_cs(y=y, X=X) + residuals = y - self.model_["fitted"] + self.model_["sigma"] = _calculate_sigma(residuals, y.size) + return self + + def predict( + self, + h: int, + X: Optional[np.ndarray] = None, + level: Optional[List[int]] = None, + ) -> Dict[str, Any]: + """Predict with fitted MFLES. + + Parameters + ---------- + h : int + Forecast horizon. + X : array-like, optional (default=None) + Exogenous of shape (h, n_x). + level: List[int] + Confidence levels (0-100) for prediction intervals. + + Returns + ------- + forecasts : dict + Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions. + """ + res = {"mean": self.model_["model"].predict(forecast_horizon=h, X=X)} + if level is None: + return res + level = sorted(level) + if self.prediction_intervals is not None: + res = self._add_predict_conformal_intervals(res, level) + else: + raise Exception("You must pass `prediction_intervals` to compute them.") + return res + + def predict_in_sample(self, level: Optional[List[int]] = None) -> Dict[str, Any]: + """Access fitted SklearnModel insample predictions. + + Parameters + ---------- + level : List[int] + Confidence levels (0-100) for prediction intervals. + + Returns + ------- + forecasts : dict + Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions. + """ + res = {"fitted": self.model_["fitted"]} + if level is not None: + level = sorted(level) + res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + return res + + def forecast( + self, + y: np.ndarray, + h: int, + X: Optional[np.ndarray] = None, + X_future: Optional[np.ndarray] = None, + level: Optional[List[int]] = None, + fitted: bool = False, + ) -> Dict[str, Any]: + """Memory Efficient MFLES predictions. + + This method avoids memory burden due from object storage. + It is analogous to `fit_predict` without storing information. + It assumes you know the forecast horizon in advance. + + Parameters + ---------- + y : numpy.array + Clean time series of shape (t, ). + h : int + Forecast horizon. + X : array-like + Insample exogenous of shape (t, n_x). + X_future : array-like + Exogenous of shape (h, n_x). + level : List[int] + Confidence levels (0-100) for prediction intervals. + fitted : bool + Whether or not to return insample predictions. + + Returns + ------- + forecasts : dict + Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions. + """ + model = self._fit(y=y, X=X) + res = {"mean": model["model"].predict(forecast_horizon=h, X=X_future)} + if fitted: + res["fitted"] = model["fitted"] + if level is not None: + level = sorted(level) + if self.prediction_intervals is not None: + res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level) + else: + raise Exception("You must pass `prediction_intervals` to compute them.") + if fitted: + residuals = y - res["fitted"] + sigma = _calculate_sigma(residuals, y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) + return res + +# %% ../nbs/src/core/models.ipynb 497 +class AutoMFLES(_TS): + """AutoMFLES + + Parameters + ---------- + test_size : int + Forecast horizon used during cross validation. + season_length : int or list of int, optional (default=None) + Number of observations per unit of time. Ex: 24 Hourly data. + n_windows : int (default=2) + Number of windows used for cross validation. + config : dict, optional (default=None) + Mapping from parameter name (from the init arguments of MFLES) to a list of values to try. + If `None`, will use defaults. + step_size : int, optional (default=None) + Step size between each cross validation window. If `None` will be set to test_size. + metric : str (default='smape') + Metric used to select the best model. Possible options are: 'smape', 'mape', 'mse' and 'mae'. + verbose : bool (default=False) + Print debugging information. + prediction_intervals : Optional[ConformalIntervals] + Information to compute conformal prediction intervals. + This is required for generating future prediction intervals. + alias : str (default='AutoMFLES') + Custom name of the model. + """ + + def __init__( + self, + test_size: int, + season_length: Optional[Union[int, List[int]]] = None, + n_windows: int = 2, + config: Optional[Dict[str, Any]] = None, + step_size: Optional[int] = None, + metric: str = "smape", + verbose: bool = False, + prediction_intervals: Optional[ConformalIntervals] = None, + alias: str = "AutoMFLES", + ): + try: + import sklearn # noqa: F401 + except ImportError: + raise ImportError("MFLES requires scikit-learn.") from None + self.season_length = season_length + self.n_windows = n_windows + self.test_size = test_size + self.config = config + self.step_size = step_size if step_size is not None else test_size + self.metric = metric + self.verbose = verbose + self.prediction_intervals = prediction_intervals + self.alias = alias + + def _fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> Dict[str, Any]: + model = _MFLES(verbose=self.verbose) + optim_params = model.optimize( + y=y, + X=X, + test_size=self.test_size, + n_steps=self.n_windows, + step_size=self.step_size, + seasonal_period=self.season_length, + metric=self.metric, + params=self.config, + ) + # the seasonal_period may've been found during the optimization + seasonal_period = optim_params.pop("seasonal_period", self.season_length) + fitted = model.fit( + y=y, + X=X, + seasonal_period=seasonal_period, + **optim_params, + ) + return {"model": model, "fitted": fitted} + + def fit(self, y: np.ndarray, X: Optional[np.ndarray] = None) -> "AutoMFLES": + """Fit the model + + Parameters + ---------- + y : numpy.array + Clean time series of shape (t, ). + X : array-like, optional (default=None) + Exogenous of shape (t, n_x). + + Returns + ------- + self : AutoMFLES + Fitted AutoMFLES object. + """ + self.model_ = self._fit(y=y, X=X) + self._store_cs(y=y, X=X) + residuals = y - self.model_["fitted"] + self.model_["sigma"] = _calculate_sigma(residuals, y.size) + return self + + def predict( + self, + h: int, + X: Optional[np.ndarray] = None, + level: Optional[List[int]] = None, + ) -> Dict[str, Any]: + """Predict with fitted AutoMFLES. + + Parameters + ---------- + h : int + Forecast horizon. + X : array-like, optional (default=None) + Exogenous of shape (h, n_x). + level: List[int] + Confidence levels (0-100) for prediction intervals. + + Returns + ------- + forecasts : dict + Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions. + """ + res = {"mean": self.model_["model"].predict(forecast_horizon=h, X=X)} + if level is None: + return res + level = sorted(level) + if self.prediction_intervals is not None: + res = self._add_predict_conformal_intervals(res, level) + else: + raise Exception("You must pass `prediction_intervals` to compute them.") + return res + + def predict_in_sample(self, level: Optional[List[int]] = None) -> Dict[str, Any]: + """Access fitted AutoMFLES insample predictions. + + Parameters + ---------- + level : List[int] + Confidence levels (0-100) for prediction intervals. + + Returns + ------- + forecasts : dict + Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions. + """ + res = {"fitted": self.model_["fitted"]} + if level is not None: + level = sorted(level) + res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + return res + + def forecast( + self, + y: np.ndarray, + h: int, + X: Optional[np.ndarray] = None, + X_future: Optional[np.ndarray] = None, + level: Optional[List[int]] = None, + fitted: bool = False, + ) -> Dict[str, Any]: + """Memory Efficient AutoMFLES predictions. + + This method avoids memory burden due from object storage. + It is analogous to `fit_predict` without storing information. + It assumes you know the forecast horizon in advance. + + Parameters + ---------- + y : numpy.array + Clean time series of shape (t, ). + h : int + Forecast horizon. + X : array-like + Insample exogenous of shape (t, n_x). + X_future : array-like + Exogenous of shape (h, n_x). + level : List[int] + Confidence levels (0-100) for prediction intervals. + fitted : bool + Whether or not to return insample predictions. + + Returns + ------- + forecasts : dict + Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions. + """ + model = self._fit(y=y, X=X) + res = {"mean": model["model"].predict(forecast_horizon=h, X=X_future)} + if fitted: + res["fitted"] = model["fitted"] + if level is not None: + level = sorted(level) + if self.prediction_intervals is not None: + res = self._add_conformal_intervals(fcst=res, y=y, X=X, level=level) + else: + raise Exception("You must pass `prediction_intervals` to compute them.") + if fitted: + residuals = y - res["fitted"] + sigma = _calculate_sigma(residuals, y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) + return res + +# %% ../nbs/src/core/models.ipynb 501 class ConstantModel(_TS): def __init__(self, constant: float, alias: str = "ConstantModel"): @@ -6222,9 +6611,6 @@ def __init__(self, constant: float, alias: str = "ConstantModel"): self.constant = constant self.alias = alias - def __repr__(self): - return self.alias - def fit( self, y: np.ndarray, @@ -6390,7 +6776,7 @@ def forward( ) return res -# %% ../nbs/src/core/models.ipynb 504 +# %% ../nbs/src/core/models.ipynb 515 class ZeroModel(ConstantModel): def __init__(self, alias: str = "ZeroModel"): @@ -6405,7 +6791,7 @@ def __init__(self, alias: str = "ZeroModel"): """ super().__init__(constant=0, alias=alias) -# %% ../nbs/src/core/models.ipynb 518 +# %% ../nbs/src/core/models.ipynb 529 class NaNModel(ConstantModel): def __init__(self, alias: str = "NaNModel"):