Skip to content

Commit

Permalink
Workaround for pydantic v1, which strips typing.Annotated metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 24, 2024
1 parent 1a9448a commit 96cdc01
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tyro/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,10 @@ def _field_list_from_pydantic(
cls_cast = cast(pydantic_v1.BaseModel, cls) # type: ignore
else:
cls_cast = cls

hints = _resolver.get_type_hints_with_backported_syntax(
cls, include_extras=True
)
for pd1_field in cls_cast.__fields__.values():
helptext = pd1_field.field_info.description
if helptext is None:
Expand All @@ -686,7 +690,7 @@ def _field_list_from_pydantic(
field_list.append(
FieldDefinition.make(
name=pd1_field.name,
type_or_callable=pd1_field.outer_type_,
type_or_callable=hints[pd1_field.name],
default=default,
is_default_from_default_instance=is_default_from_default_instance,
helptext=helptext,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_py311_generated/test_pydantic_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Annotated, cast

import pytest
from helptext_utils import get_helptext_with_checks
from pydantic import BaseModel, Field, v1

import tyro
Expand Down Expand Up @@ -51,6 +52,32 @@ class ManyTypesA(v1.BaseModel):
) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~"))


def test_pydantic_v1_conf() -> None:
class ManyTypesA(v1.BaseModel):
i: int
"""This is a docstring."""
s: tyro.conf.Suppress[str] = "hello"
f: float = v1.Field(default_factory=lambda: 3.0)

class ManyTypesB(ManyTypesA):
p: pathlib.Path

# We can directly pass a dataclass to `tyro.cli()`:
assert tyro.cli(
ManyTypesB,
args=[
"--i",
"5",
"--p",
"~",
],
) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~"))
helptext = get_helptext_with_checks(ManyTypesB)
assert "--s" not in helptext
assert "--i" in helptext
assert "This is a docstring" in helptext


def test_pydantic_helptext() -> None:
class Helptext(BaseModel):
"""This docstring should be printed as a description."""
Expand Down
27 changes: 27 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import cast

import pytest
from helptext_utils import get_helptext_with_checks
from pydantic import BaseModel, Field, v1
from typing_extensions import Annotated

Expand Down Expand Up @@ -52,6 +53,32 @@ class ManyTypesA(v1.BaseModel):
) == ManyTypesA(i=5, s="hello", f=3.0, p=pathlib.Path("~"))


def test_pydantic_v1_conf() -> None:
class ManyTypesA(v1.BaseModel):
i: int
"""This is a docstring."""
s: tyro.conf.Suppress[str] = "hello"
f: float = v1.Field(default_factory=lambda: 3.0)

class ManyTypesB(ManyTypesA):
p: pathlib.Path

# We can directly pass a dataclass to `tyro.cli()`:
assert tyro.cli(
ManyTypesB,
args=[
"--i",
"5",
"--p",
"~",
],
) == ManyTypesB(i=5, s="hello", f=3.0, p=pathlib.Path("~"))
helptext = get_helptext_with_checks(ManyTypesB)
assert "--s" not in helptext
assert "--i" in helptext
assert "This is a docstring" in helptext


def test_pydantic_helptext() -> None:
class Helptext(BaseModel):
"""This docstring should be printed as a description."""
Expand Down

0 comments on commit 96cdc01

Please sign in to comment.