diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 275595dd..ed0b8276 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -674,6 +674,10 @@ def _field_list_from_pydantic( cls_cast = cast(pydantic_v1.BaseModel, cls) # type: ignore else: cls_cast = cls + + hints = _resolver.get_type_hints_with_backported_syntax( + cls, include_extras=True + ) for pd1_field in cls_cast.__fields__.values(): helptext = pd1_field.field_info.description if helptext is None: @@ -686,7 +690,7 @@ def _field_list_from_pydantic( field_list.append( FieldDefinition.make( name=pd1_field.name, - type_or_callable=pd1_field.outer_type_, + type_or_callable=hints[pd1_field.name], default=default, is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, diff --git a/tests/test_py311_generated/test_pydantic_generated.py b/tests/test_py311_generated/test_pydantic_generated.py index c54ef87d..c0892be1 100644 --- a/tests/test_py311_generated/test_pydantic_generated.py +++ b/tests/test_py311_generated/test_pydantic_generated.py @@ -6,6 +6,7 @@ from typing import Annotated, cast import pytest +from helptext_utils import get_helptext_with_checks from pydantic import BaseModel, Field, v1 import tyro @@ -51,6 +52,32 @@ class ManyTypesA(v1.BaseModel): ) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~")) +def test_pydantic_v1_conf() -> None: + class ManyTypesA(v1.BaseModel): + i: int + """This is a docstring.""" + s: tyro.conf.Suppress[str] = "hello" + f: float = v1.Field(default_factory=lambda: 3.0) + + class ManyTypesB(ManyTypesA): + p: pathlib.Path + + # We can directly pass a dataclass to `tyro.cli()`: + assert tyro.cli( + ManyTypesB, + args=[ + "--i", + "5", + "--p", + "~", + ], + ) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~")) + helptext = get_helptext_with_checks(ManyTypesB) + assert "--s" not in helptext + assert "--i" in helptext + assert "This is a docstring" in helptext + + def test_pydantic_helptext() -> None: class Helptext(BaseModel): """This docstring should be printed as a description.""" diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 5a588e71..fd8d75b0 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -6,6 +6,7 @@ from typing import cast import pytest +from helptext_utils import get_helptext_with_checks from pydantic import BaseModel, Field, v1 from typing_extensions import Annotated @@ -52,6 +53,32 @@ class ManyTypesA(v1.BaseModel): ) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~")) +def test_pydantic_v1_conf() -> None: + class ManyTypesA(v1.BaseModel): + i: int + """This is a docstring.""" + s: tyro.conf.Suppress[str] = "hello" + f: float = v1.Field(default_factory=lambda: 3.0) + + class ManyTypesB(ManyTypesA): + p: pathlib.Path + + # We can directly pass a dataclass to `tyro.cli()`: + assert tyro.cli( + ManyTypesB, + args=[ + "--i", + "5", + "--p", + "~", + ], + ) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~")) + helptext = get_helptext_with_checks(ManyTypesB) + assert "--s" not in helptext + assert "--i" in helptext + assert "This is a docstring" in helptext + + def test_pydantic_helptext() -> None: class Helptext(BaseModel): """This docstring should be printed as a description."""