Skip to content

Commit

Permalink
More tests for NestedExtensionArray
Browse files Browse the repository at this point in the history
  • Loading branch information
hombit committed May 2, 2024
1 parent 6f35f44 commit 3e90469
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 17 deletions.
34 changes: 18 additions & 16 deletions src/nested_pandas/series/ext_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from copy import deepcopy
from typing import Any, Callable, cast

import numpy as np
Expand Down Expand Up @@ -76,8 +75,22 @@ class NestedExtensionArray(ExtensionArray):

@classmethod
def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self: # type: ignore[name-defined] # noqa: F821
"""Construct a NestedExtensionArray from a sequence of scalars.
Parameters
----------
scalars : Sequence
The sequence of scalars: disctionaries, DataFrames, None, pd.NA, pa.Array or anything convertible
to PyArrow scalars.
dtype : dtype or None
dtype of the resulting array
copy : bool
Ignored, because PyArrow arrays are immutable.
"""
del copy

pa_type = to_pyarrow_dtype(dtype)
pa_array = cls._box_pa_array(scalars, pa_type=pa_type, copy=copy)
pa_array = cls._box_pa_array(scalars, pa_type=pa_type)
return cls(pa_array)

# Tricky to implement, but reqquired by things like pd.read_csv
Expand Down Expand Up @@ -155,7 +168,7 @@ def __setitem__(self, key, value) -> None:
scalar = self._box_pa_scalar(value, pa_type=self._pyarrow_dtype)
except (ValueError, TypeError):
# Copy will happen later in replace_with_mask() anyway
value = self._box_pa_array(value, pa_type=self._pyarrow_dtype, copy=False)
value = self._box_pa_array(value, pa_type=self._pyarrow_dtype)
else:
# Our replace_with_mask implementation doesm't work with scalars
value = pa.array([scalar] * pa.compute.sum(pa_mask).as_py())
Expand Down Expand Up @@ -408,13 +421,7 @@ def __getstate__(self):

# Adopted from ArrowExtensionArray
def __setstate__(self, state):
# We keep the same implementation as ArrowExtensionArray, so ignoring linter which propsoses
# to rewrite it as a ternary expression
if "_data" in state: # noqa: SIM108
data = state.pop("_data")
else:
data = state["_pa_array"]
state["_pa_array"] = pa.chunked_array(data)
state["_pa_array"] = pa.chunked_array(state["_pa_array"])
self.__dict__.update(state)

# End of Additional magic methods #
Expand All @@ -433,9 +440,7 @@ def _box_pa_scalar(cls, value, *, pa_type: pa.DataType | None) -> pa.Scalar:
return pa.scalar(value, type=pa_type, from_pandas=True)

@classmethod
def _box_pa_array(
cls, value, *, pa_type: pa.DataType | None, copy: bool = False
) -> pa.Array | pa.ChunkedArray:
def _box_pa_array(cls, value, *, pa_type: pa.DataType | None) -> pa.Array | pa.ChunkedArray:
"""Convert a value to a PyArrow array with the specified type."""
if isinstance(value, cls):
pa_array = value._pa_array
Expand All @@ -460,15 +465,12 @@ def _box_pa_array(
scalars = [s.cast(pa_type) for s in scalars]
pa_array = pa.array(scalars)
# We already copied the data into scalars
copy = False

# We always cast - even if the type is the same, it does not hurt
# If the type is different the result may still be a view, so we do not set copy=False
if pa_type is not None:
pa_array = pa_array.cast(pa_type)

if copy:
pa_array = deepcopy(pa_array)
return pa_array

@classmethod
Expand Down
67 changes: 66 additions & 1 deletion tests/nested_pandas/series/test_ext_array.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import numpy as np
import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -798,6 +800,31 @@ def test_take(allow_fill, fill_value, desired_sequence):
assert result.equals(desired)


def test_take_raises_for_empty_array_and_non_empty_index():
"""Tests that .take([i1, i2, i3]) raises for empty array"""
ext_array = NestedExtensionArray._from_sequence([], dtype=NestedDtype.from_fields({"a": pa.int64()}))
with pytest.raises(IndexError):
_result = ext_array.take([0, 1, 2])


@pytest.mark.parametrize(
"indices",
[
[100],
[1 << 65],
[-100],
[0] * 100 + [100],
],
)
def test_take_raises_for_out_of_bounds_index(indices):
"""Tests that .take([i1, i2, i3]) raises for out of bounds index."""
ext_array = NestedExtensionArray._from_sequence(
[None, None], dtype=NestedDtype.from_fields({"a": pa.int64()})
)
with pytest.raises(IndexError):
ext_array.take(indices)


def test__formatter_unboxed():
"""Tests formatting of array values, when displayed alone."""
formatter = NestedExtensionArray._from_sequence(
Expand All @@ -817,6 +844,43 @@ def test__formatter_boxed():
assert formatter(df) == str(d)


def test__formetter_boxed_na():
"""Tests formatting of NA array value, when displayed in a DataFrame or Series"""
formatter = NestedExtensionArray._from_sequence(
[], dtype=NestedDtype.from_fields({"a": pa.int64(), "b": pa.float64()})
)._formatter(boxed=True)
assert formatter(pd.NA) == str(pd.NA)


def test_nbytes():
"""Test that the nbytes property is correct."""
struct_array = pa.StructArray.from_arrays(
arrays=[
pa.array([np.array([1, 2, 3]), np.array([1, 2, 1])], type=pa.list_(pa.uint32())),
pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])], pa.list_(pa.float64())),
],
names=["a", "b"],
)
ext_array = NestedExtensionArray(struct_array)

# Assume a typical 64-bit platform
a_data_size = 6 * 4
a_validity_buffer = 8 # cannot be smaller than 8 bytes because of alignment
b_data_size = 6 * 8
b_validity_buffer = 8 # cannot be smaller than 8 bytes because of alignment

assert ext_array.nbytes == a_data_size + a_validity_buffer + b_data_size + b_validity_buffer


def test_pickability():
"""Test that the extension array can be dumped and loaded back with pickle."""
ext_array = NestedExtensionArray._from_sequence(
[{"a": [1, None, 3], "b": [-4.0, -5.0, None]}, None, {"a": [100] * 10_000, "b": [-7.0] * 10_000}]
)
pickled = pickle.loads(pickle.dumps(ext_array))
assert ext_array.equals(pickled)


def test__concat_same_type():
"""Test concatenating of three NestedExtensionArrays with the same dtype."""
dtype = NestedDtype.from_fields({"a": pa.int64(), "b": pa.float64()})
Expand Down Expand Up @@ -984,6 +1048,7 @@ def test___array__():
type=pa.struct([pa.field("a", pa.list_(pa.string())), pa.field("b", pa.list_(pa.float64()))]),
),
),
(pa.scalar(None), None, pa.scalar(None)),
],
)
def test__box_pa_scalar(value, pa_type, desired):
Expand Down Expand Up @@ -1036,7 +1101,7 @@ def test__box_pa_scalar(value, pa_type, desired):
)
def test__box_pa_array(value, pa_type, desired):
"""Tests _box_pa_array"""
actual = NestedExtensionArray._box_pa_array(value, pa_type=pa_type, copy=False)
actual = NestedExtensionArray._box_pa_array(value, pa_type=pa_type)
assert actual == desired


Expand Down

0 comments on commit 3e90469

Please sign in to comment.