From f8ac5057a55e266fcc4a28a3af2246c8d5a5bfb2 Mon Sep 17 00:00:00 2001 From: brentyi Date: Mon, 11 Nov 2024 15:34:39 -0800 Subject: [PATCH 1/7] Support Python 3.13 --- pyproject.toml | 4 +- src/tyro/_fields.py | 6 +-- src/tyro/_parsers.py | 10 ++-- src/tyro/_resolver.py | 20 ++++---- src/tyro/conf/_markers.py | 2 +- src/tyro/constructors/_struct_spec.py | 4 +- src/tyro/extras/_base_configs.py | 4 +- tests/conftest.py | 14 +++--- ...test_base_configs_nested_exclude_py313.py} | 0 tests/test_dcargs.py | 23 +--------- ...py => test_flax_min_py38_exclude_py313.py} | 0 tests/test_helptext.py | 25 +--------- ...configs_nested_exclude_py313_generated.py} | 0 .../test_dcargs_generated.py | 30 ++++-------- ..._flax_min_py38_exclude_py313_generated.py} | 0 .../test_helptext_generated.py | 22 --------- .../test_torch_exclude_py313_generated.py | 46 +++++++++++++++++++ tests/test_torch_exclude_py313.py | 46 +++++++++++++++++++ 18 files changed, 133 insertions(+), 123 deletions(-) rename tests/{test_base_configs_nested.py => test_base_configs_nested_exclude_py313.py} (100%) rename tests/{test_flax_min_py38.py => test_flax_min_py38_exclude_py313.py} (100%) rename tests/test_py311_generated/{test_base_configs_nested_generated.py => test_base_configs_nested_exclude_py313_generated.py} (100%) rename tests/test_py311_generated/{test_flax_min_py38_generated.py => test_flax_min_py38_exclude_py313_generated.py} (100%) create mode 100644 tests/test_py311_generated/test_torch_exclude_py313_generated.py create mode 100644 tests/test_torch_exclude_py313.py diff --git a/pyproject.toml b/pyproject.toml index 5ba0d3ca..e9ba46cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,14 +45,14 @@ dev = [ "pytest-cov>=3.0.0", "omegaconf>=2.2.2", "attrs>=21.4.0", - "torch>=1.10.0", + "torch>=1.10.0;python_version<='3.12'", "pyright>=1.1.349,!=1.1.379", "ruff>=0.1.13", "mypy>=1.4.1", "numpy>=1.20.0", # As of 7/27/2023, flax install fails for Python 3.7 without pinning to an # old version. But doing so breaks other Python versions. - "flax>=0.6.9;python_version>='3.8'", + "flax>=0.6.9;python_version>='3.8' and python_version<='3.12'", "pydantic>=2.5.2", "coverage[toml]>=6.5.0", "eval_type_backport>=0.1.3", diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 4e786c5f..415020bf 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -186,9 +186,7 @@ def with_new_type_stripped( self, new_type_stripped: TypeForm[Any] | Callable ) -> FieldDefinition: if get_origin(self.type) is Annotated: - new_type = Annotated.__class_getitem__( # type: ignore - (new_type_stripped, *get_args(self.type)[1:]) - ) + new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] else: new_type = new_type_stripped return dataclasses.replace( @@ -400,7 +398,7 @@ def _field_list_from_function( # param.annotation doesn't resolve forward references. typ=typ if default_instance in MISSING_SINGLETONS - else Annotated.__class_getitem__((typ, _markers._OPTIONAL_GROUP)), # type: ignore + else Annotated[(typ, _markers._OPTIONAL_GROUP)], # type: ignore default=default, is_default_from_default_instance=False, helptext=helptext, diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index edfceb72..7a84ec01 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -434,12 +434,12 @@ def from_field( len(found_subcommand_configs) > 0 and found_subcommand_configs[0].constructor_factory is not None ): - options[i] = Annotated.__class_getitem__( # type: ignore + options[i] = Annotated[ # type: ignore ( found_subcommand_configs[0].constructor_factory(), *_resolver.unwrap_annotated(option, "all")[1], ) - ) + ] # Exit if we don't contain any nested types. if not any( @@ -541,13 +541,11 @@ def from_field( if len(annotations) == 0: option = option_origin else: - option = Annotated.__class_getitem__( # type: ignore - (option_origin,) + annotations - ) + option = Annotated[(option_origin,) + annotations] with _fields.FieldDefinition.marker_context(tuple(field.markers)): subparser = ParserSpecification.from_callable_or_type( - option, + option, # type: ignore markers=field.markers, description=subcommand_config.description, parent_classes=parent_classes, diff --git a/src/tyro/_resolver.py b/src/tyro/_resolver.py index 2473447f..f045a0cf 100644 --- a/src/tyro/_resolver.py +++ b/src/tyro/_resolver.py @@ -137,12 +137,12 @@ def resolve_newtype_and_aliases( ) -> TypeOrCallableOrNone: # Handle type aliases, eg via the `type` statement in Python 3.12. if isinstance(typ, TypeAliasType): - return Annotated.__class_getitem__( + return Annotated[ ( cast(Any, resolve_newtype_and_aliases(typ.__value__)), TyroTypeAliasBreadCrumb(typ.__name__), ) - ) + ] # We'll unwrap NewType annotations here; this is needed before issubclass # checks! @@ -156,7 +156,7 @@ def resolve_newtype_and_aliases( typ = resolve_newtype_and_aliases(getattr(typ, "__supertype__")) if return_name is not None: - typ = Annotated.__class_getitem__((typ, TyroTypeAliasBreadCrumb(return_name))) # type: ignore + typ = Annotated[(typ, TyroTypeAliasBreadCrumb(return_name))] # type: ignore return cast(TypeOrCallableOrNone, typ) @@ -192,9 +192,7 @@ def narrow_subtypes( if superclass is Any or issubclass(potential_subclass, superclass): # type: ignore if get_origin(typ) is Annotated: - return Annotated.__class_getitem__( # type: ignore - (potential_subclass,) + get_args(typ)[1:] - ) + return Annotated[(potential_subclass,) + get_args(typ)[1:]] # type: ignore typ = cast(TypeOrCallable, potential_subclass) except TypeError: # TODO: document where this TypeError can be raised, and reduce the amount of @@ -221,9 +219,7 @@ def swap_type_using_confstruct(typ: TypeOrCallable) -> TypeOrCallable: ) and anno.constructor_factory is not None ): - return Annotated.__class_getitem__( # type: ignore - (anno.constructor_factory(),) + annotations - ) + return Annotated[(anno.constructor_factory(),) + annotations] # type: ignore return typ @@ -589,7 +585,7 @@ def resolve_generic_types( return typ, type_from_typevar else: return ( - Annotated.__class_getitem__((typ, *annotations)), # type: ignore + Annotated[(typ, *annotations)], # type: ignore type_from_typevar, ) @@ -625,12 +621,12 @@ def get_type_hints_with_backported_syntax( non_none = args[1] if args[0] is NoneType else args[0] if get_origin(non_none) is Annotated: annotated_args = get_args(non_none) - out[k] = Annotated.__class_getitem__( # type: ignore + out[k] = Annotated[ # type: ignore ( Union.__getitem__((annotated_args[0], None)), # type: ignore *annotated_args[1:], ) - ) + ] return out except TypeError as e: # pragma: no cover diff --git a/src/tyro/conf/_markers.py b/src/tyro/conf/_markers.py index 4e6bf1b0..227da7d4 100644 --- a/src/tyro/conf/_markers.py +++ b/src/tyro/conf/_markers.py @@ -170,7 +170,7 @@ class Args: class _Marker(_singleton.Singleton): def __getitem__(self, key): - return Annotated.__class_getitem__((key, self)) # type: ignore + return Annotated[(key, self)] # type: ignore Marker = Any diff --git a/src/tyro/constructors/_struct_spec.py b/src/tyro/constructors/_struct_spec.py index 1da5502d..d826aef2 100644 --- a/src/tyro/constructors/_struct_spec.py +++ b/src/tyro/constructors/_struct_spec.py @@ -531,9 +531,9 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None: StructFieldSpec( name=name, type=( - Annotated.__class_getitem__( # type: ignore + Annotated[ # type: ignore (pd2_field.annotation,) + tuple(pd2_field.metadata) - ) + ] if len(pd2_field.metadata) > 0 else pd2_field.annotation ), diff --git a/src/tyro/extras/_base_configs.py b/src/tyro/extras/_base_configs.py index 44f90efa..51b8bbe8 100644 --- a/src/tyro/extras/_base_configs.py +++ b/src/tyro/extras/_base_configs.py @@ -134,7 +134,7 @@ def subcommand_type_from_defaults( assert len(defaults) >= 2, "At least two subcommands are required." return Union.__getitem__( # type: ignore tuple( - Annotated.__class_getitem__( # type: ignore + Annotated[ # type: ignore ( type(v), tyro.conf.subcommand( @@ -144,7 +144,7 @@ def subcommand_type_from_defaults( prefix_name=prefix_names, ), ) - ) + ] for k, v in defaults.items() ) ) diff --git a/tests/conftest.py b/tests/conftest.py index 23d01aae..b4605fec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,20 +4,20 @@ collect_ignore_glob: List[str] = [] if not sys.version_info >= (3, 8): - collect_ignore_glob.append("*_min_py38.py") - collect_ignore_glob.append("*_min_py38_generated.py") + collect_ignore_glob.append("*min_py38*.py") if not sys.version_info >= (3, 9): - collect_ignore_glob.append("*_min_py39.py") - collect_ignore_glob.append("*_min_py39_generated.py") + collect_ignore_glob.append("*min_py39*.py") if not sys.version_info >= (3, 10): - collect_ignore_glob.append("*_min_py310.py") + collect_ignore_glob.append("*min_py310*.py") collect_ignore_glob.append("*_min_py310_generated.py") if not sys.version_info >= (3, 12): - collect_ignore_glob.append("*_min_py312.py") - collect_ignore_glob.append("*_min_py312_generated.py") + collect_ignore_glob.append("*min_py312*.py") if not sys.version_info >= (3, 11): collect_ignore_glob.append("test_py311_generated/*.py") + +if sys.version_info >= (3, 13): + collect_ignore_glob.append("*_exclude_py313*.py") diff --git a/tests/test_base_configs_nested.py b/tests/test_base_configs_nested_exclude_py313.py similarity index 100% rename from tests/test_base_configs_nested.py rename to tests/test_base_configs_nested_exclude_py313.py diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index 317ad1b7..1c71c726 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -20,10 +20,8 @@ ) import pytest -import torch -from typing_extensions import Annotated, Final, Literal, TypeAlias - import tyro +from typing_extensions import Annotated, Final, Literal, TypeAlias def test_no_args() -> None: @@ -608,25 +606,6 @@ def test_missing_singleton() -> None: assert tyro.MISSING is copy.deepcopy(tyro.MISSING) -def test_torch_device() -> None: - def main(device: torch.device) -> torch.device: - return device - - assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu") - - -def test_supports_inference_mode_decorator() -> None: - @torch.inference_mode() - def main(x: int, device: str) -> Tuple[int, str]: - return x, device - - assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda") - - -def test_torch_device_2() -> None: - assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu") - - def test_just_int() -> None: assert tyro.cli(int, args=["123"]) == 123 diff --git a/tests/test_flax_min_py38.py b/tests/test_flax_min_py38_exclude_py313.py similarity index 100% rename from tests/test_flax_min_py38.py rename to tests/test_flax_min_py38_exclude_py313.py diff --git a/tests/test_helptext.py b/tests/test_helptext.py index 2301285c..4844e751 100644 --- a/tests/test_helptext.py +++ b/tests/test_helptext.py @@ -6,11 +6,10 @@ from collections.abc import Callable from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast -from helptext_utils import get_helptext_with_checks -from torch import nn +import tyro from typing_extensions import Annotated, Literal, NotRequired, TypedDict -import tyro +from helptext_utils import get_helptext_with_checks def test_helptext() -> None: @@ -662,26 +661,6 @@ class Something( assert "But this text should!" in helptext -def test_unparsable() -> None: - class Struct: - a: int = 5 - b: str = "7" - - def main(x: Any = Struct()): - pass - - helptext = get_helptext_with_checks(main) - assert "--x {fixed}" not in helptext - - def main2(x: Callable = nn.ReLU): - pass - - helptext = get_helptext_with_checks(main2) - assert "--x {fixed}" in helptext - assert "(fixed to:" in helptext - assert "torch" in helptext - - def test_pathlike() -> None: def main(x: os.PathLike) -> None: pass diff --git a/tests/test_py311_generated/test_base_configs_nested_generated.py b/tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py similarity index 100% rename from tests/test_py311_generated/test_base_configs_nested_generated.py rename to tests/test_py311_generated/test_base_configs_nested_exclude_py313_generated.py diff --git a/tests/test_py311_generated/test_dcargs_generated.py b/tests/test_py311_generated/test_dcargs_generated.py index 7f45093b..909b3a03 100644 --- a/tests/test_py311_generated/test_dcargs_generated.py +++ b/tests/test_py311_generated/test_dcargs_generated.py @@ -16,13 +16,13 @@ List, Literal, Optional, + Text, Tuple, TypeAlias, TypeVar, ) import pytest -import torch import tyro @@ -575,6 +575,15 @@ def main(x: AnyStr) -> AnyStr: assert tyro.cli(main, args=["--x", "hello„"]) == "hello„" +def test_text() -> None: + # `Text` is an alias for `str` in Python 3. + def main(x: Text) -> Text: + return x + + assert tyro.cli(main, args=["--x", "hello"]) == "hello" + assert tyro.cli(main, args=["--x", "hello„"]) == "hello„" + + def test_fixed() -> None: def main(x: Callable[[int], int] = lambda x: x * 2) -> Callable[[int], int]: return x @@ -600,25 +609,6 @@ def test_missing_singleton() -> None: assert tyro.MISSING is copy.deepcopy(tyro.MISSING) -def test_torch_device() -> None: - def main(device: torch.device) -> torch.device: - return device - - assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu") - - -def test_supports_inference_mode_decorator() -> None: - @torch.inference_mode() - def main(x: int, device: str) -> Tuple[int, str]: - return x, device - - assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda") - - -def test_torch_device_2() -> None: - assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu") - - def test_just_int() -> None: assert tyro.cli(int, args=["123"]) == 123 diff --git a/tests/test_py311_generated/test_flax_min_py38_generated.py b/tests/test_py311_generated/test_flax_min_py38_exclude_py313_generated.py similarity index 100% rename from tests/test_py311_generated/test_flax_min_py38_generated.py rename to tests/test_py311_generated/test_flax_min_py38_exclude_py313_generated.py diff --git a/tests/test_py311_generated/test_helptext_generated.py b/tests/test_py311_generated/test_helptext_generated.py index 3d0f8fb8..92eac7b5 100644 --- a/tests/test_py311_generated/test_helptext_generated.py +++ b/tests/test_py311_generated/test_helptext_generated.py @@ -3,7 +3,6 @@ import json import os import pathlib -from collections.abc import Callable from typing import ( Annotated, Any, @@ -20,7 +19,6 @@ ) from helptext_utils import get_helptext_with_checks -from torch import nn import tyro @@ -663,26 +661,6 @@ class Something( assert "But this text should!" in helptext -def test_unparsable() -> None: - class Struct: - a: int = 5 - b: str = "7" - - def main(x: Any = Struct()): - pass - - helptext = get_helptext_with_checks(main) - assert "--x {fixed}" not in helptext - - def main2(x: Callable = nn.ReLU): - pass - - helptext = get_helptext_with_checks(main2) - assert "--x {fixed}" in helptext - assert "(fixed to:" in helptext - assert "torch" in helptext - - def test_pathlike() -> None: def main(x: os.PathLike) -> None: pass diff --git a/tests/test_py311_generated/test_torch_exclude_py313_generated.py b/tests/test_py311_generated/test_torch_exclude_py313_generated.py new file mode 100644 index 00000000..8284a163 --- /dev/null +++ b/tests/test_py311_generated/test_torch_exclude_py313_generated.py @@ -0,0 +1,46 @@ +from typing import Any, Callable, Tuple + +import torch +from helptext_utils import get_helptext_with_checks +from torch import nn + +import tyro + + +def test_torch_device() -> None: + def main(device: torch.device) -> torch.device: + return device + + assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu") + + +def test_supports_inference_mode_decorator() -> None: + @torch.inference_mode() + def main(x: int, device: str) -> Tuple[int, str]: + return x, device + + assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda") + + +def test_torch_device_2() -> None: + assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu") + + +def test_unparsable() -> None: + class Struct: + a: int = 5 + b: str = "7" + + def main(x: Any = Struct()): + pass + + helptext = get_helptext_with_checks(main) + assert "--x {fixed}" not in helptext + + def main2(x: Callable = nn.ReLU): + pass + + helptext = get_helptext_with_checks(main2) + assert "--x {fixed}" in helptext + assert "(fixed to:" in helptext + assert "torch" in helptext diff --git a/tests/test_torch_exclude_py313.py b/tests/test_torch_exclude_py313.py new file mode 100644 index 00000000..bad381b9 --- /dev/null +++ b/tests/test_torch_exclude_py313.py @@ -0,0 +1,46 @@ +from typing import Any, Callable, Tuple + +import torch +import tyro +from torch import nn + +from helptext_utils import get_helptext_with_checks + + +def test_torch_device() -> None: + def main(device: torch.device) -> torch.device: + return device + + assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu") + + +def test_supports_inference_mode_decorator() -> None: + @torch.inference_mode() + def main(x: int, device: str) -> Tuple[int, str]: + return x, device + + assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda") + + +def test_torch_device_2() -> None: + assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu") + + +def test_unparsable() -> None: + class Struct: + a: int = 5 + b: str = "7" + + def main(x: Any = Struct()): + pass + + helptext = get_helptext_with_checks(main) + assert "--x {fixed}" not in helptext + + def main2(x: Callable = nn.ReLU): + pass + + helptext = get_helptext_with_checks(main2) + assert "--x {fixed}" in helptext + assert "(fixed to:" in helptext + assert "torch" in helptext From 5b9901928ba1c29467178e576c144394b5dd7c31 Mon Sep 17 00:00:00 2001 From: brentyi Date: Mon, 11 Nov 2024 15:34:55 -0800 Subject: [PATCH 2/7] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e9ba46cd..fffa4ffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent" ] From 16dd501516abdb46e4b1aca795f0529bddc59b95 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 12 Nov 2024 15:42:26 -0800 Subject: [PATCH 3/7] Test annotated struct narrowing --- src/tyro/_fields.py | 6 +++--- tests/test_nested.py | 19 +++++++++++++++++-- .../test_nested_generated.py | 15 +++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 415020bf..03e7d1a9 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -186,12 +186,12 @@ def with_new_type_stripped( self, new_type_stripped: TypeForm[Any] | Callable ) -> FieldDefinition: if get_origin(self.type) is Annotated: - new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] + new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] # type: ignore else: - new_type = new_type_stripped + new_type = new_type_stripped # type: ignore return dataclasses.replace( self, - type=new_type, + type=new_type, # type: ignore type_stripped=new_type_stripped, ) diff --git a/tests/test_nested.py b/tests/test_nested.py index 98ec5cb6..d8815bc9 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,11 +2,11 @@ from typing import Any, Generic, Mapping, NewType, Optional, Tuple, TypeVar, Union import pytest +import tyro from frozendict import frozendict # type: ignore -from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, Final, Literal -import tyro +from helptext_utils import get_helptext_with_checks def test_nested() -> None: @@ -1231,3 +1231,18 @@ class Args: assert tyro.cli(Args, args=[]) == Args(A("hello")) assert "default: inner:alt" in get_helptext_with_checks(Args) + + +def test_annotated_narrow() -> None: + @dataclasses.dataclass + class A: ... + + @dataclasses.dataclass + class B(A): + x: int + + def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any: + return x + + assert tyro.cli(main, args=[]) == B(x=3) + assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5) diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index 5a611b42..fafdeb45 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -1241,3 +1241,18 @@ class Args: assert tyro.cli(Args, args=[]) == Args(A("hello")) assert "default: inner:alt" in get_helptext_with_checks(Args) + + +def test_annotated_narrow() -> None: + @dataclasses.dataclass + class A: ... + + @dataclasses.dataclass + class B(A): + x: int + + def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any: + return x + + assert tyro.cli(main, args=[]) == B(x=3) + assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5) From 1fd4fc4e401ca126798cf614f1d9f24b1c5fdb59 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 12 Nov 2024 15:43:02 -0800 Subject: [PATCH 4/7] ruff --- tests/test_dcargs.py | 3 ++- tests/test_helptext.py | 5 ++--- tests/test_nested.py | 4 ++-- tests/test_torch_exclude_py313.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_dcargs.py b/tests/test_dcargs.py index 1c71c726..b82543d1 100644 --- a/tests/test_dcargs.py +++ b/tests/test_dcargs.py @@ -20,9 +20,10 @@ ) import pytest -import tyro from typing_extensions import Annotated, Final, Literal, TypeAlias +import tyro + def test_no_args() -> None: def main() -> int: diff --git a/tests/test_helptext.py b/tests/test_helptext.py index 4844e751..768da2f6 100644 --- a/tests/test_helptext.py +++ b/tests/test_helptext.py @@ -3,13 +3,12 @@ import json import os import pathlib -from collections.abc import Callable from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast -import tyro +from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, Literal, NotRequired, TypedDict -from helptext_utils import get_helptext_with_checks +import tyro def test_helptext() -> None: diff --git a/tests/test_nested.py b/tests/test_nested.py index d8815bc9..e328370a 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,11 +2,11 @@ from typing import Any, Generic, Mapping, NewType, Optional, Tuple, TypeVar, Union import pytest -import tyro from frozendict import frozendict # type: ignore +from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, Final, Literal -from helptext_utils import get_helptext_with_checks +import tyro def test_nested() -> None: diff --git a/tests/test_torch_exclude_py313.py b/tests/test_torch_exclude_py313.py index bad381b9..8284a163 100644 --- a/tests/test_torch_exclude_py313.py +++ b/tests/test_torch_exclude_py313.py @@ -1,10 +1,10 @@ from typing import Any, Callable, Tuple import torch -import tyro +from helptext_utils import get_helptext_with_checks from torch import nn -from helptext_utils import get_helptext_with_checks +import tyro def test_torch_device() -> None: From 91d487006c4117b5cb9f668e8caf923b01983c94 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 12 Nov 2024 15:48:46 -0800 Subject: [PATCH 5/7] mypy --- src/tyro/_parsers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tyro/_parsers.py b/src/tyro/_parsers.py index 7a84ec01..41a9d84b 100644 --- a/src/tyro/_parsers.py +++ b/src/tyro/_parsers.py @@ -541,7 +541,7 @@ def from_field( if len(annotations) == 0: option = option_origin else: - option = Annotated[(option_origin,) + annotations] + option = Annotated[(option_origin,) + annotations] # type: ignore with _fields.FieldDefinition.marker_context(tuple(field.markers)): subparser = ParserSpecification.from_callable_or_type( From f40f86d77bbad4c9ed5f0b28880166c8aacb7ad1 Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 12 Nov 2024 20:27:08 -0800 Subject: [PATCH 6/7] Add narrow_subtype test --- tests/test_nested.py | 19 ++++++++++++++++--- .../test_nested_generated.py | 15 ++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/test_nested.py b/tests/test_nested.py index e328370a..be90dfa8 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,11 +2,11 @@ from typing import Any, Generic, Mapping, NewType, Optional, Tuple, TypeVar, Union import pytest +import tyro from frozendict import frozendict # type: ignore -from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, Final, Literal -import tyro +from helptext_utils import get_helptext_with_checks def test_nested() -> None: @@ -1233,7 +1233,7 @@ class Args: assert "default: inner:alt" in get_helptext_with_checks(Args) -def test_annotated_narrow() -> None: +def test_annotated_narrow_0() -> None: @dataclasses.dataclass class A: ... @@ -1246,3 +1246,16 @@ def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any: assert tyro.cli(main, args=[]) == B(x=3) assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5) + + +def test_annotated_narrow_1() -> None: + @dataclasses.dataclass + class A: ... + + @dataclasses.dataclass + class B(A): + x: int + + from tyro._resolver import narrow_subtypes + + assert narrow_subtypes(Annotated[A, False], B(3)) == Annotated[B, False] # type: ignore diff --git a/tests/test_py311_generated/test_nested_generated.py b/tests/test_py311_generated/test_nested_generated.py index fafdeb45..97324544 100644 --- a/tests/test_py311_generated/test_nested_generated.py +++ b/tests/test_py311_generated/test_nested_generated.py @@ -1243,7 +1243,7 @@ class Args: assert "default: inner:alt" in get_helptext_with_checks(Args) -def test_annotated_narrow() -> None: +def test_annotated_narrow_0() -> None: @dataclasses.dataclass class A: ... @@ -1256,3 +1256,16 @@ def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any: assert tyro.cli(main, args=[]) == B(x=3) assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5) + + +def test_annotated_narrow_1() -> None: + @dataclasses.dataclass + class A: ... + + @dataclasses.dataclass + class B(A): + x: int + + from tyro._resolver import narrow_subtypes + + assert narrow_subtypes(Annotated[A, False], B(3)) == Annotated[B, False] # type: ignore From 1cc176856f078aad6bcd6da12f14f6a4cc7474bf Mon Sep 17 00:00:00 2001 From: brentyi Date: Tue, 12 Nov 2024 20:38:25 -0800 Subject: [PATCH 7/7] ruff --- tests/test_nested.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_nested.py b/tests/test_nested.py index be90dfa8..9bc1748b 100644 --- a/tests/test_nested.py +++ b/tests/test_nested.py @@ -2,11 +2,11 @@ from typing import Any, Generic, Mapping, NewType, Optional, Tuple, TypeVar, Union import pytest -import tyro from frozendict import frozendict # type: ignore +from helptext_utils import get_helptext_with_checks from typing_extensions import Annotated, Final, Literal -from helptext_utils import get_helptext_with_checks +import tyro def test_nested() -> None: