From 4003a588ba8f9aaa5c8f78fdb101753aed573d5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Veith=20R=C3=B6thlingsh=C3=B6fer?= Date: Tue, 14 Jan 2020 10:51:58 +0100 Subject: [PATCH 1/2] Fix nested optional type when loading from schema, including test cases --- dataclasses_json/core.py | 5 ++++- dataclasses_json/mm.py | 4 ++++ dataclasses_json/utils.py | 6 ++++-- tests/entities.py | 21 ++++++++++++++++++++- tests/test_schema.py | 17 +++++++++++++++-- 5 files changed, 47 insertions(+), 6 deletions(-) diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index b0a58173..6f603c69 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -235,7 +235,10 @@ def _decode_generic(type_, value, infer_missing): except TypeError: res = type_(xs) else: # Optional or Union - if _is_optional(type_) and len(type_.__args__) == 2: # Optional + if not hasattr(type_, "__args__"): + # Any, just accept + res = value + elif _is_optional(type_) and len(type_.__args__) == 2: # Optional type_arg = type_.__args__[0] if is_dataclass(type_arg) or is_dataclass(value): res = _decode_dataclass(type_arg, value, infer_missing) diff --git a/dataclasses_json/mm.py b/dataclasses_json/mm.py index 815b2455..54fa999c 100644 --- a/dataclasses_json/mm.py +++ b/dataclasses_json/mm.py @@ -119,6 +119,7 @@ def _deserialize(self, value, attr, data, **kwargs): typing.Dict: fields.Dict, typing.Tuple: fields.Tuple, typing.Callable: fields.Function, + typing.Any: fields.Raw, dict: fields.Dict, list: fields.List, str: fields.Str, @@ -249,6 +250,9 @@ def inner(type_, options): args = [inner(a, {}) for a in getattr(type_, '__args__', []) if a is not type(None)] + if _is_optional(type_): + options["allow_none"] = True + if origin in TYPES: return TYPES[origin](*args, **options) diff --git a/dataclasses_json/utils.py b/dataclasses_json/utils.py index c8f93818..cdf63cc7 100644 --- a/dataclasses_json/utils.py +++ b/dataclasses_json/utils.py @@ -1,7 +1,7 @@ import inspect import sys from datetime import datetime, timezone -from typing import Collection, Mapping, Optional, TypeVar +from typing import Collection, Mapping, Optional, TypeVar, Any def _get_type_cons(type_): @@ -88,7 +88,9 @@ def _is_new_type(type_): def _is_optional(type_): - return _issubclass_safe(type_, Optional) or _hasargs(type_, type(None)) + return (_issubclass_safe(type_, Optional) or + _hasargs(type_, type(None)) or + type_ is Any) def _is_mapping(type_): diff --git a/tests/entities.py b/tests/entities.py index 866fe638..449b17b7 100644 --- a/tests/entities.py +++ b/tests/entities.py @@ -10,7 +10,8 @@ Set, Tuple, TypeVar, - Union) + Union, + Any) from uuid import UUID from marshmallow import fields @@ -246,3 +247,21 @@ class DataClassWithOptionalDecimal: @dataclass class DataClassWithOptionalUuid: a: Optional[UUID] + + +@dataclass_json +@dataclass +class DataClassWithNestedAny: + a: Dict[str, Any] + + +@dataclass_json +@dataclass +class DataClassWithNestedOptionalAny: + a: Dict[str, Optional[Any]] + + +@dataclass_json +@dataclass +class DataClassWithNestedOptional: + a: Dict[str, Optional[int]] diff --git a/tests/test_schema.py b/tests/test_schema.py index e5652429..2545821d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,9 +1,10 @@ -from .entities import DataClassDefaultListStr, DataClassDefaultOptionalList, DataClassList, DataClassOptional +from .entities import (DataClassDefaultListStr, DataClassDefaultOptionalList, DataClassList, DataClassOptional, + DataClassWithNestedOptional, DataClassWithNestedOptionalAny, DataClassWithNestedAny) from .test_letter_case import CamelCasePerson, KebabCasePerson, SnakeCasePerson, FieldNamePerson - test_do_list = """[{}, {"children": [{"name": "a"}, {"name": "b"}]}]""" test_list = '[{"children": [{"name": "a"}, {"name": "b"}]}]' +nested_optional_data = '{"a": {"test": null}}' class TestSchema: @@ -27,3 +28,15 @@ def test_letter_case(self): for cls in (CamelCasePerson, KebabCasePerson, SnakeCasePerson, FieldNamePerson): p = cls('Alice') assert p.to_dict() == cls.schema().dump(p) + + def test_nested_optional(self): + DataClassWithNestedOptional.schema().loads(nested_optional_data) + assert True + + def test_nested_optional_any(self): + DataClassWithNestedOptionalAny.schema().loads(nested_optional_data) + assert True + + def test_nested_any_accepts_optional(self): + DataClassWithNestedAny.schema().loads(nested_optional_data) + assert True From 0f1041c5b1a9d9a439227cbf3f118336549c6023 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Veith=20R=C3=B6thlingsh=C3=B6fer?= Date: Thu, 30 Apr 2020 13:02:43 +0200 Subject: [PATCH 2/2] Allow string enums as dict keys --- dataclasses_json/core.py | 17 +++++++++++--- tests/test_enum.py | 51 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index 57289695..ad7b1d73 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -234,9 +234,18 @@ def _decode_generic(type_, value, infer_missing): if value is None: res = value elif _issubclass_safe(type_, Enum): - # Convert to an Enum using the type as a constructor. - # Assumes a direct match is found. - res = type_(value) + # we got the enum value + for enum_member in type_: + # We rely on the user that the enum values are unique + # We need to check for the string value + if str(enum_member.value) == value: + res = enum_member + break + else: + # Convert to an Enum using the type as a constructor. + # Assumes a direct match is found. + # Enums can overwrite missing, so we can try this as a last resort + res = type_(value) # FIXME this is a hack to fix a deeper underlying issue. A refactor is due. elif _is_collection(type_): if _is_mapping(type_): @@ -326,5 +335,7 @@ def _asdict(obj, encode_json=False): elif isinstance(obj, Collection) and not isinstance(obj, str) \ and not isinstance(obj, bytes): return list(_asdict(v, encode_json=encode_json) for v in obj) + elif isinstance(obj, Enum): + return _asdict(obj.value, encode_json=encode_json) else: return copy.deepcopy(obj) diff --git a/tests/test_enum.py b/tests/test_enum.py index 7c15082d..93881439 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -65,6 +65,18 @@ class EnumContainer: dict_enum_value={"key1str": MyEnum.STR1, "key1float": MyEnum.FLOAT1}) +@dataclass_json +@dataclass(frozen=True) +class DataWithEnumKeys: + name: str + enum_dict: Dict[MyEnum, str] + + +keys_json = '{"name": "name1", "enum_dict": {"str1": "str_test", "1": "int_test"}}' +d_keys = DataWithEnumKeys(name="name1", enum_dict={MyEnum.STR1: "str_test", + MyEnum.INT1: "int_test"}) + + class TestEncoder: def test_data_with_enum(self): assert d1.to_json() == d1_json, f'Actual: {d1.to_json()}, Expected: {d1_json}' @@ -82,6 +94,9 @@ def test_data_with_enum_default_value(self): def test_collection_with_enum(self): assert container.to_json() == container_json + def test_enum_dict_keys(self): + assert d_keys.to_json() == keys_json + class TestDecoder: def test_data_with_enum(self): @@ -114,6 +129,11 @@ def test_collection_with_enum(self): assert container == container_from_json assert container_from_json.to_json() == container_json + def test_enum_dict_keys(self): + dict_keys_from_json = DataWithEnumKeys.from_json(keys_json) + assert dict_keys_from_json == d_keys + assert dict_keys_from_json.to_json() == keys_json + class TestValidator: @pytest.mark.parametrize('enum_value, is_valid', [ @@ -140,7 +160,21 @@ def test_data_with_str_enum(self, enum_value, is_valid): data = '{"my_str_enum": "' + str(enum_value) + '"}' schema = DataWithStrEnum.schema() res = schema.validate(json.loads(data)) - assert not res == is_valid + no_errors = not res + assert no_errors == is_valid + + @pytest.mark.parametrize("enum_value, is_valid", [ + ("str1", True), + # This may be counter intuitive, but json only allows string keys + ("1", False), (1, False), + ("FOO", False) + ]) + def test_data_with_enum_keys(self, enum_value, is_valid): + data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}' + schema = DataWithEnumKeys.schema() + res = schema.validate(json.loads(data)) + no_errors = not res + assert no_errors == is_valid class TestLoader: @@ -170,3 +204,18 @@ def test_data_with_str_enum_exception(self): schema = DataWithStrEnum.schema() with pytest.raises(ValidationError): schema.loads('{"my_str_enum": "str2"}') + + def test_data_with_enum_keys_works_with_str_values(self): + schema = DataWithEnumKeys.schema() + data = '{"name": "name1", "enum_dict": {"str1": "bar"}}' + loaded = schema.loads(data) + assert loaded == DataWithEnumKeys(name="name1", + enum_dict={MyEnum.STR1: "bar"}) + + @pytest.mark.parametrize("enum_value", [MyEnum.INT1.value, + MyEnum.FLOAT1.value]) + def test_data_with_enum_keys_requires_str_values(self, enum_value): + schema = DataWithEnumKeys.schema() + data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}' + with pytest.raises(ValidationError): + schema.loads(data)