diff --git a/pontos/enum.py b/pontos/enum.py new file mode 100644 index 00000000..b3144e2f --- /dev/null +++ b/pontos/enum.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2024 Greenbone AG +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from argparse import ArgumentTypeError +from enum import Enum +from typing import Callable, Type, TypeVar, Union + + +class StrEnum(str, Enum): + # Should be replaced by enum.StrEnum when we require Python >= 3.11 + """ + An Enum that provides str like behavior + """ + + def __str__(self) -> str: + return self.value + + +def enum_choice(enum: Type[Enum]) -> list[str]: + """ + Return a sequence of choices for argparse from an enum + """ + return [str(e) for e in enum] + + +def to_choices(enum: Type[Enum]) -> str: + """ + Convert an enum to a comma separated string of choices. For example useful + in help messages for argparse. + """ + return ", ".join([str(t) for t in enum]) + + +T = TypeVar("T", bound=Enum) + + +def enum_type(enum: Type[T]) -> Callable[[Union[str, T]], T]: + """ + Create a argparse type function for converting the string input into an Enum + """ + + def convert(value: Union[str, T]) -> T: + if isinstance(value, str): + try: + return enum(value) + except ValueError: + raise ArgumentTypeError( + f"invalid value {value}. Expected one of {to_choices(enum)}." + ) from None + return value + + return convert diff --git a/pontos/models/__init__.py b/pontos/models/__init__.py index d956ed26..e3a839eb 100644 --- a/pontos/models/__init__.py +++ b/pontos/models/__init__.py @@ -5,12 +5,12 @@ from dataclasses import dataclass from datetime import date, datetime, timezone -from enum import Enum from inspect import isclass from typing import Any, Dict, Type, Union, get_args, get_origin, get_type_hints from dateutil import parser as dateparser +from pontos.enum import StrEnum from pontos.errors import PontosError __all__ = ( @@ -27,16 +27,6 @@ class ModelError(PontosError): """ -class StrEnum(str, Enum): - # Should be replaced by enum.StrEnum when we require Python >= 3.11 - """ - An Enum that provides str like behavior - """ - - def __str__(self) -> str: - return self.value - - def dotted_attributes(obj: Any, data: Dict[str, Any]) -> Any: """ Set dotted attributes on an object diff --git a/tests/test_enum.py b/tests/test_enum.py new file mode 100644 index 00000000..0d7c82a2 --- /dev/null +++ b/tests/test_enum.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2024 Greenbone AG +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import unittest +from argparse import ArgumentTypeError + +from pontos.enum import StrEnum, enum_type + + +class EnumTypeTestCase(unittest.TestCase): + def test_enum_type(self): + class FooEnum(StrEnum): + ALL = "all" + NONE = "none" + + func = enum_type(FooEnum) + + self.assertEqual(func("all"), FooEnum.ALL) + self.assertEqual(func("none"), FooEnum.NONE) + + self.assertEqual(func(FooEnum.ALL), FooEnum.ALL) + self.assertEqual(func(FooEnum.NONE), FooEnum.NONE) + + def test_enum_type_error(self): + class FooEnum(StrEnum): + ALL = "all" + NONE = "none" + + func = enum_type(FooEnum) + + with self.assertRaisesRegex( + ArgumentTypeError, + r"invalid value foo. Expected one of all, none", + ): + func("foo")