From 75e2fc5fbcff68d8e9ff65d75de078b021e7498a Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 26 Sep 2023 10:50:01 -0700 Subject: [PATCH 1/4] Add `use_underscores` argument --- tests/helptext_utils.py | 10 +++-- tests/test_dcargs.py | 31 +++++++++++++ tests/test_helptext.py | 98 +++++++++++++++++++++++++++++++++++++++++ tyro/_arguments.py | 8 ++-- tyro/_cli.py | 70 ++++++++++++++++++----------- tyro/_strings.py | 60 ++++++++++++++++++------- 6 files changed, 227 insertions(+), 50 deletions(-) diff --git a/tests/helptext_utils.py b/tests/helptext_utils.py index fa8e8e84..58553e1f 100644 --- a/tests/helptext_utils.py +++ b/tests/helptext_utils.py @@ -11,13 +11,15 @@ import tyro._strings -def get_helptext(f: Callable, args: List[str] = ["--help"]) -> str: +def get_helptext( + f: Callable, args: List[str] = ["--help"], use_underscores: bool = False +) -> str: target = io.StringIO() with pytest.raises(SystemExit), contextlib.redirect_stdout(target): - tyro.cli(f, args=args) + tyro.cli(f, args=args, use_underscores=use_underscores) # Check tyro.extras.get_parser(). - parser = tyro.extras.get_parser(f) + parser = tyro.extras.get_parser(f, use_underscores=use_underscores) assert isinstance(parser, argparse.ArgumentParser) # Returned parser should have formatting information stripped. External tools rarely @@ -44,7 +46,7 @@ def get_helptext(f: Callable, args: List[str] = ["--help"]) -> str: target2 = io.StringIO() with pytest.raises(SystemExit), contextlib.redirect_stdout(target2): tyro._arguments.USE_RICH = False - tyro.cli(f, args=args) + tyro.cli(f, args=args, use_underscores=use_underscores) tyro._arguments.USE_RICH = True if target2.getvalue() != tyro._strings.strip_ansi_sequences(target.getvalue()): diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index 3d21d512..61e5f389 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -688,3 +688,34 @@ class A: assert tyro.cli(A, args="--x".split(" ")).x == () assert tyro.cli(A, args="--y".split(" ")).y == [] + + +def test_unknown_args_with_consistent_duplicates_use_underscores() -> None: + @dataclasses.dataclass + class A: + a_b: List[int] = dataclasses.field(default_factory=list) + c_d: List[int] = dataclasses.field(default_factory=list) + + # Tests logic for consistent duplicate arguments when performing argument fixing. + # i.e., we can fix arguments if the separator is consistent (all _'s or all -'s). + a, unknown_args = tyro.cli( + A, + args=[ + "--a-b", + "5", + "--a-b", + "7", + "--c_d", + "5", + "--c_d", + "7", + "--e-f", + "--e-f", + "--g_h", + "--g_h", + ], + return_unknown_args=True, + use_underscores=True, + ) + assert a == A(a_b=[7], c_d=[7]) + assert unknown_args == ["--e-f", "--e-f", "--g_h", "--g_h"] diff --git a/tests/test_helptext.py b/tests/test_helptext.py index 1e8055c8..3ba07210 100644 --- a/tests/test_helptext.py +++ b/tests/test_helptext.py @@ -587,3 +587,101 @@ def main(child: Child) -> None: helptext = get_helptext(main) assert "--child.x | --child.no-x" in helptext + + +def test_multiple_subparsers_helptext_hyphens() -> None: + @dataclasses.dataclass + class SubcommandOne: + """2% milk.""" # % symbol is prone to bugs in argparse. + + arg_x: int = 0 + arg_flag: bool = False + + @dataclasses.dataclass + class SubcommandTwo: + arg_y: int = 1 + + @dataclasses.dataclass + class SubcommandThree: + arg_z: int = 2 + + @dataclasses.dataclass + class MultipleSubparsers: + # Field a description. + a: Union[SubcommandOne, SubcommandTwo, SubcommandThree] + # Field b description. + b: Union[SubcommandOne, SubcommandTwo, SubcommandThree] + # Field c description. + c: Union[SubcommandOne, SubcommandTwo, SubcommandThree] = dataclasses.field( + default_factory=SubcommandThree + ) + + helptext = get_helptext(MultipleSubparsers) + + assert "2% milk." in helptext + assert "Field a description." in helptext + assert "Field b description." not in helptext + assert "Field c description." not in helptext + + helptext = get_helptext( + MultipleSubparsers, args=["a:subcommand-one", "b:subcommand-one", "--help"] + ) + + assert "2% milk." in helptext + assert "Field a description." not in helptext + assert "Field b description." not in helptext + assert "Field c description." in helptext + assert "(default: c:subcommand-three)" in helptext + assert "--b.arg-x" in helptext + assert "--b.no-arg-flag" in helptext + assert "--b.arg-flag" in helptext + + +def test_multiple_subparsers_helptext_underscores() -> None: + @dataclasses.dataclass + class SubcommandOne: + """2% milk.""" # % symbol is prone to bugs in argparse. + + arg_x: int = 0 + arg_flag: bool = False + + @dataclasses.dataclass + class SubcommandTwo: + arg_y: int = 1 + + @dataclasses.dataclass + class SubcommandThree: + arg_z: int = 2 + + @dataclasses.dataclass + class MultipleSubparsers: + # Field a description. + a: Union[SubcommandOne, SubcommandTwo, SubcommandThree] + # Field b description. + b: Union[SubcommandOne, SubcommandTwo, SubcommandThree] + # Field c description. + c: Union[SubcommandOne, SubcommandTwo, SubcommandThree] = dataclasses.field( + default_factory=SubcommandThree + ) + + helptext = get_helptext(MultipleSubparsers, use_underscores=True) + + assert "2% milk." in helptext + assert "Field a description." in helptext + assert "Field b description." not in helptext + assert "Field c description." not in helptext + + helptext = get_helptext( + MultipleSubparsers, + args=["a:subcommand_one", "b:subcommand_one", "--help"], + use_underscores=True, + ) + + assert "2% milk." in helptext + assert "Field a description." not in helptext + assert "Field b description." not in helptext + assert "Field c description." in helptext + assert "(default: c:subcommand_three)" in helptext + assert "--b.arg_x" in helptext + assert "--b.no_arg_flag" in helptext + assert "--b.arg_flag" in helptext diff --git a/tyro/_arguments.py b/tyro/_arguments.py index 55f3c384..8854edb5 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -67,11 +67,13 @@ def __init__( if option_string.startswith("--"): if "." not in option_string: - option_string = "--no-" + option_string[2:] + option_string = ( + "--no" + _strings.get_delimeter() + option_string[2:] + ) else: - # Loose heuristic for where to add the no- prefix. + # Loose heuristic for where to add the no-/no_ prefix. left, _, right = option_string.rpartition(".") - option_string = left + ".no-" + right + option_string = left + ".no" + _strings.get_delimeter() + right self._no_strings.add(option_string) _option_strings.append(option_string) diff --git a/tyro/_cli.py b/tyro/_cli.py index 921c87da..7fe1fd24 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -50,6 +50,7 @@ def cli( args: Optional[Sequence[str]] = None, default: Optional[OutT] = None, return_unknown_args: Literal[False] = False, + use_underscores: bool = False, ) -> OutT: ... @@ -63,6 +64,7 @@ def cli( args: Optional[Sequence[str]] = None, default: Optional[OutT] = None, return_unknown_args: Literal[True], + use_underscores: bool = False, ) -> Tuple[OutT, List[str]]: ... @@ -79,6 +81,7 @@ def cli( # of the callable itself. default: None = None, return_unknown_args: Literal[False] = False, + use_underscores: bool = False, ) -> OutT: ... @@ -95,6 +98,7 @@ def cli( # of the callable itself. default: None = None, return_unknown_args: Literal[True], + use_underscores: bool = False, ) -> Tuple[OutT, List[str]]: ... @@ -107,6 +111,7 @@ def cli( args: Optional[Sequence[str]] = None, default: Optional[OutT] = None, return_unknown_args: bool = False, + use_underscores: bool = False, **deprecated_kwargs, ) -> Union[OutT, Tuple[OutT, List[str]]]: """Call or instantiate `f`, with inputs populated from an automatically generated @@ -163,6 +168,10 @@ def cli( return_unknown_args: If True, return a tuple of the output of `f` and a list of unknown arguments. Mirrors the unknown arguments returned from `argparse.ArgumentParser.parse_known_args()`. + use_underscores: If True, use underscores as a word delimeter instead of hyphens. + This primarily impacts helptext; underscores and hyphens are treated equivalently + when parsing happens. We default helptext to hyphens to follow the GNU style guide. + https://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html Returns: The output of `f(...)` or an instance `f`. If `f` is a class, the two are @@ -174,16 +183,18 @@ def cli( # memory address conflicts. _unsafe_cache.clear_cache() - output = _cli_impl( - f, - prog=prog, - description=description, - args=args, - default=default, - return_parser=False, - return_unknown_args=return_unknown_args, - **deprecated_kwargs, - ) + with _strings.delimeter_context("_" if use_underscores else "-"): + output = _cli_impl( + f, + prog=prog, + description=description, + args=args, + default=default, + return_parser=False, + return_unknown_args=return_unknown_args, + use_underscores=use_underscores, + **deprecated_kwargs, + ) # Prevent unnecessary memory usage. _unsafe_cache.clear_cache() @@ -201,6 +212,7 @@ def get_parser( prog: Optional[str] = None, description: Optional[str] = None, default: Optional[OutT] = None, + use_underscores: bool = False, ) -> argparse.ArgumentParser: ... @@ -212,6 +224,7 @@ def get_parser( prog: Optional[str] = None, description: Optional[str] = None, default: Optional[OutT] = None, + use_underscores: bool = False, ) -> argparse.ArgumentParser: ... @@ -224,24 +237,27 @@ def get_parser( prog: Optional[str] = None, description: Optional[str] = None, default: Optional[OutT] = None, + use_underscores: bool = False, ) -> argparse.ArgumentParser: """Get the `argparse.ArgumentParser` object generated under-the-hood by `tyro.cli()`. Useful for tools like `sphinx-argparse`, `argcomplete`, etc. For tab completion, we recommend using `tyro.cli()`'s built-in `--tyro-write-completion` flag.""" - return cast( - argparse.ArgumentParser, - _cli_impl( - f, - prog=prog, - description=description, - args=None, - default=default, - return_parser=True, - return_unknown_args=False, - ), - ) + with _strings.delimeter_context("_" if use_underscores else "-"): + return cast( + argparse.ArgumentParser, + _cli_impl( + f, + prog=prog, + description=description, + args=None, + default=default, + return_parser=True, + return_unknown_args=False, + use_underscores=use_underscores, + ), + ) def _cli_impl( @@ -302,19 +318,21 @@ def _cli_impl( args = list(sys.argv[1:]) if args is None else list(args) # Fix arguments. This will modify all option-style arguments replacing - # underscores with dashes. This is to support the common convention of using - # underscores in variable names, but dashes in command line arguments. + # underscores with hyphens, or vice versa if use_underscores=True. # If two options are ambiguous, e.g., --a_b and --a-b, raise a runtime error. modified_args: Dict[str, str] = {} for index, arg in enumerate(args): if not arg.startswith("--"): continue + delimeter = _strings.get_delimeter() + to_swap_delimeter = "-" if delimeter == "_" else "_" + if "=" in arg: arg, _, val = arg.partition("=") - fixed = arg.replace("_", "-") + "=" + val + fixed = "--" + arg[2:].replace(to_swap_delimeter, delimeter) + "=" + val else: - fixed = arg.replace("_", "-") + fixed = "--" + arg[2:].replace(to_swap_delimeter, delimeter) if ( return_unknown_args and fixed in modified_args diff --git a/tyro/_strings.py b/tyro/_strings.py index e6b9527a..c53fec92 100644 --- a/tyro/_strings.py +++ b/tyro/_strings.py @@ -1,21 +1,39 @@ """Utilities and constants for working with strings.""" +import contextlib import functools import re import textwrap from typing import Iterable, List, Sequence, Tuple, Type -from typing_extensions import get_args, get_origin +from typing_extensions import Literal, get_args, get_origin from . import _resolver dummy_field_name = "__tyro_dummy_field__" +DELIMETER: Literal["-", "_"] = "-" def _strip_dummy_field_names(parts: Iterable[str]) -> Iterable[str]: return filter(lambda name: len(name) > 0 and name != dummy_field_name, parts) +@contextlib.contextmanager +def delimeter_context(delimeter: Literal["-", "_"]): + """Context for setting the delimeter. Determines if `field_a` is populated as + `--field-a` or `--field_a`. Not thread-safe.""" + global DELIMETER + delimeter_restore = DELIMETER + DELIMETER = delimeter + yield + DELIMETER = delimeter_restore + + +def get_delimeter() -> Literal["-", "_"]: + """Get delimeter used to separate words.""" + return DELIMETER + + def make_field_name(parts: Sequence[str]) -> str: """Join parts of a field name together. Used for nesting. @@ -28,13 +46,16 @@ def make_field_name(parts: Sequence[str]) -> str: out.append(".") # Replace all underscores with hyphens, except ones at the start of a string. - num_underscore_prefix = 0 - for i in range(len(p)): - if p[i] == "_": - num_underscore_prefix += 1 - else: - break - p = "_" * num_underscore_prefix + p[num_underscore_prefix:].replace("_", "-") + if get_delimeter() == "-": + num_underscore_prefix = 0 + for i in range(len(p)): + if p[i] == "_": + num_underscore_prefix += 1 + else: + break + p = "_" * num_underscore_prefix + ( + p[num_underscore_prefix:].replace("_", "-") + ) out.append(p) return "".join(out) @@ -52,13 +73,13 @@ def dedent(text: str) -> str: return f"{first_line.strip()}\n{textwrap.dedent(rest)}" -_camel_separator_pattern = functools.lru_cache(maxsize=1)( - lambda: re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))") -) - - def hyphen_separated_from_camel_case(name: str) -> str: - return _camel_separator_pattern().sub(r"-\1", name).lower() + out = ( + re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))") + .sub(get_delimeter() + r"\1", name) + .lower() + ) + return out def _subparser_name_from_type(cls: Type) -> Tuple[str, bool]: @@ -85,7 +106,7 @@ def get_name(cls: Type) -> str: if orig is not None and hasattr(orig, "__name__"): parts = [orig.__name__] # type: ignore parts.extend(map(get_name, get_args(cls))) - return "-".join(parts) + return get_delimeter().join(parts) elif hasattr(cls, "__name__"): return hyphen_separated_from_camel_case(cls.__name__) else: @@ -97,7 +118,7 @@ def get_name(cls: Type) -> str: return get_name(cls), prefix_name # type: ignore return ( - "-".join( + get_delimeter().join( map( lambda x: _subparser_name_from_type(x)[0], [cls] + list(type_from_typevar.values()), @@ -113,7 +134,12 @@ def subparser_name_from_type(prefix: str, cls: Type) -> str: ) if len(prefix) == 0 or not use_prefix: return suffix - return f"{prefix}:{suffix}".replace("_", "-") + + if get_delimeter() == "-": + return f"{prefix}:{suffix}".replace("_", "-") + else: + assert get_delimeter() == "_" + return f"{prefix}:{suffix}" @functools.lru_cache(maxsize=None) From e1c9ee4d865ec9ccf2587beb80812d0ecd8dcee3 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 26 Sep 2023 11:42:38 -0700 Subject: [PATCH 2/4] Dictionary tests --- tests/test_dict_namedtuple.py | 84 +++++++++++++++++++++++++++++++++++ tyro/_arguments.py | 1 + tyro/_strings.py | 2 + 3 files changed, 87 insertions(+) diff --git a/tests/test_dict_namedtuple.py b/tests/test_dict_namedtuple.py index 6aa61bd1..bcf33e73 100644 --- a/tests/test_dict_namedtuple.py +++ b/tests/test_dict_namedtuple.py @@ -311,6 +311,36 @@ class HelptextNamedTuple(NamedTuple): def test_nested_dict() -> None: + loaded_config = { + "batch_size": 32, + "optimizer": { + "learning_rate": 1e-4, + "epsilon": 1e-8, + "scheduler": {"schedule_type": "constant"}, + }, + } + backup_config = copy.deepcopy(loaded_config) + overrided_config = tyro.cli( + dict, + default=loaded_config, + args=[ + "--batch-size", + "16", + "--optimizer.scheduler.schedule_type", + "exponential", + ], + ) + + # Overridden config should be different from loaded config. + assert overrided_config != loaded_config + assert overrided_config["batch_size"] == 16 + assert overrided_config["optimizer"]["scheduler"]["schedule_type"] == "exponential" + + # Original loaded config should not be mutated. + assert loaded_config == backup_config + + +def test_nested_dict_use_underscores() -> None: loaded_config = { "batch_size": 32, "optimizer": { @@ -329,6 +359,7 @@ def test_nested_dict() -> None: "--optimizer.scheduler.schedule-type", "exponential", ], + use_underscores=True, ) # Overridden config should be different from loaded config. @@ -372,6 +403,59 @@ def test_nested_dict_hyphen() -> None: assert loaded_config == backup_config +def test_nested_dict_hyphen_use_underscores() -> None: + # We do a lot of underscore <=> conversion in the code; this is just to make sure it + # doesn't break anything! + loaded_config = { + "batch-size": 32, + "optimizer": { + "learning-rate": 1e-4, + "epsilon": 1e-8, + "scheduler": {"schedule-type": "constant"}, + }, + } + backup_config = copy.deepcopy(loaded_config) + overrided_config = tyro.cli( + dict, + default=loaded_config, + args=[ + "--batch-size", + "16", + "--optimizer.scheduler.schedule-type", + "exponential", + ], + use_underscores=True, + ) + + # Overridden config should be different from loaded config. + assert overrided_config != loaded_config + assert overrided_config["batch-size"] == 16 + assert overrided_config["optimizer"]["scheduler"]["schedule-type"] == "exponential" + + # Original loaded config should not be mutated. + assert loaded_config == backup_config + + overrided_config = tyro.cli( + dict, + default=loaded_config, + args=[ + "--batch_size", + "16", + "--optimizer.scheduler.schedule_type", + "exponential", + ], + use_underscores=True, + ) + + # Overridden config should be different from loaded config. + assert overrided_config != loaded_config + assert overrided_config["batch-size"] == 16 + assert overrided_config["optimizer"]["scheduler"]["schedule-type"] == "exponential" + + # Original loaded config should not be mutated. + assert loaded_config == backup_config + + def test_nested_dict_annotations() -> None: loaded_config = { "optimizer": { diff --git a/tyro/_arguments.py b/tyro/_arguments.py index 8854edb5..67f7f767 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -463,6 +463,7 @@ def _rule_set_name_or_flag_and_dest( and _markers.OmitArgPrefixes not in arg.field.markers else [arg.field.name] ) + print(name_or_flag) # Prefix keyword arguments with --. if not arg.field.is_positional(): diff --git a/tyro/_strings.py b/tyro/_strings.py index c53fec92..1caa03a0 100644 --- a/tyro/_strings.py +++ b/tyro/_strings.py @@ -56,6 +56,8 @@ def make_field_name(parts: Sequence[str]) -> str: p = "_" * num_underscore_prefix + ( p[num_underscore_prefix:].replace("_", "-") ) + else: + p = p.replace("-", "_") out.append(p) return "".join(out) From 7f24f768813736544c94c97353dacf385607050c Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 26 Sep 2023 11:47:00 -0700 Subject: [PATCH 3/4] Remove debug print --- tyro/_arguments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tyro/_arguments.py b/tyro/_arguments.py index 67f7f767..8854edb5 100644 --- a/tyro/_arguments.py +++ b/tyro/_arguments.py @@ -463,7 +463,6 @@ def _rule_set_name_or_flag_and_dest( and _markers.OmitArgPrefixes not in arg.field.markers else [arg.field.name] ) - print(name_or_flag) # Prefix keyword arguments with --. if not arg.field.is_positional(): From 24d75ab143d3de214cba03ee6b7f8df34012971b Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 26 Sep 2023 11:52:24 -0700 Subject: [PATCH 4/4] Format --- tests/test_helptext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_helptext.py b/tests/test_helptext.py index 3ba07210..47d682ed 100644 --- a/tests/test_helptext.py +++ b/tests/test_helptext.py @@ -670,7 +670,7 @@ class MultipleSubparsers: assert "Field a description." in helptext assert "Field b description." not in helptext assert "Field c description." not in helptext - + helptext = get_helptext( MultipleSubparsers, args=["a:subcommand_one", "b:subcommand_one", "--help"],