Skip to content

Commit

Permalink
Support container edge case in subcommands
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Mar 12, 2023
1 parent 620073c commit 8ccd20b
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 18 deletions.
25 changes: 25 additions & 0 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,31 @@ def main(x: Tuple[Tuple[Color], Location, float]):
) == ((Color(255, 0, 0),), Location(5.0, 0.0, 2.0), 4.0)


def test_tuple_nesting_union() -> None:
@dataclasses.dataclass(frozen=True)
class Color:
r: int
g: int
b: int

@dataclasses.dataclass(frozen=True)
class Location:
x: float
y: float
z: float

def main(x: Union[Tuple[Tuple[Color], Location, float], Tuple[Color, Color]]):
return x

assert tyro.cli(
main,
args=(
"x:tuple-tuple-color-location-float --x.0.0.r 255 --x.0.0.g 0 --x.0.0.b 0"
" --x.1.x 5.0 --x.1.y 0.0 --x.1.z 2.0 --x.2 4.0".split(" ")
),
) == ((Color(255, 0, 0),), Location(5.0, 0.0, 2.0), 4.0)


def test_generic_subparsers() -> None:
T = TypeVar("T")

Expand Down
6 changes: 5 additions & 1 deletion tyro/_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ def get_value_from_arg(prefixed_field_name: str) -> Any:
value, consumed_keywords_child = call_from_args(
chosen_f,
subparser_def.parser_from_name[subparser_name],
field.default if type(field.default) is chosen_f else None,
(
field.default
if type(field.default) is chosen_f
else _fields.MISSING_NONPROP
),
value_from_prefixed_field_name,
field_name_prefix=prefixed_field_name,
)
Expand Down
2 changes: 1 addition & 1 deletion tyro/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _cli_impl(
and modified_args[fixed] != arg
):
raise RuntimeError(
f"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
"Ambiguous arguments: " + modified_args[fixed] + " and " + arg
)
modified_args[fixed] = arg
args[index] = fixed
Expand Down
12 changes: 1 addition & 11 deletions tyro/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,7 @@
import io
import itertools
import tokenize
from typing import (
Callable,
Dict,
Generic,
Hashable,
List,
Optional,
Type,
TypeVar,
cast,
)
from typing import Callable, Dict, Generic, Hashable, List, Optional, Type, TypeVar

import docstring_parser
from typing_extensions import get_origin, is_typeddict
Expand Down
2 changes: 0 additions & 2 deletions tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
_resolver,
_strings,
_subcommand_matching,
_unsafe_cache,
)
from ._typing import TypeForm
from .conf import _confstruct, _markers
Expand Down Expand Up @@ -355,7 +354,6 @@ def from_field(
option, found_subcommand_configs = _resolver.unwrap_annotated(
option, _confstruct._SubcommandConfiguration
)
default_hash = None
if len(found_subcommand_configs) != 0:
# Explicitly annotated default.
assert len(found_subcommand_configs) == 1, (
Expand Down
17 changes: 15 additions & 2 deletions tyro/_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import textwrap
from typing import Iterable, List, Sequence, Tuple, Type, Union

from typing_extensions import get_args, get_origin

from . import _resolver

dummy_field_name = "__tyro_dummy_field__"
Expand Down Expand Up @@ -78,9 +80,20 @@ def _subparser_name_from_type(cls: Type) -> Tuple[str, bool]:
return found_name, prefix_name

# Subparser name from class name.
def get_name(cls: Type) -> str:
if hasattr(cls, "__name__"):
return hyphen_separated_from_camel_case(cls.__name__)
elif hasattr(get_origin(cls), "__name__"):
parts = [get_origin(cls).__name__] # type: ignore
parts.extend(map(get_name, get_args(cls)))
return "-".join(parts)
else:
raise AssertionError(
f"Tried to interpret {cls} as a subcommand, but could not infer name"
)

if len(type_from_typevar) == 0:
assert hasattr(cls, "__name__")
return hyphen_separated_from_camel_case(cls.__name__), prefix_name # type: ignore
return get_name(cls), prefix_name # type: ignore

return (
"-".join(
Expand Down
2 changes: 1 addition & 1 deletion tyro/_subcommand_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing_extensions import get_args, get_origin

from . import _fields, _instantiators, _resolver, _typing, _unsafe_cache
from . import _fields, _instantiators, _resolver, _typing
from .conf import _confstruct


Expand Down

0 comments on commit 8ccd20b

Please sign in to comment.