Skip to content

Commit

Permalink
Test annotated struct narrowing
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 12, 2024
1 parent 5b99019 commit 16dd501
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ def with_new_type_stripped(
self, new_type_stripped: TypeForm[Any] | Callable
) -> FieldDefinition:
if get_origin(self.type) is Annotated:
new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])]
new_type = Annotated[(new_type_stripped, *get_args(self.type)[1:])] # type: ignore
else:
new_type = new_type_stripped
new_type = new_type_stripped # type: ignore
return dataclasses.replace(
self,
type=new_type,
type=new_type, # type: ignore
type_stripped=new_type_stripped,
)

Expand Down
19 changes: 17 additions & 2 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Generic, Mapping, NewType, Optional, Tuple, TypeVar, Union

import pytest
import tyro
from frozendict import frozendict # type: ignore
from helptext_utils import get_helptext_with_checks
from typing_extensions import Annotated, Final, Literal

import tyro
from helptext_utils import get_helptext_with_checks


def test_nested() -> None:

Check failure on line 12 in tests/test_nested.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/test_nested.py:1:1: I001 Import block is un-sorted or un-formatted
Expand Down Expand Up @@ -1231,3 +1231,18 @@ class Args:

assert tyro.cli(Args, args=[]) == Args(A("hello"))
assert "default: inner:alt" in get_helptext_with_checks(Args)


def test_annotated_narrow() -> None:
@dataclasses.dataclass
class A: ...

@dataclasses.dataclass
class B(A):
x: int

def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any:
return x

assert tyro.cli(main, args=[]) == B(x=3)
assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5)
15 changes: 15 additions & 0 deletions tests/test_py311_generated/test_nested_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,3 +1241,18 @@ class Args:

assert tyro.cli(Args, args=[]) == Args(A("hello"))
assert "default: inner:alt" in get_helptext_with_checks(Args)


def test_annotated_narrow() -> None:
@dataclasses.dataclass
class A: ...

@dataclasses.dataclass
class B(A):
x: int

def main(x: Annotated[A, tyro.conf.OmitArgPrefixes] = B(x=3)) -> Any:
return x

assert tyro.cli(main, args=[]) == B(x=3)
assert tyro.cli(main, args="--x 5".split(" ")) == B(x=5)

0 comments on commit 16dd501

Please sign in to comment.