Skip to content

Commit

Permalink
More unit tests for NestedExtendedArray
Browse files Browse the repository at this point in the history
100% coverage!
  • Loading branch information
hombit committed May 3, 2024
1 parent 3e90469 commit e39977b
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions tests/nested_pandas/series/test_ext_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,13 @@ def test_equals():
assert array1.equals(array2)


def test_equals_when_other_is_different_type():
"""Test that equals() raises for different dtypes."""
ext_array = NestedExtensionArray._from_sequence([{"a": [1, None, 3], "b": [-4.0, -5.0, None]}, None])
other = ext_array.to_arrow_ext_array()
assert not ext_array.equals(other)


def test_dropna():
"""Test .dropna()"""
dtype = NestedDtype.from_fields({"a": pa.int64(), "b": pa.float64()})
Expand Down Expand Up @@ -1364,6 +1371,38 @@ def test_set_list_field_replace_field():
assert_series_equal(pd.Series(ext_array), pd.Series(desired))


def test_set_list_field_raises_for_non_list_array():
"""Tests that we raise an error when trying to set a field with a non-list array."""
struct_array = pa.StructArray.from_arrays(
arrays=[
pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0, 2.0])]),
pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]),
],
names=["a", "b"],
)
ext_array = NestedExtensionArray(struct_array)

with pytest.raises(ValueError):
ext_array.set_list_field("b", [1.0, 2.0])


def test_set_list_field_raises_for_wrong_length():
"""Tests that we raise an error when trying to set a field with an array-like of the wrong length."""
struct_array = pa.StructArray.from_arrays(
arrays=[
pa.array([np.array([1.0, 2.0, 3.0])]),
pa.array([-np.array([4.0, 5.0, 6.0])]),
],
names=["a", "b"],
)
ext_array = NestedExtensionArray(struct_array)

longer_array = np.array([[1.0, 2.0, 3.0], [1.0, 2.0]], dtype=object)

with pytest.raises(ValueError):
ext_array.set_list_field("b", longer_array)


def test_pop_field():
"""Tests that we can pop a field from the extension array."""
struct_array = pa.StructArray.from_arrays(
Expand All @@ -1390,6 +1429,21 @@ def test_pop_field():
assert_series_equal(pd.Series(ext_array), pd.Series(desired))


def test_pop_field_raises_for_invalid_field():
"""Tests that we raise an error when trying to pop a field that does not exist."""
struct_array = pa.StructArray.from_arrays(
arrays=[
pa.array([np.array([1.0, 2.0, 3.0])]),
pa.array([-np.array([4.0, 5.0, 6.0])]),
],
names=["a", "b"],
)
ext_array = NestedExtensionArray(struct_array)

with pytest.raises(ValueError):
ext_array.pop_field("c")


def test_delete_last_field_raises():
"""Tests that we raise an error when trying to delete the last field left."""
struct_array = pa.StructArray.from_arrays(
Expand Down

0 comments on commit e39977b

Please sign in to comment.