Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Failed to remove the previous best checkpoints #1086

Merged
merged 5 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,9 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
runner.message_hub.update_info(best_score_key, best_score)

if best_ckpt_path and \
self.file_client.isfile(best_ckpt_path) and \
self.file_backend.isfile(best_ckpt_path) and \
is_main_process():
self.file_client.remove(best_ckpt_path)
self.file_backend.remove(best_ckpt_path)
runner.logger.info(
f'The previous best checkpoint {best_ckpt_path} '
'is removed')
Expand All @@ -490,13 +490,13 @@ def _save_best_checkpoint(self, runner, metrics) -> None:
# Replace illegal characters for filename with `_`
best_ckpt_name = best_ckpt_name.replace('/', '_')
if len(self.key_indicators) == 1:
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501
self.best_ckpt_path = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(runtime_best_ckpt_key,
self.best_ckpt_path)
else:
self.best_ckpt_path_dict[
key_indicator] = self.file_client.join_path( # type: ignore # noqa: E501
key_indicator] = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name)
runner.message_hub.update_info(
runtime_best_ckpt_key,
Expand Down
6 changes: 5 additions & 1 deletion mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,7 +2191,11 @@ def save_checkpoint(
checkpoint['param_schedulers'].append(state_dict)

self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
save_checkpoint(checkpoint, filepath)
save_checkpoint(
checkpoint,
filepath,
file_client_args=file_client_args,
backend_args=backend_args)

@master_only
def dump_config(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions mmengine/testing/runner_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import shutil
import tempfile
import time
from unittest import TestCase
Expand Down Expand Up @@ -184,3 +185,12 @@ def setup_dist_env(self):
os.environ['RANK'] = self.dist_cfg['RANK']
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

def clear_work_dir(self):
logging.shutdown()
for filename in os.listdir(self.temp_dir.name):
filepath = os.path.join(self.temp_dir.name, filename)
if os.path.isfile(filepath):
os.remove(filepath)
else:
shutil.rmtree(filepath)
50 changes: 50 additions & 0 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import os.path as osp
import re
import sys
from unittest.mock import MagicMock, patch

import torch
from parameterized import parameterized
Expand Down Expand Up @@ -312,6 +314,54 @@ def test_after_val_epoch(self):
self.assertFalse(
osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))

# There should only one best checkpoint be reserved
# dist backend
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
self.clear_work_dir()
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=by_epoch, save_best='acc')
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
all_files = os.listdir(runner.work_dir)
best_ckpts = [
file for file in all_files if file.startswith('best')
]
self.assertTrue(len(best_ckpts) == 1)

# petrel backend
# TODO use real petrel oss bucket to test
petrel_client = MagicMock()
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
isfile = MagicMock(return_value=True)
self.clear_work_dir()
with patch.dict(sys.modules, {'petrel_client': petrel_client}), \
patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile: # noqa: E501
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
metrics = dict(acc=0.5)
petrel_client.client.Client = MagicMock(
return_value=petrel_client)
checkpoint_hook = CheckpointHook(
interval=2,
by_epoch=by_epoch,
save_best='acc',
backend_args=dict(backend='petrel'))
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
put_mock.assert_called_once()
metrics['acc'] += 0.1
runner.train_loop._epoch += 1
runner.train_loop._iter += 1
checkpoint_hook.after_val_epoch(runner, metrics)
isfile.assert_called_once()
remove_mock.assert_called_once()

def test_after_train_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
Expand Down
7 changes: 2 additions & 5 deletions tests/test_runner/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
load_from_local, load_from_pavi,
save_checkpoint)

sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()


@MODEL_WRAPPERS.register_module()
class DDPWrapper:
Expand Down Expand Up @@ -150,9 +147,8 @@ def test_get_state_dict():
wrapped_model.module.conv.module.bias)


@patch.dict(sys.modules, {'pavi': MagicMock()})
def test_load_pavimodel_dist():
sys.modules['pavi'] = MagicMock()
sys.modules['pavi.modelcloud'] = MagicMock()
pavimodel = Mockpavimodel()
import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel)
Expand Down Expand Up @@ -296,6 +292,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)


@patch.dict(sys.modules, {'petrel_client': MagicMock()})
def test_checkpoint_loader():
filenames = [
'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
Expand Down