Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve Union deserialization when "__type" field specifier is not present #478

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions dataclasses_json/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,17 @@ def _decode_generic(type_, value, infer_missing):
type_options = _get_type_args(type_)
res = value # assume already decoded
if type(value) is dict and dict not in type_options:
# FIXME if all types in the union are dataclasses this
george-zubrienko marked this conversation as resolved.
Show resolved Hide resolved
# will just pick the first option -
# maybe find the best fitting class in that case instead?
for type_option in type_options:
if is_dataclass(type_option):
res = _decode_dataclass(type_option, value, infer_missing)
break
try:
res = _decode_dataclass(type_option, value, infer_missing)
break
except (KeyError, ValueError):
george-zubrienko marked this conversation as resolved.
Show resolved Hide resolved
continue
if res == value:
warnings.warn(
f"Failed to encode {value} Union dataclasses."
f"Expected Union to include a dataclass and it didn't."
f"Failed to decode {value} Union dataclasses."
f"Expected Union to include a matching dataclass and it didn't."
)
return res

Expand Down
27 changes: 18 additions & 9 deletions dataclasses_json/mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,25 @@ def _deserialize(self, value, attr, data, **kwargs):
if is_dataclass(type_) and type_.__name__ == dc_name:
del tmp_value['__type']
return schema_._deserialize(tmp_value, attr, data, **kwargs)
for type_, schema_ in self.desc.items():
if isinstance(tmp_value, _get_type_origin(type_)):
return schema_._deserialize(tmp_value, attr, data, **kwargs)
else:
elif isinstance(tmp_value, dict):
warnings.warn(
f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
f'is not in the list of possible types of typing.Union '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
f'Value cannot be deserialized properly.')
return super()._deserialize(tmp_value, attr, data, **kwargs)
f'Attempting to deserialize "dict" (value: "{tmp_value}) '
f'that does not have a "__type" type specifier field into'
f'(dataclass: {self.cls.__name__}, field: {self.field.name}).'
f'Deserialization may fail, or deserialization to wrong type may occur.'
)
return super()._deserialize(tmp_value, attr, data, **kwargs)
else:
for type_, schema_ in self.desc.items():
if isinstance(tmp_value, _get_type_origin(type_)):
return schema_._deserialize(tmp_value, attr, data, **kwargs)
else:
warnings.warn(
f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
f'is not in the list of possible types of typing.Union '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
f'Value cannot be deserialized properly.')
return super()._deserialize(tmp_value, attr, data, **kwargs)


class _TupleVarLen(fields.List):
Expand Down
42 changes: 42 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,21 @@ class Aux2:
f1: str


@dataclass_json
@dataclass
class Aux3:
f2: str

@dataclass_json
@dataclass
class C4:
f1: Union[Aux1, Aux2]

@dataclass_json
@dataclass
class C12:
f1: Union[Aux2, Aux3]


@dataclass_json
@dataclass
Expand Down Expand Up @@ -198,3 +208,35 @@ def test_deserialize_with_error(cls, data):
s = cls.schema()
with pytest.raises(ValidationError):
assert s.load(data)

def test_deserialize_without_discriminator():
# determine based on type
json = '{"f1": {"f1": 1}}'
s = C4.schema()
obj = s.loads(json)
assert obj.f1 is not None
assert type(obj.f1) == Aux1

json = '{"f1": {"f1": "str1"}}'
s = C4.schema()
obj = s.loads(json)
assert obj.f1 is not None
assert type(obj.f1) == Aux2

# determine based on field name
json = '{"f1": {"f1": "str1"}}'
s = C12.schema()
obj = s.loads(json)
assert obj.f1 is not None
assert type(obj.f1) == Aux2
json = '{"f1": {"f2": "str1"}}'
s = C12.schema()
obj = s.loads(json)
assert obj.f1 is not None
assert type(obj.f1) == Aux3

# if no matching types, type should remain dict
json = '{"f1": {"f3": "str2"}}'
s = C12.schema()
obj = s.loads(json)
assert type(obj.f1) == dict