diff --git a/seaborn/_base.py b/seaborn/_base.py
index 8b890252d4..a3861aed86 100644
--- a/seaborn/_base.py
+++ b/seaborn/_base.py
@@ -11,14 +11,15 @@
 import pandas as pd
 import matplotlib as mpl

-from ._decorators import (
+from seaborn._core.data import PlotData
+from seaborn._decorators import (
     share_init_params_with_map,
 )
-from .palettes import (
+from seaborn.palettes import (
     QUAL_PALETTES,
     color_palette,
 )
-from .utils import (
+from seaborn.utils import (
     _check_argument,
     desaturate,
     get_color_cycle,
@@ -698,23 +699,26 @@ def assign_variables(self, data=None, variables={}):

         if x is None and y is None:
             self.input_format = "wide"
-            plot_data, variables = self._assign_variables_wideform(
+            frame, names = self._assign_variables_wideform(
                 data, **variables,
             )
         else:
+            # When dealing with long-form input, use the newer PlotData
+            # object (internal but introduced for the objects interface)
+            # to centralize / standardize data consumption logic.
             self.input_format = "long"
-            plot_data, variables = self._assign_variables_longform(
-                data, **variables,
-            )
+            plot_data = PlotData(data, variables)
+            frame = plot_data.frame
+            names = plot_data.names

-        self.plot_data = plot_data
-        self.variables = variables
+        self.plot_data = frame
+        self.variables = names
         self.var_types = {
             v: variable_type(
-                plot_data[v],
+                frame[v],
                 boolean_type="numeric" if v in "xy" else "categorical"
             )
-            for v in variables
+            for v in names
         }

         return self
@@ -861,120 +865,6 @@ def _assign_variables_wideform(self, data=None, **kwargs):

         return plot_data, variables

-    def _assign_variables_longform(self, data=None, **kwargs):
-        """Define plot variables given long-form data and/or vector inputs.
-
-        Parameters
-        ----------
-        data : dict-like collection of vectors
-            Input data where variable names map to vector values.
-        kwargs : variable -> data mappings
-            Keys are seaborn variables (x, y, hue, ...) and values are vectors
-            in any format that can construct a :class:`pandas.DataFrame` or
-            names of columns or index levels in ``data``.
-
-        Returns
-        -------
-        plot_data : :class:`pandas.DataFrame`
-            Long-form data object mapping seaborn variables (x, y, hue, ...)
-            to data vectors.
-        variables : dict
-            Keys are defined seaborn variables; values are names inferred from
-            the inputs (or None when no name can be determined).
-
-        Raises
-        ------
-        ValueError
-            When variables are strings that don't appear in ``data``.
-
-        """
-        plot_data = {}
-        variables = {}
-
-        # Data is optional; all variables can be defined as vectors
-        if data is None:
-            data = {}
-
-        # TODO should we try a data.to_dict() or similar here to more
-        # generally accept objects with that interface?
-        # Note that dict(df) also works for pandas, and gives us what we
-        # want, whereas DataFrame.to_dict() gives a nested dict instead of
-        # a dict of series.
-
-        # Variables can also be extracted from the index attribute
-        # TODO is this the most general way to enable it?
-        # There is no index.to_dict on multiindex, unfortunately
-        try:
-            index = data.index.to_frame()
-        except AttributeError:
-            index = {}
-
-        # The caller will determine the order of variables in plot_data
-        for key, val in kwargs.items():
-
-            # First try to treat the argument as a key for the data collection.
-            # But be flexible about what can be used as a key.
-            # Usually it will be a string, but allow numbers or tuples too when
-            # taking from the main data object. Only allow strings to reference
-            # fields in the index, because otherwise there is too much ambiguity.
-            try:
-                val_as_data_key = (
-                    val in data
-                    or (isinstance(val, (str, bytes)) and val in index)
-                )
-            except (KeyError, TypeError):
-                val_as_data_key = False
-
-            if val_as_data_key:
-
-                # We know that __getitem__ will work
-
-                if val in data:
-                    plot_data[key] = data[val]
-                elif val in index:
-                    plot_data[key] = index[val]
-                variables[key] = val
-
-            elif isinstance(val, (str, bytes)):
-
-                # This looks like a column name but we don't know what it means!
-
-                err = f"Could not interpret value `{val}` for parameter `{key}`"
-                raise ValueError(err)
-
-            else:
-
-                # Otherwise, assume the value is itself data
-
-                # Raise when data object is present and a vector can't matched
-                if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
-                    if np.ndim(val) and len(data) != len(val):
-                        val_cls = val.__class__.__name__
-                        err = (
-                            f"Length of {val_cls} vectors must match length of `data`"
-                            f" when both are used, but `data` has length {len(data)}"
-                            f" and the vector passed to `{key}` has length {len(val)}."
-                        )
-                        raise ValueError(err)
-
-                plot_data[key] = val
-
-                # Try to infer the name of the variable
-                variables[key] = getattr(val, "name", None)
-
-        # Construct a tidy plot DataFrame. This will convert a number of
-        # types automatically, aligning on index in case of pandas objects
-        plot_data = pd.DataFrame(plot_data)
-
-        # Reduce the variables dictionary to fields with valid data
-        variables = {
-            var: name
-            for var, name in variables.items()
-            if plot_data[var].notnull().any()
-        }
-
-        return plot_data, variables
-
     def iter_data(
         self, grouping_vars=None, *, reverse=False, from_comp_data=False,
diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py
index 535fafe83f..c17bfe95c5 100644
--- a/seaborn/_core/data.py
+++ b/seaborn/_core/data.py
@@ -5,11 +5,13 @@

 from collections.abc import Mapping, Sized
 from typing import cast
+import warnings

 import pandas as pd
 from pandas import DataFrame

 from seaborn._core.typing import DataSource, VariableSpec, ColumnName
+from seaborn.utils import _version_predates


 class PlotData:
@@ -52,13 +54,19 @@ def __init__(
         variables: dict[str, VariableSpec],
     ):

+        data = handle_data_source(data)
         frame, names, ids = self._assign_variables(data, variables)

         self.frame = frame
         self.names = names
         self.ids = ids

-        self.frames = {}  # TODO this is a hack, remove
+        # The reason we possibly have a dictionary of frames is to support the
+        # Plot.pair operation, post scaling, where each x/y variable needs its
+        # own frame. This feels pretty clumsy and there are a bunch of places in
+        # the client code with awkward if frame / elif frames constructions.
+        # It would be great to have a cleaner abstraction here.
+        self.frames = {}

         self.source_data = data
         self.source_vars = variables
@@ -75,7 +83,7 @@ def join(
         variables: dict[str, VariableSpec] | None,
     ) -> PlotData:
         """Add, replace, or drop variables and return as a new dataset."""
-        # Inherit the original source of the upsteam data by default
+        # Inherit the original source of the upstream data by default
         if data is None:
             data = self.source_data

@@ -118,7 +126,7 @@ def join(

     def _assign_variables(
         self,
-        data: DataSource,
+        data: DataFrame | Mapping | None,
         variables: dict[str, VariableSpec],
     ) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]:
         """
@@ -147,6 +155,8 @@ def _assign_variables(

         Raises
         ------
+        TypeError
+            When data source is not a DataFrame or Mapping.
         ValueError
             When variables are strings that don't appear in `data`, or when they are
             non-indexed vector datatypes that have a different length from `data`.
@@ -162,15 +172,12 @@ def _assign_variables(
         ids = {}

         given_data = data is not None
-        if data is not None:
-            source_data = data
-        else:
+        if data is None:
             # Data is optional; all variables can be defined as vectors
             # But simplify downstream code by always having a usable source data object
             source_data = {}
-
-        # TODO Generally interested in accepting a generic DataFrame interface
-        # Track https://data-apis.org/ for development
+        else:
+            source_data = data

         # Variables can also be extracted from the index of a DataFrame
         if isinstance(source_data, pd.DataFrame):
@@ -258,3 +265,55 @@ def _assign_variables(
         frame = pd.DataFrame(plot_data)

         return frame, names, ids
+
+
+def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
+    """Convert the data source object to a common union representation."""
+    if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
+        # Check for pd.DataFrame inheritance could be removed once
+        # minimal pandas version supports dataframe interchange (1.5.0).
+        data = convert_dataframe_to_pandas(data)
+    elif data is not None and not isinstance(data, Mapping):
+        err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
+        raise TypeError(err)
+
+    return data
+
+
+def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
+    """Use the DataFrame exchange protocol, or fail gracefully."""
+    if isinstance(data, pd.DataFrame):
+        return data
+
+    if not hasattr(pd.api, "interchange"):
+        msg = (
+            "Support for non-pandas DataFrame objects requires a version of pandas "
+            "that implements the DataFrame interchange protocol. Please upgrade "
+            "your pandas version or coerce your data to pandas before passing "
+            "it to seaborn."
+        )
+        raise TypeError(msg)
+
+    if _version_predates(pd, "2.0.2"):
+        msg = (
+            "DataFrame interchange with pandas<2.0.2 has some known issues. "
+            f"You are using pandas {pd.__version__}. "
+            "Continuing, but it is recommended to carefully inspect the results and to "
+            "consider upgrading."
+        )
+        warnings.warn(msg, stacklevel=2)
+
+    try:
+        # This is going to convert all columns in the input dataframe, even though
+        # we may only need one or two of them. It would be more efficient to select
+        # the columns that are going to be used in the plot prior to interchange.
+        # Solving that in general is a hard problem, especially with the objects
+        # interface where variables passed in Plot() may only be referenced later
+        # in Plot.add(). But noting here in case this seems to be a bottleneck.
+        return pd.api.interchange.from_dataframe(data)
+    except Exception as err:
+        msg = (
+            "Encountered an exception when converting data source "
+            "to a pandas DataFrame. See traceback above for details."
+        )
+        raise RuntimeError(msg) from err
diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py
index 2ec7b6ecb8..d07e99bc42 100644
--- a/seaborn/_core/plot.py
+++ b/seaborn/_core/plot.py
@@ -346,9 +346,10 @@ def _resolve_positionals(
             err = "Plot() accepts no more than 3 positional arguments (data, x, y)."
             raise TypeError(err)

-        # TODO need some clearer way to differentiate data / vector here
-        # (There might be an abstract DataFrame class to use here?)
-        if isinstance(args[0], (abc.Mapping, pd.DataFrame)):
+        if (
+            isinstance(args[0], (abc.Mapping, pd.DataFrame))
+            or hasattr(args[0], "__dataframe__")
+        ):
             if data is not None:
                 raise TypeError("`data` given by both name and position.")
             data, args = args[0], args[1:]
diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py
index cc3522d48d..9bdf8a6ef8 100644
--- a/seaborn/_core/typing.py
+++ b/seaborn/_core/typing.py
@@ -1,11 +1,11 @@
 from __future__ import annotations

+from collections.abc import Iterable, Mapping
 from datetime import date, datetime, timedelta
-from typing import Any, Optional, Union, Mapping, Tuple, List, Dict
-from collections.abc import Hashable, Iterable
+from typing import Any, Optional, Union, Tuple, List, Dict

 from numpy import ndarray  # TODO use ArrayLike?
-from pandas import DataFrame, Series, Index, Timestamp, Timedelta
+from pandas import Series, Index, Timestamp, Timedelta
 from matplotlib.colors import Colormap, Normalize

@@ -17,7 +17,11 @@

 VariableSpec = Union[ColumnName, Vector, None]
 VariableSpecList = Union[List[VariableSpec], Index, None]
-DataSource = Union[DataFrame, Mapping[Hashable, Vector], None]
+# A DataSource can be an object implementing __dataframe__, or a Mapping
+# (and is optional in all contexts where it is used).
+# I don't think there's an abc for "has __dataframe__", so we type as object
+# but keep the (slightly odd) Union alias for better user-facing annotations.
+DataSource = Union[object, Mapping, None]

 OrderSpec = Union[Iterable, None]  # TODO technically str is iterable
 NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None]
diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py
index 69a4e37ff9..7b0e51db06 100644
--- a/seaborn/axisgrid.py
+++ b/seaborn/axisgrid.py
@@ -10,6 +10,7 @@
 import matplotlib.pyplot as plt

 from ._base import VectorPlotter, variable_type, categorical_order
+from ._core.data import handle_data_source
 from ._compat import share_axis, get_legend_handles
 from . import utils
 from .utils import (
@@ -374,6 +375,7 @@ def __init__(
     ):

         super().__init__()
+        data = handle_data_source(data)

         # Determine the hue facet layer information
         hue_var = hue
@@ -1240,6 +1242,7 @@ def __init__(
         """

         super().__init__()
+        data = handle_data_source(data)

         # Sort out the variables that define the grid
         numeric_cols = self._find_numeric_cols(data)
diff --git a/seaborn/categorical.py b/seaborn/categorical.py
index c18f3a109e..5e9d8c709c 100644
--- a/seaborn/categorical.py
+++ b/seaborn/categorical.py
@@ -2593,10 +2593,10 @@ def countplot(

     if x is None and y is not None:
         orient = "y"
-        x = 1
+        x = 1 if list(y) else None
     elif x is not None and y is None:
         orient = "x"
-        y = 1
+        y = 1 if list(x) else None
     elif x is not None and y is not None:
         raise TypeError("Cannot pass values for both `x` and `y`.")
diff --git a/seaborn/distributions.py b/seaborn/distributions.py
index 8cf7f231f8..c8c6b5f057 100644
--- a/seaborn/distributions.py
+++ b/seaborn/distributions.py
@@ -2012,6 +2012,7 @@ def rugplot(
             x = data
         elif axis == "y":
             y = data
+        data = None
         msg = textwrap.dedent(f"""\n
         The `axis` parameter has been deprecated; use the `{axis}` parameter instead.
         Please update your code; this will become an error in seaborn v0.13.0.
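
Reviewer note (illustrative sketch, not part of the patch): the behavior of the new handle_data_source helper introduced in seaborn/_core/data.py above can be exercised directly. The literals and variable names below are hypothetical and only stand in for arbitrary user input.

import pandas as pd
from seaborn._core.data import handle_data_source  # added by this patch

df = pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]})

# pandas DataFrames and Mappings pass through unchanged; data remains optional.
assert handle_data_source(df) is df
assert handle_data_source({"x": [1, 2]}) == {"x": [1, 2]}
assert handle_data_source(None) is None

# Objects implementing __dataframe__ (e.g. the MockInterchangeableDataFrame
# fixture added in tests/conftest.py below) are converted to pandas via
# pd.api.interchange.from_dataframe (pandas >= 1.5; a warning is issued for
# pandas < 2.0.2). Anything else raises the new TypeError:
try:
    handle_data_source([1, 2, 3])
except TypeError as err:
    print(err)  # Data source must be a DataFrame or Mapping, not <class 'list'>.
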
diff --git a/tests/_core/test_data.py b/tests/_core/test_data.py
index aeffc9ed7d..0e67ed37b4 100644
--- a/tests/_core/test_data.py
+++ b/tests/_core/test_data.py
@@ -144,7 +144,7 @@ def test_frame_and_vector_mismatched_lengths(self, long_df):
             PlotData(long_df, {"x": "x", "y": vector})

     @pytest.mark.parametrize(
-        "arg", [[], np.array([]), pd.DataFrame()],
+        "arg", [{}, pd.DataFrame()],
     )
     def test_empty_data_input(self, arg):

@@ -397,3 +397,43 @@ def test_join_multiple_inherits_from_orig(self, rng):
         p = PlotData(d1, {"x": "a"}).join(d2, {"y": "a"}).join(None, {"y": "a"})
         assert_vector_equal(p.frame["x"], d1["a"])
         assert_vector_equal(p.frame["y"], d1["a"])
+
+    def test_bad_type(self, flat_list):
+
+        err = "Data source must be a DataFrame or Mapping"
+        with pytest.raises(TypeError, match=err):
+            PlotData(flat_list, {})
+
+    @pytest.mark.skipif(
+        condition=not hasattr(pd.api, "interchange"),
+        reason="Tests behavior assuming support for dataframe interchange"
+    )
+    def test_data_interchange(self, mock_long_df, long_df):
+
+        variables = {"x": "x", "y": "z", "color": "a"}
+        p = PlotData(mock_long_df, variables)
+        for var, col in variables.items():
+            assert_vector_equal(p.frame[var], long_df[col])
+
+        p = PlotData(mock_long_df, {**variables, "color": long_df["a"]})
+        for var, col in variables.items():
+            assert_vector_equal(p.frame[var], long_df[col])
+
+    @pytest.mark.skipif(
+        condition=not hasattr(pd.api, "interchange"),
+        reason="Tests behavior assuming support for dataframe interchange"
+    )
+    def test_data_interchange_failure(self, mock_long_df):
+
+        mock_long_df._data = None  # Break __dataframe__()
+        with pytest.raises(RuntimeError, match="Encountered an exception"):
+            PlotData(mock_long_df, {"x": "x"})
+
+    @pytest.mark.skipif(
+        condition=hasattr(pd.api, "interchange"),
+        reason="Tests graceful failure without support for dataframe interchange"
+    )
+    def test_data_interchange_support_test(self, mock_long_df):
+
+        with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
+            PlotData(mock_long_df, {"x": "x"})
diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py
index 355537f71e..bf69864cfb 100644
--- a/tests/_core/test_plot.py
+++ b/tests/_core/test_plot.py
@@ -170,6 +170,15 @@ def test_positional_x(self, long_df):
         assert p._data.source_data is None
         assert list(p._data.source_vars) == ["x"]

+    @pytest.mark.skipif(
+        condition=not hasattr(pd.api, "interchange"),
+        reason="Tests behavior assuming support for dataframe interchange"
+    )
+    def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
+
+        p = Plot(mock_long_df, x="x")
+        assert_frame_equal(p._data.source_data, long_df)
+
     def test_positional_too_many(self, long_df):

         err = r"Plot\(\) accepts no more than 3 positional arguments \(data, x, y\)"
diff --git a/tests/conftest.py b/tests/conftest.py
index 01d93a4941..9f9eef49dc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -178,3 +178,19 @@ def object_df(rng, long_df):
 def null_series(flat_series):

     return pd.Series(index=flat_series.index, dtype='float64')
+
+
+class MockInterchangeableDataFrame:
+    # Mock object that is not a pandas.DataFrame but that can
+    # be converted to one via the DataFrame exchange protocol
+    def __init__(self, data):
+        self._data = data
+
+    def __dataframe__(self, *args, **kwargs):
+        return self._data.__dataframe__(*args, **kwargs)
+
+
+@pytest.fixture
+def mock_long_df(long_df):
+
+    return MockInterchangeableDataFrame(long_df)
diff --git a/tests/test_axisgrid.py b/tests/test_axisgrid.py
index 1390fc0522..30f930aead 100644
--- a/tests/test_axisgrid.py
+++ b/tests/test_axisgrid.py
@@ -707,6 +707,19 @@ def test_tick_params(self):
             assert mpl.colors.same_color(tick.tick2line.get_color(), color)
             assert tick.get_pad() == pad

+    @pytest.mark.skipif(
+        condition=not hasattr(pd.api, "interchange"),
+        reason="Tests behavior assuming support for dataframe interchange"
+    )
+    def test_data_interchange(self, mock_long_df, long_df):
+
+        g = ag.FacetGrid(mock_long_df, col="a", row="b")
+        g.map(scatterplot, "x", "y")
+
+        assert g.axes.shape == (long_df["b"].nunique(), long_df["a"].nunique())
+        for ax in g.axes.flat:
+            assert len(ax.collections) == 1
+

 class TestPairGrid:

@@ -1462,6 +1475,18 @@ def test_tick_params(self):
             assert mpl.colors.same_color(tick.tick2line.get_color(), color)
             assert tick.get_pad() == pad

+    @pytest.mark.skipif(
+        condition=not hasattr(pd.api, "interchange"),
+        reason="Tests behavior assuming support for dataframe interchange"
+    )
+    def test_data_interchange(self, mock_long_df, long_df):
+
+        g = ag.PairGrid(mock_long_df, vars=["x", "y", "z"], hue="a")
+        g.map(scatterplot)
+        assert g.axes.shape == (3, 3)
+        for ax in g.axes.flat:
+            assert len(ax.collections) == long_df["a"].nunique() + 1
+

 class TestJointGrid:

diff --git a/tests/test_base.py b/tests/test_base.py
index c29e8a0631..e6cfd8c123 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -757,7 +757,7 @@ def test_long_numeric_name(self, long_df, name):
         p = VectorPlotter()
         p.assign_variables(data=long_df, variables={"x": name})
         assert_array_equal(p.plot_data["x"], long_df[name])
-        assert p.variables["x"] == name
+        assert p.variables["x"] == str(name)

     def test_long_hierarchical_index(self, rng):

@@ -771,7 +771,7 @@ def test_long_hierarchical_index(self, rng):
         p = VectorPlotter()
         p.assign_variables(data=df, variables={var: name})
         assert_array_equal(p.plot_data[var], df[name])
-        assert p.variables[var] == name
+        assert p.variables[var] == str(name)

     def test_long_scalar_and_data(self, long_df):

@@ -788,7 +788,7 @@ def test_wide_semantic_error(self, wide_df):

     def test_long_unknown_error(self, long_df):

-        err = "Could not interpret value `what` for parameter `hue`"
+        err = "Could not interpret value `what` for `hue`"
         with pytest.raises(ValueError, match=err):
             VectorPlotter(data=long_df, variables={"x": "x", "hue": "what"})
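
Usage sketch (illustrative, not part of the patch), mirroring the new tests above: after this change, both the objects interface and the figure-level grids accept any data source that implements __dataframe__, not only pandas. The InterchangeOnly wrapper and the sample data below are hypothetical; a real polars or pyarrow table would play the same role. Assumes pandas >= 1.5.

import numpy as np
import pandas as pd
import seaborn as sns
import seaborn.objects as so


class InterchangeOnly:
    # Hypothetical non-pandas data source that only exposes the
    # DataFrame interchange protocol, like the test fixture above.
    def __init__(self, df):
        self._df = df

    def __dataframe__(self, *args, **kwargs):
        return self._df.__dataframe__(*args, **kwargs)


rng = np.random.default_rng(0)
pdf = pd.DataFrame({
    "x": rng.normal(size=60),
    "y": rng.normal(size=60),
    "a": rng.choice(["p", "q", "r"], size=60),
})
data = InterchangeOnly(pdf)

# Objects interface: positional data is recognized via __dataframe__ and
# converted to pandas when the Plot spec is constructed.
p = so.Plot(data, x="x", y="y", color="a").add(so.Dot())

# Classic interface: FacetGrid (and PairGrid) convert the source up front.
g = sns.FacetGrid(data, col="a")
g.map(sns.scatterplot, "x", "y")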