-
Notifications
You must be signed in to change notification settings - Fork 341
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] enhance runner test case (#631)
* Add runner test cast * Fix unit test * fix unit test * pop None if key does not exist * Fix is_model_wrapper and force register class in test_runner * [Fix] Fix is_model_wrapper * destroy group after ut * register module in testcase * fix as comment * minor refine Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
- Loading branch information
Showing
5 changed files
with
301 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import copy | ||
import logging | ||
import os | ||
import tempfile | ||
import time | ||
from unittest import TestCase | ||
from uuid import uuid4 | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.distributed import destroy_process_group | ||
from torch.utils.data import Dataset | ||
|
||
import mmengine.hooks # noqa F401 | ||
import mmengine.optim # noqa F401 | ||
from mmengine.config import Config | ||
from mmengine.dist import is_distributed | ||
from mmengine.evaluator import BaseMetric | ||
from mmengine.logging import MessageHub, MMLogger | ||
from mmengine.model import BaseModel | ||
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope | ||
from mmengine.runner import Runner | ||
from mmengine.visualization import Visualizer | ||
|
||
|
||
class ToyModel(BaseModel): | ||
|
||
def __init__(self, data_preprocessor=None): | ||
super().__init__(data_preprocessor=data_preprocessor) | ||
self.linear1 = nn.Linear(2, 2) | ||
self.linear2 = nn.Linear(2, 1) | ||
|
||
def forward(self, inputs, data_samples, mode='tensor'): | ||
if isinstance(inputs, list): | ||
inputs = torch.stack(inputs) | ||
if isinstance(data_samples, list): | ||
data_sample = torch.stack(data_samples) | ||
outputs = self.linear1(inputs) | ||
outputs = self.linear2(outputs) | ||
|
||
if mode == 'tensor': | ||
return outputs | ||
elif mode == 'loss': | ||
loss = (data_sample - outputs).sum() | ||
outputs = dict(loss=loss) | ||
return outputs | ||
elif mode == 'predict': | ||
return outputs | ||
|
||
|
||
class ToyDataset(Dataset): | ||
METAINFO = dict() # type: ignore | ||
data = torch.randn(12, 2) | ||
label = torch.ones(12) | ||
|
||
@property | ||
def metainfo(self): | ||
return self.METAINFO | ||
|
||
def __len__(self): | ||
return self.data.size(0) | ||
|
||
def __getitem__(self, index): | ||
return dict(inputs=self.data[index], data_samples=self.label[index]) | ||
|
||
|
||
class ToyMetric(BaseMetric): | ||
|
||
def __init__(self, collect_device='cpu', dummy_metrics=None): | ||
super().__init__(collect_device=collect_device) | ||
self.dummy_metrics = dummy_metrics | ||
|
||
def process(self, data_batch, predictions): | ||
result = {'acc': 1} | ||
self.results.append(result) | ||
|
||
def compute_metrics(self, results): | ||
return dict(acc=1) | ||
|
||
|
||
class RunnerTestCase(TestCase): | ||
"""A test case to build runner easily. | ||
`RunnerTestCase` will do the following things: | ||
1. Registers a toy model, a toy metric, and a toy dataset, which can be | ||
used to run the `Runner` successfully. | ||
2. Provides epoch based and iteration based cfg to build runner. | ||
3. Provides `build_runner` method to build runner easily. | ||
4. Clean the global variable used by the runner. | ||
""" | ||
dist_cfg = dict( | ||
MASTER_ADDR='127.0.0.1', | ||
MASTER_PORT=29600, | ||
RANK='0', | ||
WORLD_SIZE='1', | ||
LOCAL_RANK='0') | ||
|
||
def setUp(self) -> None: | ||
self.temp_dir = tempfile.TemporaryDirectory() | ||
# Prevent from registering module with the same name by other unit | ||
# test. These registries will be cleared in `tearDown` | ||
MODELS.register_module(module=ToyModel, force=True) | ||
METRICS.register_module(module=ToyMetric, force=True) | ||
DATASETS.register_module(module=ToyDataset, force=True) | ||
epoch_based_cfg = dict( | ||
work_dir=self.temp_dir.name, | ||
model=dict(type='ToyModel'), | ||
train_dataloader=dict( | ||
dataset=dict(type='ToyDataset'), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
batch_size=3, | ||
num_workers=0), | ||
val_dataloader=dict( | ||
dataset=dict(type='ToyDataset'), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
batch_size=3, | ||
num_workers=0), | ||
val_evaluator=[dict(type='ToyMetric')], | ||
test_dataloader=dict( | ||
dataset=dict(type='ToyDataset'), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
batch_size=3, | ||
num_workers=0), | ||
test_evaluator=[dict(type='ToyMetric')], | ||
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)), | ||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), | ||
val_cfg=dict(), | ||
test_cfg=dict(), | ||
default_hooks=dict(logger=dict(type='LoggerHook', interval=1)), | ||
custom_hooks=[], | ||
env_cfg=dict(dist_cfg=dict(backend='nccl')), | ||
experiment_name='test1') | ||
self.epoch_based_cfg = Config(epoch_based_cfg) | ||
|
||
# prepare iter based cfg. | ||
self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg) | ||
self.iter_based_cfg.train_dataloader = dict( | ||
dataset=dict(type='ToyDataset'), | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
batch_size=3, | ||
num_workers=0) | ||
self.iter_based_cfg.log_processor = dict(by_epoch=False) | ||
|
||
self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) | ||
self.iter_based_cfg.default_hooks = dict( | ||
logger=dict(type='LoggerHook', interval=1), | ||
checkpoint=dict( | ||
type='CheckpointHook', interval=12, by_epoch=False)) | ||
|
||
def tearDown(self): | ||
# `FileHandler` should be closed in Windows, otherwise we cannot | ||
# delete the temporary directory | ||
logging.shutdown() | ||
MMLogger._instance_dict.clear() | ||
Visualizer._instance_dict.clear() | ||
DefaultScope._instance_dict.clear() | ||
MessageHub._instance_dict.clear() | ||
MODELS.module_dict.pop('ToyModel', None) | ||
METRICS.module_dict.pop('ToyMetric', None) | ||
DATASETS.module_dict.pop('ToyDataset', None) | ||
self.temp_dir.cleanup() | ||
if is_distributed(): | ||
destroy_process_group() | ||
|
||
def build_runner(self, cfg: Config): | ||
cfg.experiment_name = self.experiment_name | ||
runner = Runner.from_cfg(cfg) | ||
return runner | ||
|
||
@property | ||
def experiment_name(self): | ||
# Since runners could be built too fast to have a unique experiment | ||
# name(timestamp is the same), here we use uuid to make sure each | ||
# runner has the unique experiment name. | ||
return f'{self._testMethodName}_{time.time()} + ' \ | ||
f'{uuid4()}' | ||
|
||
def setup_dist_env(self): | ||
self.dist_cfg['MASTER_PORT'] += 1 | ||
os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT']) | ||
os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR'] | ||
os.environ['RANK'] = self.dist_cfg['RANK'] | ||
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE'] | ||
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.