From 0923d228b9ac4e6c6a0c8110819a71a32b4fed1d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 15 Aug 2024 10:22:57 -0500 Subject: [PATCH] fix: generalize `Index.ptr` (#3206) * fix: generalize 'Index.ptr' * include a test that checks TypeTracer --- src/awkward/index.py | 10 ++++++++-- tests/test_3206_generalize_index_ptr.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 tests/test_3206_generalize_index_ptr.py diff --git a/src/awkward/index.py b/src/awkward/index.py index 3e6f879cfe..c6726e726e 100644 --- a/src/awkward/index.py +++ b/src/awkward/index.py @@ -136,10 +136,16 @@ def metadata(self) -> dict: @property def ptr(self): - if self._nplike == Numpy.instance(): + if isinstance(self._nplike, Numpy): return self._data.ctypes.data - elif self._nplike == Cupy.instance(): + elif isinstance(self._nplike, Cupy): return self._data.data.ptr + elif isinstance(self._nplike, TypeTracer): + return 0 + else: + raise NotImplementedError( + f"this function hasn't been implemented for the {type(self._nplike).__name__} backend" + ) @property def length(self) -> ShapeItem: diff --git a/tests/test_3206_generalize_index_ptr.py b/tests/test_3206_generalize_index_ptr.py new file mode 100644 index 0000000000..33796aa7ab --- /dev/null +++ b/tests/test_3206_generalize_index_ptr.py @@ -0,0 +1,22 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import pytest + +import awkward as ak + + +def test_1(): + arr = ak.Array([[1, 3, 4], 5]) + tarr = arr.layout.to_typetracer() + + with pytest.raises(ak.errors.AxisError, match="exceeds the depth of this array"): + ak.flatten(tarr) + + +def test_2(): + arr = ak.Array([[[1, 3, 4]], [5]]) + tarr = arr.layout.to_typetracer() + + assert ak.flatten(tarr).type == ak.flatten(arr).type