Skip to content

Commit

Permalink
Merge b824621 into 7a074fa
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Mar 10, 2023
2 parents 7a074fa + b824621 commit 580b3be
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
31 changes: 31 additions & 0 deletions mmengine/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@
from mmengine.utils.manager import _accquire_lock, _release_lock


class FilterDuplicateWarning(logging.Filter):
"""Filter the repeated warning message.
Args:
name (str): name of the filter.
"""

def __init__(self, name: str = 'mmengine'):
super().__init__(name)
self.seen: set = set()

def filter(self, record: LogRecord) -> bool:
"""Filter the repeated warning message.
Args:
record (LogRecord): The log record.
Returns:
bool: Whether to output the log record.
"""
if record.levelno != logging.WARNING:
return True

if record.msg not in self.seen:
self.seen.add(record.msg)
return True
return False


class MMFormatter(logging.Formatter):
"""Colorful format for MMLogger. If the log level is error, the logger will
additionally output the location of the code.
Expand Down Expand Up @@ -164,6 +193,7 @@ def __init__(self,
# Only rank0 `StreamHandler` will log messages below error level.
stream_handler.setLevel(log_level) if rank == 0 else \
stream_handler.setLevel(logging.ERROR)
stream_handler.addFilter(FilterDuplicateWarning(logger_name))
self.handlers.append(stream_handler)

if log_file is not None:
Expand Down Expand Up @@ -191,6 +221,7 @@ def __init__(self,
file_handler.setFormatter(
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
file_handler.setLevel(log_level)
file_handler.addFilter(FilterDuplicateWarning(logger_name))
self.handlers.append(file_handler)

@classmethod
Expand Down
14 changes: 14 additions & 0 deletions tests/test_logging/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,17 @@ def test_set_level(self, capsys):
logger.warning('hello')
out, _ = capsys.readouterr()
assert 'WARNING' in out

def test_filter(self, capsys):
logger = MMLogger.get_instance('test_filter')
logger.warning('hello')
out, _ = capsys.readouterr()
assert 'WARNING' in out
# Filter repeated warning.
logger.warning('hello')
out, _ = capsys.readouterr()
assert not out
# Pass new warning
logger.warning('hello1')
out, _ = capsys.readouterr()
assert 'WARNING' in out

0 comments on commit 580b3be

Please sign in to comment.