From 3f3e6f2fe90fb00355382ccbb9d53ed3b7bae5bd Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 30 Aug 2023 13:09:44 -0700 Subject: [PATCH] Handle flags / custom actions in unrecognized argument errors --- tests/test_errors.py | 21 ++++++++ tyro/_argparse_formatter.py | 102 ++++++++++++++++++++++-------------- tyro/_parsers.py | 3 ++ 3 files changed, 88 insertions(+), 38 deletions(-) diff --git a/tests/test_errors.py b/tests/test_errors.py index 35ea355d..8feb632f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -432,3 +432,24 @@ class ClassI: assert "rewarde" not in error assert "[...]" in error assert error.count("--help") == 4 + + +def test_similar_flag() -> None: + @dataclasses.dataclass + class Args: + flag: bool = False + + target = io.StringIO() + with pytest.raises(SystemExit), contextlib.redirect_stdout(target): + tyro.cli( + Args, + args="--lag".split(" "), + ) + + error = target.getvalue() + + # Printed in the usage message. + assert error.count("--flag | --no-flag") == 1 + + # Printed in the similar argument list. + assert error.count("--flag, --no-flag") == 1 diff --git a/tyro/_argparse_formatter.py b/tyro/_argparse_formatter.py index 42f72e66..b9d39a19 100644 --- a/tyro/_argparse_formatter.py +++ b/tyro/_argparse_formatter.py @@ -19,6 +19,7 @@ import re as _re import shutil import sys +from gettext import gettext as _ from typing import Any, Generator, List, NoReturn, Optional, Tuple from rich.columns import Columns @@ -136,7 +137,7 @@ def str_from_rich( @dataclasses.dataclass(frozen=True) class _ArgumentInfo: - flag: str + option_strings: Tuple[str, ...] metavar: Optional[str] usage_hint: str help: Optional[str] @@ -237,8 +238,8 @@ def take_action(action, argument_strings, option_string=None): for conflict_action in action_conflicts.get(action, []): if conflict_action in seen_non_default_actions: msg = _("not allowed with argument %s") - action_name = _get_action_name(conflict_action) - raise ArgumentError(action, msg % action_name) + action_name = argparse._get_action_name(conflict_action) + raise argparse.ArgumentError(action, msg % action_name) # take the action if we didn't receive a SUPPRESS value # (e.g. from a default) @@ -293,7 +294,7 @@ def consume_optional(start_index): explicit_arg = new_explicit_arg else: msg = _("ignored explicit argument %r") - raise ArgumentError(action, msg % explicit_arg) + raise argparse.ArgumentError(action, msg % explicit_arg) # if the action expect exactly one argument, we've # successfully matched the option; exit the loop @@ -307,7 +308,7 @@ def consume_optional(start_index): # explicit argument else: msg = _("ignored explicit argument %r") - raise ArgumentError(action, msg % explicit_arg) + raise argparse.ArgumentError(action, msg % explicit_arg) # if there is no explicit argument, try to match the # optional's string arguments with the following strings @@ -417,7 +418,7 @@ def consume_positionals(start_index): if required_actions: self.error( - argparse._("the following arguments are required: %s") + _("the following arguments are required: %s") % ", ".join(required_actions) ) @@ -436,7 +437,7 @@ def consume_positionals(start_index): if action.help is not argparse.SUPPRESS ] msg = _("one of the arguments %s is required") - self.error(msg % " ".join(names)) + self.error(msg % " ".join(names)) # type: ignore # return the updated namespace and the extra arguments return namespace, extras @@ -504,11 +505,20 @@ def _recursive_arg_search( ): continue + option_strings = (arg.lowered.name_or_flag,) + + # Handle actions, eg BooleanOptionalAction will map ("--flag",) to + # ("--flag", "--no-flag"). + if arg.lowered.action is not None: + option_strings = arg.lowered.action( + option_strings, dest="" # dest should not matter. + ).option_strings + arguments.append( _ArgumentInfo( # Currently doesn't handle actions well, eg boolean optional # arguments. - arg.lowered.name_or_flag, + option_strings, metavar=arg.lowered.metavar, usage_hint=subcommands + help_flag, help=arg.lowered.help, @@ -560,26 +570,31 @@ def _recursive_arg_search( for argument in arguments: # Compute a score for each argument. assert unrecognized_argument.startswith("--") - if argument.flag.endswith( - unrecognized_argument[2:] - ) or argument.flag.startswith(unrecognized_argument[2:]): - score = 0.9 - elif len(unrecognized_argument) >= 4 and all( - map( - lambda part: part in argument.flag, - unrecognized_argument[2:].split("."), - ) - ): - score = 0.9 - else: - score = difflib.SequenceMatcher( - a=unrecognized_argument, b=argument.flag - ).ratio() - scored_arguments.append((argument, score)) + + def get_score(option_string: str) -> float: + if option_string.endswith( + unrecognized_argument[2:] + ) or option_string.startswith(unrecognized_argument[2:]): + return 0.9 + elif len(unrecognized_argument) >= 4 and all( + map( + lambda part: part in option_string, + unrecognized_argument[2:].split("."), + ) + ): + return 0.9 + else: + return difflib.SequenceMatcher( + a=unrecognized_argument, b=option_string + ).ratio() + + scored_arguments.append( + (argument, max(map(get_score, argument.option_strings))) + ) # Add information about similar arguments. - prev_argument_flag: Optional[str] = None - show_arguments = [] + prev_arg_option_strings: Optional[Tuple[str, ...]] = None + show_arguments: List[_ArgumentInfo] = [] unique_counter = 0 for argument, score in ( # Sort scores greatest to least. @@ -592,7 +607,7 @@ def _recursive_arg_search( # subcommands. -arg_score[0].subcommand_match_score, # Cluster by flag name, usage hint, help message. - arg_score[0].flag, + arg_score[0].option_strings[0], arg_score[0].usage_hint, arg_score[0].help, ), @@ -603,15 +618,15 @@ def _recursive_arg_search( if ( score < 0.9 and unique_counter >= 3 - and prev_argument_flag != argument.flag + and prev_arg_option_strings != argument.option_strings ): break - unique_counter += prev_argument_flag != argument.flag + unique_counter += prev_arg_option_strings != argument.option_strings show_arguments.append(argument) - prev_argument_flag = argument.flag + prev_arg_option_strings = argument.option_strings - prev_argument_flag = None + prev_arg_option_strings = None prev_argument_help: Optional[str] = None same_counter = 0 dots_printed = False @@ -627,7 +642,7 @@ def _recursive_arg_search( unique_counter = 0 for argument in show_arguments: same_counter += 1 - if argument.flag != prev_argument_flag: + if argument.option_strings != prev_arg_option_strings: same_counter = 0 if unique_counter >= 10: break @@ -638,7 +653,7 @@ def _recursive_arg_search( if ( len(show_arguments) >= 8 and same_counter >= 4 - and argument.flag == prev_argument_flag + and argument.option_strings == prev_arg_option_strings ): if not dots_printed: extra_info.append( @@ -650,10 +665,21 @@ def _recursive_arg_search( dots_printed = True continue - if not (has_subcommands and argument.flag == prev_argument_flag): + if not ( + has_subcommands + and argument.option_strings == prev_arg_option_strings + ): extra_info.append( Padding( - f"[bold]{argument.flag if argument.metavar is None else argument.flag + ' ' + argument.metavar}[/bold]", + "[bold]" + + ( + ", ".join(argument.option_strings) + if argument.metavar is None + else ", ".join(argument.option_strings) + + " " + + argument.metavar + ) + + "[/bold]", (0, 0, 0, 4), ) ) @@ -669,7 +695,7 @@ def _recursive_arg_search( # Only print help messages if it's not the same as the previous # one. argument.help != prev_argument_help - or argument.flag != prev_argument_flag + or argument.option_strings != prev_arg_option_strings ): extra_info.append(Padding(argument.help, (0, 0, 0, 8))) @@ -682,7 +708,7 @@ def _recursive_arg_search( ) ) - prev_argument_flag = argument.flag + prev_arg_option_strings = argument.option_strings prev_argument_help = argument.help # print(self._parser_specification) @@ -1065,7 +1091,7 @@ def _format_actions_usage(self, actions, groups): # pragma: no cover group_action_count - suppressed_actions_count ) - if not group.required: + if not group.required: # type: ignore if start in inserts: inserts[start] += " [" else: diff --git a/tyro/_parsers.py b/tyro/_parsers.py index 7ee9778e..a8cb131b 100644 --- a/tyro/_parsers.py +++ b/tyro/_parsers.py @@ -540,11 +540,14 @@ def apply( help=helptext, allow_abbrev=False, ) + + # Attributes used for error message generation. assert isinstance(subparser, _argparse_formatter.TyroArgumentParser) assert isinstance(parent_parser, _argparse_formatter.TyroArgumentParser) subparser._parsing_known_args = parent_parser._parsing_known_args subparser._parser_specification = parent_parser._parser_specification subparser._args = parent_parser._args + subparser_tree_leaves.extend(subparser_def.apply(subparser)) return tuple(subparser_tree_leaves)