Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Python 3.13 #200

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
Expand All @@ -45,14 +46,14 @@ dev = [
"pytest-cov>=3.0.0",
"omegaconf>=2.2.2",
"attrs>=21.4.0",
"torch>=1.10.0",
"torch>=1.10.0;python_version<='3.12'",
"pyright>=1.1.349,!=1.1.379",
"ruff>=0.1.13",
"mypy>=1.4.1",
"numpy>=1.20.0",
# As of 7/27/2023, flax install fails for Python 3.7 without pinning to an
# old version. But doing so breaks other Python versions.
"flax>=0.6.9;python_version>='3.8'",
"flax>=0.6.9;python_version>='3.8' and python_version<='3.12'",
"pydantic>=2.5.2",
"coverage[toml]>=6.5.0",
"eval_type_backport>=0.1.3",
Expand Down
10 changes: 4 additions & 6 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,12 @@ def with_new_type_stripped(
self, new_type_stripped: TypeForm[Any] | Callable
) -> FieldDefinition:
if get_origin(self.type) is Annotated:
new_type = Annotated.__class_getitem__( # type: ignore
(new_type_stripped, *get_args(self.type)[1:])
)
new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] # type: ignore
else:
new_type = new_type_stripped
new_type = new_type_stripped # type: ignore
return dataclasses.replace(
self,
type=new_type,
type=new_type, # type: ignore
type_stripped=new_type_stripped,
)

Expand Down Expand Up @@ -400,7 +398,7 @@ def _field_list_from_function(
# param.annotation doesn't resolve forward references.
typ=typ
if default_instance in MISSING_SINGLETONS
else Annotated.__class_getitem__((typ, _markers._OPTIONAL_GROUP)), # type: ignore
else Annotated[(typ, _markers._OPTIONAL_GROUP)], # type: ignore
default=default,
is_default_from_default_instance=False,
helptext=helptext,
Expand Down
10 changes: 4 additions & 6 deletions src/tyro/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,12 @@ def from_field(
len(found_subcommand_configs) > 0
and found_subcommand_configs[0].constructor_factory is not None
):
options[i] = Annotated.__class_getitem__( # type: ignore
options[i] = Annotated[ # type: ignore
(
found_subcommand_configs[0].constructor_factory(),
*_resolver.unwrap_annotated(option, "all")[1],
)
)
]

# Exit if we don't contain any nested types.
if not any(
Expand Down Expand Up @@ -541,13 +541,11 @@ def from_field(
if len(annotations) == 0:
option = option_origin
else:
option = Annotated.__class_getitem__( # type: ignore
(option_origin,) + annotations
)
option = Annotated[(option_origin,) + annotations] # type: ignore

with _fields.FieldDefinition.marker_context(tuple(field.markers)):
subparser = ParserSpecification.from_callable_or_type(
option,
option, # type: ignore
markers=field.markers,
description=subcommand_config.description,
parent_classes=parent_classes,
Expand Down
20 changes: 8 additions & 12 deletions src/tyro/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def resolve_newtype_and_aliases(
) -> TypeOrCallableOrNone:
# Handle type aliases, eg via the `type` statement in Python 3.12.
if isinstance(typ, TypeAliasType):
return Annotated.__class_getitem__(
return Annotated[
(
cast(Any, resolve_newtype_and_aliases(typ.__value__)),
TyroTypeAliasBreadCrumb(typ.__name__),
)
)
]

# We'll unwrap NewType annotations here; this is needed before issubclass
# checks!
Expand All @@ -156,7 +156,7 @@ def resolve_newtype_and_aliases(
typ = resolve_newtype_and_aliases(getattr(typ, "__supertype__"))

if return_name is not None:
typ = Annotated.__class_getitem__((typ, TyroTypeAliasBreadCrumb(return_name))) # type: ignore
typ = Annotated[(typ, TyroTypeAliasBreadCrumb(return_name))] # type: ignore

return cast(TypeOrCallableOrNone, typ)

Expand Down Expand Up @@ -192,9 +192,7 @@ def narrow_subtypes(

if superclass is Any or issubclass(potential_subclass, superclass): # type: ignore
if get_origin(typ) is Annotated:
return Annotated.__class_getitem__( # type: ignore
(potential_subclass,) + get_args(typ)[1:]
)
return Annotated[(potential_subclass,) + get_args(typ)[1:]] # type: ignore
typ = cast(TypeOrCallable, potential_subclass)
except TypeError:
# TODO: document where this TypeError can be raised, and reduce the amount of
Expand All @@ -221,9 +219,7 @@ def swap_type_using_confstruct(typ: TypeOrCallable) -> TypeOrCallable:
)
and anno.constructor_factory is not None
):
return Annotated.__class_getitem__( # type: ignore
(anno.constructor_factory(),) + annotations
)
return Annotated[(anno.constructor_factory(),) + annotations] # type: ignore
return typ


Expand Down Expand Up @@ -589,7 +585,7 @@ def resolve_generic_types(
return typ, type_from_typevar
else:
return (
Annotated.__class_getitem__((typ, *annotations)), # type: ignore
Annotated[(typ, *annotations)], # type: ignore
type_from_typevar,
)

Expand Down Expand Up @@ -625,12 +621,12 @@ def get_type_hints_with_backported_syntax(
non_none = args[1] if args[0] is NoneType else args[0]
if get_origin(non_none) is Annotated:
annotated_args = get_args(non_none)
out[k] = Annotated.__class_getitem__( # type: ignore
out[k] = Annotated[ # type: ignore
(
Union.__getitem__((annotated_args[0], None)), # type: ignore
*annotated_args[1:],
)
)
]
return out

except TypeError as e: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion src/tyro/conf/_markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Args:

class _Marker(_singleton.Singleton):
def __getitem__(self, key):
return Annotated.__class_getitem__((key, self)) # type: ignore
return Annotated[(key, self)] # type: ignore


Marker = Any
Expand Down
4 changes: 2 additions & 2 deletions src/tyro/constructors/_struct_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,9 @@ def pydantic_rule(info: StructTypeInfo) -> StructConstructorSpec | None:
StructFieldSpec(
name=name,
type=(
Annotated.__class_getitem__( # type: ignore
Annotated[ # type: ignore
(pd2_field.annotation,) + tuple(pd2_field.metadata)
)
]
if len(pd2_field.metadata) > 0
else pd2_field.annotation
),
Expand Down
4 changes: 2 additions & 2 deletions src/tyro/extras/_base_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def subcommand_type_from_defaults(
assert len(defaults) >= 2, "At least two subcommands are required."
return Union.__getitem__( # type: ignore
tuple(
Annotated.__class_getitem__( # type: ignore
Annotated[ # type: ignore
(
type(v),
tyro.conf.subcommand(
Expand All @@ -144,7 +144,7 @@ def subcommand_type_from_defaults(
prefix_name=prefix_names,
),
)
)
]
for k, v in defaults.items()
)
)
14 changes: 7 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
collect_ignore_glob: List[str] = []

if not sys.version_info >= (3, 8):
collect_ignore_glob.append("*_min_py38.py")
collect_ignore_glob.append("*_min_py38_generated.py")
collect_ignore_glob.append("*min_py38*.py")

if not sys.version_info >= (3, 9):
collect_ignore_glob.append("*_min_py39.py")
collect_ignore_glob.append("*_min_py39_generated.py")
collect_ignore_glob.append("*min_py39*.py")

if not sys.version_info >= (3, 10):
collect_ignore_glob.append("*_min_py310.py")
collect_ignore_glob.append("*min_py310*.py")
collect_ignore_glob.append("*_min_py310_generated.py")

if not sys.version_info >= (3, 12):
collect_ignore_glob.append("*_min_py312.py")
collect_ignore_glob.append("*_min_py312_generated.py")
collect_ignore_glob.append("*min_py312*.py")

if not sys.version_info >= (3, 11):
collect_ignore_glob.append("test_py311_generated/*.py")

if sys.version_info >= (3, 13):
collect_ignore_glob.append("*_exclude_py313*.py")
20 changes: 0 additions & 20 deletions tests/test_dcargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)

import pytest
import torch
from typing_extensions import Annotated, Final, Literal, TypeAlias

import tyro
Expand Down Expand Up @@ -608,25 +607,6 @@ def test_missing_singleton() -> None:
assert tyro.MISSING is copy.deepcopy(tyro.MISSING)


def test_torch_device() -> None:
def main(device: torch.device) -> torch.device:
return device

assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu")


def test_supports_inference_mode_decorator() -> None:
@torch.inference_mode()
def main(x: int, device: str) -> Tuple[int, str]:
return x, device

assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda")


def test_torch_device_2() -> None:
assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu")


def test_just_int() -> None:
assert tyro.cli(int, args=["123"]) == 123

Expand Down
File renamed without changes.
22 changes: 0 additions & 22 deletions tests/test_helptext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import json
import os
import pathlib
from collections.abc import Callable
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union, cast

from helptext_utils import get_helptext_with_checks
from torch import nn
from typing_extensions import Annotated, Literal, NotRequired, TypedDict

import tyro
Expand Down Expand Up @@ -662,26 +660,6 @@ class Something(
assert "But this text should!" in helptext


def test_unparsable() -> None:
class Struct:
a: int = 5
b: str = "7"

def main(x: Any = Struct()):
pass

helptext = get_helptext_with_checks(main)
assert "--x {fixed}" not in helptext

def main2(x: Callable = nn.ReLU):
pass

helptext = get_helptext_with_checks(main2)
assert "--x {fixed}" in helptext
assert "(fixed to:" in helptext
assert "torch" in helptext


def test_pathlike() -> None:
def main(x: os.PathLike) -> None:
pass
Expand Down
28 changes: 28 additions & 0 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,3 +1231,31 @@ class Args:

assert tyro.cli(Args, args=[]) == Args(A("hello"))
assert "default: inner:alt" in get_helptext_with_checks(Args)


def test_annotated_narrow_0() -> None:
@dataclasses.dataclass
class A: ...

@dataclasses.dataclass
class B(A):
x: int

def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any:
return x

assert tyro.cli(main, args=[]) == B(x=3)
assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5)


def test_annotated_narrow_1() -> None:
@dataclasses.dataclass
class A: ...

@dataclasses.dataclass
class B(A):
x: int

from tyro._resolver import narrow_subtypes

assert narrow_subtypes(Annotated[A, False], B(3)) == Annotated[B, False] # type: ignore
30 changes: 10 additions & 20 deletions tests/test_py311_generated/test_dcargs_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
List,
Literal,
Optional,
Text,
Tuple,
TypeAlias,
TypeVar,
)

import pytest
import torch

import tyro

Expand Down Expand Up @@ -575,6 +575,15 @@ def main(x: AnyStr) -> AnyStr:
assert tyro.cli(main, args=["--x", "hello„"]) == "hello„"


def test_text() -> None:
# `Text` is an alias for `str` in Python 3.
def main(x: Text) -> Text:
return x

assert tyro.cli(main, args=["--x", "hello"]) == "hello"
assert tyro.cli(main, args=["--x", "hello„"]) == "hello„"


def test_fixed() -> None:
def main(x: Callable[[int], int] = lambda x: x * 2) -> Callable[[int], int]:
return x
Expand All @@ -600,25 +609,6 @@ def test_missing_singleton() -> None:
assert tyro.MISSING is copy.deepcopy(tyro.MISSING)


def test_torch_device() -> None:
def main(device: torch.device) -> torch.device:
return device

assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu")


def test_supports_inference_mode_decorator() -> None:
@torch.inference_mode()
def main(x: int, device: str) -> Tuple[int, str]:
return x, device

assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda")


def test_torch_device_2() -> None:
assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu")


def test_just_int() -> None:
assert tyro.cli(int, args=["123"]) == 123

Expand Down
Loading
Loading