diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index 442bef2fc4..94ca1a7955 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -11,6 +11,35 @@ from mmengine.utils.manager import _accquire_lock, _release_lock +class FilterDuplicateWarning(logging.Filter): + """Filter the repeated warning message. + + Args: + name (str): name of the filter. + """ + + def __init__(self, name: str = 'mmengine'): + super().__init__(name) + self.seen: set = set() + + def filter(self, record: LogRecord) -> bool: + """Filter the repeated warning message. + + Args: + record (LogRecord): The log record. + + Returns: + bool: Whether to output the log record. + """ + if record.levelno != logging.WARNING: + return True + + if record.msg not in self.seen: + self.seen.add(record.msg) + return True + return False + + class MMFormatter(logging.Formatter): """Colorful format for MMLogger. If the log level is error, the logger will additionally output the location of the code. @@ -164,6 +193,7 @@ def __init__(self, # Only rank0 `StreamHandler` will log messages below error level. stream_handler.setLevel(log_level) if rank == 0 else \ stream_handler.setLevel(logging.ERROR) + stream_handler.addFilter(FilterDuplicateWarning(logger_name)) self.handlers.append(stream_handler) if log_file is not None: @@ -191,6 +221,7 @@ def __init__(self, file_handler.setFormatter( MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S')) file_handler.setLevel(log_level) + file_handler.addFilter(FilterDuplicateWarning(logger_name)) self.handlers.append(file_handler) @classmethod diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index e0734702cc..4644f3be30 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -184,3 +184,17 @@ def test_set_level(self, capsys): logger.warning('hello') out, _ = capsys.readouterr() assert 'WARNING' in out + + def test_filter(self, capsys): + logger = MMLogger.get_instance('test_filter') + logger.warning('hello') + out, _ = capsys.readouterr() + assert 'WARNING' in out + # Filter repeated warning. + logger.warning('hello') + out, _ = capsys.readouterr() + assert not out + # Pass new warning + logger.warning('hello1') + out, _ = capsys.readouterr() + assert 'WARNING' in out