Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support interchange protocol #3340

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
142086a
support interchange protocol
Apr 19, 2023
55df47b
raise if trying to interchange before pd 2.0.2
May 4, 2023
194a564
Merge remote-tracking branch 'upstream/master' into support-interchan…
May 7, 2023
088cb08
revert temporary change
May 7, 2023
03c1717
simplify
May 7, 2023
7dd9ff6
fixup
May 7, 2023
bbfdadb
Merge remote-tracking branch 'upstream/master' into support-interchan…
MarcoGorelli May 19, 2023
ad48c8a
try adding polars workflow
MarcoGorelli May 19, 2023
a0bd3f7
3.10
MarcoGorelli May 19, 2023
8edcf14
try fixup;
MarcoGorelli May 19, 2023
22df733
include pyarrow install
MarcoGorelli May 19, 2023
fa37b56
pandas nightly
MarcoGorelli May 19, 2023
b5c4ff8
wip
MarcoGorelli May 19, 2023
064b2e6
fixup
MarcoGorelli May 19, 2023
3f32596
reduce dependency to pandas 2.0.1
MarcoGorelli May 19, 2023
f4f3317
test that all load_dataset examples can actually interchange
MarcoGorelli May 19, 2023
1103aa9
better msg
MarcoGorelli May 20, 2023
338e119
coverage
MarcoGorelli May 20, 2023
5a44bed
pyarrow
MarcoGorelli May 20, 2023
73f35d5
fix deps
MarcoGorelli May 20, 2023
63c21ee
gotta remember pyarrow
MarcoGorelli May 20, 2023
0e9586f
wip
MarcoGorelli May 20, 2023
9f1927f
wip
MarcoGorelli May 20, 2023
9c4aedb
wip
MarcoGorelli May 20, 2023
e7e84f5
wip
MarcoGorelli May 20, 2023
30e1002
wip
MarcoGorelli May 20, 2023
1985028
increase test coverage even more
MarcoGorelli May 20, 2023
b8584ee
pre-commit run -a
MarcoGorelli May 20, 2023
5b2532e
skip estimateaggregator tests for the polars fixtures
MarcoGorelli May 20, 2023
4897344
simplify
MarcoGorelli May 21, 2023
8117fe6
convert as soon as possible
MarcoGorelli May 21, 2023
3812e5f
try convert in facetgrid
MarcoGorelli May 21, 2023
6494ef4
convert in pairgrid
MarcoGorelli May 21, 2023
991c343
remove separate workflow;
MarcoGorelli May 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ci/deps_pinned.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
numpy~=1.20.0
pandas~=1.2.0
polars~=0.17.0
pyarrow~=12.0.0
matplotlib~=3.3.0
scipy~=1.7.0
statsmodels~=0.12.0
3 changes: 2 additions & 1 deletion seaborn/_core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pandas import DataFrame

from seaborn._core.typing import DataSource, VariableSpec, ColumnName
from seaborn.utils import try_convert_to_pandas


class PlotData:
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
data: DataSource,
variables: dict[str, VariableSpec],
):

data = try_convert_to_pandas(data)
frame, names, ids = self._assign_variables(data, variables)

self.frame = frame
Expand Down
8 changes: 6 additions & 2 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from seaborn._compat import set_scale_obj, set_layout_engine
from seaborn.rcmod import axes_style, plotting_context
from seaborn.palettes import color_palette
from seaborn.utils import _version_predates
from seaborn.utils import _version_predates, try_convert_to_pandas

from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
Expand Down Expand Up @@ -309,6 +309,7 @@ def __init__(

if args:
data, variables = self._resolve_positionals(args, data, variables)
data = try_convert_to_pandas(data)

unknown = [x for x in variables if x not in PROPERTIES]
if unknown:
Expand Down Expand Up @@ -347,7 +348,10 @@ def _resolve_positionals(

# TODO need some clearer way to differentiate data / vector here
# (There might be an abstract DataFrame class to use here?)
if isinstance(args[0], (abc.Mapping, pd.DataFrame)):
if (
isinstance(args[0], (abc.Mapping, pd.DataFrame))
or hasattr(args[0], '__dataframe__')
):
if data is not None:
raise TypeError("`data` given by both name and position.")
data, args = args[0], args[1:]
Expand Down
4 changes: 2 additions & 2 deletions seaborn/_oldcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def _assign_variables_wideform(self, data=None, **kwargs):
# (Could be accomplished with a more general to_series() interface)
flat_data = pd.Series(data).copy()
names = {
"@values": flat_data.name,
"@values": getattr(data, 'name', None),
"@index": flat_data.index.name
}

Expand Down Expand Up @@ -922,7 +922,7 @@ def _assign_variables_longform(self, data=None, **kwargs):
val in data
or (isinstance(val, (str, bytes)) and val in index)
)
except (KeyError, TypeError):
except (KeyError, TypeError, ValueError):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If `val` is a non-pandas Series, then checking `val in data` will raise a `ValueError`.

val_as_data_key = False

if val_as_data_key:
Expand Down
6 changes: 4 additions & 2 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def __init__(
margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
gridspec_kws=None,
):

data = utils.try_convert_to_pandas(data)
super().__init__()

# Determine the hue facet layer information
Expand Down Expand Up @@ -1238,7 +1238,7 @@ def __init__(
.. include:: ../docstrings/PairGrid.rst

"""

data = utils.try_convert_to_pandas(data)
super().__init__()

# Sort out the variables that define the grid
Expand Down Expand Up @@ -2087,6 +2087,8 @@ def pairplot(
# Avoid circular import
from .distributions import histplot, kdeplot

data = utils.try_convert_to_pandas(data)

# Handle deprecations
if size is not None:
height = size
Expand Down
3 changes: 2 additions & 1 deletion seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,7 +2490,6 @@ def stripplot(
hue_norm=None, native_scale=False, formatter=None, legend="auto",
ax=None, **kwargs
):

p = _CategoricalPlotterNew(
data=data,
variables=_CategoricalPlotterNew.get_semantics(locals()),
Expand Down Expand Up @@ -2618,6 +2617,7 @@ def swarmplot(
ax=None, **kwargs
):

data = utils.try_convert_to_pandas(data)
p = _CategoricalPlotterNew(
data=data,
variables=_CategoricalPlotterNew.get_semantics(locals()),
Expand Down Expand Up @@ -3162,6 +3162,7 @@ def catplot(
margin_titles=False, facet_kws=None, ci="deprecated",
**kwargs
):
data = utils.try_convert_to_pandas(data)

# Determine the plotting function
try:
Expand Down
3 changes: 3 additions & 0 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_check_argument,
_assign_default_kwargs,
_default_color,
try_convert_to_pandas,
)
from .palettes import color_palette
from .external import husl
Expand Down Expand Up @@ -1392,6 +1393,7 @@ def histplot(
# Other appearance keywords
**kwargs,
):
data = try_convert_to_pandas(data)

p = _DistributionPlotter(
data=data,
Expand Down Expand Up @@ -2123,6 +2125,7 @@ def displot(
**kwargs,
):

data = try_convert_to_pandas(data)
p = _DistributionFacetPlotter(
data=data,
variables=_DistributionFacetPlotter.get_semantics(locals())
Expand Down
1 change: 1 addition & 0 deletions seaborn/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def lmplot(
truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
line_kws=None, facet_kws=None,
):
data = utils.try_convert_to_pandas(data)

if facet_kws is None:
facet_kws = {}
Expand Down
3 changes: 2 additions & 1 deletion seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
adjust_legend_subtitles,
_default_color,
_deprecate_ci,
try_convert_to_pandas,
)
from ._statistics import EstimateAggregator
from .axisgrid import FacetGrid, _facet_docs
Expand Down Expand Up @@ -799,7 +800,7 @@ def relplot(
legend="auto", kind="scatter", height=5, aspect=1, facet_kws=None,
**kwargs
):

data = try_convert_to_pandas(data)
if kind == "scatter":

plotter = _ScatterPlotter
Expand Down
18 changes: 18 additions & 0 deletions seaborn/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Utility functions, mostly for internal use."""
from __future__ import annotations

import os
import inspect
import warnings
Expand Down Expand Up @@ -894,3 +896,19 @@ def _disable_autolayout():
def _version_predates(lib: ModuleType, version: str) -> bool:
    """Return whether ``lib.__version__`` is strictly older than ``version``."""
    installed = Version(lib.__version__)
    threshold = Version(version)
    return installed < threshold


def try_convert_to_pandas(data: object | None) -> pd.DataFrame:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why object | None and not Any?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just because `Any` turns off the type checker, so I tend to avoid using it unless I have to, whereas `object` prevents me from making assumptions about what properties the variable might have.

if data is None:
return None
elif isinstance(data, pd.DataFrame):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this guard important? Will passing a pandas.DataFrame to pd.api.interchange.from_dataframe be costly? I'd expect it to be a no-op without looking closer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, thanks, it is in fact a no-op:

https://github.com/pandas-dev/pandas/blob/360bf218d68c703911731aec58a52b6501b2f4ce/pandas/core/interchange/from_dataframe.py#L48-L49

However, seaborn still supports versions of pandas older than those which support the interchange protocol, so I introduced this to keep it a no-op in such cases

return data
elif hasattr(data, "__dataframe__") and _version_predates(pd, "2.0.2"):
msg = (
"Interchanging to pandas requires at least pandas version '2.0.2'. "
"Please upgrade pandas to at least version '2.0.2'."
)
raise RuntimeError(msg)
elif hasattr(data, "__dataframe__"):
return pd.api.interchange.from_dataframe(data)
return data
12 changes: 12 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,18 @@ def test_move_legend_input_checks():
def check_load_dataset(name):
    """Load the example dataset ``name`` uncached and check it is a DataFrame.

    When polars can be imported, the frame is additionally converted to polars
    and passed through ``utils.try_convert_to_pandas``: on pandas older than
    2.0.2 that call must raise ``RuntimeError``; otherwise it must return a
    pandas DataFrame.
    """
    frame = load_dataset(name, cache=False)
    assert isinstance(frame, pd.DataFrame)

    # Check that the example datasets can actually be interchanged.
    try:
        import polars as pl
    except ModuleNotFoundError:
        # Without polars there is nothing further to check.
        return

    if not _version_predates(pd, '2.0.2'):
        converted = utils.try_convert_to_pandas(pl.from_pandas(frame))
        assert isinstance(converted, pd.DataFrame)
        return

    # Older pandas cannot consume the interchange protocol; expect the
    # explicit upgrade error instead of a conversion.
    with pytest.raises(RuntimeError, match='Please upgrade pandas'):
        utils.try_convert_to_pandas(pl.from_pandas(frame))


def check_load_cached_dataset(name):
Expand Down