diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index 81dc4f0845..75d3890923 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -12,9 +12,9 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook, DistSamplerSeedHook, EMAHook, EvalHook, Fp16OptimizerHook, Hook, IterTimerHook, LoggerHook, LrUpdaterHook, - MlflowLoggerHook, OptimizerHook, PaviLoggerHook, - SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, - WandbLoggerHook) + MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook, + PaviLoggerHook, SyncBuffersHook, TensorboardLoggerHook, + TextLoggerHook, WandbLoggerHook) from .iter_based_runner import IterBasedRunner, IterLoader from .log_buffer import LogBuffer from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS, @@ -28,15 +28,16 @@ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', - 'WandbLoggerHook', 'MlflowLoggerHook', '_load_checkpoint', - 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', - 'Priority', 'get_priority', 'get_host_info', 'get_time_str', - 'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only', - 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor', - 'build_optimizer', 'build_optimizer_constructor', 'IterLoader', - 'set_random_seed', 'auto_fp16', 'force_fp32', 'wrap_fp16_model', - 'Fp16OptimizerHook', 'SyncBuffersHook', 'EMAHook', 'build_runner', - 'RUNNERS', 'allreduce_grads', 'allreduce_params', 'LossScaler', - 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', - 'EvalHook', 'DistEvalHook', 'Sequential', 'ModuleList' + 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook', + '_load_checkpoint', 'load_state_dict', 'load_checkpoint', 'weights_to_cpu', + 'save_checkpoint', 'Priority', 'get_priority', 'get_host_info', + 'get_time_str', 'obj_from_dict', 'init_dist', 'get_dist_info', + 'master_only', 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', + 'DefaultOptimizerConstructor', 'build_optimizer', + 'build_optimizer_constructor', 'IterLoader', 'set_random_seed', + 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook', + 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads', + 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', + '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', + 'ModuleList' ] diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py index caa4df6b8f..648b9fa42d 100644 --- a/mmcv/runner/hooks/__init__.py +++ b/mmcv/runner/hooks/__init__.py @@ -5,8 +5,9 @@ from .evaluation import DistEvalHook, EvalHook from .hook import HOOKS, Hook from .iter_timer import IterTimerHook -from .logger import (LoggerHook, MlflowLoggerHook, PaviLoggerHook, - TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook) +from .logger import (LoggerHook, MlflowLoggerHook, NeptuneLoggerHook, + PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) from .lr_updater import LrUpdaterHook from .memory import EmptyCacheHook from .momentum_updater import MomentumUpdaterHook @@ -20,6 +21,6 @@ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook', - 'WandbLoggerHook', 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', - 'EvalHook', 'DistEvalHook', 'ProfilerHook' + 'NeptuneLoggerHook', 'WandbLoggerHook', 'MomentumUpdaterHook', + 'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook' ] diff --git a/mmcv/runner/hooks/logger/__init__.py b/mmcv/runner/hooks/logger/__init__.py index 8fe4d81492..28c11933e0 100644 --- a/mmcv/runner/hooks/logger/__init__.py +++ b/mmcv/runner/hooks/logger/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Open-MMLab. All rights reserved. from .base import LoggerHook from .mlflow import MlflowLoggerHook +from .neptune import NeptuneLoggerHook from .pavi import PaviLoggerHook from .tensorboard import TensorboardLoggerHook from .text import TextLoggerHook @@ -8,5 +9,6 @@ __all__ = [ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', - 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook' + 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', + 'NeptuneLoggerHook' ] diff --git a/mmcv/runner/hooks/logger/neptune.py b/mmcv/runner/hooks/logger/neptune.py new file mode 100644 index 0000000000..2e695863b1 --- /dev/null +++ b/mmcv/runner/hooks/logger/neptune.py @@ -0,0 +1,82 @@ +# Copyright (c) Open-MMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class NeptuneLoggerHook(LoggerHook): + """Class to log metrics to NeptuneAI. + + It requires `neptune-client` to be installed. + + Args: + init_kwargs (dict): a dict contains the initialization keys as below: + - project (str): Name of a project in a form of + namespace/project_name. If None, the value of + NEPTUNE_PROJECT environment variable will be taken. + - api_token (str): User’s API token. + If None, the value of NEPTUNE_API_TOKEN environment + variable will be taken. Note: It is strongly recommended + to use NEPTUNE_API_TOKEN environment variable rather than + placing your API token in plain text in your source code. + - name (str, optional, default is 'Untitled'): Editable name of + the run. Name is displayed in the run's Details and in + Runs table as a column. + Check https://docs.neptune.ai/api-reference/neptune#init for + more init arguments. + interval (int): Logging interval (every k iterations). + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. + reset_flag (bool): Whether to clear the output buffer after logging + by_epoch (bool): Whether EpochBasedRunner is used. + + .. _NeptuneAI: + https://docs.neptune.ai/you-should-know/logging-metadata + """ + + def __init__(self, + init_kwargs=None, + interval=10, + ignore_last=True, + reset_flag=True, + with_step=True, + by_epoch=True): + + super(NeptuneLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_neptune() + self.init_kwargs = init_kwargs + self.with_step = with_step + + def import_neptune(self): + try: + import neptune.new as neptune + except ImportError: + raise ImportError( + 'Please run "pip install neptune-client" to install neptune') + self.neptune = neptune + self.run = None + + @master_only + def before_run(self, runner): + if self.init_kwargs: + self.run = self.neptune.init(**self.init_kwargs) + else: + self.run = self.neptune.init() + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + for tag_name, tag_value in tags.items(): + if self.with_step: + self.run[tag_name].log( + tag_value, step=self.get_iter(runner)) + else: + tags['global_step'] = self.get_iter(runner) + self.run[tag_name].log(tags) + + @master_only + def after_run(self, runner): + self.run.stop() diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 13a0514feb..bc692db136 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -19,8 +19,8 @@ from torch.utils.data import DataLoader from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook, - MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook, - build_runner) + MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook, + WandbLoggerHook, build_runner) from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, @@ -915,6 +915,22 @@ def test_wandb_hook(): hook.wandb.join.assert_called_with() +def test_neptune_hook(): + sys.modules['neptune'] = MagicMock() + sys.modules['neptune.new'] = MagicMock() + runner = _build_demo_runner() + hook = NeptuneLoggerHook() + loader = DataLoader(torch.ones((5, 2))) + + runner.register_hook(hook) + runner.run([loader, loader], [('train', 1), ('val', 1)]) + shutil.rmtree(runner.work_dir) + + hook.neptune.init.assert_called_with() + hook.run['momentum'].log.assert_called_with(0.95, step=6) + hook.run.stop.assert_called_with() + + def _build_demo_runner_without_hook(runner_type='EpochBasedRunner', max_epochs=1, max_iters=None,