From c7264fcfb6cbb95f9cf76ce5382929b3ed14a4be Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Wed, 1 Feb 2023 14:51:17 +0800 Subject: [PATCH] Clean the commit history --- mmengine/logging/message_hub.py | 2 +- tests/test_logging/test_message_hub.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 79fc131ae4..d173565d84 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -324,7 +324,7 @@ def _get_valid_value( Returns: float or int: python built-in type value. """ - if isinstance(value, np.ndarray): + if isinstance(value, (np.ndarray, np.number)): assert value.size == 1 value = value.item() elif isinstance(value, (int, float)): diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index b6061f82e8..4dffdec06c 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -34,14 +34,18 @@ def test_init(self): def test_update_scalar(self): message_hub = MessageHub.get_instance('mmengine') - # test create target `HistoryBuffer` by name + # Update scalar with int. message_hub.update_scalar('name', 1) log_buffer = message_hub.log_scalars['name'] assert (log_buffer._log_history == np.array([1])).all() - # test update target `HistoryBuffer` by name - message_hub.update_scalar('name', 1) + + # Update scalar with np.ndarray. + message_hub.update_scalar('name', np.array(1)) assert (log_buffer._log_history == np.array([1, 1])).all() - # unmatched string will raise a key error + + # Update scalar with np.int + message_hub.update_scalar('name', np.int32(1)) + assert (log_buffer._log_history == np.array([1, 1, 1])).all() def test_update_info(self): message_hub = MessageHub.get_instance('mmengine')