From 29d7d202a48c62804e444012300e1c87e0eee4d8 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 8 Aug 2024 16:31:34 -0700 Subject: [PATCH] Update autogenerated Python>=3.11 tests --- tests/test_py311_generated/_generate.py | 1 - .../test_collections_generated.py | 122 ++++++++++++++++++ .../test_conf_generated.py | 38 ++++++ .../test_dcargs_generated.py | 8 ++ ...st_generics_and_serialization_generated.py | 25 ++-- 5 files changed, 182 insertions(+), 12 deletions(-) diff --git a/tests/test_py311_generated/_generate.py b/tests/test_py311_generated/_generate.py index 77c93968..6ff72a43 100644 --- a/tests/test_py311_generated/_generate.py +++ b/tests/test_py311_generated/_generate.py @@ -49,7 +49,6 @@ def generate_from_path(test_path: pathlib.Path) -> None: ) out_path.write_text(content) - subprocess.run(["isort", "--profile=black", str(out_path)], check=True) subprocess.run(["ruff", "format", str(out_path)], check=True) subprocess.run(["ruff", "check", "--fix", str(out_path)], check=True) diff --git a/tests/test_py311_generated/test_collections_generated.py b/tests/test_py311_generated/test_collections_generated.py index ff03ac8d..5b6e3d81 100644 --- a/tests/test_py311_generated/test_collections_generated.py +++ b/tests/test_py311_generated/test_collections_generated.py @@ -1,8 +1,10 @@ import collections +import collections.abc import contextlib import dataclasses import enum import io +import sys from typing import ( Any, Deque, @@ -154,6 +156,63 @@ class A: tyro.cli(A, args=[]) +def test_sequences_narrow() -> None: + @dataclasses.dataclass + class A: + x: Sequence = dataclasses.field(default_factory=lambda: [0]) + + assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + assert tyro.cli(A, args=[]) == A(x=[0]) + assert tyro.cli(A, args=["--x"]) == A(x=[]) + + +def test_sequences_narrow_any() -> None: + @dataclasses.dataclass + class A: + x: Sequence[Any] = dataclasses.field(default_factory=lambda: [0]) + + assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + assert tyro.cli(A, args=[]) == A(x=[0]) + assert tyro.cli(A, args=["--x"]) == A(x=[]) + + +if sys.version_info >= (3, 9): + + def test_abc_sequences() -> None: + @dataclasses.dataclass + class A: + x: collections.abc.Sequence[int] + + assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + assert tyro.cli(A, args=["--x"]) == A(x=[]) + with pytest.raises(SystemExit): + tyro.cli(A, args=[]) + + +def test_abc_sequences_narrow() -> None: + @dataclasses.dataclass + class A: + x: collections.abc.Sequence = dataclasses.field(default_factory=lambda: [0]) + + assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + assert tyro.cli(A, args=[]) == A(x=[0]) + assert tyro.cli(A, args=["--x"]) == A(x=[]) + + +if sys.version_info >= (3, 9): + + def test_abc_sequences_narrow_any() -> None: + @dataclasses.dataclass + class A: + x: collections.abc.Sequence[Any] = dataclasses.field( + default_factory=lambda: [0] + ) + + assert tyro.cli(A, args=["--x", "1", "2", "3"]) == A(x=[1, 2, 3]) + assert tyro.cli(A, args=[]) == A(x=[0]) + assert tyro.cli(A, args=["--x"]) == A(x=[]) + + def test_lists() -> None: @dataclasses.dataclass class A: @@ -446,6 +505,27 @@ def main(x: list = [0, 1, 2, "hello"]) -> Any: assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", 5] +def test_list_narrowing_any() -> None: + def main(x: List[Any] = [0, 1, 2, "hello"]) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", 5] + + +def test_list_narrowing_empty() -> None: + def main(x: list = []) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", "5"] + + +def test_list_narrowing_empty_any() -> None: + def main(x: List[Any] = []) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == ["hi", "there", "5"] + + def test_set_narrowing() -> None: def main(x: set = {0, 1, 2, "hello"}) -> Any: return x @@ -453,6 +533,27 @@ def main(x: set = {0, 1, 2, "hello"}) -> Any: assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", 5} +def test_set_narrowing_any() -> None: + def main(x: Set[Any] = {0, 1, 2, "hello"}) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", 5} + + +def test_set_narrowing_empty() -> None: + def main(x: set = set()) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", "5"} + + +def test_set_narrowing_any_empty() -> None: + def main(x: Set[Any] = set()) -> Any: + return x + + assert tyro.cli(main, args="--x hi there 5".split(" ")) == {"hi", "there", "5"} + + def test_tuple_narrowing() -> None: def main(x: tuple = (0, 1, 2, "hello")) -> Any: return x @@ -460,6 +561,27 @@ def main(x: tuple = (0, 1, 2, "hello")) -> Any: assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == (0, 1, 2, "3") +def test_tuple_narrowing_any() -> None: + def main(x: Tuple[Any, ...] = (0, 1, 2, "hello")) -> Any: + return x + + assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == (0, 1, 2, "3") + + +def test_tuple_narrowing_empty() -> None: + def main(x: tuple = ()) -> Any: + return x + + assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == ("0", "1", "2", "3") + + +def test_tuple_narrowing_empty_any() -> None: + def main(x: Tuple[Any, ...] = ()) -> Any: + return x + + assert tyro.cli(main, args="--x 0 1 2 3".split(" ")) == ("0", "1", "2", "3") + + def test_tuple_narrowing_empty_default() -> None: def main(x: tuple = ()) -> Any: return x diff --git a/tests/test_py311_generated/test_conf_generated.py b/tests/test_py311_generated/test_conf_generated.py index dd26d8bd..7f688f52 100644 --- a/tests/test_py311_generated/test_conf_generated.py +++ b/tests/test_py311_generated/test_conf_generated.py @@ -1207,6 +1207,44 @@ def commit(branch: str) -> int: ) +def test_custom_constructor_10() -> None: + def commit(branch: str) -> int: + """Commit""" + print(f"commit branch={branch}") + return 3 + + def inner(x: Annotated[Any, tyro.conf.arg(constructor=commit)]) -> None: + return x + + def inner_no_prefix( + x: Annotated[Any, tyro.conf.arg(constructor=commit, prefix_name=False)], + ) -> None: + return x + + def outer(x: Annotated[Any, tyro.conf.arg(constructor=inner)]) -> None: + return x + + def outer_no_prefix( + x: Annotated[Any, tyro.conf.arg(constructor=inner_no_prefix)], + ) -> None: + return x + + assert ( + tyro.cli( + outer, + args="--x.x.branch 5".split(" "), + ) + == 3 + ) + assert ( + tyro.cli( + outer_no_prefix, + args="--x.branch 5".split(" "), + ) + == 3 + ) + + def test_alias() -> None: """Arguments with aliases.""" diff --git a/tests/test_py311_generated/test_dcargs_generated.py b/tests/test_py311_generated/test_dcargs_generated.py index 25b1b7c1..fadcf18e 100644 --- a/tests/test_py311_generated/test_dcargs_generated.py +++ b/tests/test_py311_generated/test_dcargs_generated.py @@ -514,6 +514,14 @@ def main(device: torch.device) -> torch.device: assert tyro.cli(main, args=["--device", "cpu"]) == torch.device("cpu") +def test_supports_inference_mode_decorator() -> None: + @torch.inference_mode() + def main(x: int, device: str) -> Tuple[int, str]: + return x, device + + assert tyro.cli(main, args="--x 3 --device cuda".split(" ")) == (3, "cuda") + + def test_torch_device_2() -> None: assert tyro.cli(torch.device, args=["cpu"]) == torch.device("cpu") diff --git a/tests/test_py311_generated/test_generics_and_serialization_generated.py b/tests/test_py311_generated/test_generics_and_serialization_generated.py index 93eea527..d832a3f3 100644 --- a/tests/test_py311_generated/test_generics_and_serialization_generated.py +++ b/tests/test_py311_generated/test_generics_and_serialization_generated.py @@ -436,20 +436,23 @@ class Wrapper: assert wrapper1 == tyro.extras.from_yaml(Wrapper, tyro.extras.to_yaml(wrapper1)) -def test_superclass() -> None: - # https://github.com/brentyi/tyro/issues/7 +@dataclasses.dataclass +class TypeA: + data: int - @dataclasses.dataclass - class TypeA: - data: int - @dataclasses.dataclass - class TypeASubclass(TypeA): - pass +@dataclasses.dataclass +class TypeASubclass(TypeA): + pass - @dataclasses.dataclass - class Wrapper: - subclass: TypeA + +@dataclasses.dataclass +class Wrapper: + subclass: TypeA + + +def test_superclass() -> None: + # https://github.com/brentyi/tyro/issues/7 wrapper1 = Wrapper(TypeASubclass(3)) # Create Wrapper object. assert wrapper1 == tyro.extras.from_yaml(Wrapper, tyro.extras.to_yaml(wrapper1))