Skip to content

Commit

Permalink
Change: Move StrEnum to pontos.enum and add functions for argparse
Browse files Browse the repository at this point in the history
Create a dedicated enum module and add functions for using enum with
argparse.
  • Loading branch information
bjoernricks committed Feb 5, 2024
1 parent 2ed58cb commit 9498433
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 11 deletions.
53 changes: 53 additions & 0 deletions pontos/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: 2024 Greenbone AG
#
# SPDX-License-Identifier: GPL-3.0-or-later

from argparse import ArgumentTypeError
from enum import Enum
from typing import Callable, Type, TypeVar, Union


class StrEnum(str, Enum):
# Should be replaced by enum.StrEnum when we require Python >= 3.11
"""
An Enum that provides str like behavior
"""

def __str__(self) -> str:
return self.value


def enum_choice(enum: Type[Enum]) -> list[str]:
"""
Return a sequence of choices for argparse from an enum
"""
return [str(e) for e in enum]


def to_choices(enum: Type[Enum]) -> str:
"""
Convert an enum to a comma separated string of choices. For example useful
in help messages for argparse.
"""
return ", ".join([str(t) for t in enum])


T = TypeVar("T", bound=Enum)


def enum_type(enum: Type[T]) -> Callable[[Union[str, T]], T]:
"""
Create a argparse type function for converting the string input into an Enum
"""

def convert(value: Union[str, T]) -> T:
if isinstance(value, str):
try:
return enum(value)
except ValueError:
raise ArgumentTypeError(
f"invalid value {value}. Expected one of {to_choices(enum)}."
) from None
return value

return convert
12 changes: 1 addition & 11 deletions pontos/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from dataclasses import dataclass
from datetime import date, datetime, timezone
from enum import Enum
from inspect import isclass
from typing import Any, Dict, Type, Union, get_args, get_origin, get_type_hints

from dateutil import parser as dateparser

from pontos.enum import StrEnum
from pontos.errors import PontosError

__all__ = (
Expand All @@ -27,16 +27,6 @@ class ModelError(PontosError):
"""


class StrEnum(str, Enum):
# Should be replaced by enum.StrEnum when we require Python >= 3.11
"""
An Enum that provides str like behavior
"""

def __str__(self) -> str:
return self.value


def dotted_attributes(obj: Any, data: Dict[str, Any]) -> Any:
"""
Set dotted attributes on an object
Expand Down
37 changes: 37 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: 2024 Greenbone AG
#
# SPDX-License-Identifier: GPL-3.0-or-later


import unittest
from argparse import ArgumentTypeError

from pontos.enum import StrEnum, enum_type


class EnumTypeTestCase(unittest.TestCase):
def test_enum_type(self):
class FooEnum(StrEnum):
ALL = "all"
NONE = "none"

func = enum_type(FooEnum)

self.assertEqual(func("all"), FooEnum.ALL)
self.assertEqual(func("none"), FooEnum.NONE)

self.assertEqual(func(FooEnum.ALL), FooEnum.ALL)
self.assertEqual(func(FooEnum.NONE), FooEnum.NONE)

def test_enum_type_error(self):
class FooEnum(StrEnum):
ALL = "all"
NONE = "none"

func = enum_type(FooEnum)

with self.assertRaisesRegex(
ArgumentTypeError,
r"invalid value foo. Expected one of all, none",
):
func("foo")

0 comments on commit 9498433

Please sign in to comment.