From 5ce547c290c81965801dacd69a9aa763d7a29587 Mon Sep 17 00:00:00 2001 From: David McKeone Date: Tue, 23 May 2023 13:56:02 -0600 Subject: [PATCH 1/6] Add support for msgspec injection --- sanic_ext/extensions/injection/constructor.py | 3 +- sanic_ext/extensions/openapi/types.py | 39 +- sanic_ext/extras/validation/check.py | 25 +- sanic_ext/extras/validation/schema.py | 22 +- sanic_ext/utils/typing.py | 12 +- tests/extensions/openapi/test_model_fields.py | 10 +- tests/extra/test_validation_msgspec.py | 392 ++++++++++++++++++ 7 files changed, 488 insertions(+), 15 deletions(-) create mode 100644 tests/extra/test_validation_msgspec.py diff --git a/sanic_ext/extensions/injection/constructor.py b/sanic_ext/extensions/injection/constructor.py index c3f17b9..7271316 100644 --- a/sanic_ext/extensions/injection/constructor.py +++ b/sanic_ext/extensions/injection/constructor.py @@ -20,7 +20,7 @@ from sanic.exceptions import ServerError from sanic_ext.exceptions import InitError -from sanic_ext.utils.typing import is_attrs, is_optional, is_pydantic +from sanic_ext.utils.typing import is_attrs, is_optional, is_pydantic, is_msgspec if TYPE_CHECKING: from .registry import ConstantRegistry, InjectionRegistry @@ -153,6 +153,7 @@ def _get_hints(self): or is_dataclass(self.func) or is_attrs(self.func) or is_pydantic(self.func) + or is_msgspec(self.func) ): return get_type_hints(self.func) elif isclass(self.func): diff --git a/sanic_ext/extensions/openapi/types.py b/sanic_ext/extensions/openapi/types.py index bc369bf..03b4af6 100644 --- a/sanic_ext/extensions/openapi/types.py +++ b/sanic_ext/extensions/openapi/types.py @@ -17,7 +17,7 @@ from sanic_routing.patterns import alpha, ext, nonemptystr, parse_date, slug -from sanic_ext.utils.typing import is_attrs, is_generic, is_pydantic +from sanic_ext.utils.typing import is_attrs, is_generic, is_pydantic, is_msgspec try: import attrs @@ -26,6 +26,30 @@ except ImportError: NOTHING = object() +try: + import msgspec + from msgspec.inspect import type_info as msgspec_type_info, Metadata as MsgspecMetadata + + MsgspecMetadata: Any = MsgspecMetadata + NODEFAULT: Any = msgspec.NODEFAULT + UNSET: Any = msgspec.UNSET + + class MsgspecAdapter(msgspec.Struct): + name: str + default: Any + metadata: dict + +except ImportError: + def msgspec_type_info(struct): + pass + + class MsgspecAdapter: + pass + + MsgspecMetadata = object() + NODEFAULT = object() + UNSET = object() + class Definition: __nullable__: Optional[List[str]] = [] @@ -290,7 +314,7 @@ def __init__( def make(cls, value: Any, **kwargs): extra: Dict[str, Any] = {} - # Extract from field metadata if pydantic, attrs, or dataclass + # Extract from field metadata if msgspec, pydantic, attrs, or dataclass if isclass(value): fields = () if is_pydantic(value): @@ -303,6 +327,17 @@ def make(cls, value: Any, **kwargs): fields = value.__attrs_attrs__ elif is_dataclass(value): fields = value.__dataclass_fields__.values() + elif is_msgspec(value): + # adapt to msgspec metadata layout -- annotated type -- to match dataclass "metadata" attribute + fields = [ + MsgspecAdapter( + name=f.name, + default=MISSING if f.default in (UNSET, NODEFAULT) else f.default, + metadata=getattr(f.type, 'extra', {}) + ) + for f in msgspec_type_info(value).fields + ] + if fields: extra = { field.name: { diff --git a/sanic_ext/extras/validation/check.py b/sanic_ext/extras/validation/check.py index 5a6a166..6162176 100644 --- a/sanic_ext/extras/validation/check.py +++ b/sanic_ext/extras/validation/check.py @@ -4,6 +4,7 @@ from typing import ( Any, Literal, + Mapping, NamedTuple, Optional, Tuple, @@ -12,7 +13,7 @@ get_origin, ) -from sanic_ext.utils.typing import UnionType, is_generic, is_optional +from sanic_ext.utils.typing import UnionType, is_generic, is_optional, is_msgspec MISSING: Tuple[Any, ...] = (_HAS_DEFAULT_FACTORY,) @@ -29,6 +30,13 @@ ATTRS = False +try: + import msgspec + MSGSPEC = True +except ImportError: + MSGSPEC = False + + class Hint(NamedTuple): hint: Any model: bool @@ -169,7 +177,13 @@ def check_data(model, data, schema, allow_multiple=False, allow_coerce=False): except ValueError as e: raise TypeError(e) - return model(**hydration_values) + if MSGSPEC and is_msgspec(model): + try: + return msgspec.from_builtins(hydration_values, model) + except msgspec.ValidationError as e: + raise TypeError(e) + else: + return model(**hydration_values) def _check_types(value, literal, expected): @@ -179,7 +193,12 @@ def _check_types(value, literal, expected): elif value != expected: raise ValueError(f"Value '{value}' must be {expected}") else: - if not isinstance(value, expected): + if MSGSPEC and is_msgspec(expected) and isinstance(value, Mapping): + try: + expected(**value) + except (TypeError, msgspec.ValidationError): + raise ValueError(f"Value '{value}' is not of type {expected}") + elif not isinstance(value, expected): raise ValueError(f"Value '{value}' is not of type {expected}") diff --git a/sanic_ext/extras/validation/schema.py b/sanic_ext/extras/validation/schema.py index 1d66c25..ea4d4e3 100644 --- a/sanic_ext/extras/validation/schema.py +++ b/sanic_ext/extras/validation/schema.py @@ -13,7 +13,7 @@ get_type_hints, ) -from sanic_ext.utils.typing import is_attrs, is_generic +from sanic_ext.utils.typing import is_attrs, is_generic, is_msgspec from .check import Hint @@ -28,6 +28,12 @@ NOTHING = object() # type: ignore Attribute = type("Attribute", (), {}) # type: ignore +try: + from msgspec.inspect import type_info as msgspec_type_info +except ModuleNotFoundError: + def msgspec_type_info(val): + pass + def make_schema(agg, item): if type(item) in (bool, str, int, float): @@ -36,12 +42,14 @@ def make_schema(agg, item): if is_generic(item) and (args := get_args(item)): for arg in args: make_schema(agg, arg) - elif item.__name__ not in agg and (is_dataclass(item) or is_attrs(item)): - fields = ( - item.__dataclass_fields__ - if is_dataclass(item) - else {attr.name: attr for attr in item.__attrs_attrs__} - ) + elif item.__name__ not in agg and (is_dataclass(item) or is_attrs(item) or is_msgspec(item)): + if is_dataclass(item): + fields = item.__dataclass_fields__ + elif is_msgspec(item): + fields = {f.name: f.type for f in msgspec_type_info(item).fields} + else: + fields = {attr.name: attr for attr in item.__attrs_attrs__} + sig = signature(item) hints = parse_hints(get_type_hints(item), fields) diff --git a/sanic_ext/utils/typing.py b/sanic_ext/utils/typing.py index ec5d9f3..6b525d4 100644 --- a/sanic_ext/utils/typing.py +++ b/sanic_ext/utils/typing.py @@ -21,6 +21,13 @@ except ImportError: ATTRS = False +try: + from msgspec import Struct + + MSGSPEC = True +except ImportError: + MSGSPEC = False + def is_generic(item): return ( @@ -42,11 +49,14 @@ def is_pydantic(model): issubclass(model, BaseModel) or hasattr(model, "__pydantic_model__") ) - def is_attrs(model): return ATTRS and (hasattr(model, "__attrs_attrs__")) +def is_msgspec(model): + return MSGSPEC and issubclass(model, Struct) + + def flat_values( item: typing.Union[ typing.Dict[str, typing.Any], typing.Iterable[typing.Any] diff --git a/tests/extensions/openapi/test_model_fields.py b/tests/extensions/openapi/test_model_fields.py index 7246bce..5c521e6 100644 --- a/tests/extensions/openapi/test_model_fields.py +++ b/tests/extensions/openapi/test_model_fields.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field -from typing import List +from typing import List, Annotated from uuid import UUID import attrs import pytest +from msgspec import Struct, Meta from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass as pydataclass @@ -47,6 +48,12 @@ class FooPydanticDataclass: ident: str = Field("XXXX", example="ABC123") +class FooStruct(Struct): + links: List[UUID] + priority: Annotated[int, Meta(extra={"openapi": {"exclusiveMinimum": 1, "exclusiveMaximum": 10}})] + ident: Annotated[str, Meta(extra={"openapi": {"example": "ABC123"}})] = "XXXX" + + @pytest.mark.parametrize( "Foo", ( @@ -54,6 +61,7 @@ class FooPydanticDataclass: FooAttrs, FooPydanticBaseModel, FooPydanticDataclass, + FooStruct, ), ) def test_pydantic_base_model(app, Foo): diff --git a/tests/extra/test_validation_msgspec.py b/tests/extra/test_validation_msgspec.py new file mode 100644 index 0000000..387175c --- /dev/null +++ b/tests/extra/test_validation_msgspec.py @@ -0,0 +1,392 @@ +import sys + +from msgspec import Struct +from typing import List, Optional + +import pytest +from sanic import json +from sanic.views import HTTPMethodView + +from sanic_ext import validate +from sanic_ext.extras.validation.check import check_data +from sanic_ext.extras.validation.schema import make_schema, parse_hint + +from . import __models__ as models + +SNOOPY_DATA = {"name": "Snoopy", "alter_ego": ["Flying Ace", "Joe Cool"]} + + +def test_schema(): + class Pet(Struct): + name: str + + class Person(Struct): + name: str + age: int + pets: Optional[List[Pet]] + + schema = make_schema({}, Person) + + assert "Person" in schema + assert schema["Person"]["hints"]["name"] == parse_hint(str) + assert schema["Person"]["hints"]["age"] == parse_hint(int) + assert schema["Person"]["hints"]["pets"] == parse_hint(Optional[List[Pet]]) + + assert "Pet" in schema + assert schema["Pet"]["hints"]["name"] == parse_hint(str) + + +def test_should_hydrate(): + class Pet(Struct): + name: str + + class Person(Struct): + name: str + age: int + pets: List[Pet] + + data = {"name": "Charlie Brown", "age": 8, "pets": [{"name": "Snoopy"}]} + + schema = make_schema({}, Person) + cb = check_data(Person, data, schema) + + assert cb.name == "Charlie Brown" + assert cb.age == 8 + assert cb.pets[0].name == "Snoopy" + + +@pytest.mark.parametrize( + "data", + ( + {"name": "Charlie Brown", "age": 8, "pets": {"name": "Snoopy"}}, + {"name": "Charlie Brown", "age": 8, "pets": [{"name": 123}]}, + {"name": "Charlie Brown", "age": 8, "pets": [123]}, + {"name": "Charlie Brown", "age": 8, "pets": 123}, + {"name": "Charlie Brown", "age": "8", "pets": {"name": "Snoopy"}}, + {"name": True, "age": 8, "pets": {"name": "Snoopy"}}, + ), +) +def test_should_not_hydrate(data): + class Pet(Struct): + name: str + + class Person(Struct): + name: str + age: int + pets: List[Pet] + + schema = make_schema({}, Person) + with pytest.raises(TypeError): + check_data(Person, data, schema) + + +@pytest.mark.parametrize( + "model,okay,data", + ( + (models.ModelStr, True, {"foo": "bar"}), + (models.ModelStr, False, {"foo": 1}), + (models.ModelStr, False, {"foo": True}), + (models.ModelStr, False, {"foo": ["bar"]}), + (models.ModelStr, False, {"bar": "bar"}), + (models.ModelStr, False, {"foo": None}), + (models.ModelStr, False, 123), + (models.ModelInt, True, {"foo": 1}), + (models.ModelInt, True, {"foo": True}), + (models.ModelInt, False, {"foo": "1"}), + (models.ModelInt, False, {"foo": 1.1}), + (models.ModelInt, False, {"foo": None}), + (models.ModelFloat, True, {"foo": 1.1}), + (models.ModelFloat, False, {"foo": 1}), + (models.ModelFloat, False, {"foo": "1.1"}), + (models.ModelFloat, False, {"foo": None}), + (models.ModelBool, True, {"foo": True}), + (models.ModelBool, True, {"foo": False}), + (models.ModelBool, False, {"foo": 1}), + (models.ModelBool, False, {"foo": 0}), + (models.ModelBool, False, {"foo": 2}), + (models.ModelBool, False, {"foo": "True"}), + (models.ModelBool, False, {"foo": None}), + (models.ModelOptionalStr, True, {"foo": "bar"}), + (models.ModelOptionalStr, True, {"foo": None}), + (models.ModelOptionalStr, False, {"foo": 0}), + (models.ModelUnion, True, {"foo": 1}), + (models.ModelUnion, True, {"foo": 1.1}), + (models.ModelUnion, False, {"foo": "1.1"}), + (models.ModelUnion, False, {"foo": None}), + (models.ModelUnionModels, True, {"foo": {"foo": 1}}), + (models.ModelUnionModels, True, {"foo": {"foo": 1.1}}), + (models.ModelUnionModels, False, {"foo": {"foo": "1.1"}}), + (models.ModelUnionModels, False, {"foo": 1}), + (models.ModelUnionModels, False, {"foo": 1.1}), + (models.ModelUnionModels, False, {"foo": None}), + (models.ModelUnionStrInt, True, {"foo": "1"}), + (models.ModelUnionStrInt, True, {"foo": "1q"}), + (models.ModelUnionStrInt, True, {"foo": 1}), + (models.ModelUnionStrInt, False, {"foo": 1.1}), + (models.ModelUnionStrInt, False, {"foo": None}), + (models.ModelUnionIntStr, True, {"foo": "1"}), + (models.ModelUnionIntStr, True, {"foo": "1q"}), + (models.ModelUnionIntStr, True, {"foo": 1}), + (models.ModelUnionIntStr, False, {"foo": 1.1}), + (models.ModelUnionIntStr, False, {"foo": None}), + (models.ModelOptionalUnionStrInt, True, {"foo": "1"}), + (models.ModelOptionalUnionStrInt, True, {"foo": "1q"}), + (models.ModelOptionalUnionStrInt, True, {"foo": 1}), + (models.ModelOptionalUnionStrInt, False, {"foo": 1.1}), + (models.ModelOptionalUnionStrInt, True, {"foo": None}), + (models.ModelOptionalUnionIntStr, True, {"foo": "1"}), + (models.ModelOptionalUnionIntStr, True, {"foo": "1q"}), + (models.ModelOptionalUnionIntStr, True, {"foo": 1}), + (models.ModelOptionalUnionIntStr, False, {"foo": 1.1}), + (models.ModelOptionalUnionIntStr, True, {"foo": None}), + (models.ModelListStr, True, {"foo": ["bar"]}), + (models.ModelListStr, True, {"foo": ["one", "two"]}), + (models.ModelListStr, False, {"foo": "bar"}), + (models.ModelListStr, False, {"foo": ["one", 2]}), + (models.ModelListStr, False, {"foo": ["one", None]}), + (models.ModelListStr, False, {"foo": None}), + (models.ModelListModel, True, {"foo": [{"foo": "bar"}]}), + ( + models.ModelListModel, + True, + {"foo": [{"foo": "one"}, {"foo": "two"}]}, + ), + (models.ModelListModel, False, {"foo": {"foo": "bar"}}), + (models.ModelListModel, False, {"foo": [{"foo": "bar"}, 2]}), + (models.ModelListModel, False, {"foo": [{"foo": "bar"}, None]}), + (models.ModelListModel, False, {"foo": None}), + (models.ModelOptionalList, True, {"foo": None}), + (models.ModelOptionalList, True, {"foo": ["bar"]}), + (models.ModelOptionalList, False, {"foo": [1]}), + (models.ModelOptionalList, False, {"foo": [None]}), + (models.ModelListUnion, True, {"foo": [1]}), + (models.ModelListUnion, True, {"foo": [1.1]}), + (models.ModelListUnion, True, {"foo": [1, 1.1]}), + (models.ModelListUnion, False, {"foo": [1, 1.1, "one"]}), + (models.ModelListUnion, False, {"foo": [1, 1.1, None]}), + (models.ModelListUnion, False, {"foo": 1}), + (models.ModelListUnion, False, {"foo": 1.1}), + (models.ModelListUnion, False, {"foo": None}), + (models.ModelOptionalListUnion, True, {"foo": [1]}), + (models.ModelOptionalListUnion, True, {"foo": [1.1]}), + (models.ModelOptionalListUnion, True, {"foo": [1, 1.1]}), + (models.ModelOptionalListUnion, True, {"foo": None}), + (models.ModelOptionalListUnion, False, {"foo": [1, 1.1, "one"]}), + (models.ModelOptionalListUnion, False, {"foo": [1, 1.1, None]}), + (models.ModelOptionalListUnion, False, {"foo": 1}), + (models.ModelOptionalListUnion, False, {"foo": 1.1}), + (models.ModelModel, True, {"foo": {"foo": "one"}}), + (models.ModelModel, False, {"foo": {"foo": 1}}), + (models.ModelModel, False, {"foo": {"foo": None}}), + (models.ModelModel, False, {"foo": "one"}), + (models.ModelModel, False, {"foo": None}), + (models.ModelOptionalModel, True, {"foo": {"foo": "one"}}), + (models.ModelOptionalModel, True, {"foo": None}), + (models.ModelOptionalModel, False, {"foo": {"foo": 1}}), + (models.ModelOptionalModel, False, {"foo": {"foo": None}}), + (models.ModelOptionalModel, False, {"foo": "one"}), + (models.ModelDictStr, True, {"foo": {"foo": "one"}}), + (models.ModelDictStr, False, {"foo": {"foo": 1}}), + (models.ModelDictStr, False, {"foo": {"foo": None}}), + (models.ModelDictStr, False, {"foo": "one"}), + (models.ModelDictStr, False, {"foo": None}), + (models.ModelDictModel, True, {"foo": {"foo": {"foo": "one"}}}), + (models.ModelDictModel, False, {"foo": {"foo": {"foo": 1}}}), + (models.ModelDictModel, False, {"foo": {"foo": 1}}), + (models.ModelDictModel, False, {"foo": {"foo": None}}), + (models.ModelDictModel, False, {"foo": "one"}), + (models.ModelDictModel, False, {"foo": None}), + (models.ModelOptionalDict, True, {"foo": {"foo": "one"}}), + (models.ModelOptionalDict, True, {"foo": None}), + (models.ModelOptionalDict, False, {"foo": {"foo": 1}}), + (models.ModelOptionalDict, False, {"foo": {"foo": None}}), + (models.ModelOptionalDict, False, {"foo": "one"}), + (models.ModelDictUnion, True, {"foo": {"foo": 1}}), + (models.ModelDictUnion, True, {"foo": {"foo": 1.1}}), + (models.ModelDictUnion, False, {"foo": {"foo": "one"}}), + (models.ModelDictUnion, False, {"foo": {"foo": None}}), + (models.ModelDictUnion, False, {"foo": "one"}), + (models.ModelDictUnion, False, {"foo": 1}), + (models.ModelDictUnion, False, {"foo": 1.1}), + (models.ModelDictUnion, False, {"foo": None}), + (models.ModelOptionalDictUnion, True, {"foo": {"foo": 1}}), + (models.ModelOptionalDictUnion, True, {"foo": {"foo": 1.1}}), + (models.ModelOptionalDictUnion, True, {"foo": None}), + (models.ModelOptionalDictUnion, False, {"foo": {"foo": "one"}}), + (models.ModelOptionalDictUnion, False, {"foo": {"foo": None}}), + (models.ModelOptionalDictUnion, False, {"foo": "one"}), + (models.ModelOptionalDictUnion, False, {"foo": 1}), + (models.ModelOptionalDictUnion, False, {"foo": 1.1}), + (models.ModelSingleLiteral, True, {"foo": True}), + (models.ModelSingleLiteral, False, {"foo": False}), + (models.ModelSingleLiteral, False, {"foo": "True"}), + (models.ModelSingleLiteral, False, {"foo": None}), + (models.ModelOptionalSingleLiteral, True, {"foo": True}), + (models.ModelOptionalSingleLiteral, True, {"foo": None}), + (models.ModelOptionalSingleLiteral, False, {"foo": False}), + (models.ModelOptionalSingleLiteral, False, {"foo": "True"}), + (models.ModelOptionalMultipleLiteral, True, {"foo": True}), + (models.ModelOptionalMultipleLiteral, True, {"foo": 1}), + (models.ModelOptionalMultipleLiteral, True, {"foo": "y"}), + (models.ModelOptionalMultipleLiteral, True, {"foo": "Y"}), + (models.ModelOptionalMultipleLiteral, True, {"foo": None}), + (models.ModelOptionalMultipleLiteral, False, {"foo": "n"}), + (models.ModelOptionalMultipleLiteral, False, {"foo": False}), + (models.ModelListStrWithDefaultFactory, True, {}), + (models.ModelListStrWithDefaultFactory, True, {"foo": ["bar"]}), + (models.ModelListStrWithDefaultFactory, True, {"foo": []}), + (models.ModelListStrWithDefaultFactory, False, {"foo": [1]}), + (models.ModelListStrWithDefaultFactory, False, {"foo": None}), + ), +) +def test_modeling(model, okay, data): + schema = make_schema({}, model) + + if okay: + check_data(model, data, schema) + else: + with pytest.raises(TypeError): + check_data(model, data, schema) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="UnionType added in 3.10" +) +def test_modeling_union_type_ModelUnionTypeStrNone(): + schema = make_schema({}, models.ModelUnionTypeStrNone) + + check_data(models.ModelUnionTypeStrNone, {"foo": "bar"}, schema) + check_data(models.ModelUnionTypeStrNone, {"foo": None}, schema) + with pytest.raises(TypeError): + check_data(models.ModelUnionTypeStrNone, {"foo": 0}, schema) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="UnionType added in 3.10" +) +def test_modeling_union_type_ModelUnionTypeStrIntNone(): + schema = make_schema({}, models.ModelUnionTypeStrIntNone) + + check_data(models.ModelUnionTypeStrIntNone, {"foo": "1"}, schema) + check_data(models.ModelUnionTypeStrIntNone, {"foo": "bar"}, schema) + check_data(models.ModelUnionTypeStrIntNone, {"foo": None}, schema) + check_data(models.ModelUnionTypeStrIntNone, {"foo": 1}, schema) + check_data(models.ModelUnionTypeStrIntNone, {"foo": 0}, schema) + with pytest.raises(TypeError): + check_data(models.ModelUnionTypeStrIntNone, {"foo": 1.1}, schema) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="UnionType added in 3.10" +) +def test_modeling_union_type_ModelUnionTypeStrInt(): + schema = make_schema({}, models.ModelUnionTypeStrInt) + + check_data(models.ModelUnionTypeStrInt, {"foo": "1"}, schema) + check_data(models.ModelUnionTypeStrInt, {"foo": "bar"}, schema) + check_data(models.ModelUnionTypeStrInt, {"foo": 1}, schema) + check_data(models.ModelUnionTypeStrInt, {"foo": 0}, schema) + with pytest.raises(TypeError): + check_data(models.ModelUnionTypeStrInt, {"foo": None}, schema) + with pytest.raises(TypeError): + check_data(models.ModelUnionTypeStrInt, {"foo": 1.1}, schema) + + +def test_validate_json(app): + class Pet(Struct): + name: str + alter_ego: List[str] + + @app.post("/function") + @validate(json=Pet) + async def handler(_, body: Pet): + return json( + { + "is_pet": isinstance(body, Pet), + "pet": {"name": body.name, "alter_ego": body.alter_ego}, + } + ) + + class MethodView(HTTPMethodView, attach=app, uri="/method"): + decorators = [validate(json=Pet)] + + async def post(self, _, body: Pet): + return json( + { + "is_pet": isinstance(body, Pet), + "pet": {"name": body.name, "alter_ego": body.alter_ego}, + } + ) + + _, response = app.test_client.post("/function", json=SNOOPY_DATA) + assert response.status == 200 + assert response.json["is_pet"] + assert response.json["pet"] == SNOOPY_DATA + + _, response = app.test_client.post("/method", json=SNOOPY_DATA) + assert response.status == 200 + assert response.json["is_pet"] + assert response.json["pet"] == SNOOPY_DATA + + +def test_validate_form(app): + class Pet(Struct): + name: str + alter_ego: List[str] + + @app.post("/function") + @validate(form=Pet) + async def handler(_, body: Pet): + return json( + { + "is_pet": isinstance(body, Pet), + "pet": {"name": body.name, "alter_ego": body.alter_ego}, + } + ) + + class MethodView(HTTPMethodView, attach=app, uri="/method"): + decorators = [validate(form=Pet)] + + async def post(self, _, body: Pet): + return json( + { + "is_pet": isinstance(body, Pet), + "pet": {"name": body.name, "alter_ego": body.alter_ego}, + } + ) + + _, response = app.test_client.post("/function", data=SNOOPY_DATA) + assert response.status == 200 + assert response.json["is_pet"] + assert response.json["pet"] == SNOOPY_DATA + + _, response = app.test_client.post("/method", data=SNOOPY_DATA) + assert response.status == 200 + assert response.json["is_pet"] + assert response.json["pet"] == SNOOPY_DATA + + +def test_validate_query(app): + class Search(Struct): + q: str + + @app.get("/function") + @validate(query=Search) + async def handler(_, query: Search): + return json({"q": query.q, "is_search": isinstance(query, Search)}) + + class MethodView(HTTPMethodView, attach=app, uri="/method"): + decorators = [validate(query=Search)] + + async def get(self, _, query: Search): + return json({"q": query.q, "is_search": isinstance(query, Search)}) + + _, response = app.test_client.get("/function", params={"q": "Snoopy"}) + assert response.status == 200 + assert response.json["is_search"] + assert response.json["q"] == "Snoopy" + + _, response = app.test_client.get("/method", params={"q": "Snoopy"}) + assert response.status == 200 + assert response.json["is_search"] + assert response.json["q"] == "Snoopy" From c7872236200f449a53a79093f3ad56a6934b9ee8 Mon Sep 17 00:00:00 2001 From: David McKeone Date: Sat, 3 Jun 2023 15:51:13 -0400 Subject: [PATCH 2/6] Add msgspec to test build environment --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 5828480..48545b1 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ deps = git+https://github.com/sanic-org/sanic.git#egg=sanic pydantic attrs + msgspec ; sanic21.6: sanic==21.6 ; sanic21.6: sanic_testing From 7fcd91772188fa3409171277d4a8994d4272c786 Mon Sep 17 00:00:00 2001 From: David McKeone Date: Sat, 3 Jun 2023 15:53:06 -0400 Subject: [PATCH 3/6] Use recommended arguments from msgspec author when validating data using check_data() --- sanic_ext/extras/validation/check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic_ext/extras/validation/check.py b/sanic_ext/extras/validation/check.py index 6162176..7d14158 100644 --- a/sanic_ext/extras/validation/check.py +++ b/sanic_ext/extras/validation/check.py @@ -179,7 +179,7 @@ def check_data(model, data, schema, allow_multiple=False, allow_coerce=False): if MSGSPEC and is_msgspec(model): try: - return msgspec.from_builtins(hydration_values, model) + return msgspec.from_builtins(hydration_values, model, str_values=True, str_keys=True) except msgspec.ValidationError as e: raise TypeError(e) else: From 82313262928d740b5eccfc6e4169266bcd68e90b Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 9 Jul 2023 23:48:55 +0300 Subject: [PATCH 4/6] Shorten line length --- sanic_ext/extensions/openapi/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sanic_ext/extensions/openapi/types.py b/sanic_ext/extensions/openapi/types.py index 2111059..5250c83 100644 --- a/sanic_ext/extensions/openapi/types.py +++ b/sanic_ext/extensions/openapi/types.py @@ -336,7 +336,8 @@ def make(cls, value: Any, **kwargs): elif is_dataclass(value): fields = value.__dataclass_fields__.values() elif is_msgspec(value): - # adapt to msgspec metadata layout -- annotated type -- to match dataclass "metadata" attribute + # adapt to msgspec metadata layout -- annotated type -- + # to match dataclass "metadata" attribute fields = [ MsgspecAdapter( name=f.name, From f60a43f07c8c7f84ca48359e2a16ce3537d7405c Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 9 Jul 2023 23:58:45 +0300 Subject: [PATCH 5/6] Annotated only on 3.9 --- tests/extensions/openapi/test_model_fields.py | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tests/extensions/openapi/test_model_fields.py b/tests/extensions/openapi/test_model_fields.py index 40d58de..1f57396 100644 --- a/tests/extensions/openapi/test_model_fields.py +++ b/tests/extensions/openapi/test_model_fields.py @@ -1,5 +1,6 @@ +import sys from dataclasses import dataclass, field -from typing import Annotated, List +from typing import List from uuid import UUID import attrs @@ -7,11 +8,13 @@ from msgspec import Meta, Struct from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass as pydataclass - from sanic_ext import openapi from .utils import get_spec +if sys.version_info >= (3, 9): + from typing import Annotated + @dataclass class FooDataclass: @@ -48,30 +51,36 @@ class FooPydanticDataclass: ident: str = Field("XXXX", example="ABC123") -class FooStruct(Struct): - links: List[UUID] - priority: Annotated[ - int, - Meta( - extra={"openapi": {"exclusiveMinimum": 1, "exclusiveMaximum": 10}} - ), - ] - ident: Annotated[ - str, Meta(extra={"openapi": {"example": "ABC123"}}) - ] = "XXXX" - - -@pytest.mark.parametrize( - "Foo", - ( - FooDataclass, - FooAttrs, - FooPydanticBaseModel, - FooPydanticDataclass, - FooStruct, - ), -) -def test_pydantic_base_model(app, Foo): +if sys.version_info >= (3, 9): + + class FooStruct(Struct): + links: List[UUID] + priority: Annotated[ + int, + Meta( + extra={ + "openapi": {"exclusiveMinimum": 1, "exclusiveMaximum": 10} + } + ), + ] + ident: Annotated[ + str, Meta(extra={"openapi": {"example": "ABC123"}}) + ] = "XXXX" + + +models = [ + FooDataclass, + FooAttrs, + FooPydanticBaseModel, + FooPydanticDataclass, +] + +if sys.version_info >= (3, 9): + models.append(FooStruct) + + +@pytest.mark.parametrize("Foo", models) +def test_models(app, Foo): @app.get("/") @openapi.definition(body={"application/json": Foo}) async def handler(_): From e94f6b97071f779dd27ac198b1d2056b19447410 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 10 Jul 2023 00:03:28 +0300 Subject: [PATCH 6/6] Make pretty --- tests/extensions/openapi/test_model_fields.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/extensions/openapi/test_model_fields.py b/tests/extensions/openapi/test_model_fields.py index 1f57396..b7692d9 100644 --- a/tests/extensions/openapi/test_model_fields.py +++ b/tests/extensions/openapi/test_model_fields.py @@ -8,6 +8,7 @@ from msgspec import Meta, Struct from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass as pydataclass + from sanic_ext import openapi from .utils import get_spec