From 8ccd20b683d12cc72de5382d5c7dfd694413c11f Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 11 Mar 2023 22:19:28 -0800 Subject: [PATCH] Support container edge case in subcommands --- tests/test_nested.py | 25 +++++++++++++++++++++++++ tyro/_calling.py | 6 +++++- tyro/_cli.py | 2 +- tyro/_docstrings.py | 12 +----------- tyro/_parsers.py | 2 -- tyro/_strings.py | 17 +++++++++++++++-- tyro/_subcommand_matching.py | 2 +- 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/tests/test_nested.py b/tests/test_nested.py index 66cd6dac..e8c90a57 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -726,6 +726,31 @@ def main(x: Tuple[Tuple[Color], Location, float]): ) == ((Color(255, 0, 0),), Location(5.0, 0.0, 2.0), 4.0) +def test_tuple_nesting_union() -> None: + @dataclasses.dataclass(frozen=True) + class Color: + r: int + g: int + b: int + + @dataclasses.dataclass(frozen=True) + class Location: + x: float + y: float + z: float + + def main(x: Union[Tuple[Tuple[Color], Location, float], Tuple[Color, Color]]): + return x + + assert tyro.cli( + main, + args=( + "x:tuple-tuple-color-location-float --x.0.0.r 255 --x.0.0.g 0 --x.0.0.b 0" + " --x.1.x 5.0 --x.1.y 0.0 --x.1.z 2.0 --x.2 4.0".split(" ") + ), + ) == ((Color(255, 0, 0),), Location(5.0, 0.0, 2.0), 4.0) + + def test_generic_subparsers() -> None: T = TypeVar("T") diff --git a/tyro/_calling.py b/tyro/_calling.py index ec6b38ac..e38dc386 100644 --- a/tyro/_calling.py +++ b/tyro/_calling.py @@ -155,7 +155,11 @@ def get_value_from_arg(prefixed_field_name: str) -> Any: value, consumed_keywords_child = call_from_args( chosen_f, subparser_def.parser_from_name[subparser_name], - field.default if type(field.default) is chosen_f else None, + ( + field.default + if type(field.default) is chosen_f + else _fields.MISSING_NONPROP + ), value_from_prefixed_field_name, field_name_prefix=prefixed_field_name, ) diff --git a/tyro/_cli.py b/tyro/_cli.py index 683a3612..591c3a78 100644 --- a/tyro/_cli.py +++ b/tyro/_cli.py @@ -326,7 +326,7 @@ def _cli_impl( and modified_args[fixed] != arg ): raise RuntimeError( - f"Ambiguous arguments: " + modified_args[fixed] + " and " + arg + "Ambiguous arguments: " + modified_args[fixed] + " and " + arg ) modified_args[fixed] = arg args[index] = fixed diff --git a/tyro/_docstrings.py b/tyro/_docstrings.py index ac172c7b..265562ce 100644 --- a/tyro/_docstrings.py +++ b/tyro/_docstrings.py @@ -7,17 +7,7 @@ import io import itertools import tokenize -from typing import ( - Callable, - Dict, - Generic, - Hashable, - List, - Optional, - Type, - TypeVar, - cast, -) +from typing import Callable, Dict, Generic, Hashable, List, Optional, Type, TypeVar import docstring_parser from typing_extensions import get_origin, is_typeddict diff --git a/tyro/_parsers.py b/tyro/_parsers.py index f492c821..b56f4d93 100644 --- a/tyro/_parsers.py +++ b/tyro/_parsers.py @@ -30,7 +30,6 @@ _resolver, _strings, _subcommand_matching, - _unsafe_cache, ) from ._typing import TypeForm from .conf import _confstruct, _markers @@ -355,7 +354,6 @@ def from_field( option, found_subcommand_configs = _resolver.unwrap_annotated( option, _confstruct._SubcommandConfiguration ) - default_hash = None if len(found_subcommand_configs) != 0: # Explicitly annotated default. assert len(found_subcommand_configs) == 1, ( diff --git a/tyro/_strings.py b/tyro/_strings.py index 8b31fc50..7ca994e7 100644 --- a/tyro/_strings.py +++ b/tyro/_strings.py @@ -5,6 +5,8 @@ import textwrap from typing import Iterable, List, Sequence, Tuple, Type, Union +from typing_extensions import get_args, get_origin + from . import _resolver dummy_field_name = "__tyro_dummy_field__" @@ -78,9 +80,20 @@ def _subparser_name_from_type(cls: Type) -> Tuple[str, bool]: return found_name, prefix_name # Subparser name from class name. + def get_name(cls: Type) -> str: + if hasattr(cls, "__name__"): + return hyphen_separated_from_camel_case(cls.__name__) + elif hasattr(get_origin(cls), "__name__"): + parts = [get_origin(cls).__name__] # type: ignore + parts.extend(map(get_name, get_args(cls))) + return "-".join(parts) + else: + raise AssertionError( + f"Tried to interpret {cls} as a subcommand, but could not infer name" + ) + if len(type_from_typevar) == 0: - assert hasattr(cls, "__name__") - return hyphen_separated_from_camel_case(cls.__name__), prefix_name # type: ignore + return get_name(cls), prefix_name # type: ignore return ( "-".join( diff --git a/tyro/_subcommand_matching.py b/tyro/_subcommand_matching.py index c7bf6bfc..b7e6c6b3 100644 --- a/tyro/_subcommand_matching.py +++ b/tyro/_subcommand_matching.py @@ -5,7 +5,7 @@ from typing_extensions import get_args, get_origin -from . import _fields, _instantiators, _resolver, _typing, _unsafe_cache +from . import _fields, _instantiators, _resolver, _typing from .conf import _confstruct