From b458523e68aafc7d2810c81bb39b84ff8744995d Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Wed, 26 Jun 2024 22:28:25 +0530 Subject: [PATCH] feat: copy_behaviors to make sub-classing easy (#3137) * feat: copy_behaviors to make sub-classing easy * pylint errors and tests * make 'copy_behaviors' safer by making it immutable --------- Co-authored-by: Jim Pivarski Co-authored-by: Jim Pivarski --- src/awkward/_util.py | 16 ++++ tests/test_2433_copy_behaviors.py | 125 ++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 tests/test_2433_copy_behaviors.py diff --git a/src/awkward/_util.py b/src/awkward/_util.py index 9c68aaf89b..89ea936ae8 100644 --- a/src/awkward/_util.py +++ b/src/awkward/_util.py @@ -6,6 +6,7 @@ import os import struct import sys +import typing from collections.abc import Collection import numpy as np # noqa: TID251 @@ -102,3 +103,18 @@ def unique_list(items: Collection[T]) -> list[T]: seen.add(item) result.append(item) return result + + +def copy_behaviors(existing_class: typing.Any, new_class: typing.Any, behavior: dict): + output = {} + + oldname = existing_class.__name__ + newname = new_class.__name__ + + for key, value in behavior.items(): + if oldname in key: + if not isinstance(key, str) and "*" not in key: + new_tuple = tuple(newname if k == oldname else k for k in key) + output[new_tuple] = value + + return output diff --git a/tests/test_2433_copy_behaviors.py b/tests/test_2433_copy_behaviors.py new file mode 100644 index 0000000000..c52decc09e --- /dev/null +++ b/tests/test_2433_copy_behaviors.py @@ -0,0 +1,125 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy +import pytest + +import awkward as ak + + +def test(): + class SuperVector: + def add(self, other): + """Add two vectors together elementwise using `x` and `y` components""" + return ak.zip( + {"x": self.x + other.x, "y": self.y + other.y}, + with_name="VectorTwoD", + behavior=self.behavior, + ) + + # first sub-class + @ak.mixin_class(ak.behavior) + class VectorTwoD(SuperVector): + def __eq__(self, other): + return ak.all(self.x == other.x) and ak.all(self.y == other.y) + + v = ak.Array( + [ + [{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}], + [], + [{"x": 3, "y": 3.3}], + [ + {"x": 4, "y": 4.4}, + {"x": 5, "y": 5.5}, + {"x": 6, "y": 6.6}, + ], + ], + with_name="VectorTwoD", + behavior=ak.behavior, + ) + v_added = ak.Array( + [ + [{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}], + [], + [{"x": 6, "y": 6.6}], + [ + {"x": 8, "y": 8.8}, + {"x": 10, "y": 11}, + {"x": 12, "y": 13.2}, + ], + ], + with_name="VectorTwoD", + behavior=ak.behavior, + ) + + # add method works but the binary operator does not + assert v.add(v) == v_added + with pytest.raises(TypeError): + v + v + + # registering the operator makes everything work + ak.behavior[numpy.add, "VectorTwoD", "VectorTwoD"] = lambda v1, v2: v1.add(v2) + assert v + v == v_added + + # second sub-class + @ak.mixin_class(ak.behavior) + class VectorTwoDAgain(VectorTwoD): + pass + + v = ak.Array( + [ + [{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}], + [], + [{"x": 3, "y": 3.3}], + [ + {"x": 4, "y": 4.4}, + {"x": 5, "y": 5.5}, + {"x": 6, "y": 6.6}, + ], + ], + with_name="VectorTwoDAgain", + behavior=ak.behavior, + ) + # add method works but the binary operator does not + assert v.add(v) == v_added + with pytest.raises(TypeError): + v + v + + # instead of registering every operator again, just copy the behaviors of + # another class to this class + ak.behavior.update( + ak._util.copy_behaviors(VectorTwoD, VectorTwoDAgain, ak.behavior) + ) + assert v + v == v_added + + # third sub-class + @ak.mixin_class(ak.behavior) + class VectorTwoDAgainAgain(VectorTwoDAgain): + pass + + v = ak.Array( + [ + [{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}], + [], + [{"x": 3, "y": 3.3}], + [ + {"x": 4, "y": 4.4}, + {"x": 5, "y": 5.5}, + {"x": 6, "y": 6.6}, + ], + ], + with_name="VectorTwoDAgainAgain", + behavior=ak.behavior, + ) + # add method works but the binary operator does not + assert v.add(v) == v_added + with pytest.raises(TypeError): + v + v + + # instead of registering every operator again, just copy the behaviors of + # another class to this class + ak.behavior.update( + ak._util.copy_behaviors(VectorTwoDAgain, VectorTwoDAgainAgain, ak.behavior) + ) + assert v + v == v_added