Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow enum dict keys #215

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,18 @@ def _decode_generic(type_, value, infer_missing):
if value is None:
res = value
elif _issubclass_safe(type_, Enum):
# Convert to an Enum using the type as a constructor.
# Assumes a direct match is found.
res = type_(value)
# we got the enum value
for enum_member in type_:
# We rely on the user that the enum values are unique
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this is necessarily true -- enum names are not necessarily one-to-one with values. There is a decorator @unique that enforces it, but the standard library does allow for it

# We need to check for the string value
if str(enum_member.value) == value:
res = enum_member
break
else:
# Convert to an Enum using the type as a constructor.
# Assumes a direct match is found.
# Enums can overwrite missing, so we can try this as a last resort
res = type_(value)
# FIXME this is a hack to fix a deeper underlying issue. A refactor is due.
elif _is_collection(type_):
if _is_mapping(type_):
Expand Down Expand Up @@ -326,5 +335,7 @@ def _asdict(obj, encode_json=False):
elif isinstance(obj, Collection) and not isinstance(obj, str) \
and not isinstance(obj, bytes):
return list(_asdict(v, encode_json=encode_json) for v in obj)
elif isinstance(obj, Enum):
return _asdict(obj.value, encode_json=encode_json)
else:
return copy.deepcopy(obj)
51 changes: 50 additions & 1 deletion tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ class EnumContainer:
dict_enum_value={"key1str": MyEnum.STR1, "key1float": MyEnum.FLOAT1})


@dataclass_json
@dataclass(frozen=True)
class DataWithEnumKeys:
name: str
enum_dict: Dict[MyEnum, str]


keys_json = '{"name": "name1", "enum_dict": {"str1": "str_test", "1": "int_test"}}'
d_keys = DataWithEnumKeys(name="name1", enum_dict={MyEnum.STR1: "str_test",
MyEnum.INT1: "int_test"})


class TestEncoder:
def test_data_with_enum(self):
assert d1.to_json() == d1_json, f'Actual: {d1.to_json()}, Expected: {d1_json}'
Expand All @@ -82,6 +94,9 @@ def test_data_with_enum_default_value(self):
def test_collection_with_enum(self):
assert container.to_json() == container_json

def test_enum_dict_keys(self):
assert d_keys.to_json() == keys_json


class TestDecoder:
def test_data_with_enum(self):
Expand Down Expand Up @@ -114,6 +129,11 @@ def test_collection_with_enum(self):
assert container == container_from_json
assert container_from_json.to_json() == container_json

def test_enum_dict_keys(self):
dict_keys_from_json = DataWithEnumKeys.from_json(keys_json)
assert dict_keys_from_json == d_keys
assert dict_keys_from_json.to_json() == keys_json


class TestValidator:
@pytest.mark.parametrize('enum_value, is_valid', [
Expand All @@ -140,7 +160,21 @@ def test_data_with_str_enum(self, enum_value, is_valid):
data = '{"my_str_enum": "' + str(enum_value) + '"}'
schema = DataWithStrEnum.schema()
res = schema.validate(json.loads(data))
assert not res == is_valid
no_errors = not res
assert no_errors == is_valid

@pytest.mark.parametrize("enum_value, is_valid", [
("str1", True),
# This may be counter intuitive, but json only allows string keys
("1", False), (1, False),
("FOO", False)
])
def test_data_with_enum_keys(self, enum_value, is_valid):
data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}'
schema = DataWithEnumKeys.schema()
res = schema.validate(json.loads(data))
no_errors = not res
assert no_errors == is_valid


class TestLoader:
Expand Down Expand Up @@ -170,3 +204,18 @@ def test_data_with_str_enum_exception(self):
schema = DataWithStrEnum.schema()
with pytest.raises(ValidationError):
schema.loads('{"my_str_enum": "str2"}')

def test_data_with_enum_keys_works_with_str_values(self):
schema = DataWithEnumKeys.schema()
data = '{"name": "name1", "enum_dict": {"str1": "bar"}}'
loaded = schema.loads(data)
assert loaded == DataWithEnumKeys(name="name1",
enum_dict={MyEnum.STR1: "bar"})

@pytest.mark.parametrize("enum_value", [MyEnum.INT1.value,
MyEnum.FLOAT1.value])
def test_data_with_enum_keys_requires_str_values(self, enum_value):
schema = DataWithEnumKeys.schema()
data = '{"name": "name1", "enum_dict": {"' + str(enum_value) + '": "bar"}}'
with pytest.raises(ValidationError):
schema.loads(data)