Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for msgspec injection #197

Merged
merged 7 commits into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion sanic_ext/extensions/injection/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from sanic.exceptions import ServerError

from sanic_ext.exceptions import InitError
from sanic_ext.utils.typing import is_attrs, is_optional, is_pydantic
from sanic_ext.utils.typing import (
is_attrs,
is_msgspec,
is_optional,
is_pydantic,
)

if TYPE_CHECKING:
from .registry import ConstantRegistry, InjectionRegistry
Expand Down Expand Up @@ -153,6 +158,7 @@ def _get_hints(self):
or is_dataclass(self.func)
or is_attrs(self.func)
or is_pydantic(self.func)
or is_msgspec(self.func)
):
return get_type_hints(self.func)
elif isclass(self.func):
Expand Down
50 changes: 48 additions & 2 deletions sanic_ext/extensions/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

from sanic_routing.patterns import alpha, ext, nonemptystr, parse_date, slug

from sanic_ext.utils.typing import UnionType, is_attrs, is_generic, is_pydantic
from sanic_ext.utils.typing import (
UnionType,
is_attrs,
is_generic,
is_msgspec,
is_pydantic,
)

try:
import attrs
Expand All @@ -26,6 +32,32 @@
except ImportError:
NOTHING = object()

try:
import msgspec
from msgspec.inspect import Metadata as MsgspecMetadata
from msgspec.inspect import type_info as msgspec_type_info

MsgspecMetadata: Any = MsgspecMetadata
NODEFAULT: Any = msgspec.NODEFAULT
UNSET: Any = msgspec.UNSET

class MsgspecAdapter(msgspec.Struct):
name: str
default: Any
metadata: dict

except ImportError:

def msgspec_type_info(struct):
pass

class MsgspecAdapter:
pass

MsgspecMetadata = object()
NODEFAULT = object()
UNSET = object()


class Definition:
__nullable__: Optional[List[str]] = []
Expand Down Expand Up @@ -290,7 +322,7 @@ def __init__(
def make(cls, value: Any, **kwargs):
extra: Dict[str, Any] = {}

# Extract from field metadata if pydantic, attrs, or dataclass
# Extract from field metadata if msgspec, pydantic, attrs, or dataclass
if isclass(value):
fields = ()
if is_pydantic(value):
Expand All @@ -303,6 +335,20 @@ def make(cls, value: Any, **kwargs):
fields = value.__attrs_attrs__
elif is_dataclass(value):
fields = value.__dataclass_fields__.values()
elif is_msgspec(value):
# adapt to msgspec metadata layout -- annotated type --
# to match dataclass "metadata" attribute
fields = [
MsgspecAdapter(
name=f.name,
default=MISSING
if f.default in (UNSET, NODEFAULT)
else f.default,
metadata=getattr(f.type, "extra", {}),
)
for f in msgspec_type_info(value).fields
]

if fields:
extra = {
field.name: {
Expand Down
33 changes: 30 additions & 3 deletions sanic_ext/extras/validation/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import (
Any,
Literal,
Mapping,
NamedTuple,
Optional,
Tuple,
Expand All @@ -12,7 +13,12 @@
get_origin,
)

from sanic_ext.utils.typing import UnionType, is_generic, is_optional
from sanic_ext.utils.typing import (
UnionType,
is_generic,
is_msgspec,
is_optional,
)

MISSING: Tuple[Any, ...] = (_HAS_DEFAULT_FACTORY,)

Expand All @@ -29,6 +35,14 @@
ATTRS = False


try:
import msgspec

MSGSPEC = True
except ImportError:
MSGSPEC = False


class Hint(NamedTuple):
hint: Any
model: bool
Expand Down Expand Up @@ -169,7 +183,15 @@ def check_data(model, data, schema, allow_multiple=False, allow_coerce=False):
except ValueError as e:
raise TypeError(e)

return model(**hydration_values)
if MSGSPEC and is_msgspec(model):
try:
return msgspec.from_builtins(
hydration_values, model, str_values=True, str_keys=True
)
except msgspec.ValidationError as e:
raise TypeError(e)
else:
return model(**hydration_values)


def _check_types(value, literal, expected):
Expand All @@ -179,7 +201,12 @@ def _check_types(value, literal, expected):
elif value != expected:
raise ValueError(f"Value '{value}' must be {expected}")
else:
if not isinstance(value, expected):
if MSGSPEC and is_msgspec(expected) and isinstance(value, Mapping):
try:
expected(**value)
except (TypeError, msgspec.ValidationError):
raise ValueError(f"Value '{value}' is not of type {expected}")
elif not isinstance(value, expected):
raise ValueError(f"Value '{value}' is not of type {expected}")


Expand Down
25 changes: 18 additions & 7 deletions sanic_ext/extras/validation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
get_type_hints,
)

from sanic_ext.utils.typing import is_attrs, is_generic
from sanic_ext.utils.typing import is_attrs, is_generic, is_msgspec

from .check import Hint

Expand All @@ -28,6 +28,13 @@
NOTHING = object() # type: ignore
Attribute = type("Attribute", (), {}) # type: ignore

try:
from msgspec.inspect import type_info as msgspec_type_info
except ModuleNotFoundError:

def msgspec_type_info(val):
pass


def make_schema(agg, item):
if type(item) in (bool, str, int, float):
Expand All @@ -36,12 +43,16 @@ def make_schema(agg, item):
if is_generic(item) and (args := get_args(item)):
for arg in args:
make_schema(agg, arg)
elif item.__name__ not in agg and (is_dataclass(item) or is_attrs(item)):
fields = (
item.__dataclass_fields__
if is_dataclass(item)
else {attr.name: attr for attr in item.__attrs_attrs__}
)
elif item.__name__ not in agg and (
is_dataclass(item) or is_attrs(item) or is_msgspec(item)
):
if is_dataclass(item):
fields = item.__dataclass_fields__
elif is_msgspec(item):
fields = {f.name: f.type for f in msgspec_type_info(item).fields}
else:
fields = {attr.name: attr for attr in item.__attrs_attrs__}

sig = signature(item)
hints = parse_hints(get_type_hints(item), fields)

Expand Down
11 changes: 11 additions & 0 deletions sanic_ext/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
except ImportError:
ATTRS = False

try:
from msgspec import Struct

MSGSPEC = True
except ImportError:
MSGSPEC = False


def is_generic(item):
return (
Expand All @@ -47,6 +54,10 @@ def is_attrs(model):
return ATTRS and (hasattr(model, "__attrs_attrs__"))


def is_msgspec(model):
return MSGSPEC and issubclass(model, Struct)


def flat_values(
item: typing.Union[
typing.Dict[str, typing.Any], typing.Iterable[typing.Any]
Expand Down
45 changes: 35 additions & 10 deletions tests/extensions/openapi/test_model_fields.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import sys
from dataclasses import dataclass, field
from typing import List
from uuid import UUID

import attrs
import pytest
from msgspec import Meta, Struct
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass as pydataclass

from sanic_ext import openapi

from .utils import get_spec

if sys.version_info >= (3, 9):
from typing import Annotated


@dataclass
class FooDataclass:
Expand Down Expand Up @@ -47,16 +52,36 @@ class FooPydanticDataclass:
ident: str = Field("XXXX", example="ABC123")


@pytest.mark.parametrize(
"Foo",
(
FooDataclass,
FooAttrs,
FooPydanticBaseModel,
FooPydanticDataclass,
),
)
def test_pydantic_base_model(app, Foo):
if sys.version_info >= (3, 9):

class FooStruct(Struct):
links: List[UUID]
priority: Annotated[
int,
Meta(
extra={
"openapi": {"exclusiveMinimum": 1, "exclusiveMaximum": 10}
}
),
]
ident: Annotated[
str, Meta(extra={"openapi": {"example": "ABC123"}})
] = "XXXX"


models = [
FooDataclass,
FooAttrs,
FooPydanticBaseModel,
FooPydanticDataclass,
]

if sys.version_info >= (3, 9):
models.append(FooStruct)


@pytest.mark.parametrize("Foo", models)
def test_models(app, Foo):
@app.get("/")
@openapi.definition(body={"application/json": Foo})
async def handler(_):
Expand Down
Loading