From 84af1434cf2bf5aa41639aec5261ce55d33cc451 Mon Sep 17 00:00:00 2001 From: mata Date: Sun, 15 Oct 2023 21:56:02 +0200 Subject: [PATCH 1/2] Add support for InitVar --- dataclasses_json/core.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/dataclasses_json/core.py b/dataclasses_json/core.py index e1cfdcbb..dd5a25e7 100644 --- a/dataclasses_json/core.py +++ b/dataclasses_json/core.py @@ -1,4 +1,5 @@ import copy +import inspect import json import sys import warnings @@ -145,7 +146,7 @@ def _decode_dataclass(cls, kvs, infer_missing): return kvs overrides = _user_overrides_or_exts(cls) kvs = {} if kvs is None and infer_missing else kvs - field_names = [field.name for field in fields(cls)] + field_names = set(cls.__dataclass_fields__.keys()) decode_names = _decode_letter_case_overrides(field_names, overrides) kvs = {decode_names.get(k, k): v for k, v in kvs.items()} missing_fields = {field for field in fields(cls) if field.name not in kvs} @@ -163,18 +164,20 @@ def _decode_dataclass(cls, kvs, infer_missing): init_kwargs = {} types = get_type_hints(cls) - for field in fields(cls): + constructor_args = set(inspect.signature(cls).parameters.keys()) + + for field_name in field_names: # The field should be skipped from being added # to init_kwargs as it's not intended as a constructor argument. - if not field.init: + if field_name not in constructor_args: continue - field_value = kvs[field.name] - field_type = types[field.name] + field_value = kvs[field_name] + field_type = types[field_name] if field_value is None: if not _is_optional(field_type): warning = ( - f"value of non-optional type {field.name} detected " + f"value of non-optional type {field_name} detected " f"when decoding {cls.__name__}" ) if infer_missing: @@ -188,7 +191,7 @@ def _decode_dataclass(cls, kvs, infer_missing): warnings.warn( f"'NoneType' object {warning}.", RuntimeWarning ) - init_kwargs[field.name] = field_value + init_kwargs[field_name] = field_value continue while True: @@ -197,13 +200,13 @@ def _decode_dataclass(cls, kvs, infer_missing): field_type = field_type.__supertype__ - if (field.name in overrides - and overrides[field.name].decoder is not None): + if (field_name in overrides + and overrides[field_name].decoder is not None): # FIXME hack if field_type is type(field_value): - init_kwargs[field.name] = field_value + init_kwargs[field_name] = field_value else: - init_kwargs[field.name] = overrides[field.name].decoder( + init_kwargs[field_name] = overrides[field_name].decoder( field_value) elif is_dataclass(field_type): # FIXME this is a band-aid to deal with the value already being @@ -215,13 +218,13 @@ def _decode_dataclass(cls, kvs, infer_missing): else: value = _decode_dataclass(field_type, field_value, infer_missing) - init_kwargs[field.name] = value + init_kwargs[field_name] = value elif _is_supported_generic(field_type) and field_type != str: - init_kwargs[field.name] = _decode_generic(field_type, + init_kwargs[field_name] = _decode_generic(field_type, field_value, infer_missing) else: - init_kwargs[field.name] = _support_extended_types(field_type, + init_kwargs[field_name] = _support_extended_types(field_type, field_value) return cls(**init_kwargs) From e9d80c60c9d0c5875cd620b7bd3e748c4fc1d12d Mon Sep 17 00:00:00 2001 From: mata Date: Sun, 15 Oct 2023 21:56:44 +0200 Subject: [PATCH 2/2] Add tests --- tests/test_init_var.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_init_var.py diff --git a/tests/test_init_var.py b/tests/test_init_var.py new file mode 100644 index 00000000..a2c989b6 --- /dev/null +++ b/tests/test_init_var.py @@ -0,0 +1,26 @@ +from dataclasses import InitVar, dataclass +from typing import Optional + +import pytest + +from dataclasses_json import DataClassJsonMixin + + +@dataclass +class A(DataClassJsonMixin): + a_init: InitVar[int] + _a: Optional[int] = None + + def __post_init__(self, a_init: int): + self._a = a_init + + +class TestEncoder: + def test_init_var(self): + assert A(a_init=1).to_dict() == {'_a': 1} + + +class TestDecoder: + def test_init_var(self): + result = A.from_dict({'a_init': 1}) + assert result._a == 1