Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support abstract collections #532

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import sys
import warnings
from collections import defaultdict, namedtuple
from collections.abc import (Collection as ABCCollection, Mapping as ABCMapping, MutableMapping, MutableSequence,
MutableSet, Sequence, Set)
from dataclasses import (MISSING,
fields,
is_dataclass # type: ignore
)
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from types import MappingProxyType
from typing import (Any, Collection, Mapping, Union, get_type_hints,
Tuple, TypeVar, Type)
from uuid import UUID
Expand All @@ -31,6 +34,15 @@

confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']
FieldOverride = namedtuple('FieldOverride', confs) # type: ignore
collections_abc_type_to_implementation_type = MappingProxyType({
ABCCollection: tuple,
ABCMapping: dict,
MutableMapping: dict,
MutableSequence: list,
MutableSet: set,
Sequence: tuple,
Set: frozenset,
})


class _ExtendedEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -302,14 +314,8 @@ def _decode_generic(type_, value, infer_missing):
else:
xs = _decode_items(_get_type_arg_param(type_, 0), value, infer_missing)

# get the constructor if using corresponding generic type in `typing`
# otherwise fallback on constructing using type_ itself
materialize_type = type_
try:
materialize_type = _get_type_cons(type_)
except (TypeError, AttributeError):
pass
res = materialize_type(xs)
collection_type = _resolve_collection_type_to_decode_to(type_)
res = collection_type(xs)
elif _is_generic_dataclass(type_):
origin = _get_type_origin(type_)
res = _decode_dataclass(origin, value, infer_missing)
Expand Down Expand Up @@ -402,6 +408,18 @@ def handle_pep0673(pre_0673_hint: str) -> Union[Type, str]:
return list(_decode_type(type_args, x, infer_missing) for x in xs)


def _resolve_collection_type_to_decode_to(type_):
# get the constructor if using corresponding generic type in `typing`
# otherwise fallback on constructing using type_ itself
try:
collection_type = _get_type_cons(type_)
except (TypeError, AttributeError):
collection_type = type_

# map abstract collection to concrete implementation
return collections_abc_type_to_implementation_type.get(collection_type, collection_type)


def _asdict(obj, encode_json=False):
"""
A re-implementation of `asdict` (based on the original in the `dataclasses`
Expand Down
46 changes: 45 additions & 1 deletion tests/entities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from collections import deque
from collections.abc import Mapping, MutableMapping, MutableSequence, MutableSet, Sequence, Set as ABCSet
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
Expand All @@ -18,8 +19,9 @@
from uuid import UUID
if sys.version_info >= (3, 9):
from collections import Counter
from collections.abc import Mapping, MutableMapping, MutableSequence, MutableSet, Sequence, Set as ABCSet
else:
from typing import Counter
from typing import Counter, Mapping, MutableMapping, MutableSequence, MutableSet, Sequence, Set as ABCSet

from marshmallow import fields

Expand Down Expand Up @@ -388,3 +390,45 @@ class DataClassWithCounter:
class DataClassWithSelf(DataClassJsonMixin):
id: str
ref: Optional['DataClassWithSelf']


@dataclass_json
@dataclass
class DataClassWithCollection(DataClassJsonMixin):
c: Collection[int]


@dataclass_json
@dataclass
class DataClassWithMapping(DataClassJsonMixin):
c: Mapping[str, int]


@dataclass_json
@dataclass
class DataClassWithMutableMapping(DataClassJsonMixin):
c: MutableMapping[str, int]


@dataclass_json
@dataclass
class DataClassWithMutableSequence(DataClassJsonMixin):
c: MutableSequence[int]


@dataclass_json
@dataclass
class DataClassWithMutableSet(DataClassJsonMixin):
c: MutableSet[int]


@dataclass_json
@dataclass
class DataClassWithSequence(DataClassJsonMixin):
c: Sequence[int]


@dataclass_json
@dataclass
class DataClassWithAbstractSet(DataClassJsonMixin):
c: ABCSet[int]
22 changes: 21 additions & 1 deletion tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import Counter, deque

import pytest

from tests.entities import (DataClassIntImmutableDefault,
DataClassMutableDefaultDict,
DataClassMutableDefaultList, DataClassWithDeque,
Expand All @@ -20,7 +22,10 @@
DataClassWithDequeCollections,
DataClassWithTuple, DataClassWithTupleUnbound,
DataClassWithUnionIntNone, MyCollection,
DataClassWithCounter)
DataClassWithCounter, DataClassWithCollection,
DataClassWithMapping, DataClassWithMutableMapping,
DataClassWithMutableSet, DataClassWithMutableSequence,
DataClassWithSequence, DataClassWithAbstractSet)


class TestEncoder:
Expand Down Expand Up @@ -244,3 +249,18 @@ def test_mutable_default_dict(self):
def test_counter(self):
assert DataClassWithCounter.from_json('{"c": {"f": 1, "o": 2}}') == \
DataClassWithCounter(c=Counter('foo'))

@pytest.mark.parametrize(
"json_string, expected_instance",
[
pytest.param('{"c": [1, 2]}', DataClassWithCollection((1, 2)), id="collection"),
pytest.param('{"c": [1, 2]}', DataClassWithSequence((1, 2)), id="sequence"),
pytest.param('{"c": [1, 2]}', DataClassWithMutableSequence([1, 2]), id="mutable-sequence"),
pytest.param('{"c": [1, 2]}', DataClassWithAbstractSet({1, 2}), id="set"),
pytest.param('{"c": [1, 2]}', DataClassWithMutableSet({1, 2}), id="mutable-set"),
pytest.param('{"c": {"1": 1, "2": 2}}', DataClassWithMapping({"1": 1, "2": 2}), id="mapping"),
pytest.param('{"c": {"1": 1, "2": 2}}', DataClassWithMutableMapping({"1": 1, "2": 2}), id="mutable-mapping"),
]
)
def test_abstract_collections(self, json_string, expected_instance):
assert type(expected_instance).from_json(json_string) == expected_instance
Loading