From 6440faa24abfc34c2d10e31797224f5aec2d43c0 Mon Sep 17 00:00:00 2001 From: Andrew Truong <40660973+adhtruong@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:53:48 +0100 Subject: [PATCH] fix: handle recursive collections (#587) --- polyfactory/value_generators/complex_types.py | 5 ++++- .../constrained_collections.py | 18 ++++++++++++++---- tests/test_recursive_models.py | 6 +++++- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/polyfactory/value_generators/complex_types.py b/polyfactory/value_generators/complex_types.py index beef1a39..d8280261 100644 --- a/polyfactory/value_generators/complex_types.py +++ b/polyfactory/value_generators/complex_types.py @@ -34,7 +34,10 @@ def handle_collection_type( container_type = INSTANTIABLE_TYPE_MAPPING[container_type] # type: ignore[assignment] container = container_type() - if not field_meta.children: + if field_meta.children is None or any( + child_meta.annotation in factory._get_build_context(build_context)["seen_models"] + for child_meta in field_meta.children + ): return container if issubclass(container_type, MutableMapping) or is_typeddict(container_type): diff --git a/polyfactory/value_generators/constrained_collections.py b/polyfactory/value_generators/constrained_collections.py index 0cbaba10..580b7d17 100644 --- a/polyfactory/value_generators/constrained_collections.py +++ b/polyfactory/value_generators/constrained_collections.py @@ -1,18 +1,18 @@ from __future__ import annotations from enum import EnumMeta -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, TypeVar from polyfactory.exceptions import ParameterException -from polyfactory.field_meta import FieldMeta if TYPE_CHECKING: from polyfactory.factories.base import BaseFactory, BuildContext + from polyfactory.field_meta import FieldMeta T = TypeVar("T", list, set, frozenset) -def handle_constrained_collection( +def handle_constrained_collection( # noqa: C901 collection_type: Callable[..., T], factory: type[BaseFactory[Any]], field_meta: FieldMeta, @@ -37,6 +37,10 @@ def handle_constrained_collection( :returns: A collection value. """ + build_context = factory._get_build_context(build_context) + if field_meta.annotation in build_context["seen_models"]: + return collection_type() + min_items = abs(min_items if min_items is not None else (max_items or 0)) max_items = abs(max_items if max_items is not None else min_items + 1) @@ -99,6 +103,12 @@ def handle_constrained_mapping( :returns: A mapping instance. """ + build_context = factory._get_build_context(build_context) + if field_meta.children is None or any( + child_meta.annotation in build_context["seen_models"] for child_meta in field_meta.children + ): + return {} + min_items = abs(min_items if min_items is not None else (max_items or 0)) max_items = abs(max_items if max_items is not None else min_items + 1) @@ -110,7 +120,7 @@ def handle_constrained_mapping( collection: dict[Any, Any] = {} - children = cast(List[FieldMeta], field_meta.children) + children = field_meta.children key_field_meta = children[0] value_field_meta = children[1] while len(collection) < length: diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index ac4c1c30..0b2a2f9b 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import pytest @@ -47,6 +47,8 @@ class PydanticNode(BaseModel): optional_union_child: Union[PydanticNode, None] # noqa: UP007 optional_child: Optional[PydanticNode] # noqa: UP007 child: PydanticNode = Field(default=_Sentinel) # type: ignore[assignment] + recursive_key: Dict[PydanticNode, Any] # noqa: UP006 + recursive_value: Dict[str, PydanticNode] # noqa: UP006 @pytest.mark.parametrize("factory_use_construct", (True, False)) @@ -59,6 +61,8 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None: assert result.optional_union_child is None assert result.optional_child is None assert result.list_child == [] + assert result.recursive_key == {} + assert result.recursive_value == {} @dataclass