Skip to content

Commit

Permalink
Handle flags / custom actions in unrecognized argument errors
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 30, 2023
1 parent 2dbc399 commit 3f3e6f2
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 38 deletions.
21 changes: 21 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,24 @@ class ClassI:
assert "rewarde" not in error
assert "[...]" in error
assert error.count("--help") == 4


def test_similar_flag() -> None:
@dataclasses.dataclass
class Args:
flag: bool = False

target = io.StringIO()
with pytest.raises(SystemExit), contextlib.redirect_stdout(target):
tyro.cli(
Args,
args="--lag".split(" "),
)

error = target.getvalue()

# Printed in the usage message.
assert error.count("--flag | --no-flag") == 1

# Printed in the similar argument list.
assert error.count("--flag, --no-flag") == 1
102 changes: 64 additions & 38 deletions tyro/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re as _re
import shutil
import sys
from gettext import gettext as _
from typing import Any, Generator, List, NoReturn, Optional, Tuple

from rich.columns import Columns
Expand Down Expand Up @@ -136,7 +137,7 @@ def str_from_rich(

@dataclasses.dataclass(frozen=True)
class _ArgumentInfo:
flag: str
option_strings: Tuple[str, ...]
metavar: Optional[str]
usage_hint: str
help: Optional[str]
Expand Down Expand Up @@ -237,8 +238,8 @@ def take_action(action, argument_strings, option_string=None):
for conflict_action in action_conflicts.get(action, []):
if conflict_action in seen_non_default_actions:
msg = _("not allowed with argument %s")
action_name = _get_action_name(conflict_action)
raise ArgumentError(action, msg % action_name)
action_name = argparse._get_action_name(conflict_action)
raise argparse.ArgumentError(action, msg % action_name)

# take the action if we didn't receive a SUPPRESS value
# (e.g. from a default)
Expand Down Expand Up @@ -293,7 +294,7 @@ def consume_optional(start_index):
explicit_arg = new_explicit_arg
else:
msg = _("ignored explicit argument %r")
raise ArgumentError(action, msg % explicit_arg)
raise argparse.ArgumentError(action, msg % explicit_arg)

# if the action expect exactly one argument, we've
# successfully matched the option; exit the loop
Expand All @@ -307,7 +308,7 @@ def consume_optional(start_index):
# explicit argument
else:
msg = _("ignored explicit argument %r")
raise ArgumentError(action, msg % explicit_arg)
raise argparse.ArgumentError(action, msg % explicit_arg)

# if there is no explicit argument, try to match the
# optional's string arguments with the following strings
Expand Down Expand Up @@ -417,7 +418,7 @@ def consume_positionals(start_index):

if required_actions:
self.error(
argparse._("the following arguments are required: %s")
_("the following arguments are required: %s")
% ", ".join(required_actions)
)

Expand All @@ -436,7 +437,7 @@ def consume_positionals(start_index):
if action.help is not argparse.SUPPRESS
]
msg = _("one of the arguments %s is required")
self.error(msg % " ".join(names))
self.error(msg % " ".join(names)) # type: ignore

# return the updated namespace and the extra arguments
return namespace, extras
Expand Down Expand Up @@ -504,11 +505,20 @@ def _recursive_arg_search(
):
continue

option_strings = (arg.lowered.name_or_flag,)

# Handle actions, eg BooleanOptionalAction will map ("--flag",) to
# ("--flag", "--no-flag").
if arg.lowered.action is not None:
option_strings = arg.lowered.action(
option_strings, dest="" # dest should not matter.
).option_strings

arguments.append(
_ArgumentInfo(
# Currently doesn't handle actions well, eg boolean optional
# arguments.
arg.lowered.name_or_flag,
option_strings,
metavar=arg.lowered.metavar,
usage_hint=subcommands + help_flag,
help=arg.lowered.help,
Expand Down Expand Up @@ -560,26 +570,31 @@ def _recursive_arg_search(
for argument in arguments:
# Compute a score for each argument.
assert unrecognized_argument.startswith("--")
if argument.flag.endswith(
unrecognized_argument[2:]
) or argument.flag.startswith(unrecognized_argument[2:]):
score = 0.9
elif len(unrecognized_argument) >= 4 and all(
map(
lambda part: part in argument.flag,
unrecognized_argument[2:].split("."),
)
):
score = 0.9
else:
score = difflib.SequenceMatcher(
a=unrecognized_argument, b=argument.flag
).ratio()
scored_arguments.append((argument, score))

def get_score(option_string: str) -> float:
if option_string.endswith(
unrecognized_argument[2:]
) or option_string.startswith(unrecognized_argument[2:]):
return 0.9
elif len(unrecognized_argument) >= 4 and all(
map(
lambda part: part in option_string,
unrecognized_argument[2:].split("."),
)
):
return 0.9
else:
return difflib.SequenceMatcher(
a=unrecognized_argument, b=option_string
).ratio()

scored_arguments.append(
(argument, max(map(get_score, argument.option_strings)))
)

# Add information about similar arguments.
prev_argument_flag: Optional[str] = None
show_arguments = []
prev_arg_option_strings: Optional[Tuple[str, ...]] = None
show_arguments: List[_ArgumentInfo] = []
unique_counter = 0
for argument, score in (
# Sort scores greatest to least.
Expand All @@ -592,7 +607,7 @@ def _recursive_arg_search(
# subcommands.
-arg_score[0].subcommand_match_score,
# Cluster by flag name, usage hint, help message.
arg_score[0].flag,
arg_score[0].option_strings[0],
arg_score[0].usage_hint,
arg_score[0].help,
),
Expand All @@ -603,15 +618,15 @@ def _recursive_arg_search(
if (
score < 0.9
and unique_counter >= 3
and prev_argument_flag != argument.flag
and prev_arg_option_strings != argument.option_strings
):
break
unique_counter += prev_argument_flag != argument.flag
unique_counter += prev_arg_option_strings != argument.option_strings

show_arguments.append(argument)
prev_argument_flag = argument.flag
prev_arg_option_strings = argument.option_strings

prev_argument_flag = None
prev_arg_option_strings = None
prev_argument_help: Optional[str] = None
same_counter = 0
dots_printed = False
Expand All @@ -627,7 +642,7 @@ def _recursive_arg_search(
unique_counter = 0
for argument in show_arguments:
same_counter += 1
if argument.flag != prev_argument_flag:
if argument.option_strings != prev_arg_option_strings:
same_counter = 0
if unique_counter >= 10:
break
Expand All @@ -638,7 +653,7 @@ def _recursive_arg_search(
if (
len(show_arguments) >= 8
and same_counter >= 4
and argument.flag == prev_argument_flag
and argument.option_strings == prev_arg_option_strings
):
if not dots_printed:
extra_info.append(
Expand All @@ -650,10 +665,21 @@ def _recursive_arg_search(
dots_printed = True
continue

if not (has_subcommands and argument.flag == prev_argument_flag):
if not (
has_subcommands
and argument.option_strings == prev_arg_option_strings
):
extra_info.append(
Padding(
f"[bold]{argument.flag if argument.metavar is None else argument.flag + ' ' + argument.metavar}[/bold]",
"[bold]"
+ (
", ".join(argument.option_strings)
if argument.metavar is None
else ", ".join(argument.option_strings)
+ " "
+ argument.metavar
)
+ "[/bold]",
(0, 0, 0, 4),
)
)
Expand All @@ -669,7 +695,7 @@ def _recursive_arg_search(
# Only print help messages if it's not the same as the previous
# one.
argument.help != prev_argument_help
or argument.flag != prev_argument_flag
or argument.option_strings != prev_arg_option_strings
):
extra_info.append(Padding(argument.help, (0, 0, 0, 8)))

Expand All @@ -682,7 +708,7 @@ def _recursive_arg_search(
)
)

prev_argument_flag = argument.flag
prev_arg_option_strings = argument.option_strings
prev_argument_help = argument.help

# print(self._parser_specification)
Expand Down Expand Up @@ -1065,7 +1091,7 @@ def _format_actions_usage(self, actions, groups): # pragma: no cover
group_action_count - suppressed_actions_count
)

if not group.required:
if not group.required: # type: ignore
if start in inserts:
inserts[start] += " ["
else:
Expand Down
3 changes: 3 additions & 0 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,14 @@ def apply(
help=helptext,
allow_abbrev=False,
)

# Attributes used for error message generation.
assert isinstance(subparser, _argparse_formatter.TyroArgumentParser)
assert isinstance(parent_parser, _argparse_formatter.TyroArgumentParser)
subparser._parsing_known_args = parent_parser._parsing_known_args
subparser._parser_specification = parent_parser._parser_specification
subparser._args = parent_parser._args

subparser_tree_leaves.extend(subparser_def.apply(subparser))

return tuple(subparser_tree_leaves)
Expand Down

0 comments on commit 3f3e6f2

Please sign in to comment.