diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt index 27a51b4043..f1c681018a 100644 --- a/ci/deps_pinned.txt +++ b/ci/deps_pinned.txt @@ -1,5 +1,7 @@ numpy~=1.20.0 pandas~=1.2.0 +polars~=0.17.0 +pyarrow~=12.0.0 matplotlib~=3.3.0 scipy~=1.7.0 statsmodels~=0.12.0 diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index 535fafe83f..4c0e11beb2 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -10,6 +10,7 @@ from pandas import DataFrame from seaborn._core.typing import DataSource, VariableSpec, ColumnName +from seaborn.utils import try_convert_to_pandas class PlotData: @@ -51,7 +52,7 @@ def __init__( data: DataSource, variables: dict[str, VariableSpec], ): - + data = try_convert_to_pandas(data) frame, names, ids = self._assign_variables(data, variables) self.frame = frame diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index a73af9bfd8..0b529ba787 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -42,7 +42,7 @@ from seaborn._compat import set_scale_obj, set_layout_engine from seaborn.rcmod import axes_style, plotting_context from seaborn.palettes import color_palette -from seaborn.utils import _version_predates +from seaborn.utils import _version_predates, try_convert_to_pandas from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: @@ -309,6 +309,7 @@ def __init__( if args: data, variables = self._resolve_positionals(args, data, variables) + data = try_convert_to_pandas(data) unknown = [x for x in variables if x not in PROPERTIES] if unknown: @@ -347,7 +348,10 @@ def _resolve_positionals( # TODO need some clearer way to differentiate data / vector here # (There might be an abstract DataFrame class to use here?) - if isinstance(args[0], (abc.Mapping, pd.DataFrame)): + if ( + isinstance(args[0], (abc.Mapping, pd.DataFrame)) + or hasattr(args[0], '__dataframe__') + ): if data is not None: raise TypeError("`data` given by both name and position.") data, args = args[0], args[1:] diff --git a/seaborn/_oldcore.py b/seaborn/_oldcore.py index 9bfebccc20..f427bd3f4b 100644 --- a/seaborn/_oldcore.py +++ b/seaborn/_oldcore.py @@ -775,7 +775,7 @@ def _assign_variables_wideform(self, data=None, **kwargs): # (Could be accomplished with a more general to_series() interface) flat_data = pd.Series(data).copy() names = { - "@values": flat_data.name, + "@values": getattr(data, 'name', None), "@index": flat_data.index.name } @@ -922,7 +922,7 @@ def _assign_variables_longform(self, data=None, **kwargs): val in data or (isinstance(val, (str, bytes)) and val in index) ) - except (KeyError, TypeError): + except (KeyError, TypeError, ValueError): val_as_data_key = False if val_as_data_key: diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 7534909920..9e58641699 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -372,7 +372,7 @@ def __init__( margin_titles=False, xlim=None, ylim=None, subplot_kws=None, gridspec_kws=None, ): - + data = utils.try_convert_to_pandas(data) super().__init__() # Determine the hue facet layer information @@ -1238,7 +1238,7 @@ def __init__( .. include:: ../docstrings/PairGrid.rst """ - + data = utils.try_convert_to_pandas(data) super().__init__() # Sort out the variables that define the grid @@ -2087,6 +2087,8 @@ def pairplot( # Avoid circular import from .distributions import histplot, kdeplot + data = utils.try_convert_to_pandas(data) + # Handle deprecations if size is not None: height = size diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 95a7243f43..11094a5823 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -2490,7 +2490,6 @@ def stripplot( hue_norm=None, native_scale=False, formatter=None, legend="auto", ax=None, **kwargs ): - p = _CategoricalPlotterNew( data=data, variables=_CategoricalPlotterNew.get_semantics(locals()), @@ -2618,6 +2617,7 @@ def swarmplot( ax=None, **kwargs ): + data = utils.try_convert_to_pandas(data) p = _CategoricalPlotterNew( data=data, variables=_CategoricalPlotterNew.get_semantics(locals()), @@ -3162,6 +3162,7 @@ def catplot( margin_titles=False, facet_kws=None, ci="deprecated", **kwargs ): + data = utils.try_convert_to_pandas(data) # Determine the plotting function try: diff --git a/seaborn/distributions.py b/seaborn/distributions.py index b1a4c9da37..febf859624 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -33,6 +33,7 @@ _check_argument, _assign_default_kwargs, _default_color, + try_convert_to_pandas, ) from .palettes import color_palette from .external import husl @@ -1392,6 +1393,7 @@ def histplot( # Other appearance keywords **kwargs, ): + data = try_convert_to_pandas(data) p = _DistributionPlotter( data=data, @@ -2123,6 +2125,7 @@ def displot( **kwargs, ): + data = try_convert_to_pandas(data) p = _DistributionFacetPlotter( data=data, variables=_DistributionFacetPlotter.get_semantics(locals()) diff --git a/seaborn/regression.py b/seaborn/regression.py index c6b81a1727..8b03c1c9bb 100644 --- a/seaborn/regression.py +++ b/seaborn/regression.py @@ -575,6 +575,7 @@ def lmplot( truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None, line_kws=None, facet_kws=None, ): + data = utils.try_convert_to_pandas(data) if facet_kws is None: facet_kws = {} diff --git a/seaborn/relational.py b/seaborn/relational.py index de3cf68348..5a2f83626d 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -13,6 +13,7 @@ adjust_legend_subtitles, _default_color, _deprecate_ci, + try_convert_to_pandas, ) from ._statistics import EstimateAggregator from .axisgrid import FacetGrid, _facet_docs @@ -799,7 +800,7 @@ def relplot( legend="auto", kind="scatter", height=5, aspect=1, facet_kws=None, **kwargs ): - + data = try_convert_to_pandas(data) if kind == "scatter": plotter = _ScatterPlotter diff --git a/seaborn/utils.py b/seaborn/utils.py index c5acc5e28f..3877fff18b 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -1,4 +1,6 @@ """Utility functions, mostly for internal use.""" +from __future__ import annotations + import os import inspect import warnings @@ -894,3 +896,19 @@ def _disable_autolayout(): def _version_predates(lib: ModuleType, version: str) -> bool: """Helper function for checking version compatibility.""" return Version(lib.__version__) < Version(version) + + +def try_convert_to_pandas(data: object | None) -> pd.DataFrame: + if data is None: + return None + elif isinstance(data, pd.DataFrame): + return data + elif hasattr(data, "__dataframe__") and _version_predates(pd, "2.0.2"): + msg = ( + "Interchanging to pandas requires at least pandas version '2.0.2'. " + "Please upgrade pandas to at least version '2.0.2'." + ) + raise RuntimeError(msg) + elif hasattr(data, "__dataframe__"): + return pd.api.interchange.from_dataframe(data) + return data diff --git a/tests/test_utils.py b/tests/test_utils.py index 28c836e999..ae74ecf4b0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -433,6 +433,18 @@ def test_move_legend_input_checks(): def check_load_dataset(name): ds = load_dataset(name, cache=False) assert isinstance(ds, pd.DataFrame) + # Check that the example datasets can actually be interchanged. + try: + import polars as pl + except ModuleNotFoundError: + pass + else: + if _version_predates(pd, '2.0.2'): + with pytest.raises(RuntimeError, match='Please upgrade pandas'): + utils.try_convert_to_pandas(pl.from_pandas(ds)) + else: + ds = utils.try_convert_to_pandas(pl.from_pandas(ds)) + assert isinstance(ds, pd.DataFrame) def check_load_cached_dataset(name):