Skip to content

Commit

Permalink
[Fix] Fix error format of log message (open-mmlab#508)
Browse files Browse the repository at this point in the history
* Fix error format of log message

* Fix unit test

* remove unnecessary comment
  • Loading branch information
HAOCHENYE authored and C1rN09 committed Nov 1, 2022
1 parent 84528ea commit fb37c2c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
42 changes: 21 additions & 21 deletions mmengine/hooks/iter_timer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class IterTimerHook(Hook):

def __init__(self):
self.time_sec_tot = 0
self.time_sec_test_val = 0
self.start_iter = 0

def before_train(self, runner) -> None:
Expand All @@ -41,6 +42,9 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
"""
self.t = time.time()

def _after_epoch(self, runner, mode: str = 'train') -> None:
self.time_sec_test_val = 0

def _before_iter(self,
runner,
batch_idx: int,
Expand Down Expand Up @@ -82,26 +86,22 @@ def _after_iter(self,
message_hub = runner.message_hub
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
window_size = runner.log_processor.window_size
# Calculate eta every `window_size` iterations. Since test and val
# loop will not update runner.iter, use `every_n_innter_iters`to check
# the interval.
if self.every_n_inner_iters(batch_idx, window_size):
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
window_size)
if mode == 'train':
self.time_sec_tot += iter_time * window_size
# Calculate average iterative time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
iter_time = message_hub.get_scalar(f'{mode}/time')
if mode == 'train':
self.time_sec_tot += iter_time.current()
# Calculate average iterative time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
else:
if mode == 'val':
cur_dataloader = runner.val_dataloader
else:
if mode == 'val':
cur_dataloader = runner.val_dataloader
else:
cur_dataloader = runner.test_dataloader
cur_dataloader = runner.test_dataloader

eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)
self.time_sec_test_val += iter_time.current()
time_sec_avg = self.time_sec_test_val / (batch_idx + 1)
eta_sec = time_sec_avg * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)
32 changes: 22 additions & 10 deletions tests/test_hooks/test_iter_timer_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,30 @@ def test_after_iter(self):
runner.iter = 0
runner.test_dataloader = [0] * 20
runner.val_dataloader = [0] * 20
self.hook._before_epoch(runner)
self.hook.before_run(runner)
self.hook._after_iter(runner, batch_idx=1)
runner.message_hub.update_scalar.assert_called()
runner.message_hub.get_log.assert_not_called()
runner.message_hub.update_info.assert_not_called()
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
runner.iter = 9

self.hook.before_run(runner)
self.hook._before_epoch(runner)
# eta = (100 - 10) / 1
self.hook._after_iter(runner, batch_idx=89)
for _ in range(10):
self.hook._after_iter(runner, 1)
runner.iter += 1
assert runner.message_hub.get_info('eta') == 90
self.hook._after_iter(runner, batch_idx=9, mode='val')

for i in range(10):
self.hook._after_iter(runner, batch_idx=i, mode='val')
assert runner.message_hub.get_info('eta') == 10
self.hook._after_iter(runner, batch_idx=19, mode='test')

for i in range(11, 20):
self.hook._after_iter(runner, batch_idx=i, mode='val')
assert runner.message_hub.get_info('eta') == 0

self.hook.after_val_epoch(runner)

for i in range(10):
self.hook._after_iter(runner, batch_idx=i, mode='test')
assert runner.message_hub.get_info('eta') == 10

for i in range(11, 20):
self.hook._after_iter(runner, batch_idx=i, mode='test')
assert runner.message_hub.get_info('eta') == 0

0 comments on commit fb37c2c

Please sign in to comment.