Skip to content

Commit

Permalink
fix: handle recursive collections (#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong authored Sep 21, 2024
1 parent 135d7fe commit 6440faa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
5 changes: 4 additions & 1 deletion polyfactory/value_generators/complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def handle_collection_type(
container_type = INSTANTIABLE_TYPE_MAPPING[container_type] # type: ignore[assignment]

container = container_type()
if not field_meta.children:
if field_meta.children is None or any(
child_meta.annotation in factory._get_build_context(build_context)["seen_models"]
for child_meta in field_meta.children
):
return container

if issubclass(container_type, MutableMapping) or is_typeddict(container_type):
Expand Down
18 changes: 14 additions & 4 deletions polyfactory/value_generators/constrained_collections.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, TypeVar

from polyfactory.exceptions import ParameterException
from polyfactory.field_meta import FieldMeta

if TYPE_CHECKING:
from polyfactory.factories.base import BaseFactory, BuildContext
from polyfactory.field_meta import FieldMeta

T = TypeVar("T", list, set, frozenset)


def handle_constrained_collection(
def handle_constrained_collection( # noqa: C901
collection_type: Callable[..., T],
factory: type[BaseFactory[Any]],
field_meta: FieldMeta,
Expand All @@ -37,6 +37,10 @@ def handle_constrained_collection(
:returns: A collection value.
"""
build_context = factory._get_build_context(build_context)
if field_meta.annotation in build_context["seen_models"]:
return collection_type()

min_items = abs(min_items if min_items is not None else (max_items or 0))
max_items = abs(max_items if max_items is not None else min_items + 1)

Expand Down Expand Up @@ -99,6 +103,12 @@ def handle_constrained_mapping(
:returns: A mapping instance.
"""
build_context = factory._get_build_context(build_context)
if field_meta.children is None or any(
child_meta.annotation in build_context["seen_models"] for child_meta in field_meta.children
):
return {}

min_items = abs(min_items if min_items is not None else (max_items or 0))
max_items = abs(max_items if max_items is not None else min_items + 1)

Expand All @@ -110,7 +120,7 @@ def handle_constrained_mapping(

collection: dict[Any, Any] = {}

children = cast(List[FieldMeta], field_meta.children)
children = field_meta.children
key_field_meta = children[0]
value_field_meta = children[1]
while len(collection) < length:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_recursive_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pytest

Expand Down Expand Up @@ -47,6 +47,8 @@ class PydanticNode(BaseModel):
optional_union_child: Union[PydanticNode, None] # noqa: UP007
optional_child: Optional[PydanticNode] # noqa: UP007
child: PydanticNode = Field(default=_Sentinel) # type: ignore[assignment]
recursive_key: Dict[PydanticNode, Any] # noqa: UP006
recursive_value: Dict[str, PydanticNode] # noqa: UP006


@pytest.mark.parametrize("factory_use_construct", (True, False))
Expand All @@ -59,6 +61,8 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
assert result.optional_union_child is None
assert result.optional_child is None
assert result.list_child == []
assert result.recursive_key == {}
assert result.recursive_value == {}


@dataclass
Expand Down

0 comments on commit 6440faa

Please sign in to comment.