Skip to content

Commit

Permalink
Merge pull request #57 from pavalos6401/add_argument_group
Browse files Browse the repository at this point in the history
Add argument group
  • Loading branch information
mivade committed Aug 5, 2023
2 parents 37baa76 + 3a1d35c commit c41a7c6
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 1 deletion.
33 changes: 33 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,39 @@ Configuring a field with the Optional generic type:
>>> print(parser.parse_args(["--name", "John", "--id", "1234"]))
Options(name='John', id=1234)
Creating argument groups by group title:

.. code-block:: pycon
>>> from dataclasses import dataclass, field
>>> from argparse_dataclass import ArgumentParser
>>> @dataclass
... class Options:
... foo: str = field(metadata=dict(group="string group"))
... bar: str = field(metadata=dict(group=dict(title="dict group", description="using a dict")))
... baz: str = field(metadata=dict(group=("sequence group", "using a sequence")))
...
>>> parser = ArgumentParser(Options)
>>> parser.print_help()
usage: [-h] --foo FOO --bar BAR --baz BAZ
options:
-h, --help show this help message and exit
string group:
--foo FOO
dict group:
using a dict
--bar BAR
sequence group:
using a sequence
--baz BAZ
Contributors
------------

Expand Down
30 changes: 29 additions & 1 deletion argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,11 @@ def _add_dataclass_options(
"For Union types other than 'Optional', a custom 'type' must be specified using "
"'metadata'."
)
parser.add_argument(*args, **kwargs)

if "group" in field.metadata:
_handle_argument_group(parser, field, args, kwargs)
else:
parser.add_argument(*args, **kwargs)


def _get_kwargs(namespace: argparse.Namespace) -> Dict[str, Any]:
Expand Down Expand Up @@ -450,6 +454,30 @@ def _handle_bool_type(field: Field, args: list, kwargs: dict):
kwargs["required"] = True


def _handle_argument_group(
parser: argparse.ArgumentParser, field: Field, args: list, kwargs: dict
) -> None:
"""Handles adding the argument to an argument group."""
groups = {x.title: x for x in parser._action_groups}
group = field.metadata.get("group")
if isinstance(group, str):
title = group
description = None
elif isinstance(group, dict):
title = group.get("title")
description = group.get("description")
elif isinstance(group, Sequence):
len_ = len(group)
title = group[0] if len_ > 0 else None
description = group[1] if len_ > 1 else None
else:
raise TypeError("'group' must be a group title, dictionary, or sequence")
group = groups.get(title)
if title is None or group is None:
group = parser.add_argument_group(title, description)
group.add_argument(*args, **kwargs)


class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
"""Command line argument parser that derives its options from a dataclass.
Expand Down
181 changes: 181 additions & 0 deletions tests/test_argumentgroups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import argparse
from dataclasses import dataclass, field

import unittest

from argparse_dataclass import ArgumentParser


class ArgumentParserGroupsTests(unittest.TestCase):
def test_basic_str(self):
parser = argparse.ArgumentParser()
group = parser.add_argument_group("title")
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": "title"})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_basic_dict(self):
title = "title"

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title)
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": {"title": title}})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_basic_dict_description(self):
title = "title"
description = "description"

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title, description)
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(
metadata={"group": {"title": title, "description": description}}
)

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_basic_sequence(self):
parser = argparse.ArgumentParser()
group = parser.add_argument_group("group")
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": ("group")})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_basic_sequence_description(self):
parser = argparse.ArgumentParser()
group = parser.add_argument_group("group", "description")
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": ("group", "description")})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_basic_empty(self):
parser = argparse.ArgumentParser()
group = parser.add_argument_group()
group.add_argument("--x", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": ()})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_multiple_arguments(self):
title = "title"

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title)
group.add_argument("--x", required=True, type=int)
group.add_argument("--y", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": title})
y: int = field(metadata={"group": title})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_multiple_groups_empty(self):
parser = argparse.ArgumentParser()
group_a = parser.add_argument_group()
group_a.add_argument("--x", required=True, type=int)
group_b = parser.add_argument_group()
group_b.add_argument("--y", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
x: int = field(metadata={"group": ()})
y: int = field(metadata={"group": ()})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

def test_argument_groups(self):
title_a = "Group A"
title_b = "Group B"
descr_b = "Description B"
title_c = "Group C"
descr_c = "Description C"

parser = argparse.ArgumentParser()
group_a = parser.add_argument_group(title_a)
group_a.add_argument("--arga1", required=True, type=int)
group_a.add_argument("--arga2", required=True, type=int)
group_b = parser.add_argument_group(title_b, descr_b)
group_b.add_argument("--argb", required=True, type=int)
group_c = parser.add_argument_group(title_c, descr_c)
group_c.add_argument("--argc", required=True, type=int)
group_d = parser.add_argument_group()
group_d.add_argument("--argd", required=True, type=int)
group_e = parser.add_argument_group()
group_e.add_argument("--arge", required=True, type=int)
group_a.add_argument("--arga3", required=True, type=int)
expected = parser.format_help()

@dataclass
class Opt:
arga1: int = field(metadata={"group": title_a})
arga2: int = field(metadata={"group": title_a})
argb: int = field(
metadata={"group": {"title": title_b, "description": descr_b}}
)
argc: int = field(metadata={"group": (title_c, descr_c)})
argd: int = field(metadata={"group": ()})
arge: int = field(metadata={"group": ()})
arga3: int = field(metadata={"group": title_a})

parser = ArgumentParser(Opt)
out = parser.format_help()

self.assertEqual(expected, out)

0 comments on commit c41a7c6

Please sign in to comment.