Skip to content

Commit

Permalink
Add tyro.conf.ConsolidateSubcommandArgs, tyro.conf.configure()
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 20, 2022
1 parent 62f49a1 commit fd13815
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 7 deletions.
119 changes: 119 additions & 0 deletions tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,122 @@ def main(x: int, y: tyro.conf.Positional[int]) -> int:

assert tyro.cli(main, args="5 --x 3".split(" ")) == 8
assert tyro.cli(main, args="--x 3 5".split(" ")) == 8


def test_omit_subcommand_prefix_and_consolidate_subcommand_args():
@dataclasses.dataclass
class DefaultInstanceHTTPServer:
y: int = 0
flag: bool = True

@dataclasses.dataclass
class DefaultInstanceSMTPServer:
z: int = 0

@dataclasses.dataclass
class DefaultInstanceSubparser:
x: int
# bc: Union[DefaultInstanceHTTPServer, DefaultInstanceSMTPServer]
bc: tyro.conf.OmitSubcommandPrefixes[
Union[DefaultInstanceHTTPServer, DefaultInstanceSMTPServer]
]

assert (
tyro.cli(
tyro.conf.ConsolidateSubcommandArgs[DefaultInstanceSubparser],
args=[
"bc:default-instance-http-server",
"--x",
"1",
"--y",
"5",
"--no-flag",
],
)
== tyro.cli(
tyro.conf.ConsolidateSubcommandArgs[DefaultInstanceSubparser],
args=[
"bc:default-instance-http-server",
"--x",
"1",
"--y",
"5",
],
default=DefaultInstanceSubparser(
x=1, bc=DefaultInstanceHTTPServer(y=3, flag=False)
),
)
== DefaultInstanceSubparser(x=1, bc=DefaultInstanceHTTPServer(y=5, flag=False))
)
assert (
tyro.cli(
tyro.conf.ConsolidateSubcommandArgs[DefaultInstanceSubparser],
args=[
"bc:default-instance-http-server",
"--x",
"1",
"--y",
"8",
],
)
== tyro.cli(
tyro.conf.ConsolidateSubcommandArgs[DefaultInstanceSubparser],
args=[
"bc:default-instance-http-server",
"--x",
"1",
"--y",
"8",
],
default=DefaultInstanceSubparser(x=1, bc=DefaultInstanceHTTPServer(y=7)),
)
== DefaultInstanceSubparser(x=1, bc=DefaultInstanceHTTPServer(y=8))
)


def test_omit_subcommand_prefix_and_consolidate_subcommand_args_in_function():
@dataclasses.dataclass
class DefaultInstanceHTTPServer:
y: int = 0
flag: bool = True

@dataclasses.dataclass
class DefaultInstanceSMTPServer:
z: int = 0

@dataclasses.dataclass
class DefaultInstanceSubparser:
x: int
# bc: Union[DefaultInstanceHTTPServer, DefaultInstanceSMTPServer]
bc: Union[DefaultInstanceHTTPServer, DefaultInstanceSMTPServer]

@tyro.conf.configure(
tyro.conf.OmitSubcommandPrefixes,
tyro.conf.ConsolidateSubcommandArgs,
)
def func(parent: DefaultInstanceSubparser) -> DefaultInstanceSubparser:
return parent

assert tyro.cli(
func,
args=[
"parent.bc:default-instance-http-server",
"--parent.x",
"1",
# --y and --no-flag are in a subcommand with prefix omission.
"--y",
"5",
"--no-flag",
],
) == DefaultInstanceSubparser(x=1, bc=DefaultInstanceHTTPServer(y=5, flag=False))
assert tyro.cli(
func,
args=[
"parent.bc:default-instance-http-server",
"--parent.x",
"1",
# --y is in a subcommand with prefix omission.
"--y",
"8",
],
) == DefaultInstanceSubparser(x=1, bc=DefaultInstanceHTTPServer(y=8))
1 change: 1 addition & 0 deletions tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def fix_arg(arg: str) -> str:
parent_classes=set(), # Used for recursive calls.
default_instance=default_instance_internal, # Overrides for default values.
prefix="", # Used for recursive calls.
subcommand_prefix="", # Used for recursive calls.
)

# Generate parser!
Expand Down
50 changes: 44 additions & 6 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@

import argparse
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union, cast
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)

from typing_extensions import Annotated, get_args, get_origin

Expand Down Expand Up @@ -33,6 +45,7 @@ class ParserSpecification:
subparsers: Optional[SubparsersSpecification]
prefix: str
has_required_args: bool
consolidate_subcommand_args: bool

@staticmethod
def from_callable_or_type(
Expand All @@ -47,6 +60,12 @@ def from_callable_or_type(
) -> ParserSpecification:
"""Create a parser definition from a callable or type."""

# Consolidate subcommand types.
consolidate_subcommand_args = (
_markers.ConsolidateSubcommandArgs
in _resolver.unwrap_annotated(f, _markers.Marker)[1]
)

# Resolve generic types.
f, type_from_typevar = _resolver.resolve_generic_types(f)
f = _resolver.narrow_type(f, default_instance)
Expand Down Expand Up @@ -197,18 +216,32 @@ def from_callable_or_type(
subparsers=subparsers,
prefix=prefix,
has_required_args=has_required_args,
consolidate_subcommand_args=consolidate_subcommand_args,
)

def apply(self, parser: argparse.ArgumentParser) -> None:
def apply(
self, parser: argparse.ArgumentParser
) -> Tuple[argparse.ArgumentParser, ...]:
"""Create defined arguments and subparsers."""

# Generate helptext.
parser.description = self.description
self.apply_args(parser)

# Create subparser tree.
if self.subparsers is not None:
self.subparsers.apply(parser)
leaves = self.subparsers.apply(parser)
else:
leaves = (parser,)

# Depending on whether we want to consolidate subcommand args, we can either
# apply arguments to the intermediate parser or only on the leaves.
if self.consolidate_subcommand_args:
for leaf in leaves:
self.apply_args(leaf)
else:
self.apply_args(parser)

return leaves

def apply_args(self, parser: argparse.ArgumentParser) -> None:
"""Create defined arguments and subparsers."""
Expand Down Expand Up @@ -460,7 +493,9 @@ def from_field(
can_be_none=options != options_no_none,
)

def apply(self, parent_parser: argparse.ArgumentParser) -> None:
def apply(
self, parent_parser: argparse.ArgumentParser
) -> Tuple[argparse.ArgumentParser, ...]:
title = "subcommands"
metavar = (
"{"
Expand Down Expand Up @@ -494,6 +529,7 @@ def apply(self, parent_parser: argparse.ArgumentParser) -> None:
help="",
)

subparser_tree_leaves: List[argparse.ArgumentParser] = []
for name, subparser_def in self.parser_from_name.items():
helptext = subparser_def.description.replace("%", "%%")
if len(helptext) > 0:
Expand All @@ -505,7 +541,9 @@ def apply(self, parent_parser: argparse.ArgumentParser) -> None:
formatter_class=_argparse_formatter.TyroArgparseHelpFormatter,
help=helptext,
)
subparser_def.apply(subparser)
subparser_tree_leaves.extend(subparser_def.apply(subparser))

return tuple(subparser_tree_leaves)


def add_subparsers_to_leaves(
Expand Down
4 changes: 4 additions & 0 deletions tyro/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,26 @@
from ._confstruct import arg, subcommand
from ._markers import (
AvoidSubcommands,
ConsolidateSubcommandArgs,
Fixed,
FlagConversionOff,
OmitSubcommandPrefixes,
Positional,
Suppress,
SuppressFixed,
configure,
)

__all__ = [
"arg",
"subcommand",
"AvoidSubcommands",
"ConsolidateSubcommandArgs",
"Fixed",
"FlagConversionOff",
"OmitSubcommandPrefixes",
"Positional",
"Suppress",
"SuppressFixed",
"configure",
]
53 changes: 52 additions & 1 deletion tyro/conf/_markers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Callable, Type, TypeVar

from typing_extensions import Annotated

Expand Down Expand Up @@ -51,6 +51,28 @@
Can be used directly on union types, `AvoidSubcommands[Union[...]]`, or recursively
applied to nested types."""

ConsolidateSubcommandArgs = Annotated[T, None]
"""Consolidate arguments applied to subcommands. Makes CLI less sensitive to argument
ordering, at the cost of support for optional subcommands.
By default, `tyro` will generate a traditional CLI interface where args are applied to
the directly preceding subcommand. When we have two subcommands `s1` and `s2`:
```
python x.py {--root options} s1 {--s1 options} s2 {--s2 options}
```
This can be frustrating because the resulting CLI is sensitive the exact positioning and
ordering of options.
To consolidate subcommands, we push arguments to the end, after all subcommands:
```
python x.py s1 s2 {--root, s1, and s2 options}
```
This is more robust to reordering of options, ensuring that any new options can simply
be placed at the end of the command>
"""

OmitSubcommandPrefixes = Annotated[T, None]
"""Make flags used for keyword arguments in subcommands shorter by omitting prefixes.
Expand All @@ -62,16 +84,45 @@
If subcommand prefixes are omitted, we would instead simply have `--arg`.
"""

CallableType = TypeVar("CallableType", bound=Callable)

# Dynamically generate marker singletons.
# These can be used one of two ways:
# - Marker[T]
# - Annotated[T, Marker]


class Marker(_singleton.Singleton):
def __getitem__(self, key):
return Annotated.__class_getitem__((key, self)) # type: ignore


def configure(*markers: Marker) -> Callable[[CallableType], CallableType]:
"""Decorator for configuring functions.
Configuration markers are implemented via `typing.Annotated` and straightforward to
apply to types, for example:
```python
field: tyro.conf.FlagConversionOff[bool]
```
This decorator makes markers applicable to general functions as well:
```python
# Recursively apply FlagConversionOff to all field in `main()`.
@tyro.conf.configure_function(tyro.conf.FlagConversionOff)
def main(field: bool) -> None:
...
```
"""

def _inner(callable: CallableType) -> CallableType:
return Annotated.__class_getitem__((callable,) + tuple(markers)) # type: ignore

return _inner


if not TYPE_CHECKING:

def _make_marker(description: str) -> Marker:
Expand Down

0 comments on commit fd13815

Please sign in to comment.