From 41ce56601137f5127b4cd13460d57f67e4a069a9 Mon Sep 17 00:00:00 2001 From: yukinarit Date: Sun, 26 May 2024 22:25:05 +0900 Subject: [PATCH] Cache code templates Before 4572556 function calls (4301342 primitive calls) in 1.405 seconds After 1080875 function calls (1052233 primitive calls) in 0.383 seconds --- profile_codegen.py | 119 +++++++++++++++++++++++++++++ serde/__init__.py | 3 +- serde/compat.py | 2 +- serde/core.py | 2 +- serde/de.py | 185 +++++++++++++++++++++++---------------------- serde/se.py | 157 +++++++++++++++++++------------------- 6 files changed, 295 insertions(+), 173 deletions(-) create mode 100644 profile_codegen.py diff --git a/profile_codegen.py b/profile_codegen.py new file mode 100644 index 00000000..f4c69dfa --- /dev/null +++ b/profile_codegen.py @@ -0,0 +1,119 @@ +from serde import serde, serialize, deserialize +from serde.json import from_json, to_json +from dataclasses import dataclass +from beartype import beartype +import cProfile +import pstats + + +@beartype +@dataclass +class FewFields: + a: int + b: float + c: str + d: bool + + +@beartype +@dataclass +class ManyFields: + a1: int + a2: int + a3: int + a4: int + a5: int + a6: int + a7: int + a8: int + a9: int + a10: int + a11: int + a12: int + a13: int + a14: int + a15: int + a16: int + a17: int + a18: int + a19: int + a20: int + b1: int + b2: int + b3: int + b4: int + b5: int + b6: int + b7: int + b8: int + b9: int + b10: int + b11: int + b12: int + b13: int + b14: int + b15: int + b16: int + b17: int + b18: int + b19: int + b20: int + c1: int + c2: int + c3: int + c4: int + c5: int + c6: int + c7: int + c8: int + c9: int + c10: int + c11: int + c12: int + c13: int + c14: int + c15: int + c16: int + c17: int + c18: int + c19: int + c20: int + d1: int + d2: int + d3: int + d4: int + d5: int + d6: int + d7: int + d8: int + d9: int + d10: int + d11: int + d12: int + d13: int + d14: int + d15: int + d16: int + d17: int + d18: int + d19: int + d20: int + + +def profile_few_fields() -> None: + for n in range(100): + serde(FewFields) + + +def profile_many_fields() -> None: + for n in range(100): + serde(ManyFields) + + +cProfile.run("profile_few_fields()", filename="profile_results.prof") +stats = pstats.Stats("profile_results.prof") +stats.sort_stats("tottime").print_stats(20) + +cProfile.run("profile_many_fields()", filename="profile_results.prof") +stats = pstats.Stats("profile_results.prof") +stats.sort_stats("tottime").print_stats(20) diff --git a/serde/__init__.py b/serde/__init__.py index b879cbc0..02dfa00f 100644 --- a/serde/__init__.py +++ b/serde/__init__.py @@ -129,6 +129,7 @@ def serde( @overload def serde( + _cls: Any = None, rename_all: Optional[str] = None, reuse_instances_default: bool = True, convert_sets_default: bool = False, @@ -142,7 +143,7 @@ def serde( ) -> Callable[[type[T]], type[T]]: ... -@dataclass_transform(field_specifiers=(field,)) # type: ignore +@dataclass_transform(field_specifiers=(field,)) def serde( _cls: Any = None, rename_all: Optional[str] = None, diff --git a/serde/compat.py b/serde/compat.py index 0248327f..b841d6dd 100644 --- a/serde/compat.py +++ b/serde/compat.py @@ -60,7 +60,7 @@ def get_np_args(tp: Any) -> tuple[Any, ...]: """ List of datetime types """ -@dataclasses.dataclass +@dataclasses.dataclass(unsafe_hash=True) class _WithTagging(Generic[T]): """ Intermediate data structure for (de)serializaing Union without dataclass. diff --git a/serde/core.py b/serde/core.py index 71f47e05..0c15d50c 100644 --- a/serde/core.py +++ b/serde/core.py @@ -754,7 +754,7 @@ def literal_func_name(literal_args: Sequence[Any]) -> str: ) -@dataclass +@dataclass(unsafe_hash=True) class Tagging: """ Controls how union is (de)serialized. This is the same concept as in diff --git a/serde/de.py b/serde/de.py index 3b205925..56a46883 100644 --- a/serde/de.py +++ b/serde/de.py @@ -162,7 +162,7 @@ def _make_deserialize( """ Create a deserializable class programatically. """ - C = dataclasses.make_dataclass(cls_name, fields, *args, **kwargs) + C: type[Any] = dataclasses.make_dataclass(cls_name, fields, *args, **kwargs) C = deserialize( C, rename_all=rename_all, @@ -1063,13 +1063,10 @@ def renderable(f: DeField[Any]) -> bool: return f.init -def render_from_iter( - cls: type[Any], - legacy_class_deserializer: Optional[DeserializeFunc] = None, - type_check: TypeCheck = strict, - class_deserializer: Optional[ClassDeserializer] = None, -) -> str: - template = """ +jinja2_env = jinja2.Environment( + loader=jinja2.DictLoader( + { + "iter": """ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None, variable_type_args=None, reuse_instances=None): if reuse_instances is None: @@ -1078,7 +1075,7 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non maybe_generic_type_vars = maybe_generic_type_vars or {{cls_type_vars}} {% for f in fields %} - __{{f.name}} = {{f|arg(loop.index-1)|rvalue}} + __{{f.name}} = {{rvalue(arg(f,loop.index-1))}} {% endfor %} try: @@ -1091,8 +1088,82 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non raise SerdeError(e) except Exception as e: raise UserError(e) - """ +""", + "dict": """ +def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None, + variable_type_args=None, reuse_instances=None): + if reuse_instances is None: + reuse_instances = {{serde_scope.reuse_instances_default}} + + maybe_generic_type_vars = maybe_generic_type_vars or {{cls_type_vars}} + + {% for f in fields %} + __{{f.name}} = {{rvalue(arg(f,loop.index-1))}} + {% endfor %} + try: + return cls( + {% for f in fields %} + {% if f.kw_only %} + {{f.name}}=__{{f.name}}, + {% else %} + __{{f.name}}, + {% endif %} + {% endfor %} + ) + except BeartypeCallHintParamViolation as e: + raise SerdeError(e) + except Exception as e: + raise UserError(e) +""", + "union": """ +def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None, + variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}): + errors = [] + {% for t in union_args %} + try: + # create fake dict so we can reuse the normal render function + {% if tagging.is_external() and is_taggable(t) %} + ensure("{{typename(t)}}" in data , "'{{typename(t)}}' key is not present") + fake_dict = {"fake_key": data["{{typename(t)}}"]} + + {% elif tagging.is_internal() and is_taggable(t) %} + ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present") + ensure("{{typename(t)}}" == data["{{tagging.tag}}"], "tag '{{typename(t)}}' isn't found") + fake_dict = {"fake_key": data} + + {% elif tagging.is_adjacent() and is_taggable(t) %} + ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present") + ensure("{{tagging.content}}" in data , "'{{tagging.content}}' key is not present") + ensure("{{typename(t)}}" == data["{{tagging.tag}}"], "tag '{{typename(t)}}' isn't found") + fake_dict = {"fake_key": data["{{tagging.content}}"]} + + {% else %} + fake_dict = {"fake_key": data} + {% endif %} + + {% if is_primitive(t) or is_none(t) %} + if not isinstance(fake_dict["fake_key"], {{typename(t)}}): + raise Exception("Not a type of {{typename(t)}}") + {% endif %} + return {{rvalue(arg(t))}} + except Exception as e: + errors.append(f' Failed to deserialize into {{typename(t)}}: {e}') + {% endfor %} + raise SerdeError("Can not deserialize " + repr(data) + " of type " + \ + typename(type(data)) + " into {{union_name}}.\\nReasons:\\n" + "\\n".join(errors)) +""", + } + ) +) + + +def render_from_iter( + cls: type[Any], + legacy_class_deserializer: Optional[DeserializeFunc] = None, + type_check: TypeCheck = strict, + class_deserializer: Optional[ClassDeserializer] = None, +) -> str: renderer = Renderer( FROM_ITER, cls=cls, @@ -1100,15 +1171,14 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non suppress_coerce=(not type_check.is_coerce()), class_deserializer=class_deserializer, ) - env = jinja2.Environment(loader=jinja2.DictLoader({"iter": template})) - env.filters.update({"rvalue": renderer.render}) - env.filters.update({"arg": to_iter_arg}) fields = list(filter(renderable, defields(cls))) - res = env.get_template("iter").render( + res = jinja2_env.get_template("iter").render( func=FROM_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=fields, cls_type_vars=get_type_var_names(cls), + rvalue=renderer.render, + arg=to_iter_arg, ) if renderer.import_numpy: @@ -1124,34 +1194,6 @@ def render_from_dict( type_check: TypeCheck = strict, class_deserializer: Optional[ClassDeserializer] = None, ) -> str: - template = """ -def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None, - variable_type_args=None, reuse_instances=None): - if reuse_instances is None: - reuse_instances = {{serde_scope.reuse_instances_default}} - - maybe_generic_type_vars = maybe_generic_type_vars or {{cls_type_vars}} - - {% for f in fields %} - __{{f.name}} = {{f|arg(loop.index-1)|rvalue}} - {% endfor %} - - try: - return cls( - {% for f in fields %} - {% if f.kw_only %} - {{f.name}}=__{{f.name}}, - {% else %} - __{{f.name}}, - {% endif %} - {% endfor %} - ) - except BeartypeCallHintParamViolation as e: - raise SerdeError(e) - except Exception as e: - raise UserError(e) - """ - renderer = Renderer( FROM_DICT, cls=cls, @@ -1159,16 +1201,15 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non suppress_coerce=(not type_check.is_coerce()), class_deserializer=class_deserializer, ) - env = jinja2.Environment(loader=jinja2.DictLoader({"dict": template})) - env.filters.update({"rvalue": renderer.render}) - env.filters.update({"arg": functools.partial(to_arg, rename_all=rename_all)}) fields = list(filter(renderable, defields(cls))) - res = env.get_template("dict").render( + res = jinja2_env.get_template("dict").render( func=FROM_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=fields, type_check=type_check, cls_type_vars=get_type_var_names(cls), + rvalue=renderer.render, + arg=functools.partial(to_arg, rename_all=rename_all), ) if renderer.import_numpy: @@ -1180,61 +1221,21 @@ def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=Non def render_union_func( cls: type[Any], union_args: Sequence[type[Any]], tagging: Tagging = DefaultTagging ) -> str: - template = """ -def {{func}}(cls=cls, maybe_generic=None, maybe_generic_type_vars=None, data=None, - variable_type_args=None, reuse_instances = {{serde_scope.reuse_instances_default}}): - errors = [] - {% for t in union_args %} - try: - # create fake dict so we can reuse the normal render function - {% if tagging.is_external() and is_taggable(t) %} - ensure("{{t|typename}}" in data , "'{{t|typename}}' key is not present") - fake_dict = {"fake_key": data["{{t|typename}}"]} - - {% elif tagging.is_internal() and is_taggable(t) %} - ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present") - ensure("{{t|typename}}" == data["{{tagging.tag}}"], "tag '{{t|typename}}' isn't found") - fake_dict = {"fake_key": data} - - {% elif tagging.is_adjacent() and is_taggable(t) %} - ensure("{{tagging.tag}}" in data , "'{{tagging.tag}}' key is not present") - ensure("{{tagging.content}}" in data , "'{{tagging.content}}' key is not present") - ensure("{{t|typename}}" == data["{{tagging.tag}}"], "tag '{{t|typename}}' isn't found") - fake_dict = {"fake_key": data["{{tagging.content}}"]} - - {% else %} - fake_dict = {"fake_key": data} - {% endif %} - - {% if t|is_primitive or t|is_none %} - if not isinstance(fake_dict["fake_key"], {{t|typename}}): - raise Exception("Not a type of {{t|typename}}") - {% endif %} - return {{t|arg|rvalue}} - except Exception as e: - errors.append(f' Failed to deserialize into {{t|typename}}: {e}') - {% endfor %} - raise SerdeError("Can not deserialize " + repr(data) + " of type " + \ - typename(type(data)) + " into {{union_name}}.\\nReasons:\\n" + "\\n".join(errors)) - """ union_name = f"Union[{', '.join([typename(a) for a in union_args])}]" renderer = Renderer(FROM_DICT, cls=cls, suppress_coerce=True) - env = jinja2.Environment(loader=jinja2.DictLoader({"dict": template})) - env.filters.update( - {"arg": lambda x: DeField(x, datavar="fake_dict", name="fake_key")} - ) # use custom to_arg for fake field - env.filters.update({"rvalue": renderer.render}) - env.filters.update({"is_primitive": is_primitive}) - env.filters.update({"is_none": is_none}) - env.filters.update({"typename": typename}) - return env.get_template("dict").render( + return jinja2_env.get_template("union").render( func=union_func_name(UNION_DE_PREFIX, union_args), serde_scope=getattr(cls, SERDE_SCOPE), union_args=union_args, union_name=union_name, tagging=tagging, is_taggable=Tagging.is_taggable, + arg=lambda x: DeField(x, datavar="fake_dict", name="fake_key"), + rvalue=renderer.render, + is_primitive=is_primitive, + is_none=is_none, + typename=typename, ) diff --git a/serde/se.py b/serde/se.py index 5962a0ab..44a7192e 100644 --- a/serde/se.py +++ b/serde/se.py @@ -67,7 +67,6 @@ coerce_object, disabled, strict, - conv, fields, is_instance, logger, @@ -493,14 +492,34 @@ def sefields(cls: type[Any], serialize_class_var: bool = False) -> Iterator[SeFi yield f -def render_to_tuple( - cls: type[Any], - legacy_class_serializer: Optional[SerializeFunc] = None, - type_check: TypeCheck = strict, - serialize_class_var: bool = False, - class_serializer: Optional[ClassSerializer] = None, -) -> str: - template = """ +jinja2_env = jinja2.Environment( + loader=jinja2.DictLoader( + { + "dict": """ +def {{func}}(obj, reuse_instances = None, convert_sets = None): + if reuse_instances is None: + reuse_instances = {{serde_scope.reuse_instances_default}} + if convert_sets is None: + convert_sets = {{serde_scope.convert_sets_default}} + if not is_dataclass(obj): + return copy.deepcopy(obj) + + res = {} + {% for f in fields -%} + {% if not f.skip -%} + {% if f.skip_if -%} + subres = {{rvalue(f)}} + if not {{f.skip_if.name}}(subres): + {{lvalue(f)}} = subres + {% else -%} + {{lvalue(f)}} = {{rvalue(f)}} + {% endif -%} + {% endif %} + + {% endfor -%} + return res +""", + "iter": """ def {{func}}(obj, reuse_instances=None, convert_sets=None): if reuse_instances is None: reuse_instances = {{serde_scope.reuse_instances_default}} @@ -512,12 +531,52 @@ def {{func}}(obj, reuse_instances=None, convert_sets=None): return ( {% for f in fields -%} {% if not f.skip|default(False) %} - {{f|rvalue()}}, + {{rvalue(f)}}, {% endif -%} {% endfor -%} ) - """ +""", + "union": """ +def {{func}}(obj, reuse_instances, convert_sets): + union_args = serde_scope.union_se_args['{{func}}'] + + {% for t in union_args %} + if is_instance(obj, union_args[{{loop.index0}}]): + {% if tagging.is_external() and is_taggable(t) %} + return {"{{typename(t)}}": {{rvalue(arg(t))}}} + + {% elif tagging.is_internal() and is_taggable(t) %} + res = {{rvalue(arg(t))}} + res["{{tagging.tag}}"] = "{{typename(t)}}" + return res + + {% elif tagging.is_adjacent() and is_taggable(t) %} + res = {"{{tagging.content}}": {{rvalue(arg(t))}}} + res["{{tagging.tag}}"] = "{{typename(t)}}" + return res + + {% else %} + return {{rvalue(arg(t))}} + {% endif %} + {% endfor %} + raise SerdeError("Can not serialize " + \ + repr(obj) + \ + " of type " + \ + typename(type(obj)) + \ + " for {{union_name}}") +""", + } + ) +) + +def render_to_tuple( + cls: type[Any], + legacy_class_serializer: Optional[SerializeFunc] = None, + type_check: TypeCheck = strict, + serialize_class_var: bool = False, + class_serializer: Optional[ClassSerializer] = None, +) -> str: renderer = Renderer( TO_ITER, legacy_class_serializer, @@ -525,13 +584,12 @@ def {{func}}(obj, reuse_instances=None, convert_sets=None): serialize_class_var=serialize_class_var, class_serializer=class_serializer, ) - env = jinja2.Environment(loader=jinja2.DictLoader({"iter": template})) - env.filters.update({"rvalue": renderer.render}) - return env.get_template("iter").render( + return jinja2_env.get_template("iter").render( func=TO_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=sefields(cls, serialize_class_var), type_check=type_check, + rvalue=renderer.render, ) @@ -543,30 +601,6 @@ def render_to_dict( serialize_class_var: bool = False, class_serializer: Optional[ClassSerializer] = None, ) -> str: - template = """ -def {{func}}(obj, reuse_instances = None, convert_sets = None): - if reuse_instances is None: - reuse_instances = {{serde_scope.reuse_instances_default}} - if convert_sets is None: - convert_sets = {{serde_scope.convert_sets_default}} - if not is_dataclass(obj): - return copy.deepcopy(obj) - - res = {} - {% for f in fields -%} - {% if not f.skip -%} - {% if f.skip_if -%} - subres = {{f|rvalue}} - if not {{f.skip_if.name}}(subres): - {{f|lvalue}} = subres - {% else -%} - {{f|lvalue}} = {{f|rvalue}} - {% endif -%} - {% endif %} - - {% endfor -%} - return res - """ renderer = Renderer( TO_DICT, legacy_class_serializer, @@ -574,15 +608,13 @@ def {{func}}(obj, reuse_instances = None, convert_sets = None): class_serializer=class_serializer, ) lrenderer = LRenderer(case, serialize_class_var) - env = jinja2.Environment(loader=jinja2.DictLoader({"dict": template})) - env.filters.update({"rvalue": renderer.render}) - env.filters.update({"lvalue": lrenderer.render}) - env.filters.update({"case": functools.partial(conv, case=case)}) - return env.get_template("dict").render( + return jinja2_env.get_template("dict").render( func=TO_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=sefields(cls, serialize_class_var), type_check=type_check, + lvalue=lrenderer.render, + rvalue=renderer.render, ) @@ -592,49 +624,18 @@ def render_union_func( """ Render function that serializes a field with union type. """ - template = """ -def {{func}}(obj, reuse_instances, convert_sets): - union_args = serde_scope.union_se_args['{{func}}'] - - {% for t in union_args %} - if is_instance(obj, union_args[{{loop.index0}}]): - {% if tagging.is_external() and is_taggable(t) %} - return {"{{t|typename}}": {{t|arg|rvalue}}} - - {% elif tagging.is_internal() and is_taggable(t) %} - res = {{t|arg|rvalue}} - res["{{tagging.tag}}"] = "{{t|typename}}" - return res - - {% elif tagging.is_adjacent() and is_taggable(t) %} - res = {"{{tagging.content}}": {{t|arg|rvalue}}} - res["{{tagging.tag}}"] = "{{t|typename}}" - return res - - {% else %} - return {{t|arg|rvalue}} - {% endif %} - {% endfor %} - raise SerdeError("Can not serialize " + \ - repr(obj) + \ - " of type " + \ - typename(type(obj)) + \ - " for {{union_name}}") - """ union_name = f"Union[{', '.join([typename(a) for a in union_args])}]" - renderer = Renderer(TO_DICT, suppress_coerce=True) - env = jinja2.Environment(loader=jinja2.DictLoader({"dict": template})) - env.filters.update({"arg": lambda x: SeField(x, "obj")}) - env.filters.update({"rvalue": renderer.render}) - env.filters.update({"typename": typename}) - return env.get_template("dict").render( + return jinja2_env.get_template("union").render( func=union_func_name(UNION_SE_PREFIX, union_args), serde_scope=getattr(cls, SERDE_SCOPE), union_args=union_args, union_name=union_name, tagging=tagging, is_taggable=Tagging.is_taggable, + arg=lambda x: SeField(x, "obj"), + rvalue=renderer.render, + typename=typename, )