diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index e3049bcd3b..a90057aad9 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -4,11 +4,6 @@ from unittest.mock import patch import pandas as pd -from torchmetrics import ( - MeanAbsoluteError, - MeanAbsolutePercentageError, - MetricCollection, -) from darts import TimeSeries from darts.logging import get_logger @@ -19,6 +14,11 @@ try: import torch + from torchmetrics import ( + MeanAbsoluteError, + MeanAbsolutePercentageError, + MetricCollection, + ) from darts.models.forecasting.rnn_model import RNNModel diff --git a/requirements/torch.txt b/requirements/torch.txt index be24c75a15..621887fd74 100644 --- a/requirements/torch.txt +++ b/requirements/torch.txt @@ -1,3 +1,2 @@ pytorch-lightning>=1.5.0 torch>=1.8.0 -torchmetrics>=0.9.1