From ea22883b166000bbd246bce7b19af3a1caea8d0f Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Thu, 24 Oct 2024 13:58:04 -0700 Subject: [PATCH] `typing.Annotated` workaround for pydantic v1 (#187) * Workaround for pydantic v1, which strips `typing.Annotated` metadata * docstring_parser doesn't seem to work with Pydantic v1 + Python 3.8 * sys version check --- src/tyro/_fields.py | 6 +++- .../test_pydantic_generated.py | 27 +++++++++++++++ tests/test_pydantic.py | 33 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/tyro/_fields.py b/src/tyro/_fields.py index 275595dd..ed0b8276 100644 --- a/src/tyro/_fields.py +++ b/src/tyro/_fields.py @@ -674,6 +674,10 @@ def _field_list_from_pydantic( cls_cast = cast(pydantic_v1.BaseModel, cls) # type: ignore else: cls_cast = cls + + hints = _resolver.get_type_hints_with_backported_syntax( + cls, include_extras=True + ) for pd1_field in cls_cast.__fields__.values(): helptext = pd1_field.field_info.description if helptext is None: @@ -686,7 +690,7 @@ def _field_list_from_pydantic( field_list.append( FieldDefinition.make( name=pd1_field.name, - type_or_callable=pd1_field.outer_type_, + type_or_callable=hints[pd1_field.name], default=default, is_default_from_default_instance=is_default_from_default_instance, helptext=helptext, diff --git a/tests/test_py311_generated/test_pydantic_generated.py b/tests/test_py311_generated/test_pydantic_generated.py index c54ef87d..c0892be1 100644 --- a/tests/test_py311_generated/test_pydantic_generated.py +++ b/tests/test_py311_generated/test_pydantic_generated.py @@ -6,6 +6,7 @@ from typing import Annotated, cast import pytest +from helptext_utils import get_helptext_with_checks from pydantic import BaseModel, Field, v1 import tyro @@ -51,6 +52,32 @@ class ManyTypesA(v1.BaseModel): ) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~")) +def test_pydantic_v1_conf() -> None: + class ManyTypesA(v1.BaseModel): + i: int + """This is a docstring.""" + s: tyro.conf.Suppress[str] = "hello" + f: float = v1.Field(default_factory=lambda: 3.0) + + class ManyTypesB(ManyTypesA): + p: pathlib.Path + + # We can directly pass a dataclass to `tyro.cli()`: + assert tyro.cli( + ManyTypesB, + args=[ + "--i", + "5", + "--p", + "~", + ], + ) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~")) + helptext = get_helptext_with_checks(ManyTypesB) + assert "--s" not in helptext + assert "--i" in helptext + assert "This is a docstring" in helptext + + def test_pydantic_helptext() -> None: class Helptext(BaseModel): """This docstring should be printed as a description.""" diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 5a588e71..a64ca28c 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -3,9 +3,11 @@ import contextlib import io import pathlib +import sys from typing import cast import pytest +from helptext_utils import get_helptext_with_checks from pydantic import BaseModel, Field, v1 from typing_extensions import Annotated @@ -52,6 +54,37 @@ class ManyTypesA(v1.BaseModel): ) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~")) +def test_pydantic_v1_conf() -> None: + class ManyTypesA(v1.BaseModel): + i: int + """This is a docstring.""" + s: tyro.conf.Suppress[str] = "hello" + f: float = v1.Field(default_factory=lambda: 3.0) + + class ManyTypesB(ManyTypesA): + p: pathlib.Path + + # We can directly pass a dataclass to `tyro.cli()`: + assert tyro.cli( + ManyTypesB, + args=[ + "--i", + "5", + "--p", + "~", + ], + ) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~")) + helptext = get_helptext_with_checks(ManyTypesB) + assert "--s" not in helptext + assert "--i" in helptext + + # This doesn't work when combining older versions of Python with older + # versions of pydantic. The root cause may be in the docstring_parser + # dependency. + if sys.version_info >= (3, 10): + assert "This is a docstring" in helptext + + def test_pydantic_helptext() -> None: class Helptext(BaseModel): """This docstring should be printed as a description."""