From 93f0627045bd22470c208522fedf5b8b450485d9 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 18:12:49 +0900 Subject: [PATCH 1/4] Support trial name template --- siatune/tune/tuner.py | 7 +++++-- siatune/tune/utils/__init__.py | 4 ++++ siatune/tune/utils/name_tmpl.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 siatune/tune/utils/__init__.py create mode 100644 siatune/tune/utils/name_tmpl.py diff --git a/siatune/tune/tuner.py b/siatune/tune/tuner.py index 32103177..39ffd971 100644 --- a/siatune/tune/tuner.py +++ b/siatune/tune/tuner.py @@ -13,6 +13,7 @@ from siatune.codebase import build_task from siatune.tune import (build_callback, build_scheduler, build_searcher, build_space, build_stopper) +from .utils import NAME_TMPL class Tuner: @@ -106,8 +107,10 @@ def __init__( tune_config=TuneConfig( search_alg=searcher, scheduler=trial_scheduler, - trial_name_creator=lambda trial: trial.trial_id, - trial_dirname_creator=lambda trial: trial.experiment_tag, + trial_name_creator=NAME_TMPL.get( + tune_cfg.pop('trial_name_creator', 'trial_id')), + trial_dirname_creator=NAME_TMPL.get( + tune_cfg.pop('trial_dirname_creator', 'experiment_tag')), **tune_cfg), run_config=RunConfig( name=self.experiment_name, diff --git a/siatune/tune/utils/__init__.py b/siatune/tune/utils/__init__.py new file mode 100644 index 00000000..a244cbfb --- /dev/null +++ b/siatune/tune/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) SI-Analytics. All rights reserved. +from .name_tmpl import NAME_TMPL, trial_experiment_tag, trial_id + +__all__ = ['NAME_TMPL', 'trial_experiment_tag', 'trial_id'] diff --git a/siatune/tune/utils/name_tmpl.py b/siatune/tune/utils/name_tmpl.py new file mode 100644 index 00000000..551e18e2 --- /dev/null +++ b/siatune/tune/utils/name_tmpl.py @@ -0,0 +1,15 @@ +# Copyright (c) SI-Analytics. All rights reserved. +from mmengine.registry import Registry +from ray.tune.experiment import Trial + +NAME_TMPL = Registry('name template') + + +@NAME_TMPL.register_module() +def trial_id(trial: Trial) -> str: + return trial.trial_id + + +@NAME_TMPL.register_module() +def experiment_tag(trial: Trial) -> str: + return trial.experiment_tag From 41727ac08109b9528d5f095ada312204e4a00e5e Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 18:23:15 +0900 Subject: [PATCH 2/4] Add test code --- tests/test_tune/test_utils.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_tune/test_utils.py diff --git a/tests/test_tune/test_utils.py b/tests/test_tune/test_utils.py new file mode 100644 index 00000000..b2a416a5 --- /dev/null +++ b/tests/test_tune/test_utils.py @@ -0,0 +1,28 @@ +import inspect + +import pytest +from ray.tune.experiment import Trial + +from siatune.tune.utils import NAME_TMPL + + +@pytest.fixture +def trial(): + return Trial( + trainable_name='test', + trial_id='trial_id', + experiment_tag='experiment_tag') + + +def test_trial_id(trial): + tmpl = NAME_TMPL.get('trial_id') + assert inspect.isfunction(tmpl) + assert tmpl.__name__ == 'trial_id' + assert tmpl(trial) == trial.trial_id + + +def test_experiment_tag(trial): + tmpl = NAME_TMPL.get('trial_id') + assert inspect.isfunction(tmpl) + assert tmpl.__name__ == 'experiment_tag' + assert tmpl(trial) == trial.experiment_tag From a79ad713b66aacadde13bcecd687cbcfdb005bcc Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 18:26:25 +0900 Subject: [PATCH 3/4] Fix --- siatune/tune/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/siatune/tune/utils/__init__.py b/siatune/tune/utils/__init__.py index a244cbfb..4519edd2 100644 --- a/siatune/tune/utils/__init__.py +++ b/siatune/tune/utils/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) SI-Analytics. All rights reserved. -from .name_tmpl import NAME_TMPL, trial_experiment_tag, trial_id +from .name_tmpl import NAME_TMPL, experiment_tag, trial_id -__all__ = ['NAME_TMPL', 'trial_experiment_tag', 'trial_id'] +__all__ = ['NAME_TMPL', 'experiment_tag', 'trial_id'] From 3e5f4973083d763f5ed035bf2144e5296b250008 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 18:39:02 +0900 Subject: [PATCH 4/4] Rename --- siatune/tune/tuner.py | 6 +++--- siatune/tune/utils/__init__.py | 4 ++-- siatune/tune/utils/{name_tmpl.py => name_creator.py} | 6 +++--- tests/test_tune/test_utils.py | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) rename siatune/tune/utils/{name_tmpl.py => name_creator.py} (71%) diff --git a/siatune/tune/tuner.py b/siatune/tune/tuner.py index 39ffd971..4578390f 100644 --- a/siatune/tune/tuner.py +++ b/siatune/tune/tuner.py @@ -13,7 +13,7 @@ from siatune.codebase import build_task from siatune.tune import (build_callback, build_scheduler, build_searcher, build_space, build_stopper) -from .utils import NAME_TMPL +from .utils import NAME_CREATOR class Tuner: @@ -107,9 +107,9 @@ def __init__( tune_config=TuneConfig( search_alg=searcher, scheduler=trial_scheduler, - trial_name_creator=NAME_TMPL.get( + trial_name_creator=NAME_CREATOR.get( tune_cfg.pop('trial_name_creator', 'trial_id')), - trial_dirname_creator=NAME_TMPL.get( + trial_dirname_creator=NAME_CREATOR.get( tune_cfg.pop('trial_dirname_creator', 'experiment_tag')), **tune_cfg), run_config=RunConfig( diff --git a/siatune/tune/utils/__init__.py b/siatune/tune/utils/__init__.py index 4519edd2..fbef4535 100644 --- a/siatune/tune/utils/__init__.py +++ b/siatune/tune/utils/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) SI-Analytics. All rights reserved. -from .name_tmpl import NAME_TMPL, experiment_tag, trial_id +from .name_creator import NAME_CREATOR, experiment_tag, trial_id -__all__ = ['NAME_TMPL', 'experiment_tag', 'trial_id'] +__all__ = ['NAME_CREATOR', 'experiment_tag', 'trial_id'] diff --git a/siatune/tune/utils/name_tmpl.py b/siatune/tune/utils/name_creator.py similarity index 71% rename from siatune/tune/utils/name_tmpl.py rename to siatune/tune/utils/name_creator.py index 551e18e2..dd916dce 100644 --- a/siatune/tune/utils/name_tmpl.py +++ b/siatune/tune/utils/name_creator.py @@ -2,14 +2,14 @@ from mmengine.registry import Registry from ray.tune.experiment import Trial -NAME_TMPL = Registry('name template') +NAME_CREATOR = Registry('name creator') -@NAME_TMPL.register_module() +@NAME_CREATOR.register_module() def trial_id(trial: Trial) -> str: return trial.trial_id -@NAME_TMPL.register_module() +@NAME_CREATOR.register_module() def experiment_tag(trial: Trial) -> str: return trial.experiment_tag diff --git a/tests/test_tune/test_utils.py b/tests/test_tune/test_utils.py index b2a416a5..43ff7289 100644 --- a/tests/test_tune/test_utils.py +++ b/tests/test_tune/test_utils.py @@ -1,28 +1,28 @@ import inspect +from unittest.mock import MagicMock import pytest -from ray.tune.experiment import Trial -from siatune.tune.utils import NAME_TMPL +from siatune.tune.utils import NAME_CREATOR @pytest.fixture def trial(): - return Trial( + return MagicMock( trainable_name='test', trial_id='trial_id', experiment_tag='experiment_tag') def test_trial_id(trial): - tmpl = NAME_TMPL.get('trial_id') + tmpl = NAME_CREATOR.get('trial_id') assert inspect.isfunction(tmpl) assert tmpl.__name__ == 'trial_id' assert tmpl(trial) == trial.trial_id def test_experiment_tag(trial): - tmpl = NAME_TMPL.get('trial_id') + tmpl = NAME_CREATOR.get('experiment_tag') assert inspect.isfunction(tmpl) assert tmpl.__name__ == 'experiment_tag' assert tmpl(trial) == trial.experiment_tag