diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index 57289695..ad7b1d73 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -234,9 +234,18 @@ def _decode_generic(type_, value, infer_missing): if value is None: res = value elif _issubclass_safe(type_, Enum): - # Convert to an Enum using the type as a constructor. - # Assumes a direct match is found. - res = type_(value) + # we got the enum value + for enum_member in type_: + # We rely on the user that the enum values are unique + # We need to check for the string value + if str(enum_member.value) == value: + res = enum_member + break + else: + # Convert to an Enum using the type as a constructor. + # Assumes a direct match is found. + # Enums can overwrite missing, so we can try this as a last resort + res = type_(value) # FIXME this is a hack to fix a deeper underlying issue. A refactor is due. elif _is_collection(type_): if _is_mapping(type_): @@ -326,5 +335,7 @@ def _asdict(obj, encode_json=False): elif isinstance(obj, Collection) and not isinstance(obj, str) \ and not isinstance(obj, bytes): return list(_asdict(v, encode_json=encode_json) for v in obj) + elif isinstance(obj, Enum): + return _asdict(obj.value, encode_json=encode_json) else: return copy.deepcopy(obj) diff --git a/tests/test_enum.py b/tests/test_enum.py index 7c15082d..93881439 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -65,6 +65,18 @@ class EnumContainer: dict_enum_value={"key1str": MyEnum.STR1, "key1float": MyEnum.FLOAT1}) +@dataclass_json +@dataclass(frozen=True) +class DataWithEnumKeys: + name: str + enum_dict: Dict[MyEnum, str] + + +keys_json = '{"name": "name1", "enum_dict": {"str1": "str_test", "1": "int_test"}}' +d_keys = DataWithEnumKeys(name="name1", enum_dict={MyEnum.STR1: "str_test", + MyEnum.INT1: "int_test"}) + + class TestEncoder: def test_data_with_enum(self): assert d1.to_json() == d1_json, f'Actual: {d1.to_json()}, Expected: {d1_json}' @@ -82,6 +94,9 @@ def test_data_with_enum_default_value(self): def test_collection_with_enum(self): assert container.to_json() == container_json + def test_enum_dict_keys(self): + assert d_keys.to_json() == keys_json + class TestDecoder: def test_data_with_enum(self): @@ -114,6 +129,11 @@ def test_collection_with_enum(self): assert container == container_from_json assert container_from_json.to_json() == container_json + def test_enum_dict_keys(self): + dict_keys_from_json = DataWithEnumKeys.from_json(keys_json) + assert dict_keys_from_json == d_keys + assert dict_keys_from_json.to_json() == keys_json + class TestValidator: @pytest.mark.parametrize('enum_value, is_valid', [ @@ -140,7 +160,21 @@ def test_data_with_str_enum(self, enum_value, is_valid): data = '{"my_str_enum": "' + str(enum_value) + '"}' schema = DataWithStrEnum.schema() res = schema.validate(json.loads(data)) - assert not res == is_valid + no_errors = not res + assert no_errors == is_valid + + @pytest.mark.parametrize("enum_value, is_valid", [ + ("str1", True), + # This may be counter intuitive, but json only allows string keys + ("1", False), (1, False), + ("FOO", False) + ]) + def test_data_with_enum_keys(self, enum_value, is_valid): + data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}' + schema = DataWithEnumKeys.schema() + res = schema.validate(json.loads(data)) + no_errors = not res + assert no_errors == is_valid class TestLoader: @@ -170,3 +204,18 @@ def test_data_with_str_enum_exception(self): schema = DataWithStrEnum.schema() with pytest.raises(ValidationError): schema.loads('{"my_str_enum": "str2"}') + + def test_data_with_enum_keys_works_with_str_values(self): + schema = DataWithEnumKeys.schema() + data = '{"name": "name1", "enum_dict": {"str1": "bar"}}' + loaded = schema.loads(data) + assert loaded == DataWithEnumKeys(name="name1", + enum_dict={MyEnum.STR1: "bar"}) + + @pytest.mark.parametrize("enum_value", [MyEnum.INT1.value, + MyEnum.FLOAT1.value]) + def test_data_with_enum_keys_requires_str_values(self, enum_value): + schema = DataWithEnumKeys.schema() + data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}' + with pytest.raises(ValidationError): + schema.loads(data)