diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index ef6bf013..22802113 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -145,7 +145,9 @@ def from_callable_or_type( ) # Helptext for this field; used as description for grouping arguments. - class_field_name = _strings.make_field_name([field.intern_name]) + class_field_name = _strings.make_field_name( + [intern_prefix, field.intern_name] + ) if field.helptext is not None: helptext_from_intern_prefixed_field_name[class_field_name] = ( field.helptext @@ -235,10 +237,20 @@ def apply_args( """Create defined arguments and subparsers.""" # Make argument groups. - def format_group_name(prefix: str) -> str: - return (prefix + " options").strip() + def format_group_name(group_name: str) -> str: + return (group_name + " options").strip() + + def group_name_from_arg(arg: _arguments.ArgumentDefinition) -> str: + prefix = arg.lowered.name_or_flag + if prefix.startswith("--"): + prefix = prefix[2:] + if "." in prefix: + prefix = prefix.rpartition(".")[0] + else: + prefix = "" + return prefix - group_from_prefix: Dict[str, argparse._ArgumentGroup] = { + group_from_group_name: Dict[str, argparse._ArgumentGroup] = { "": parser._action_groups[1], **{ cast(str, group.title).partition(" ")[0]: group @@ -251,9 +263,10 @@ def format_group_name(prefix: str) -> str: # Add each argument group. Note that groups with only suppressed arguments won't # be added. for arg in self.args: + group_name = group_name_from_arg(arg) if ( arg.lowered.help is not argparse.SUPPRESS - and arg.extern_prefix not in group_from_prefix + and group_name not in group_from_group_name ): description = ( parent.helptext_from_intern_prefixed_field_name.get( @@ -262,24 +275,23 @@ def format_group_name(prefix: str) -> str: if parent is not None else None ) - group_from_prefix[arg.extern_prefix] = parser.add_argument_group( - format_group_name(arg.extern_prefix), + group_from_group_name[group_name] = parser.add_argument_group( + format_group_name(group_name), description=description, ) - # Add each argument. - for arg in self.args: + # Add each argument. if arg.field.is_positional(): arg.add_argument(positional_group) continue - if arg.extern_prefix in group_from_prefix: - arg.add_argument(group_from_prefix[arg.extern_prefix]) + if group_name in group_from_group_name: + arg.add_argument(group_from_group_name[group_name]) else: # Suppressed argument: still need to add them, but they won't show up in # the helptext so it doesn't matter which group. assert arg.lowered.help is argparse.SUPPRESS - arg.add_argument(group_from_prefix[""]) + arg.add_argument(group_from_group_name[""]) for child in self.child_from_prefix.values(): child.apply_args(parser, parent=self) diff --git a/tests/test_conf.py b/tests/test_conf.py index 1750e8db..54dd9aa3 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -1497,8 +1497,8 @@ class DatasetConfig: with pytest.raises(SystemExit), contextlib.redirect_stdout(target): instantiate_dataclasses((OptimizerConfig, DatasetConfig), args=["--help"]) helptext = target.getvalue() - assert "OptimizerConfig options" in helptext - assert "DatasetConfig options" in helptext + assert "OptimizerConfig options" not in helptext + assert "DatasetConfig options" not in helptext def test_counter_action() -> None: diff --git a/tests/test_nested.py b/tests/test_nested.py index f9a57c09..98ec5cb6 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -18,11 +18,17 @@ class B: class Nested: x: int b: B + """Helptext for b""" assert tyro.cli(Nested, args=["--x", "1", "--b.y", "3"]) == Nested(x=1, b=B(y=3)) with pytest.raises(SystemExit): tyro.cli(Nested, args=["--x", "1"]) + def main(x: Nested): + return x + + assert "Helptext for b" in get_helptext_with_checks(main) + def test_nested_annotated() -> None: @dataclasses.dataclass diff --git a/tests/test_py311_generated/ok.py b/tests/test_py311_generated/ok.py new file mode 100644 index 00000000..73e5b33b --- /dev/null +++ b/tests/test_py311_generated/ok.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Literal + +import tyro + + +@dataclass(frozen=True) +class Container[T]: + a: T + + +tyro.cli(Container[Container[bool] | Container[Literal["1", "2"]]]) diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index 04301ed7..13997ecb 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1492,8 +1492,8 @@ class DatasetConfig: with pytest.raises(SystemExit), contextlib.redirect_stdout(target): instantiate_dataclasses((OptimizerConfig, DatasetConfig), args=["--help"]) helptext = target.getvalue() - assert "OptimizerConfig options" in helptext - assert "DatasetConfig options" in helptext + assert "OptimizerConfig options" not in helptext + assert "DatasetConfig options" not in helptext def test_counter_action() -> None: diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index 4ae8a4e7..5a611b42 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -28,11 +28,17 @@ class B: class Nested: x: int b: B + """Helptext for b""" assert tyro.cli(Nested, args=["--x", "1", "--b.y", "3"]) == Nested(x=1, b=B(y=3)) with pytest.raises(SystemExit): tyro.cli(Nested, args=["--x", "1"]) + def main(x: Nested): + return x + + assert "Helptext for b" in get_helptext_with_checks(main) + def test_nested_annotated() -> None: @dataclasses.dataclass diff --git a/tests/test_py311_generated/test_pydantic_helptext_advanced_generated.py b/tests/test_py311_generated/test_pydantic_helptext_advanced_generated.py new file mode 100644 index 00000000..a7da8acc --- /dev/null +++ b/tests/test_py311_generated/test_pydantic_helptext_advanced_generated.py @@ -0,0 +1,98 @@ +"""Adapted from: https://github.com/brentyi/tyro/issues/183""" + +from typing import Annotated, NamedTuple, Set + +from helptext_utils import get_helptext_with_checks +from pydantic import BaseModel, Field + +import tyro + + +class MyRange(NamedTuple): + low: int + high: int + + def __str__(self): + return f"<{self.low}, {self.high}>" + + @staticmethod + def tyro_constructor( + range_str: Annotated[ + str, + tyro.conf.arg(name=""), + ], + ): + import re + + m = re.match("([0-9]+)(-([0-9]+))*", range_str) + low = m[1] # type: ignore + high = low if not m[3] else m[3] # type: ignore + + return MyRange(int(low), int(high)) + + @staticmethod + def tyro_constructor_set( + range_str_set: Annotated[ + Set[str], + tyro.conf.arg(name=""), + ], + ): + return {MyRange.tyro_constructor(r) for r in range_str_set} + + +class MySpec(BaseModel): + some_set: Set[int] = Field( + default={1, 2, 3}, + description="Some set of integers", + title="Some set", + ) + + some_string: str = Field( + description="Some string without a default value.", title="SomeSTR" + ) + + here_comes_the_trouble: Annotated[ + Set[MyRange], + tyro.conf.arg(constructor=MyRange.tyro_constructor_set), + ] = Field( + default={MyRange(0, 1024)}, + description="I would like this one in the same group as others", + title="Please help", + ) + + +def add_spec(spec: MySpec) -> MySpec: + return spec + + +def test_functionality() -> None: + assert tyro.cli( + add_spec, args=["--spec.some-set", "1", "2", "3", "--spec.some-string", "hello"] + ) == MySpec( + some_set={1, 2, 3}, + some_string="hello", + here_comes_the_trouble={MyRange(0, 1024)}, + ) + assert tyro.cli( + add_spec, + args=[ + "--spec.some-set", + "1", + "2", + "3", + "--spec.some-string", + "hello", + "--spec.here-comes-the-trouble", + "0-512", + ], + ) == MySpec( + some_set={1, 2, 3}, + some_string="hello", + here_comes_the_trouble={MyRange(0, 512)}, + ) + + +def test_helptext() -> None: + helptext = get_helptext_with_checks(add_spec) + assert "spec options" in helptext + assert "spec.here-comes-the-trouble-options" not in helptext diff --git a/tests/test_pydantic_helptext_advanced.py b/tests/test_pydantic_helptext_advanced.py new file mode 100644 index 00000000..245a2150 --- /dev/null +++ b/tests/test_pydantic_helptext_advanced.py @@ -0,0 +1,99 @@ +"""Adapted from: https://github.com/brentyi/tyro/issues/183""" + +from typing import NamedTuple, Set + +from helptext_utils import get_helptext_with_checks +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +import tyro + + +class MyRange(NamedTuple): + low: int + high: int + + def __str__(self): + return f"<{self.low}, {self.high}>" + + @staticmethod + def tyro_constructor( + range_str: Annotated[ + str, + tyro.conf.arg(name=""), + ], + ): + import re + + m = re.match("([0-9]+)(-([0-9]+))*", range_str) + low = m[1] # type: ignore + high = low if not m[3] else m[3] # type: ignore + + return MyRange(int(low), int(high)) + + @staticmethod + def tyro_constructor_set( + range_str_set: Annotated[ + Set[str], + tyro.conf.arg(name=""), + ], + ): + return {MyRange.tyro_constructor(r) for r in range_str_set} + + +class MySpec(BaseModel): + some_set: Set[int] = Field( + default={1, 2, 3}, + description="Some set of integers", + title="Some set", + ) + + some_string: str = Field( + description="Some string without a default value.", title="SomeSTR" + ) + + here_comes_the_trouble: Annotated[ + Set[MyRange], + tyro.conf.arg(constructor=MyRange.tyro_constructor_set), + ] = Field( + default={MyRange(0, 1024)}, + description="I would like this one in the same group as others", + title="Please help", + ) + + +def add_spec(spec: MySpec) -> MySpec: + return spec + + +def test_functionality() -> None: + assert tyro.cli( + add_spec, args=["--spec.some-set", "1", "2", "3", "--spec.some-string", "hello"] + ) == MySpec( + some_set={1, 2, 3}, + some_string="hello", + here_comes_the_trouble={MyRange(0, 1024)}, + ) + assert tyro.cli( + add_spec, + args=[ + "--spec.some-set", + "1", + "2", + "3", + "--spec.some-string", + "hello", + "--spec.here-comes-the-trouble", + "0-512", + ], + ) == MySpec( + some_set={1, 2, 3}, + some_string="hello", + here_comes_the_trouble={MyRange(0, 512)}, + ) + + +def test_helptext() -> None: + helptext = get_helptext_with_checks(add_spec) + assert "spec options" in helptext + assert "spec.here-comes-the-trouble-options" not in helptext