From 1515f767990be9bd14d5ffd050815897fa46c607 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Thu, 25 Jul 2024 20:30:59 -0500 Subject: [PATCH 1/5] initial working jaxtyping serializing/deserializing --- examples/type_numpy_jaxtyping.py | 90 ++++++++++++++++++++++++++++++++ serde/de.py | 5 ++ serde/numpy.py | 16 ++++++ serde/se.py | 3 ++ 4 files changed, 114 insertions(+) create mode 100644 examples/type_numpy_jaxtyping.py diff --git a/examples/type_numpy_jaxtyping.py b/examples/type_numpy_jaxtyping.py new file mode 100644 index 00000000..e74bfd49 --- /dev/null +++ b/examples/type_numpy_jaxtyping.py @@ -0,0 +1,90 @@ +import numpy +from jaxtyping import ( + Float, + Float16, + Float32, + Float64, + Inexact, + Int, + Int8, + Int16, + Int32, + Int64, + Integer, + UInt, + UInt8, + UInt16, + UInt32, + UInt64, +) +from serde import serde +from serde.json import from_json, to_json + + +@serde +class Foo: + float_: Float[numpy.ndarray, "3 3"] + float16: Float16[numpy.ndarray, "3 3"] + float32: Float32[numpy.ndarray, "3 3"] + float64: Float64[numpy.ndarray, "3 3"] + inexact: Inexact[numpy.ndarray, "3 3"] + int_: Int[numpy.ndarray, "3 3"] + int8: Int8[numpy.ndarray, "3 3"] + int16: Int16[numpy.ndarray, "3 3"] + int32: Int32[numpy.ndarray, "3 3"] + int64: Int64[numpy.ndarray, "3 3"] + integer: Integer[numpy.ndarray, "3 3"] + uint: UInt[numpy.ndarray, "3 3"] + uint8: UInt8[numpy.ndarray, "3 3"] + uint16: UInt16[numpy.ndarray, "3 3"] + uint32: UInt32[numpy.ndarray, "3 3"] + uint64: UInt64[numpy.ndarray, "3 3"] + + +def main() -> None: + foo = Foo( + float_=numpy.zeros((3, 3), dtype=float), + float16=numpy.zeros((3, 3), dtype=numpy.float16), + float32=numpy.zeros((3, 3), dtype=numpy.float32), + float64=numpy.zeros((3, 3), dtype=numpy.float64), + inexact=numpy.zeros((3, 3), dtype=numpy.inexact), + int_=numpy.zeros((3, 3), dtype=int), + int8=numpy.zeros((3, 3), dtype=numpy.int8), + int16=numpy.zeros((3, 3), dtype=numpy.int16), + int32=numpy.zeros((3, 3), dtype=numpy.int32), + int64=numpy.zeros((3, 3), dtype=numpy.int64), + integer=numpy.zeros((3, 3), dtype=numpy.integer), + uint=numpy.zeros((3, 3), dtype=numpy.uint), + uint8=numpy.zeros((3, 3), dtype=numpy.uint8), + uint16=numpy.zeros((3, 3), dtype=numpy.uint16), + uint32=numpy.zeros((3, 3), dtype=numpy.uint32), + uint64=numpy.zeros((3, 3), dtype=numpy.uint64), + ) + + print(f"Into Json: {to_json(foo)}") + + s = """ + { + "float_": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + "float16": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + "float32": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + "float64": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + "inexact": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + "int_": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "int8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "int16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "int32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "int64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "integer": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "uint": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "uint8": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "uint16": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "uint32": [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + "uint64": [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + } + """ + print(f"From Json: {from_json(Foo, s)}") + + +if __name__ == "__main__": + main() diff --git a/serde/de.py b/serde/de.py index 047be4de..1634c551 100644 --- a/serde/de.py +++ b/serde/de.py @@ -89,7 +89,9 @@ deserialize_numpy_array, deserialize_numpy_scalar, deserialize_numpy_array_direct, + deserialize_numpy_jaxtyping_array, is_numpy_array, + is_numpy_jaxtyping, is_numpy_scalar, ) @@ -749,6 +751,9 @@ def render(self, arg: DeField[Any]) -> str: elif is_numpy_array(arg.type): self.import_numpy = True res = deserialize_numpy_array(arg) + elif is_numpy_jaxtyping(arg.type): + self.import_numpy = True + res = deserialize_numpy_jaxtyping_array(arg) elif is_union(arg.type): res = self.union_func(arg) elif is_str_serializable(arg.type): diff --git a/serde/numpy.py b/serde/numpy.py index a019a93e..c305c660 100644 --- a/serde/numpy.py +++ b/serde/numpy.py @@ -73,6 +73,12 @@ def is_numpy_array(typ) -> bool: typ = origin return typ is np.ndarray + def is_numpy_jaxtyping(typ) -> bool: + origin = get_origin(typ) + if origin is not None: + typ = origin + return issubclass(typ, np.ndarray) + def serialize_numpy_array(arg) -> str: return f"{arg.varname}.tolist()" @@ -86,6 +92,10 @@ def deserialize_numpy_array(arg) -> str: dtype = fullname(arg[1][0].type) return f"numpy.array({arg.data}, dtype={dtype})" + def deserialize_numpy_jaxtyping_array(arg) -> str: + dtype = f"numpy.{arg.type.dtypes[-1]}" + return f"numpy.array({arg.data}, dtype={dtype})" + def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any: if is_bare_numpy_array(typ): return np.array(arg) @@ -111,6 +121,9 @@ def deserialize_numpy_scalar(arg): def is_numpy_array(typ) -> bool: return False + def is_numpy_jaxtyping(typ) -> bool: + return False + def serialize_numpy_array(arg) -> str: return "" @@ -120,5 +133,8 @@ def serialize_numpy_datetime(arg) -> str: def deserialize_numpy_array(arg) -> str: return "" + def deserialize_numpy_jaxtyping_array(arg) -> str: + return "" + def deserialize_numpy_array_direct(typ: Any, arg: Any) -> Any: return arg diff --git a/serde/se.py b/serde/se.py index 1501a51b..fad3dbbf 100644 --- a/serde/se.py +++ b/serde/se.py @@ -76,6 +76,7 @@ ) from .numpy import ( is_numpy_array, + is_numpy_jaxtyping, is_numpy_datetime, is_numpy_scalar, serialize_numpy_array, @@ -751,6 +752,8 @@ def render(self, arg: SeField[Any]) -> str: res = serialize_numpy_scalar(arg) elif is_numpy_array(arg.type): res = serialize_numpy_array(arg) + elif is_numpy_jaxtyping(arg.type): + res = serialize_numpy_array(arg) elif is_primitive(arg.type): res = self.primitive(arg) elif is_union(arg.type): From 76d9c3ee783256d06e93f97d3cb42a4b0baf7ee7 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Sat, 27 Jul 2024 11:53:51 -0500 Subject: [PATCH 2/5] chore: Handle TypeError in is_numpy_jaxtyping The code changes in `serde/numpy.py` modify the `is_numpy_jaxtyping` function to handle a `TypeError` exception. This change ensures that if a `TypeError` occurs when trying to determine if a type is a numpy jaxtyping, the function will return `False` instead of raising an error. --- serde/numpy.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/serde/numpy.py b/serde/numpy.py index c305c660..f81bb90b 100644 --- a/serde/numpy.py +++ b/serde/numpy.py @@ -74,10 +74,13 @@ def is_numpy_array(typ) -> bool: return typ is np.ndarray def is_numpy_jaxtyping(typ) -> bool: - origin = get_origin(typ) - if origin is not None: - typ = origin - return issubclass(typ, np.ndarray) + try: + origin = get_origin(typ) + if origin is not None: + typ = origin + return issubclass(typ, np.ndarray) + except TypeError: + return False def serialize_numpy_array(arg) -> str: return f"{arg.varname}.tolist()" From edea8f890cc2980283895301b57785d0ccaf3639 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Sat, 27 Jul 2024 12:12:08 -0500 Subject: [PATCH 3/5] add jaxtyping as an extra dep --- pyproject.toml | 59 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06afa4e4..1cc4b88a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,16 +11,16 @@ packages = [ { include = "serde" }, ] classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - ] + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] [tool.poetry.dependencies] python = "^3.9.0" @@ -33,10 +33,10 @@ tomli = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional tomli-w = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true } pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true } numpy = [ - { version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true }, + { version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true }, ] orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true } plum-dispatch = ">=2,<2.3" @@ -49,10 +49,10 @@ tomli = { version = "*", markers = "python_version <= '3.11.0'" } tomli-w = "*" msgpack = "*" numpy = [ - { version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" }, - { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" }, + { version = ">1.21.0,<2.0.0", markers = "python_version ~= '3.9.0'" }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.10'" }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11'" }, + { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12'" }, ] mypy = "==1.10.1" pytest = "*" @@ -68,6 +68,7 @@ types-PyYAML = "^6.0.9" msgpack-types = "^0.3" envclasses = "^0.3.1" jedi = "*" +jaxtyping = "*" [tool.poetry.extras] msgpack = ["msgpack"] @@ -76,7 +77,8 @@ toml = ["tomli", "tomli-w"] yaml = ["pyyaml"] orjson = ["orjson"] sqlalchemy = ["sqlalchemy"] -all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy"] +jaxtyping = ["jaxtyping"] +all = ["msgpack", "tomli", "tomli-w", "pyyaml", "numpy", "orjson", "sqlalchemy", "jaxtyping"] [build-system] requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning"] @@ -145,16 +147,25 @@ exclude = [ "tests/test_sqlalchemy.py", ] +[[tool.mypy.overrides]] +# to avoid complaints about generic type ndarray +module = "examples.type_numpy_jaxtyping" +ignore_errors = true + [tool.ruff] select = [ - "E", # pycodestyle errors - "W", # pycodestyle warnings - "F", # pyflakes - "C", # flake8-comprehensions - "B", # flake8-bugbear + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "C", # flake8-comprehensions + "B", # flake8-bugbear ] ignore = ["B904"] line-length = 100 [tool.ruff.lint.mccabe] max-complexity = 30 + +[tool.ruff.per-file-ignores] +# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error +"examples/type_numpy_jaxtyping.py" = ["F722"] From 86cfda93f6fc8fb47ba59c5122bb28207a48f51c Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Mon, 29 Jul 2024 11:33:59 -0500 Subject: [PATCH 4/5] Add jaxtyping as an optional dependency, make sure typ is jaxtyping and not numpy --- pyproject.toml | 1 + serde/numpy.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1cc4b88a..14e7f723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ numpy = [ { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.11' and (extra == 'numpy' or extra == 'all')", optional = true }, { version = ">1.22.0,<2.0.0", markers = "python_version ~= '3.12' and (extra == 'numpy' or extra == 'all')", optional = true }, ] +jaxtyping = { version = "*", markers = "extra == 'jaxtyping' or extra == 'all'", optional = true } orjson = { version = "*", markers = "extra == 'orjson' or extra == 'all'", optional = true } plum-dispatch = ">=2,<2.3" beartype = ">=0.18.4" diff --git a/serde/numpy.py b/serde/numpy.py index f81bb90b..4fa181e4 100644 --- a/serde/numpy.py +++ b/serde/numpy.py @@ -78,7 +78,7 @@ def is_numpy_jaxtyping(typ) -> bool: origin = get_origin(typ) if origin is not None: typ = origin - return issubclass(typ, np.ndarray) + return typ is not np.ndarray and issubclass(typ, np.ndarray) except TypeError: return False From 2a012c1f1ccae0822739e358ea06464f0ebab2d6 Mon Sep 17 00:00:00 2001 From: pablovela5620 Date: Wed, 31 Jul 2024 17:32:05 -0500 Subject: [PATCH 5/5] add jaxtyping test --- tests/test_numpy.py | 61 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_numpy.py b/tests/test_numpy.py index 7404e756..9d3596ed 100644 --- a/tests/test_numpy.py +++ b/tests/test_numpy.py @@ -3,6 +3,7 @@ import numpy as np import numpy.typing as npt +import jaxtyping import pytest import serde @@ -89,6 +90,66 @@ class NumpyDate: assert de(NumpyDate, se(date_test)) == date_test + @serde.serde(**opt) + class NumpyJaxtyping: + float_: jaxtyping.Float[np.ndarray, "2 2"] # noqa: F722 + float16: jaxtyping.Float16[np.ndarray, "2 2"] # noqa: F722 + float32: jaxtyping.Float32[np.ndarray, "2 2"] # noqa: F722 + float64: jaxtyping.Float64[np.ndarray, "2 2"] # noqa: F722 + inexact: jaxtyping.Inexact[np.ndarray, "2 2"] # noqa: F722 + int_: jaxtyping.Int[np.ndarray, "2 2"] # noqa: F722 + int8: jaxtyping.Int8[np.ndarray, "2 2"] # noqa: F722 + int16: jaxtyping.Int16[np.ndarray, "2 2"] # noqa: F722 + int32: jaxtyping.Int32[np.ndarray, "2 2"] # noqa: F722 + int64: jaxtyping.Int64[np.ndarray, "2 2"] # noqa: F722 + integer: jaxtyping.Integer[np.ndarray, "2 2"] # noqa: F722 + uint: jaxtyping.UInt[np.ndarray, "2 2"] # noqa: F722 + uint8: jaxtyping.UInt8[np.ndarray, "2 2"] # noqa: F722 + uint16: jaxtyping.UInt16[np.ndarray, "2 2"] # noqa: F722 + uint32: jaxtyping.UInt32[np.ndarray, "2 2"] # noqa: F722 + uint64: jaxtyping.UInt64[np.ndarray, "2 2"] # noqa: F722 + + def __eq__(self, other): + return ( + (self.float_ == other.float_).all() + and (self.float16 == other.float16).all() + and (self.float32 == other.float32).all() + and (self.float64 == other.float64).all() + and (self.inexact == other.inexact).all() + and (self.int_ == other.int_).all() + and (self.int8 == other.int8).all() + and (self.int16 == other.int16).all() + and (self.int32 == other.int32).all() + and (self.int64 == other.int64).all() + and (self.integer == other.integer).all() + and (self.uint == other.uint).all() + and (self.uint8 == other.uint8).all() + and (self.uint16 == other.uint16).all() + and (self.uint32 == other.uint32).all() + and (self.uint64 == other.uint64).all() + ) + + jaxtyping_test = NumpyJaxtyping( + float_=np.array([[1, 2], [3, 4]], dtype=np.float_), + float16=np.array([[5, 6], [7, 8]], dtype=np.float16), + float32=np.array([[9, 10], [11, 12]], dtype=np.float32), + float64=np.array([[13, 14], [15, 16]], dtype=np.float64), + inexact=np.array([[17, 18], [19, 20]], dtype=np.float_), + int_=np.array([[21, 22], [23, 24]], dtype=np.int_), + int8=np.array([[25, 26], [27, 28]], dtype=np.int8), + int16=np.array([[29, 30], [31, 32]], dtype=np.int16), + int32=np.array([[33, 34], [35, 36]], dtype=np.int32), + int64=np.array([[37, 38], [39, 40]], dtype=np.int64), + integer=np.array([[41, 42], [43, 44]], dtype=np.int_), + uint=np.array([[45, 46], [47, 48]], dtype=np.uint), + uint8=np.array([[49, 50], [51, 52]], dtype=np.uint8), + uint16=np.array([[53, 54], [55, 56]], dtype=np.uint16), + uint32=np.array([[57, 58], [59, 60]], dtype=np.uint32), + uint64=np.array([[61, 62], [63, 64]], dtype=np.uint64), + ) + + assert de(NumpyJaxtyping, se(jaxtyping_test)) == jaxtyping_test + @pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids()) @pytest.mark.parametrize("se,de", format_json + format_msgpack)