Skip to content

Commit

Permalink
fix: use PyCapsule Interface instead of Dataframe Interchange Protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 9, 2024
1 parent b4e5f8d commit 0bd8507
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 66 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dev = [
"mypy",
"pandas-stubs",
"pre-commit",
"pyarrow",
"flit",
]
docs = [
Expand Down
66 changes: 31 additions & 35 deletions seaborn/_core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from collections.abc import Mapping, Sized
from typing import cast
import warnings

import pandas as pd
from pandas import DataFrame
Expand Down Expand Up @@ -269,9 +268,9 @@ def _assign_variables(

def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
"""Convert the data source object to a common union representation."""
if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"):
# Check for pd.DataFrame inheritance could be removed once
# minimal pandas version supports dataframe interchange (1.5.0).
# minimal pandas version supports PyCapsule Interface (2.2).
data = convert_dataframe_to_pandas(data)
elif data is not None and not isinstance(data, Mapping):
err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
Expand All @@ -285,35 +284,32 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
if isinstance(data, pd.DataFrame):
return data

if not hasattr(pd.api, "interchange"):
msg = (
"Support for non-pandas DataFrame objects requires a version of pandas "
"that implements the DataFrame interchange protocol. Please upgrade "
"your pandas version or coerce your data to pandas before passing "
"it to seaborn."
)
raise TypeError(msg)

if _version_predates(pd, "2.0.2"):
msg = (
"DataFrame interchange with pandas<2.0.2 has some known issues. "
f"You are using pandas {pd.__version__}. "
"Continuing, but it is recommended to carefully inspect the results and to "
"consider upgrading."
)
warnings.warn(msg, stacklevel=2)

try:
# This is going to convert all columns in the input dataframe, even though
# we may only need one or two of them. It would be more efficient to select
# the columns that are going to be used in the plot prior to interchange.
# Solving that in general is a hard problem, especially with the objects
# interface where variables passed in Plot() may only be referenced later
# in Plot.add(). But noting here in case this seems to be a bottleneck.
return pd.api.interchange.from_dataframe(data)
except Exception as err:
msg = (
"Encountered an exception when converting data source "
"to a pandas DataFrame. See traceback above for details."
)
raise RuntimeError(msg) from err
if hasattr(data, '__arrow_c_stream__'):
try:
import pyarrow
except ImportError as err:
msg = "PyArrow is required for non-pandas Dataframe support."
raise RuntimeError(msg) from err
if _version_predates(pyarrow, '14.0.0'):
msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support."
raise RuntimeError(msg)
try:
# This is going to convert all columns in the input dataframe, even though
# we may only need one or two of them. It would be more efficient to select
# the columns that are going to be used in the plot prior to interchange.
# Solving that in general is a hard problem, especially with the objects
# interface where variables passed in Plot() may only be referenced later
# in Plot.add(). But noting here in case this seems to be a bottleneck.
return pyarrow.table(data).to_pandas()
except Exception as err:
msg = (
"Encountered an exception when converting data source "
"to a pandas DataFrame. See traceback above for details."
)
raise RuntimeError(msg) from err

msg = (
"Expected object which implements '__arrow_c_stream__' from the "
f"PyCapsule Interface, got: {type(data)}"
)
raise TypeError(msg)
2 changes: 1 addition & 1 deletion seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _resolve_positionals(

if (
isinstance(args[0], (abc.Mapping, pd.DataFrame))
or hasattr(args[0], "__dataframe__")
or hasattr(args[0], "__arrow_c_stream__")
):
if data is not None:
raise TypeError("`data` given by both name and position.")
Expand Down
4 changes: 2 additions & 2 deletions seaborn/_core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
VariableSpec = Union[ColumnName, Vector, None]
VariableSpecList = Union[List[VariableSpec], Index, None]

# A DataSource can be an object implementing __dataframe__, or a Mapping
# A DataSource can be an object implementing __arrow_c_stream__, or a Mapping
# (and is optional in all contexts where it is used).
# I don't think there's an abc for "has __dataframe__", so we type as object
# I don't think there's an abc for "has __arrow_c_stream__", so we type as object
# but keep the (slightly odd) Union alias for better user-facing annotations.
DataSource = Union[object, Mapping, None]

Expand Down
28 changes: 15 additions & 13 deletions tests/_core/test_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import numpy as np
import pandas as pd
from seaborn.external.version import Version

import pytest
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -404,11 +405,11 @@ def test_bad_type(self, flat_list):
with pytest.raises(TypeError, match=err):
PlotData(flat_list, {})

@pytest.mark.skipif(
condition=not hasattr(pd.api, "interchange"),
reason="Tests behavior assuming support for dataframe interchange"
)
def test_data_interchange(self, mock_long_df, long_df):
pytest.importorskip(
'pyarrow', '14.0',
reason="Tests behavior assuming support for PyCapsule Interface"
)

variables = {"x": "x", "y": "z", "color": "a"}
p = PlotData(mock_long_df, variables)
Expand All @@ -419,21 +420,22 @@ def test_data_interchange(self, mock_long_df, long_df):
for var, col in variables.items():
assert_vector_equal(p.frame[var], long_df[col])

@pytest.mark.skipif(
condition=not hasattr(pd.api, "interchange"),
reason="Tests behavior assuming support for dataframe interchange"
)
def test_data_interchange_failure(self, mock_long_df):
pytest.importorskip(
'pyarrow', '14.0',
reason="Tests behavior assuming support for PyCapsule Interface"
)

mock_long_df._data = None # Break __dataframe__()
mock_long_df.__arrow_c_stream__ = lambda _x: 1 / 0 # Break __arrow_c_stream__()
with pytest.raises(RuntimeError, match="Encountered an exception"):
PlotData(mock_long_df, {"x": "x"})

@pytest.mark.skipif(
condition=hasattr(pd.api, "interchange"),
reason="Tests graceful failure without support for dataframe interchange"
)
def test_data_interchange_support_test(self, mock_long_df):
pyarrow = pytest.importorskip('pyarrow')
if Version(pyarrow.__version__) >= Version('14.0.0'):
pytest.skip(
reason="Tests graceful failure without support for PyCapsule Interface"
)

with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
PlotData(mock_long_df, {"x": "x"})
8 changes: 4 additions & 4 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def test_positional_x(self, long_df):
assert p._data.source_data is None
assert list(p._data.source_vars) == ["x"]

@pytest.mark.skipif(
condition=not hasattr(pd.api, "interchange"),
reason="Tests behavior assuming support for dataframe interchange"
)
def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
pytest.importorskip(
'pyarrow', '14.0',
reason="Tests behavior assuming support for PyCapsule Interface"
)

p = Plot(mock_long_df, x="x")
assert_frame_equal(p._data.source_data, long_df)
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,12 @@ class MockInterchangeableDataFrame:
def __init__(self, data):
self._data = data

def __dataframe__(self, *args, **kwargs):
return self._data.__dataframe__(*args, **kwargs)
def __arrow_c_stream__(self, *args, **kwargs):
return self._data.__arrow_c_stream__()


@pytest.fixture
def mock_long_df(long_df):
import pyarrow

return MockInterchangeableDataFrame(long_df)
return MockInterchangeableDataFrame(pyarrow.Table.from_pandas(long_df))
16 changes: 8 additions & 8 deletions tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,11 @@ def test_tick_params(self):
assert mpl.colors.same_color(tick.tick2line.get_color(), color)
assert tick.get_pad() == pad

@pytest.mark.skipif(
condition=not hasattr(pd.api, "interchange"),
reason="Tests behavior assuming support for dataframe interchange"
)
def test_data_interchange(self, mock_long_df, long_df):
pytest.importorskip(
'pyarrow', '14.0',
reason="Tests behavior assuming support for PyCapsule Interface"
)

g = ag.FacetGrid(mock_long_df, col="a", row="b")
g.map(scatterplot, "x", "y")
Expand Down Expand Up @@ -1477,11 +1477,11 @@ def test_tick_params(self):
assert mpl.colors.same_color(tick.tick2line.get_color(), color)
assert tick.get_pad() == pad

@pytest.mark.skipif(
condition=not hasattr(pd.api, "interchange"),
reason="Tests behavior assuming support for dataframe interchange"
)
def test_data_interchange(self, mock_long_df, long_df):
pytest.importorskip(
'pyarrow', '14.0',
reason="Tests behavior assuming support for PyCapsule Interface"
)

g = ag.PairGrid(mock_long_df, vars=["x", "y", "z"], hue="a")
g.map(scatterplot)
Expand Down

0 comments on commit 0bd8507

Please sign in to comment.