From 3a10fc088409a8790390c854a12759cd972fdc63 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 12 Oct 2022 19:17:55 +0800 Subject: [PATCH 1/6] [Fix] Fix could not handle string data --- .../model/base_model/data_preprocessor.py | 11 +++--- .../test_base_model/test_data_preprocessor.py | 36 +++++++++++++------ 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2b14d75e11..b814dacaf3 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -42,17 +42,20 @@ def cast_data(self, data: CastData) -> CastData: """ if isinstance(data, Mapping): return {key: self.cast_data(data[key]) for key in data} + elif isinstance(data, (str, bytes)): + return data # type: ignore elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, Sequence): return [self.cast_data(sample) for sample in data] - elif isinstance(data, torch.Tensor): - return data.to(self.device) - elif isinstance(data, BaseDataElement): + elif isinstance(data, (torch.Tensor, BaseDataElement)): return data.to(self.device) else: - return data + raise TypeError( + '`BaseDataPreprocessor.cast_data`: batch must contain ' + 'tensors, numpy arrays, numbers, dicts or lists, but ' + f'found {type(data)}') def forward(self, data: dict, training: bool = False) -> Union[dict, list]: """Preprocesses the data into the model input format. 
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 15ba57d390..0ed37c39d4 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -23,8 +23,8 @@ def test_forward(self): label1 = torch.randn(1) label2 = torch.randn(1) + # Test with dict of batch inputs and batch data samples data = dict(inputs=[input1, input2], data_sample=[label1, label2]) - output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output['data_sample'] self.assertTrue(torch.is_floating_point(batch_inputs[0])) @@ -36,40 +36,54 @@ def test_forward(self): assert_allclose(label2, batch_labels[1]) # Test with tuple of batch inputs and batch data samples - data = dict( - inputs=torch.stack([input1, input2]), data_sample=[label1, label2]) - output = base_data_preprocessor(data)['inputs'] + data = (torch.stack([input1, input2]), (label1, label2)) + batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) + self.assertEqual(batch_inputs[0].shape, (1, 3, 5)) + self.assertEqual(batch_inputs[1].shape, (1, 3, 5)) self.assertTrue(torch.is_floating_point(batch_inputs[0])) # Test cuda forward if torch.cuda.is_available(): # Test with list of data samples. + data = dict(inputs=[input1, input2], data_sample=[label1, label2]) base_data_preprocessor = base_data_preprocessor.cuda() output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cuda') + # Fallback to test with cpu. 
base_data_preprocessor = base_data_preprocessor.cpu() output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cpu') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cpu') + # Test `base_data_preprocessor` can be moved to cuda again. base_data_preprocessor = base_data_preprocessor.to('cuda:0') output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cuda') # device of `base_data_preprocessor` is cuda, output should be # cuda tensor. - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertEqual(batch_inputs[0].device.type, 'cuda') self.assertEqual(batch_labels[0].device.type, 'cuda') + # Test forward with string value + data = dict(string='abc') + base_data_preprocessor(data) + + with self.assertRaisesRegex(TypeError, + '`BaseDataPreprocessor.cast_data`:'): + data = dict(string=object()) + base_data_preprocessor(data) + class TestImgDataPreprocessor(TestBaseDataPreprocessor): From 762f92b36559a82204a49368aabaa0b817984115 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 12 Oct 2022 19:25:35 +0800 Subject: [PATCH 2/6] Minor refine --- mmengine/model/base_model/data_preprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index b814dacaf3..0c1d7218a0 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -42,7 +42,7 @@ def cast_data(self, data: CastData) -> CastData: """ if 
isinstance(data, Mapping): return {key: self.cast_data(data[key]) for key in data} - elif isinstance(data, (str, bytes)): + elif isinstance(data, (str, bytes)) or data is None: return data # type: ignore elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple From e31d8c691d4d0e85d228812683917b1d3b056880 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 12 Oct 2022 19:32:49 +0800 Subject: [PATCH 3/6] Refine type hint Refine type hint --- mmengine/model/base_model/data_preprocessor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 0c1d7218a0..98abf3e1f1 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -11,7 +11,8 @@ from mmengine.utils import is_list_of from ..utils import stack_batch -CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list] +CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, + None] @MODELS.register_module() @@ -43,7 +44,7 @@ def cast_data(self, data: CastData) -> CastData: if isinstance(data, Mapping): return {key: self.cast_data(data[key]) for key in data} elif isinstance(data, (str, bytes)) or data is None: - return data # type: ignore + return data elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable From 625071b2c305b5613d2cfbff10109f3336b503cf Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Mon, 17 Oct 2022 17:46:01 +0800 Subject: [PATCH 4/6] fix as comment --- mmengine/model/base_model/data_preprocessor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 98abf3e1f1..9deaf3bf5f 100644 --- 
a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -49,12 +49,12 @@ def cast_data(self, data: CastData) -> CastData: # namedtuple return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, Sequence): - return [self.cast_data(sample) for sample in data] + return type(data)([self.cast_data(sample) for sample in data]) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, (torch.Tensor, BaseDataElement)): return data.to(self.device) else: raise TypeError( - '`BaseDataPreprocessor.cast_data`: batch must contain ' + '`BaseDataPreprocessor.cast_data`: batch data must contain ' 'tensors, numpy arrays, numbers, dicts or lists, but ' f'found {type(data)}') From 258b9f707a7826f2bcbdb0c83c0b7846e557cf1d Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Thu, 27 Oct 2022 15:09:55 +0800 Subject: [PATCH 5/6] Minor refine --- mmengine/model/base_model/data_preprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 3a2c104548..8b6a7f8e36 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -53,7 +53,7 @@ def cast_data(self, data: CastData) -> CastData: return data elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple - return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable + return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, Sequence): return type(data)([self.cast_data(sample) for sample in data]) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, (torch.Tensor, BaseDataElement)): From b4f1ca97e14fec40a128e6b979c2b03127fbc910 Mon Sep 17 00:00:00 2001 From: Zaida Zhou 
<58739961+zhouzaida@users.noreply.github.com> Date: Fri, 28 Oct 2022 01:28:16 +0800 Subject: [PATCH 6/6] Update mmengine/model/base_model/data_preprocessor.py --- mmengine/model/base_model/data_preprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 8b6a7f8e36..d7d27ec907 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -55,7 +55,7 @@ def cast_data(self, data: CastData) -> CastData: # namedtuple return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, Sequence): - return type(data)([self.cast_data(sample) for sample in data]) # type: ignore # noqa: E501 # yapf:disable + return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, (torch.Tensor, BaseDataElement)): return data.to(self.device, non_blocking=self._non_blocking) else: