From ba1fecfc3d49a08f5f98ac02121ebf508a200d44 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Sun, 31 Mar 2024 19:41:36 +0200 Subject: [PATCH] higher order hints for @extend_schema_field (case 2) #1174 #1212 --- drf_spectacular/openapi.py | 2 ++ drf_spectacular/plumbing.py | 3 +-- drf_spectacular/utils.py | 2 +- tests/test_regressions.py | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 13998ff2..d954914b 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -628,6 +628,8 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False): schema = build_basic_type(override) if schema is None: return None + elif is_higher_order_type_hint(override): + schema = resolve_type_hint(override) elif isinstance(override, dict): schema = override else: diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index af5b5def..2897ad21 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -21,7 +21,6 @@ else: from typing_extensions import TypeGuard # noqa: F401 - import inflection import uritemplate from django.apps import apps @@ -1366,7 +1365,7 @@ def resolve_type_hint(hint): elif origin is collections.abc.Iterable: return build_array_type(resolve_type_hint(args[0])) else: - raise UnableToProceedError() + raise UnableToProceedError(hint) def whitelisted(obj: object, classes: Optional[List[Type[object]]], exact=False) -> bool: diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index 8689aec6..c0c33f76 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -571,7 +571,7 @@ def get_external_docs(self): def extend_schema_field( - field: Union[_SerializerType, _FieldType, OpenApiTypes, _SchemaType], + field: Union[_SerializerType, _FieldType, OpenApiTypes, _SchemaType, _KnownPythonTypes], component_name: Optional[str] = None ) -> Callable[[F], F]: """ diff --git a/tests/test_regressions.py b/tests/test_regressions.py index a0629298..2cfeecaf 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -3348,3 +3348,40 @@ def favorite(self, request, pk=None): schema = generate_schema('m', XViewSet) assert list(schema['paths'].keys()) == ['/m/', '/m/{id}/', '/m/{id}/favorite/'] + + +def test_extend_schema_field_with_types(no_warnings): + @extend_schema_field(int) + class CustomField(serializers.CharField): + pass # pragma: no cover + + @extend_schema_field(typing.List[int]) # this is the new case + class CustomField2(serializers.CharField): + pass # pragma: no cover + + class XSerializer(serializers.Serializer): + foo = serializers.SerializerMethodField() + bar = serializers.SerializerMethodField() + baz = CustomField() + qux = CustomField2() + + @extend_schema_field(int) + def get_foo(self, field, extra_param): + return 'foo' # pragma: no cover + + @extend_schema_field(typing.List[int]) + def get_bar(self, field, extra_param): + return 1 # pragma: no cover + + @extend_schema(request=XSerializer, responses=XSerializer) + @api_view(['POST']) + def view_func(request, format=None): + pass # pragma: no cover + + schema = generate_schema('/x/', view_function=view_func) + assert schema['components']['schemas']['X']['properties'] == { + 'foo': {'readOnly': True, 'type': 'integer'}, + 'bar': {'items': {'type': 'integer'}, 'readOnly': True, 'type': 'array'}, + 'baz': {'type': 'integer'}, + 'qux': {'items': {'type': 'integer'}, 'type': 'array'} + }