diff --git a/pyproject.toml b/pyproject.toml index 701c3633..3508560e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "tyro" authors = [ {name = "brentyi", email = "brentyi@berkeley.edu"}, ] -version = "0.5.6" +version = "0.5.7" description = "Strongly typed, zero-effort CLI interfaces" readme = "README.md" license = { text="MIT" } diff --git a/tests/test_errors.py b/tests/test_errors.py index bf80bd25..3ff99852 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -170,8 +170,8 @@ class Class: assert "Similar arguments" in error # --reward.track should appear in both the usage string and as a similar argument. - assert error.count("--reward.track") == 2 - assert error.count("--help") == 0 + assert error.count("--reward.track") == 1 + assert error.count("--help") == 1 def test_similar_arguments_subcommands() -> None: @@ -195,7 +195,7 @@ class ClassB: assert "Unrecognized argument" in error assert "Similar arguments:" in error assert error.count("--reward.track") == 1 - assert error.count("--help") == 2 + assert error.count("--help") == 3 def test_similar_arguments_subcommands_multiple() -> None: @@ -221,7 +221,7 @@ class ClassB: assert "Arguments similar to --reward.trac" in error assert error.count("--reward.track {True,False}") == 1 assert error.count("--reward.trace INT") == 1 - assert error.count("--help") == 4 + assert error.count("--help") == 5 def test_similar_arguments_subcommands_multiple_contains_match() -> None: @@ -247,7 +247,7 @@ class ClassB: assert "Similar arguments" in error assert error.count("--reward.track {True,False}") == 1 assert error.count("--reward.trace INT") == 1 - assert error.count("--help") == 4 # 2 subcommands * 2 arguments. + assert error.count("--help") == 5 # 2 subcommands * 2 arguments + usage hint. def test_similar_arguments_subcommands_multiple_contains_match_alt() -> None: @@ -272,7 +272,9 @@ class ClassB: assert "Unrecognized argument" in error assert "Similar arguments" in error assert error.count("--reward.track {True,False}") == 1 - assert error.count("--help") == 2 # Should show two possible subcommands. + assert ( + error.count("--help") == 3 + ) # Should show two possible subcommands + usage hint. def test_similar_arguments_subcommands_overflow_different() -> None: @@ -313,7 +315,7 @@ class ClassB: assert "Similar arguments" in error assert error.count("--reward.track") == 10 assert "[...]" not in error - assert error.count("--help") == 20 + assert error.count("--help") == 21 target = io.StringIO() with pytest.raises(SystemExit), contextlib.redirect_stdout(target): @@ -323,7 +325,7 @@ class ClassB: # Usage print should be clipped. error = target.getvalue() - assert "help:" in error + assert "For full helptext, run" in error def test_similar_arguments_subcommands_overflow_same() -> None: @@ -381,7 +383,7 @@ class ClassI: assert "Similar arguments" in error assert error.count("--reward.track") == 1 assert "[...]" in error - assert error.count("--help") == 4 + assert error.count("--help") == 5 def test_similar_arguments_subcommands_overflow_same_startswith_multiple() -> None: @@ -441,7 +443,7 @@ class ClassI: assert error.count("--rewar") == 1 assert "rewarde" not in error assert "[...]" in error - assert error.count("--help") == 4 + assert error.count("--help") == 5 def test_similar_flag() -> None: @@ -458,8 +460,8 @@ class Args: error = target.getvalue() - # Printed in the usage message. - assert error.count("--flag | --no-flag") == 1 + # We don't print usage text anymore. + assert error.count("--flag | --no-flag") == 0 # Printed in the similar argument list. assert error.count("--flag, --no-flag") == 1 diff --git a/tyro/_argparse_formatter.py b/tyro/_argparse_formatter.py index 808902eb..3e0a6a27 100644 --- a/tyro/_argparse_formatter.py +++ b/tyro/_argparse_formatter.py @@ -11,6 +11,9 @@ TODO: the current implementation should be robust given our test coverage, but unideal long-term. We should just maintain our own fork of argparse. """ + +from __future__ import annotations + import argparse import contextlib import dataclasses @@ -20,7 +23,7 @@ import shutil import sys from gettext import gettext as _ -from typing import Any, Generator, List, NoReturn, Optional, Tuple +from typing import Any, Dict, Generator, List, NoReturn, Optional, Set, Tuple from rich.columns import Columns from rich.console import Console, Group, RenderableType @@ -73,6 +76,109 @@ def set_accent_color(accent_color: Optional[str]) -> None: ) +def recursive_arg_search( + args: List[str], + parser_spec: ParserSpecification, + prog: str, + unrecognized_arguments: Set[str], +) -> Tuple[List[_ArgumentInfo], bool, bool]: + """Recursively search for arguments in a ParserSpecification. Used for error message + printing. + + Returns a list of arguments, whether the parser has subcommands or not, and -- if + `unrecognized_arguments` is passed in --- whether an unrecognized argument exists + under a different subparser. + + Args: + args: Arguments being parsed. Used for heuristics on subcommands. + parser_spec: Argument parser specification. + subcommands: Prog corresponding to parser_spec. + unrecognized_arguments: Used for same_exists return value. + """ + # Argument name => subcommands it came from. + arguments: List[_ArgumentInfo] = [] + has_subcommands = False + same_exists = False + + def _recursive_arg_search( + parser_spec: ParserSpecification, + prog: str, + subcommand_match_score: float, + ) -> None: + """Find all possible arguments that could have been passed in.""" + + # When tyro.conf.ConsolidateSubcommandArgs is turned on, arguments will + # only appear in the help message for "leaf" subparsers. + help_flag = ( + " (other subcommands) --help" + if parser_spec.consolidate_subcommand_args + and parser_spec.subparsers is not None + else " --help" + ) + for arg in parser_spec.args: + if arg.field.is_positional() or arg.lowered.is_fixed(): + # Skip positional arguments. + continue + + # Skip suppressed arguments. + if conf.Suppress in arg.field.markers or ( + conf.SuppressFixed in arg.field.markers + and conf.Fixed in arg.field.markers + ): + continue + + option_strings = (arg.lowered.name_or_flag,) + + # Handle actions, eg BooleanOptionalAction will map ("--flag",) to + # ("--flag", "--no-flag"). + if ( + arg.lowered.action is not None + # Actions are sometimes strings in Python 3.7, eg "append". + # We'll ignore these, but this kind of thing is a good reason + # for just forking argparse. + and callable(arg.lowered.action) + ): + option_strings = arg.lowered.action( + option_strings, dest="" # dest should not matter. + ).option_strings + + arguments.append( + _ArgumentInfo( + # Currently doesn't handle actions well, eg boolean optional + # arguments. + option_strings, + metavar=arg.lowered.metavar, + usage_hint=prog + help_flag, + help=arg.lowered.help, + subcommand_match_score=subcommand_match_score, + ) + ) + + # An unrecognized argument. + nonlocal same_exists + if not same_exists and arg.lowered.name_or_flag in unrecognized_arguments: + same_exists = True + + if parser_spec.subparsers is not None: + nonlocal has_subcommands + has_subcommands = True + for ( + subparser_name, + subparser, + ) in parser_spec.subparsers.parser_from_name.items(): + _recursive_arg_search( + subparser, + prog + " " + subparser_name, + # Leaky (!!) heuristic for if this subcommand is matched or not. + subcommand_match_score=subcommand_match_score + + (1 if subparser_name in args else -0.001), + ) + + _recursive_arg_search(parser_spec, prog, 0) + + return arguments, has_subcommands, same_exists + + # TODO: this is a prototype; for a v1.0.0 release we should revisit whether the global # state here is acceptable or not. THEME = TyroTheme() @@ -442,14 +548,6 @@ def consume_positionals(start_index): # return the updated namespace and the extra arguments return namespace, extras - def _print_usage_succinct(self, console: Console) -> None: - """Print usage, but abridged if too long.""" - usage = self.format_usage().strip() + "\n" - if len(usage) < 400: - print(usage) - else: # pragma: no cover - console.print(f"[bold]help:[/bold] {self.prog} --help\n") - @override def error(self, message: str) -> NoReturn: """Improve error messages from argparse. @@ -464,7 +562,6 @@ def error(self, message: str) -> NoReturn: """ console = Console(theme=THEME.as_rich_theme()) - self._print_usage_succinct(console) extra_info: List[RenderableType] = [] global global_unrecognized_args @@ -473,117 +570,33 @@ def error(self, message: str) -> NoReturn: ): global_unrecognized_args = message.partition(":")[2].strip().split(" ") + message_title = "Parsing error" + if len(global_unrecognized_args) > 0: + message_title = "Unrecognized arguments" message = f"Unrecognized arguments: {' '.join(global_unrecognized_args)}" - unrecognized_arguments = [ + unrecognized_arguments = set( arg for arg in global_unrecognized_args # If we pass in `--spell-chekc on`, we only want `spell-chekc` and not # `on`. if arg.startswith("--") - ] - - # Argument name => subcommands it came from. - arguments: List[_ArgumentInfo] = [] - has_subcommands = False - same_exists = False - - def _recursive_arg_search( - parser_spec: ParserSpecification, - subcommands: str, - subcommand_match_score: float, - ) -> None: - """Find all possible arguments that could have been passed in.""" - - # When tyro.conf.ConsolidateSubcommandArgs is turned on, arguments will - # only appear in the help message for "leaf" subparsers. - help_flag = ( - " (other subcommands) --help" - if parser_spec.consolidate_subcommand_args - and parser_spec.subparsers is not None - else " --help" - ) - for arg in parser_spec.args: - if arg.field.is_positional() or arg.lowered.is_fixed(): - # Skip positional arguments. - continue - - # Skip suppressed arguments. - if conf.Suppress in arg.field.markers or ( - conf.SuppressFixed in arg.field.markers - and conf.Fixed in arg.field.markers - ): - continue - - option_strings = (arg.lowered.name_or_flag,) - - # Handle actions, eg BooleanOptionalAction will map ("--flag",) to - # ("--flag", "--no-flag"). - if ( - arg.lowered.action is not None - # Actions are sometimes strings in Python 3.7, eg "append". - # We'll ignore these, but this kind of thing is a good reason - # for just forking argparse. - and callable(arg.lowered.action) - ): - option_strings = arg.lowered.action( - option_strings, dest="" # dest should not matter. - ).option_strings - - arguments.append( - _ArgumentInfo( - # Currently doesn't handle actions well, eg boolean optional - # arguments. - option_strings, - metavar=arg.lowered.metavar, - usage_hint=subcommands + help_flag, - help=arg.lowered.help, - subcommand_match_score=subcommand_match_score, - ) - ) - - # An unrecognized argument. - nonlocal same_exists - if ( - not same_exists - and arg.lowered.name_or_flag in unrecognized_arguments - ): - same_exists = True - - if parser_spec.subparsers is not None: - nonlocal has_subcommands - has_subcommands = True - for ( - subparser_name, - subparser, - ) in parser_spec.subparsers.parser_from_name.items(): - _recursive_arg_search( - subparser, - subcommands + " " + subparser_name, - subcommand_match_score=subcommand_match_score - + (1 if subparser_name in self._args else -0.001), - ) - - _recursive_arg_search( - self._parser_specification, - # Remove other subcommands. - self.prog.split(" ")[0], - 0, + ) + arguments, has_subcommands, same_exists = recursive_arg_search( + args=self._args, + parser_spec=self._parser_specification, + prog=self.prog.partition(" ")[0], + unrecognized_arguments=unrecognized_arguments, ) if has_subcommands and same_exists: - misplaced_arguments = message.partition(":")[2].strip() - message = ( - "unrecognized or misplaced arguments: " + misplaced_arguments - if " " in misplaced_arguments - else "unrecognized or misplaced argument: " + misplaced_arguments - ) + message = f"Unrecognized or misplaced arguments: {' '.join(global_unrecognized_args)}" # Show similar arguments for keyword options. for unrecognized_argument in unrecognized_arguments: # Sort arguments by similarity. scored_arguments: List[Tuple[_ArgumentInfo, float]] = [] - for argument in arguments: + for arg_info in arguments: # Compute a score for each argument. assert unrecognized_argument.startswith("--") @@ -605,14 +618,14 @@ def get_score(option_string: str) -> float: ).ratio() scored_arguments.append( - (argument, max(map(get_score, argument.option_strings))) + (arg_info, max(map(get_score, arg_info.option_strings))) ) # Add information about similar arguments. prev_arg_option_strings: Optional[Tuple[str, ...]] = None show_arguments: List[_ArgumentInfo] = [] unique_counter = 0 - for argument, score in ( + for arg_info, score in ( # Sort scores greatest to least. sorted( scored_arguments, @@ -634,13 +647,13 @@ def get_score(option_string: str) -> float: if ( score < 0.9 and unique_counter >= 3 - and prev_arg_option_strings != argument.option_strings + and prev_arg_option_strings != arg_info.option_strings ): break - unique_counter += prev_arg_option_strings != argument.option_strings + unique_counter += prev_arg_option_strings != arg_info.option_strings - show_arguments.append(argument) - prev_arg_option_strings = argument.option_strings + show_arguments.append(arg_info) + prev_arg_option_strings = arg_info.option_strings prev_arg_option_strings = None prev_argument_help: Optional[str] = None @@ -656,9 +669,9 @@ def get_score(option_string: str) -> float: ) unique_counter = 0 - for argument in show_arguments: + for arg_info in show_arguments: same_counter += 1 - if argument.option_strings != prev_arg_option_strings: + if arg_info.option_strings != prev_arg_option_strings: same_counter = 0 if unique_counter >= 10: break @@ -669,7 +682,7 @@ def get_score(option_string: str) -> float: if ( len(show_arguments) >= 8 and same_counter >= 4 - and argument.option_strings == prev_arg_option_strings + and arg_info.option_strings == prev_arg_option_strings ): if not dots_printed: extra_info.append( @@ -683,17 +696,17 @@ def get_score(option_string: str) -> float: if not ( has_subcommands - and argument.option_strings == prev_arg_option_strings + and arg_info.option_strings == prev_arg_option_strings ): extra_info.append( Padding( "[bold]" + ( - ", ".join(argument.option_strings) - if argument.metavar is None - else ", ".join(argument.option_strings) + ", ".join(arg_info.option_strings) + if arg_info.metavar is None + else ", ".join(arg_info.option_strings) + " " - + argument.metavar + + arg_info.metavar ) + "[/bold]", (0, 0, 0, 4), @@ -707,36 +720,115 @@ def get_score(option_string: str) -> float: # ) # ) - if argument.help is not None and ( + if arg_info.help is not None and ( # Only print help messages if it's not the same as the previous # one. - argument.help != prev_argument_help - or argument.option_strings != prev_arg_option_strings + arg_info.help != prev_argument_help + or arg_info.option_strings != prev_arg_option_strings ): - extra_info.append(Padding(argument.help, (0, 0, 0, 8))) + extra_info.append(Padding(arg_info.help, (0, 0, 0, 8))) # Show the subcommand that this argument is available in. if has_subcommands: extra_info.append( Padding( - f"in [green]{argument.usage_hint}[/green]", + f"in [green]{arg_info.usage_hint}[/green]", (0, 0, 0, 12), ) ) - prev_arg_option_strings = argument.option_strings - prev_argument_help = argument.help + prev_arg_option_strings = arg_info.option_strings + prev_argument_help = arg_info.help + + elif message.startswith("the following arguments are required:"): + message_title = "Required arguments" + + info_from_required_arg: Dict[str, Optional[_ArgumentInfo]] = {} + for arg in message.partition(":")[2].strip().split(", "): + info_from_required_arg[arg] = None + + arguments, has_subcommands, same_exists = recursive_arg_search( + args=self._args, + parser_spec=self._parser_specification, + prog=self.prog.partition(" ")[0], + unrecognized_arguments=set(), + ) + del same_exists + + for arg_info in arguments: + # Iterate over each option string separately. This can help us support + # aliases in the future. + for option_string in arg_info.option_strings: + # If the option string was found... + if option_string in info_from_required_arg and ( + # And it's the first time it was found... + info_from_required_arg[option_string] is None + # Or we found a better one... + or arg_info.subcommand_match_score + > info_from_required_arg[option_string].subcommand_match_score # type: ignore + ): + # Record the argument info. + info_from_required_arg[option_string] = arg_info + + # Try to print help text for required arguments. + first = True + for maybe_arg in info_from_required_arg.values(): + if maybe_arg is None: + # No argument info found. This will currently happen for + # subcommands. + continue + + if first: + extra_info.extend( + [ + Rule(style=Style(color="red")), + "Argument helptext:", + ] + ) + first = False + + extra_info.append( + Padding( + "[bold]" + + ( + ", ".join(maybe_arg.option_strings) + if maybe_arg.metavar is None + else ", ".join(maybe_arg.option_strings) + + " " + + maybe_arg.metavar + ) + + "[/bold]", + (0, 0, 0, 4), + ) + ) + if maybe_arg.help is not None: + extra_info.append(Padding(maybe_arg.help, (0, 0, 0, 8))) + if has_subcommands: + # We are explicit about where the argument helptext is being + # extracted from because the `subcommand_match_score` heuristic + # above is flawed. + # + # The stars really need to be aligned for it to fail, but this makes + # sure that if it does fail that it's obvious to the user. + extra_info.append( + Padding( + f"in [green]{maybe_arg.usage_hint}[/green]", + (0, 0, 0, 12), + ) + ) - # print(self._parser_specification) console.print( Panel( Group( f"{message[0].upper() + message[1:]}" if len(message) > 0 else "", *extra_info, + Rule(style=Style(color="red")), + f"For full helptext, run [bold]{self.prog} --help[/bold]", ), - title="[bold]Parsing error[/bold]", + title=f"[bold]{message_title}[/bold]", title_align="left", border_style=Style(color="bright_red"), + expand=False, ) ) sys.exit(2) diff --git a/tyro/_cli.py b/tyro/_cli.py index c3d3c799..921c87da 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -376,8 +376,8 @@ def _cli_impl( # Print help message when no arguments are passed in. (but arguments are # expected) - if len(args) == 0 and parser_spec.has_required_args: - args = ["--help"] + # if len(args) == 0 and parser_spec.has_required_args: + # args = ["--help"] if return_parser: _arguments.USE_RICH = True @@ -448,7 +448,6 @@ def _cli_impl( from ._argparse_formatter import THEME console = Console(theme=THEME.as_rich_theme()) - parser._print_usage_succinct(console) console.print( Panel( Group( @@ -469,6 +468,8 @@ def _cli_impl( ), pad=(0, 0, 0, 4), ), + Rule(style=Style(color="red")), + f"For full helptext, see [bold]{parser.prog} --help[/bold]", ] ), ),