diff --git a/darts/models/components/feed_forward.py b/darts/models/components/feed_forward.py index 810125ec23..65248d72ec 100644 --- a/darts/models/components/feed_forward.py +++ b/darts/models/components/feed_forward.py @@ -55,6 +55,8 @@ import torch from torch import nn as nn +from darts.utils.torch import MonteCarloDropout + class FeedForward(nn.Module): """ @@ -78,7 +80,8 @@ def __init__( """ * `d_model` is the number of features in a token embedding * `d_ff` is the number of features in the hidden layer of the FFN - * `dropout` is dropout probability for the hidden layer + * `dropout` is dropout probability for the hidden layer, + compatible with Monte Carlo dropout at inference time * `is_gated` specifies whether the hidden layer is gated * `bias1` specified whether the first fully connected layer should have a learnable bias * `bias2` specified whether the second fully connected layer should have a learnable bias @@ -90,7 +93,7 @@ def __init__( # Layer one parameterized by weight $W_1$ and bias $b_1$ self.layer2 = nn.Linear(d_ff, d_model, bias=bias2) # Hidden layer dropout - self.dropout = nn.Dropout(dropout) + self.dropout = MonteCarloDropout(dropout) # Activation function $f$ self.activation = activation # Whether there is a gate diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py index 516823cd7d..ace92aec02 100644 --- a/darts/models/forecasting/nbeats.py +++ b/darts/models/forecasting/nbeats.py @@ -13,6 +13,7 @@ from darts.logging import get_logger, raise_if_not, raise_log from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel +from darts.utils.torch import MonteCarloDropout logger = get_logger(__name__) @@ -168,7 +169,7 @@ def __init__( ) if self.dropout > 0: - self.linear_layer_stack_list.append(nn.Dropout(p=self.dropout)) + self.linear_layer_stack_list.append(MonteCarloDropout(p=self.dropout)) 
self.fc_stack = nn.ModuleList(self.linear_layer_stack_list) @@ -586,7 +587,9 @@ def __init__( The degree of the polynomial used as waveform generator in trend stacks. Only used if `generic_architecture` is set to `False`. dropout - The dropout probability to be used in the fully connected layers (default=0.0). + The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo dropout + at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at + prediction time). activation The activation function of encoder/decoder intermediate layer (default='ReLU'). Supported activations: ['ReLU','RReLU', 'PReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid'] diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index bb0f9322ca..a1e158a43c 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -13,6 +13,7 @@ from darts.logging import get_logger, raise_if_not from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel +from darts.utils.torch import MonteCarloDropout logger = get_logger(__name__) @@ -153,7 +154,7 @@ def __init__( layers.append(nn.BatchNorm1d(num_features=self.layer_widths[i + 1])) if self.dropout > 0: - layers.append(nn.Dropout(p=self.dropout)) + layers.append(MonteCarloDropout(p=self.dropout)) self.layers = nn.Sequential(*layers) @@ -520,7 +521,9 @@ def __init__( downsampling factors before interpolation, for each block in each stack. If left to ``None``, some default values will be used based on ``output_chunk_length``. dropout - Fraction of neurons affected by Dropout (default=0.1). + The dropout probability to be used in fully connected layers. This is compatible with Monte Carlo dropout + at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at + prediction time). 
def _get_mc_dropout_modules(self) -> set:
    """Return every ``MonteCarloDropout`` layer contained in this module.

    ``nn.Module.modules()`` already walks the full module tree (yielding
    each module once), so nested layers are found without a hand-written
    recursion over ``children()``.

    Returns
    -------
    set
        The ``MonteCarloDropout`` instances found, possibly empty.
    """
    return {
        module for module in self.modules() if isinstance(module, MonteCarloDropout)
    }


def set_mc_dropout(self, active: bool):
    """Switch Monte Carlo dropout on or off for all ``MonteCarloDropout`` layers.

    Parameters
    ----------
    active
        Value written to the ``mc_dropout_enabled`` flag of each layer.
    """
    for module in self._get_mc_dropout_modules():
        module.mc_dropout_enabled = active


def _is_probabilistic(self) -> bool:
    """Tell whether this module can produce stochastic predictions.

    This is the case when a likelihood is set, or when the module contains
    at least one ``MonteCarloDropout`` layer.
    """
    return self.likelihood is not None or bool(self._get_mc_dropout_modules())
+192,7 @@ def __init__( self.target_size = target_size self.nr_params = nr_params self.dilation_base = dilation_base - self.dropout = nn.Dropout(p=dropout) + self.dropout = MonteCarloDropout(p=dropout) # If num_layers is not passed, compute number of layers needed for full history coverage if num_layers is None and dilation_base > 1: @@ -288,7 +289,9 @@ def __init__( num_layers The number of convolutional layers. dropout - The dropout rate for every convolutional layer. + The dropout rate for every convolutional layer. This is compatible with Monte Carlo dropout + at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at + prediction time). **kwargs Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and Darts' :class:`TorchForecastingModel`. diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index b7adc7b7e6..caa24845e6 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -617,7 +617,9 @@ def __init__( or the TFT original FeedForward Network. ["GatedResidualNetwork"] dropout : float - Fraction of neurons affected by Dropout. + Fraction of neurons affected by dropout. This is compatible with Monte Carlo dropout + at inference time for model uncertainty estimation (enabled with ``mc_dropout=True`` at + prediction time). 
hidden_continuous_size : int Default for hidden size for processing continuous variables add_relative_index : bool diff --git a/darts/models/forecasting/tft_submodels.py b/darts/models/forecasting/tft_submodels.py index 16945fe1d2..167c7007d0 100644 --- a/darts/models/forecasting/tft_submodels.py +++ b/darts/models/forecasting/tft_submodels.py @@ -27,6 +27,7 @@ import torch.nn.functional as F from darts.logging import get_logger +from darts.utils.torch import MonteCarloDropout logger = get_logger(__name__) @@ -188,7 +189,7 @@ def __init__(self, input_size: int, hidden_size: int = None, dropout: float = No super().__init__() if dropout is not None: - self.dropout = nn.Dropout(dropout) + self.dropout = MonteCarloDropout(dropout) else: self.dropout = dropout self.hidden_size = hidden_size or input_size @@ -500,7 +501,7 @@ class _ScaledDotProductAttention(nn.Module): def __init__(self, dropout: float = None, scale: bool = True): super().__init__() if dropout is not None: - self.dropout = nn.Dropout(p=dropout) + self.dropout = MonteCarloDropout(p=dropout) else: self.dropout = dropout self.softmax = nn.Softmax(dim=2) @@ -530,7 +531,7 @@ def __init__(self, n_head: int, d_model: int, dropout: float = 0.0): self.n_head = n_head self.d_model = d_model self.d_k = self.d_q = self.d_v = d_model // n_head - self.dropout = nn.Dropout(p=dropout) + self.dropout = MonteCarloDropout(p=dropout) self.v_layer = nn.Linear(self.d_model, self.d_v) self.q_layers = nn.ModuleList( diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index a002a4c2d0..c81ee5d21a 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -953,6 +953,7 @@ def predict( roll_size: Optional[int] = None, num_samples: int = 1, num_loader_workers: int = 0, + mc_dropout: bool = False, ) -> Union[TimeSeries, Sequence[TimeSeries]]: """Predict the ``n`` time step following the end of the 
training series, or of the specified ``series``. @@ -1015,6 +1016,9 @@ def predict( for the inference/prediction dataset loaders (if any). A larger number of workers can sometimes increase performance, but can also incur extra overheads and increase memory usage, as more batches are loaded in parallel. + mc_dropout + Optionally, enable monte carlo dropout for predictions using neural network based models. + This allows bayesian approximation by specifying an implicit prior over learned models. Returns ------- @@ -1077,6 +1081,8 @@ def predict( n_jobs=n_jobs, roll_size=roll_size, num_samples=num_samples, + num_loader_workers=num_loader_workers, + mc_dropout=mc_dropout, ) return predictions[0] if called_with_single_series else predictions @@ -1093,6 +1099,7 @@ def predict_from_dataset( roll_size: Optional[int] = None, num_samples: int = 1, num_loader_workers: int = 0, + mc_dropout: bool = False, ) -> Sequence[TimeSeries]: """ @@ -1136,6 +1143,9 @@ def predict_from_dataset( for the inference/prediction dataset loaders (if any). A larger number of workers can sometimes increase performance, but can also incur extra overheads and increase memory usage, as more batches are loaded in parallel. + mc_dropout + Optionally, enable monte carlo dropout for predictions using neural network based models. + This allows bayesian approximation by specifying an implicit prior over learned models. Returns ------- @@ -1184,6 +1194,9 @@ def predict_from_dataset( collate_fn=self._batch_collate_fn, ) + # Set mc_dropout rate + self.model.set_mc_dropout(mc_dropout) + # setup trainer. 
class MonteCarloDropout(nn.Dropout):
    """
    Dropout layer that can optionally stay active at inference time
    (Monte Carlo dropout, see https://arxiv.org/pdf/1506.02142.pdf).

    Regular dropout can be interpreted as a Bayesian approximation of
    a well-known probabilistic model: the Gaussian process.
    The many different networks (with different neurons dropped out) can be
    treated as Monte Carlo samples from the space of all available models,
    which gives grounds to reason about the model's uncertainty.

    Dropout is applied when either of the following holds:

    * the module is in training mode (``self.training`` is ``True``), as for
      a regular ``nn.Dropout``;
    * ``mc_dropout_enabled`` has been set to ``True`` explicitly.

    The same rate ``p`` is used in both cases, as in the original paper.
    """

    # Explicit opt-in for dropout in eval mode. It defaults to False so that
    # evaluation (e.g. a validation round) is deterministic unless requested.
    # ``train()`` / ``eval()`` never modify it: leaving training mode must
    # not leave dropout silently switched on.
    mc_dropout_enabled: bool = False

    def forward(self, input: Tensor) -> Tensor:
        """Apply dropout with probability ``self.p`` when active, else return ``input`` unchanged."""
        # ``train`` is intentionally not overridden: the inherited
        # ``nn.Module.train`` keeps ``self.training`` up to date and returns
        # ``self``, which is what the condition below relies on.
        apply_dropout = self.training or self.mc_dropout_enabled
        return F.dropout(input, self.p, apply_dropout, self.inplace)
-When creating the model, it is possible to provide one of the *likelihood models* available in [darts.utils.likelihood_models](https://unit8co.github.io/darts/generated_api/darts.utils.likelihood_models.html), which determine the distribution that will be fit by the model. +All neural networks (torch-based models) in Darts have a rich support to estimate different kinds of probability distributions. +When creating the model, it is possible to provide one of the *likelihood models* available in [darts.utils.likelihood_models](https://unit8co.github.io/darts/generated_api/darts.utils.likelihood_models.html), which determine the distribution that will be estimated by the model. In such cases, the model will output the parameters of the distribution, and it will be trained by minimising the negative log-likelihood of the training samples. Most of the likelihood models also support prior values for the distribution's parameters, in which case the training loss is regularized by a Kullback-Leibler divergence term pushing the resulting distribution in the direction of the distribution specified by the prior parameters. +The strength of this regularization term can also be specified when creating the likelihood model object. For example, the code below trains a TCNModel to fit a Laplace distribution. So the neural network outputs 2 parameters (location and scale) of the Laplace distribution. We also specify a prior value of 0.1 on the scale parameter. @@ -166,7 +167,7 @@ pred.plot(label='forecast') ![TCN Laplace regression](./images/probabilistic/example_tcn_laplace.png) -It is also possible to perform quantile regression (using arbitrary quantiles) with neural networks, by using [darts.utils.likelihood_models.QuantileRegression](https://unit8co.github.io/darts/generated_api/darts.utils.likelihood_models.html#darts.utils.likelihood_models.QuantileRegression), in which case the network will be trained with the pinball loss. 
+It is also possible to perform quantile regression (using arbitrary quantiles) with neural networks, by using [darts.utils.likelihood_models.QuantileRegression](https://unit8co.github.io/darts/generated_api/darts.utils.likelihood_models.html#darts.utils.likelihood_models.QuantileRegression), in which case the network will be trained with the pinball loss. This produces an empirical non-parametric distribution, and it can often be a good option in practice, when one is not sure of the "real" distribution, or when fitting parametric likelihoods gives poor results. For example, the code snippet below is almost exactly the same as the preceding snippet; the only difference is that it now uses a `QuantileRegression` likelihood, which means that the neural network will be trained with a pinball loss, and its number of outputs will be dynamically configured to match the number of quantiles. ```python @@ -196,6 +197,40 @@ pred.plot(label='forecast') ![TCN quantile regression](./images/probabilistic/example_tcn_quantile.png) +### Capturing model uncertainty using Monte Carlo Dropout +In Darts, dropout can also be used as an additional way to capture model uncertainty, following the approach described in [1]. This is sometimes referred to as *epistemic uncertainty*, and can be seen as a way to marginalize over a family of models represented by all the different dropout activation functions. + +This feature is readily available for all deep learning models integrating some dropout (except RNN models - we refer to the dropout API reference documentation for a mention of models supporting this). It only requires specifying `mc_dropout=True` at prediction time. 
For example, the code below trains a TCN model (using the default MSE loss) with a dropout rate of 10%, and then produces probabilistic forecasts using Monte Carlo Dropout: + +```python +from darts.datasets import AirPassengersDataset +from darts import TimeSeries +from darts.models import TCNModel +from darts.dataprocessing.transformers import Scaler +from darts.utils.likelihood_models import QuantileRegression + +series = AirPassengersDataset().load() +train, val = series[:-36], series[-36:] + +scaler = Scaler() +train = scaler.fit_transform(train) +val = scaler.transform(val) +series = scaler.transform(series) + +model = TCNModel(input_chunk_length=30, + output_chunk_length=12, + dropout=0.1) +model.fit(train, epochs=400) +pred = model.predict(n=36, mc_dropout=True, num_samples=500) + +series.plot() +pred.plot(label='forecast') +``` + +![TCN Monte Carlo Dropout](./images/probabilistic/example_mc_dropout.png) + +Monte Carlo Dropout can be combined with other likelihood estimation in Darts, which can be interpreted as a way to capture both epistemic and aleatoric uncertainty. + ### Probabilistic regression models Some regression models can also be configured to produce probabilistic forecasts too. At the time of writing, [LinearRegressionModel](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.linear_regression_model.html) and [LightGBMModel](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.gradient_boosted_model.html) support a `likelihood` argument. When set to `"poisson"` the model will fit a Poisson distribution, and when set to `"quantile"` the model will use the pinball loss to perform quantile regression (the quantiles themselves can be specified using the `quantiles` argument). 
@@ -219,4 +254,7 @@ series.plot() pred.plot(label='forecast') ``` -![quantile linear regression](./images/probabilistic/example_linreg_quantile.png) \ No newline at end of file +![quantile linear regression](./images/probabilistic/example_linreg_quantile.png) + + +[1] Yarin Gal, Zoubin Ghahramani, ["Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning"](https://arxiv.org/abs/1506.02142) \ No newline at end of file diff --git a/docs/userguide/images/probabilistic/example_mc_dropout.png b/docs/userguide/images/probabilistic/example_mc_dropout.png new file mode 100644 index 0000000000..caecf6acfd Binary files /dev/null and b/docs/userguide/images/probabilistic/example_mc_dropout.png differ