Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 20, 2022
1 parent 94b4901 commit 62f49a1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 29 deletions.
8 changes: 4 additions & 4 deletions tyro/_argparse_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


@dataclasses.dataclass
class DcargsTheme:
class TyroTheme:
border: Style = Style()
description: Style = Style()
invocation: Style = Style()
Expand Down Expand Up @@ -65,7 +65,7 @@ def set_accent_color(accent_color: Optional[str]) -> None:

# TODO: this is a prototype; for a v1.0.0 release we should revisit whether the global
# state here is acceptable or not.
THEME = DcargsTheme()
THEME = TyroTheme()
set_accent_color(None)


Expand Down Expand Up @@ -128,7 +128,7 @@ def str_from_rich(
return out.get().rstrip("\n")


class DcargsArgparseHelpFormatter(argparse.RawDescriptionHelpFormatter):
class TyroArgparseHelpFormatter(argparse.RawDescriptionHelpFormatter):
def __init__(self, prog: str):
indent_increment = 4
width = shutil.get_terminal_size().columns - 2
Expand Down Expand Up @@ -383,7 +383,7 @@ def _tyro_format_nonroot(self):
item_content = func(*args)
if (
getattr(func, "__func__", None)
is DcargsArgparseHelpFormatter._format_action
is TyroArgparseHelpFormatter._format_action
):
(action,) = args
assert isinstance(action, argparse.Action)
Expand Down
6 changes: 2 additions & 4 deletions tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def fix_arg(arg: str) -> str:
with _argparse_formatter.ansi_context():
parser = argparse.ArgumentParser(
prog=prog,
formatter_class=_argparse_formatter.DcargsArgparseHelpFormatter,
formatter_class=_argparse_formatter.TyroArgparseHelpFormatter,
)
parser_definition.apply(parser)

Expand Down Expand Up @@ -327,9 +327,7 @@ def _cache_subparsers(parser_definition: _parsers.ParserSpecification) -> None:
subparsers = parser_definition.subparsers
if subparsers is None:
return
subparser_def_from_prefixed_field_name[
subparsers.prefix if subparsers.prefix != _strings.dummy_field_name else ""
] = subparsers
subparser_def_from_prefixed_field_name[subparsers.prefix] = subparsers
for p in subparsers.parser_from_name.values():
_cache_subparsers(p)

Expand Down
44 changes: 23 additions & 21 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def from_callable_or_type(
subparsers = subparsers_attempt
continue
else:
subparsers = subparsers.add_subparsers_to_leaves(
subparsers_attempt
subparsers = add_subparsers_to_leaves(
subparsers, subparsers_attempt
)
continue

Expand Down Expand Up @@ -150,8 +150,8 @@ def from_callable_or_type(
subparsers = (
nested_parser.subparsers
if subparsers is None
else subparsers.add_subparsers_to_leaves(
nested_parser.subparsers
else add_subparsers_to_leaves(
subparsers, nested_parser.subparsers
)
)

Expand Down Expand Up @@ -490,7 +490,7 @@ def apply(self, parent_parser: argparse.ArgumentParser) -> None:
if self.can_be_none:
subparser = argparse_subparsers.add_parser(
name=_strings.subparser_name_from_type(self.prefix, None),
formatter_class=_argparse_formatter.DcargsArgparseHelpFormatter,
formatter_class=_argparse_formatter.TyroArgparseHelpFormatter,
help="",
)

Expand All @@ -502,24 +502,26 @@ def apply(self, parent_parser: argparse.ArgumentParser) -> None:

subparser = argparse_subparsers.add_parser(
name,
formatter_class=_argparse_formatter.DcargsArgparseHelpFormatter,
formatter_class=_argparse_formatter.TyroArgparseHelpFormatter,
help=helptext,
)
subparser_def.apply(subparser)

def add_subparsers_to_leaves(
self, subparsers: SubparsersSpecification
) -> SubparsersSpecification:
new_parsers_from_name = {}
for name, parser in self.parser_from_name.items():
new_parsers_from_name[name] = dataclasses.replace(
parser,
subparsers=subparsers
if parser.subparsers is None
else parser.subparsers.add_subparsers_to_leaves(subparsers),
)
return dataclasses.replace(
self,
parser_from_name=new_parsers_from_name,
required=self.required or subparsers.required,

def add_subparsers_to_leaves(
root: Optional[SubparsersSpecification], leaf: SubparsersSpecification
) -> SubparsersSpecification:
if root is None:
return leaf

new_parsers_from_name = {}
for name, parser in root.parser_from_name.items():
new_parsers_from_name[name] = dataclasses.replace(
parser,
subparsers=add_subparsers_to_leaves(parser.subparsers, leaf),
)
return dataclasses.replace(
root,
parser_from_name=new_parsers_from_name,
required=root.required or leaf.required,
)

0 comments on commit 62f49a1

Please sign in to comment.