Skip to content

Commit

Permalink
Minor helptext generation improvements (#184)
Browse files Browse the repository at this point in the history
* Adjust helptext behavior

* Tests
  • Loading branch information
brentyi authored Oct 23, 2024
1 parent 1a9448a commit b358dfa
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 16 deletions.
36 changes: 24 additions & 12 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def from_callable_or_type(
)

# Helptext for this field; used as description for grouping arguments.
class_field_name = _strings.make_field_name([field.intern_name])
class_field_name = _strings.make_field_name(
[intern_prefix, field.intern_name]
)
if field.helptext is not None:
helptext_from_intern_prefixed_field_name[class_field_name] = (
field.helptext
Expand Down Expand Up @@ -235,10 +237,20 @@ def apply_args(
"""Create defined arguments and subparsers."""

# Make argument groups.
def format_group_name(prefix: str) -> str:
return (prefix + " options").strip()
def format_group_name(group_name: str) -> str:
return (group_name + " options").strip()

def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str:
prefix = arg.lowered.name_or_flag
if prefix.startswith("--"):
prefix = prefix[2:]
if "." in prefix:
prefix = prefix.rpartition(".")[0]
else:
prefix = ""
return prefix

group_from_prefix: Dict[str, argparse._ArgumentGroup] = {
group_from_group_name: Dict[str, argparse._ArgumentGroup] = {
"": parser._action_groups[1],
**{
cast(str, group.title).partition(" ")[0]: group
Expand All @@ -251,9 +263,10 @@ def format_group_name(prefix: str) -> str:
# Add each argument group. Note that groups with only suppressed arguments won't
# be added.
for arg in self.args:
group_name = group_name_from_arg(arg)
if (
arg.lowered.help is not argparse.SUPPRESS
and arg.extern_prefix not in group_from_prefix
and group_name not in group_from_group_name
):
description = (
parent.helptext_from_intern_prefixed_field_name.get(
Expand All @@ -262,24 +275,23 @@ def format_group_name(prefix: str) -> str:
if parent is not None
else None
)
group_from_prefix[arg.extern_prefix] = parser.add_argument_group(
format_group_name(arg.extern_prefix),
group_from_group_name[group_name] = parser.add_argument_group(
format_group_name(group_name),
description=description,
)

# Add each argument.
for arg in self.args:
# Add each argument.
if arg.field.is_positional():
arg.add_argument(positional_group)
continue

if arg.extern_prefix in group_from_prefix:
arg.add_argument(group_from_prefix[arg.extern_prefix])
if group_name in group_from_group_name:
arg.add_argument(group_from_group_name[group_name])
else:
# Suppressed argument: still need to add them, but they won't show up in
# the helptext so it doesn't matter which group.
assert arg.lowered.help is argparse.SUPPRESS
arg.add_argument(group_from_prefix[""])
arg.add_argument(group_from_group_name[""])

for child in self.child_from_prefix.values():
child.apply_args(parser, parent=self)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,8 +1497,8 @@ class DatasetConfig:
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
instantiate_dataclasses((OptimizerConfig, DatasetConfig), args=["--help"])
helptext = target.getvalue()
assert "OptimizerConfig options" in helptext
assert "DatasetConfig options" in helptext
assert "OptimizerConfig options" not in helptext
assert "DatasetConfig options" not in helptext


def test_counter_action() -> None:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ class B:
class Nested:
x: int
b: B
"""Helptext for b"""

assert tyro.cli(Nested, args=["--x", "1", "--b.y", "3"]) == Nested(x=1, b=B(y=3))
with pytest.raises(SystemExit):
tyro.cli(Nested, args=["--x", "1"])

def main(x: Nested):
return x

assert "Helptext for b" in get_helptext_with_checks(main)


def test_nested_annotated() -> None:
@dataclasses.dataclass
Expand Down
12 changes: 12 additions & 0 deletions tests/test_py311_generated/ok.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import Literal

import tyro


@dataclass(frozen=True)
class Container[T]:
a: T


tyro.cli(Container[Container[bool] | Container[Literal["1", "2"]]])
4 changes: 2 additions & 2 deletions tests/test_py311_generated/test_conf_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,8 +1492,8 @@ class DatasetConfig:
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
instantiate_dataclasses((OptimizerConfig, DatasetConfig), args=["--help"])
helptext = target.getvalue()
assert "OptimizerConfig options" in helptext
assert "DatasetConfig options" in helptext
assert "OptimizerConfig options" not in helptext
assert "DatasetConfig options" not in helptext


def test_counter_action() -> None:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_py311_generated/test_nested_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,17 @@ class B:
class Nested:
x: int
b: B
"""Helptext for b"""

assert tyro.cli(Nested, args=["--x", "1", "--b.y", "3"]) == Nested(x=1, b=B(y=3))
with pytest.raises(SystemExit):
tyro.cli(Nested, args=["--x", "1"])

def main(x: Nested):
return x

assert "Helptext for b" in get_helptext_with_checks(main)


def test_nested_annotated() -> None:
@dataclasses.dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Adapted from: https://github.com/brentyi/tyro/issues/183"""

from typing import Annotated, NamedTuple, Set

from helptext_utils import get_helptext_with_checks
from pydantic import BaseModel, Field

import tyro


class MyRange(NamedTuple):
low: int
high: int

def __str__(self):
return f"<{self.low}, {self.high}>"

@staticmethod
def tyro_constructor(
range_str: Annotated[
str,
tyro.conf.arg(name=""),
],
):
import re

m = re.match("([0-9]+)(-([0-9]+))*", range_str)
low = m[1] # type: ignore
high = low if not m[3] else m[3] # type: ignore

return MyRange(int(low), int(high))

@staticmethod
def tyro_constructor_set(
range_str_set: Annotated[
Set[str],
tyro.conf.arg(name=""),
],
):
return {MyRange.tyro_constructor(r) for r in range_str_set}


class MySpec(BaseModel):
some_set: Set[int] = Field(
default={1, 2, 3},
description="Some set of integers",
title="Some set",
)

some_string: str = Field(
description="Some string without a default value.", title="SomeSTR"
)

here_comes_the_trouble: Annotated[
Set[MyRange],
tyro.conf.arg(constructor=MyRange.tyro_constructor_set),
] = Field(
default={MyRange(0, 1024)},
description="I would like this one in the same group as others",
title="Please help",
)


def add_spec(spec: MySpec) -> MySpec:
return spec


def test_functionality() -> None:
assert tyro.cli(
add_spec, args=["--spec.some-set", "1", "2", "3", "--spec.some-string", "hello"]
) == MySpec(
some_set={1, 2, 3},
some_string="hello",
here_comes_the_trouble={MyRange(0, 1024)},
)
assert tyro.cli(
add_spec,
args=[
"--spec.some-set",
"1",
"2",
"3",
"--spec.some-string",
"hello",
"--spec.here-comes-the-trouble",
"0-512",
],
) == MySpec(
some_set={1, 2, 3},
some_string="hello",
here_comes_the_trouble={MyRange(0, 512)},
)


def test_helptext() -> None:
helptext = get_helptext_with_checks(add_spec)
assert "spec options" in helptext
assert "spec.here-comes-the-trouble-options" not in helptext
99 changes: 99 additions & 0 deletions tests/test_pydantic_helptext_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Adapted from: https://github.com/brentyi/tyro/issues/183"""

from typing import NamedTuple, Set

from helptext_utils import get_helptext_with_checks
from pydantic import BaseModel, Field
from typing_extensions import Annotated

import tyro


class MyRange(NamedTuple):
low: int
high: int

def __str__(self):
return f"<{self.low}, {self.high}>"

@staticmethod
def tyro_constructor(
range_str: Annotated[
str,
tyro.conf.arg(name=""),
],
):
import re

m = re.match("([0-9]+)(-([0-9]+))*", range_str)
low = m[1] # type: ignore
high = low if not m[3] else m[3] # type: ignore

return MyRange(int(low), int(high))

@staticmethod
def tyro_constructor_set(
range_str_set: Annotated[
Set[str],
tyro.conf.arg(name=""),
],
):
return {MyRange.tyro_constructor(r) for r in range_str_set}


class MySpec(BaseModel):
some_set: Set[int] = Field(
default={1, 2, 3},
description="Some set of integers",
title="Some set",
)

some_string: str = Field(
description="Some string without a default value.", title="SomeSTR"
)

here_comes_the_trouble: Annotated[
Set[MyRange],
tyro.conf.arg(constructor=MyRange.tyro_constructor_set),
] = Field(
default={MyRange(0, 1024)},
description="I would like this one in the same group as others",
title="Please help",
)


def add_spec(spec: MySpec) -> MySpec:
return spec


def test_functionality() -> None:
assert tyro.cli(
add_spec, args=["--spec.some-set", "1", "2", "3", "--spec.some-string", "hello"]
) == MySpec(
some_set={1, 2, 3},
some_string="hello",
here_comes_the_trouble={MyRange(0, 1024)},
)
assert tyro.cli(
add_spec,
args=[
"--spec.some-set",
"1",
"2",
"3",
"--spec.some-string",
"hello",
"--spec.here-comes-the-trouble",
"0-512",
],
) == MySpec(
some_set={1, 2, 3},
some_string="hello",
here_comes_the_trouble={MyRange(0, 512)},
)


def test_helptext() -> None:
helptext = get_helptext_with_checks(add_spec)
assert "spec options" in helptext
assert "spec.here-comes-the-trouble-options" not in helptext

0 comments on commit b358dfa

Please sign in to comment.