diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2b14d75e11..2832f86e17 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -22,13 +22,19 @@ class BaseDataPreprocessor(nn.Module): forward method to implement custom data pre-processing, such as batch-resize, MixUp, or CutMix. + Args: + non_blocking (bool): Whether block current process + when transferring data to device. + New in version 0.3.0. + Note: Data dictionary returned by dataloader must be a dict and at least contain the ``inputs`` key. """ - def __init__(self): + def __init__(self, non_blocking: Optional[bool] = False): super().__init__() + self._non_blocking = non_blocking self._device = torch.device('cpu') def cast_data(self, data: CastData) -> CastData: @@ -48,9 +54,9 @@ def cast_data(self, data: CastData) -> CastData: elif isinstance(data, Sequence): return [self.cast_data(sample) for sample in data] elif isinstance(data, torch.Tensor): - return data.to(self.device) + return data.to(self.device, non_blocking=self._non_blocking) elif isinstance(data, BaseDataElement): - return data.to(self.device) + return data.to(self.device, non_blocking=self._non_blocking) else: return data @@ -150,6 +156,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor): Defaults to False. rgb_to_bgr (bool): whether to convert image from RGB to RGB. Defaults to False. + non_blocking (bool): Whether block current process + when transferring data to device. + New in version v0.3.0. Note: if images do not need to be normalized, `std` and `mean` should be @@ -163,8 +172,9 @@ def __init__(self, pad_size_divisor: int = 1, pad_value: Union[float, int] = 0, bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False): - super().__init__() + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__(non_blocking) assert not (bgr_to_rgb and rgb_to_bgr), ( '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') assert (mean is None) == (std is None), ( diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 15ba57d390..a1bc25d41c 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -14,6 +14,11 @@ class TestBaseDataPreprocessor(TestCase): def test_init(self): base_data_preprocessor = BaseDataPreprocessor() self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._non_blocking, False) + + base_data_preprocessor = BaseDataPreprocessor(True) + self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._non_blocking, True) def test_forward(self): # Test cpu forward with list of data samples.