From 3ab3949856f010ca2230c84827d6338ba8d409f0 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Tue, 17 Jan 2023 20:42:40 +0900 Subject: [PATCH 1/9] Add support for dict input in Choice --- siatune/tune/spaces/choice.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index 329a8178..ce594cf7 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -14,9 +14,9 @@ class Choice(BaseSpace): """Sample a categorical value. Args: - categories (Sequence): The categories. - alias (Sequence, optional): A alias to be expressed. - Defaults to None. + categories (Sequence | dict): The categories. If categories is dict, + keys of dict will override the alias. + alias (Sequence, optional): A alias to be expressed. Defaults to None. """ sample: Callable = tune.choice @@ -24,6 +24,10 @@ class Choice(BaseSpace): def __init__(self, categories: Sequence, alias: Optional[Sequence] = None) -> None: + if isinstance(categories, dict): + categories = categories.values() + alias = categories.keys() + if alias is not None: assert isinstance(alias, Sequence) assert len(categories) == len(alias) From cbcc39e13522e0f185dfc22ea570d22f78d2e5f7 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 01:14:09 +0900 Subject: [PATCH 2/9] Fix typo --- siatune/tune/spaces/choice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index ce594cf7..1b35b919 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -15,7 +15,7 @@ class Choice(BaseSpace): Args: categories (Sequence | dict): The categories. If categories is dict, - keys of dict will override the alias. + keys of dict will overwrite the alias. alias (Sequence, optional): A alias to be expressed. Defaults to None. """ From c4cece720005c592aa2ca573be8cc3c6bd9e394d Mon Sep 17 00:00:00 2001 From: Junhwa Song Date: Wed, 18 Jan 2023 10:07:13 +0900 Subject: [PATCH 3/9] Update siatune/tune/spaces/choice.py Co-authored-by: YH <100389977+yhna940@users.noreply.github.com> Signed-off-by: Junhwa Song --- siatune/tune/spaces/choice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index 1b35b919..b06f3c0a 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -22,7 +22,7 @@ class Choice(BaseSpace): sample: Callable = tune.choice def __init__(self, - categories: Sequence, + categories: Union[Sequence, Dict] , alias: Optional[Sequence] = None) -> None: if isinstance(categories, dict): categories = categories.values() From c73d7c9e4ea506f335e58a15d653953c169b543c Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 10:54:40 +0900 Subject: [PATCH 4/9] Apply lint --- siatune/tune/spaces/choice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index b06f3c0a..5261a762 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -1,5 +1,5 @@ # Copyright (c) SI-Analytics. All rights reserved. -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Union import ray.tune as tune from ray.tune.search.sample import Domain @@ -22,7 +22,7 @@ class Choice(BaseSpace): sample: Callable = tune.choice def __init__(self, - categories: Union[Sequence, Dict] , + categories: Union[Sequence, dict], alias: Optional[Sequence] = None) -> None: if isinstance(categories, dict): categories = categories.values() From b37017fac609ff6434bd6046ddf5c7780b728827 Mon Sep 17 00:00:00 2001 From: Junhwa Song Date: Wed, 18 Jan 2023 10:58:55 +0900 Subject: [PATCH 5/9] Update siatune/tune/spaces/choice.py Co-authored-by: Hakjin Lee Signed-off-by: Junhwa Song --- siatune/tune/spaces/choice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index 5261a762..824ab6ea 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -16,7 +16,7 @@ class Choice(BaseSpace): Args: categories (Sequence | dict): The categories. If categories is dict, keys of dict will overwrite the alias. - alias (Sequence, optional): A alias to be expressed. Defaults to None. + alias (Sequence, optional): An alias to be expressed. Defaults to None. """ sample: Callable = tune.choice From e3b11db5f46d17551a6e28af948c31e21588e8fe Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 11:03:58 +0900 Subject: [PATCH 6/9] Update docs --- siatune/tune/spaces/choice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index 824ab6ea..bc421c18 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -14,9 +14,9 @@ class Choice(BaseSpace): """Sample a categorical value. Args: - categories (Sequence | dict): The categories. If categories is dict, - keys of dict will overwrite the alias. - alias (Sequence, optional): An alias to be expressed. Defaults to None. + categories (Sequence | dict): The categorical search space to choose + one. If categories is dict, keys of dict will overwrite the alias. + alias (Sequence, optional): A alias to be expressed. Defaults to None. """ sample: Callable = tune.choice From 457456e695e66180dea524f719882fca6af90528 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 11:05:18 +0900 Subject: [PATCH 7/9] Add test code --- tests/test_tune/test_spaces.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_tune/test_spaces.py b/tests/test_tune/test_spaces.py index 901f0bf4..50acd806 100644 --- a/tests/test_tune/test_spaces.py +++ b/tests/test_tune/test_spaces.py @@ -33,6 +33,7 @@ def is_immutable(config): with pytest.raises(AssertionError): choice = Choice(categories=[True, False], alias=['TF']) choice = Choice(categories=[True, False], alias=['T', 'F']) + choice = Choice(categories=dict(T=True, F=False)) tune.run(is_immutable, config=dict(test=choice.space)) From 83e69652c309a9a3d5bf8a3110635df7918310f1 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 11:12:21 +0900 Subject: [PATCH 8/9] Fix --- siatune/tune/spaces/choice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index bc421c18..ee3750d7 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -25,8 +25,8 @@ def __init__(self, categories: Union[Sequence, dict], alias: Optional[Sequence] = None) -> None: if isinstance(categories, dict): - categories = categories.values() alias = categories.keys() + categories = categories.values() if alias is not None: assert isinstance(alias, Sequence) From b232aad194f0bad74982551babeb4577367a9594 Mon Sep 17 00:00:00 2001 From: KKIEEK Date: Wed, 18 Jan 2023 11:22:24 +0900 Subject: [PATCH 9/9] Fix --- siatune/tune/spaces/choice.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/siatune/tune/spaces/choice.py b/siatune/tune/spaces/choice.py index ee3750d7..15e69174 100644 --- a/siatune/tune/spaces/choice.py +++ b/siatune/tune/spaces/choice.py @@ -25,8 +25,7 @@ def __init__(self, categories: Union[Sequence, dict], alias: Optional[Sequence] = None) -> None: if isinstance(categories, dict): - alias = categories.keys() - categories = categories.values() + alias, categories = zip(*[(k, v) for k, v in categories.items()]) if alias is not None: assert isinstance(alias, Sequence)