Skip to content

Commit

Permalink
Support polars and other data libraries via dataframe interchange (#3369
Browse files Browse the repository at this point in the history
)

* Support dataframe interchange in objects interface

* Add warning for dataframe interchange with pandas<2.0.2

* Fix typo

* Move more data source inference into standalone function

* Use PlotData for longform variable assignment in _base and fix simple test consequences

* Handle countplot edgecase

* Add a comment about what's going on here

* Support dataframe interchange in FacetGrid/PairGrid

* Add tests for data interchange in FacetGrid/PairGrid

* Add note about pre-selecting relevant input columns

* Skip interchange tests on pinned pandas
  • Loading branch information
mwaskom authored Aug 23, 2023
1 parent af613f1 commit 58cf628
Show file tree
Hide file tree
Showing 12 changed files with 195 additions and 147 deletions.
140 changes: 15 additions & 125 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
import pandas as pd
import matplotlib as mpl

from ._decorators import (
from seaborn._core.data import PlotData
from seaborn._decorators import (
share_init_params_with_map,
)
from .palettes import (
from seaborn.palettes import (
QUAL_PALETTES,
color_palette,
)
from .utils import (
from seaborn.utils import (
_check_argument,
desaturate,
get_color_cycle,
Expand Down Expand Up @@ -698,23 +699,26 @@ def assign_variables(self, data=None, variables={}):

if x is None and y is None:
self.input_format = "wide"
plot_data, variables = self._assign_variables_wideform(
frame, names = self._assign_variables_wideform(
data, **variables,
)
else:
# When dealing with long-form input, use the newer PlotData
# object (internal but introduced for the objects interface)
# to centralize / standardize data consumption logic.
self.input_format = "long"
plot_data, variables = self._assign_variables_longform(
data, **variables,
)
plot_data = PlotData(data, variables)
frame = plot_data.frame
names = plot_data.names

self.plot_data = plot_data
self.variables = variables
self.plot_data = frame
self.variables = names
self.var_types = {
v: variable_type(
plot_data[v],
frame[v],
boolean_type="numeric" if v in "xy" else "categorical"
)
for v in variables
for v in names
}

return self
Expand Down Expand Up @@ -861,120 +865,6 @@ def _assign_variables_wideform(self, data=None, **kwargs):

return plot_data, variables

def _assign_variables_longform(self, data=None, **kwargs):
"""Define plot variables given long-form data and/or vector inputs.
Parameters
----------
data : dict-like collection of vectors
Input data where variable names map to vector values.
kwargs : variable -> data mappings
Keys are seaborn variables (x, y, hue, ...) and values are vectors
in any format that can construct a :class:`pandas.DataFrame` or
names of columns or index levels in ``data``.
Returns
-------
plot_data : :class:`pandas.DataFrame`
Long-form data object mapping seaborn variables (x, y, hue, ...)
to data vectors.
variables : dict
Keys are defined seaborn variables; values are names inferred from
the inputs (or None when no name can be determined).
Raises
------
ValueError
When variables are strings that don't appear in ``data``.
"""
plot_data = {}
variables = {}

# Data is optional; all variables can be defined as vectors
if data is None:
data = {}

# TODO should we try a data.to_dict() or similar here to more
# generally accept objects with that interface?
# Note that dict(df) also works for pandas, and gives us what we
# want, whereas DataFrame.to_dict() gives a nested dict instead of
# a dict of series.

# Variables can also be extracted from the index attribute
# TODO is this the most general way to enable it?
# There is no index.to_dict on multiindex, unfortunately
try:
index = data.index.to_frame()
except AttributeError:
index = {}

# The caller will determine the order of variables in plot_data
for key, val in kwargs.items():

# First try to treat the argument as a key for the data collection.
# But be flexible about what can be used as a key.
# Usually it will be a string, but allow numbers or tuples too when
# taking from the main data object. Only allow strings to reference
# fields in the index, because otherwise there is too much ambiguity.
try:
val_as_data_key = (
val in data
or (isinstance(val, (str, bytes)) and val in index)
)
except (KeyError, TypeError):
val_as_data_key = False

if val_as_data_key:

# We know that __getitem__ will work

if val in data:
plot_data[key] = data[val]
elif val in index:
plot_data[key] = index[val]
variables[key] = val

elif isinstance(val, (str, bytes)):

# This looks like a column name but we don't know what it means!

err = f"Could not interpret value `{val}` for parameter `{key}`"
raise ValueError(err)

else:

# Otherwise, assume the value is itself data

# Raise when data object is present and a vector can't matched
if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
if np.ndim(val) and len(data) != len(val):
val_cls = val.__class__.__name__
err = (
f"Length of {val_cls} vectors must match length of `data`"
f" when both are used, but `data` has length {len(data)}"
f" and the vector passed to `{key}` has length {len(val)}."
)
raise ValueError(err)

plot_data[key] = val

# Try to infer the name of the variable
variables[key] = getattr(val, "name", None)

# Construct a tidy plot DataFrame. This will convert a number of
# types automatically, aligning on index in case of pandas objects
plot_data = pd.DataFrame(plot_data)

# Reduce the variables dictionary to fields with valid data
variables = {
var: name
for var, name in variables.items()
if plot_data[var].notnull().any()
}

return plot_data, variables

def iter_data(
self, grouping_vars=None, *,
reverse=False, from_comp_data=False,
Expand Down
77 changes: 68 additions & 9 deletions seaborn/_core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from collections.abc import Mapping, Sized
from typing import cast
import warnings

import pandas as pd
from pandas import DataFrame

from seaborn._core.typing import DataSource, VariableSpec, ColumnName
from seaborn.utils import _version_predates


class PlotData:
Expand Down Expand Up @@ -52,13 +54,19 @@ def __init__(
variables: dict[str, VariableSpec],
):

data = handle_data_source(data)
frame, names, ids = self._assign_variables(data, variables)

self.frame = frame
self.names = names
self.ids = ids

self.frames = {} # TODO this is a hack, remove
# The reason we possibly have a dictionary of frames is to support the
# Plot.pair operation, post scaling, where each x/y variable needs its
# own frame. This feels pretty clumsy and there are a bunch of places in
# the client code with awkard if frame / elif frames constructions.
# It would be great to have a cleaner abstraction here.
self.frames = {}

self.source_data = data
self.source_vars = variables
Expand All @@ -75,7 +83,7 @@ def join(
variables: dict[str, VariableSpec] | None,
) -> PlotData:
"""Add, replace, or drop variables and return as a new dataset."""
# Inherit the original source of the upsteam data by default
# Inherit the original source of the upstream data by default
if data is None:
data = self.source_data

Expand Down Expand Up @@ -118,7 +126,7 @@ def join(

def _assign_variables(
self,
data: DataSource,
data: DataFrame | Mapping | None,
variables: dict[str, VariableSpec],
) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]:
"""
Expand Down Expand Up @@ -147,6 +155,8 @@ def _assign_variables(
Raises
------
TypeError
When data source is not a DataFrame or Mapping.
ValueError
When variables are strings that don't appear in `data`, or when they are
non-indexed vector datatypes that have a different length from `data`.
Expand All @@ -162,15 +172,12 @@ def _assign_variables(
ids = {}

given_data = data is not None
if data is not None:
source_data = data
else:
if data is None:
# Data is optional; all variables can be defined as vectors
# But simplify downstream code by always having a usable source data object
source_data = {}

# TODO Generally interested in accepting a generic DataFrame interface
# Track https://data-apis.org/ for development
else:
source_data = data

# Variables can also be extracted from the index of a DataFrame
if isinstance(source_data, pd.DataFrame):
Expand Down Expand Up @@ -258,3 +265,55 @@ def _assign_variables(
frame = pd.DataFrame(plot_data)

return frame, names, ids


def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
"""Convert the data source object to a common union representation."""
if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
# Check for pd.DataFrame inheritance could be removed once
# minimal pandas version supports dataframe interchange (1.5.0).
data = convert_dataframe_to_pandas(data)
elif data is not None and not isinstance(data, Mapping):
err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
raise TypeError(err)

return data


def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
"""Use the DataFrame exchange protocol, or fail gracefully."""
if isinstance(data, pd.DataFrame):
return data

if not hasattr(pd.api, "interchange"):
msg = (
"Support for non-pandas DataFrame objects requires a version of pandas "
"that implements the DataFrame interchange protocol. Please upgrade "
"your pandas version or coerce your data to pandas before passing "
"it to seaborn."
)
raise TypeError(msg)

if _version_predates(pd, "2.0.2"):
msg = (
"DataFrame interchange with pandas<2.0.2 has some known issues. "
f"You are using pandas {pd.__version__}. "
"Continuing, but it is recommended to carefully inspect the results and to "
"consider upgrading."
)
warnings.warn(msg, stacklevel=2)

try:
# This is going to convert all columns in the input dataframe, even though
# we may only need one or two of them. It would be more efficient to select
# the columns that are going to be used in the plot prior to interchange.
# Solving that in general is a hard problem, especially with the objects
# interface where variables passed in Plot() may only be referenced later
# in Plot.add(). But noting here in case this seems to be a bottleneck.
return pd.api.interchange.from_dataframe(data)
except Exception as err:
msg = (
"Encountered an exception when converting data source "
"to a pandas DataFrame. See traceback above for details."
)
raise RuntimeError(msg) from err
7 changes: 4 additions & 3 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,10 @@ def _resolve_positionals(
err = "Plot() accepts no more than 3 positional arguments (data, x, y)."
raise TypeError(err)

# TODO need some clearer way to differentiate data / vector here
# (There might be an abstract DataFrame class to use here?)
if isinstance(args[0], (abc.Mapping, pd.DataFrame)):
if (
isinstance(args[0], (abc.Mapping, pd.DataFrame))
or hasattr(args[0], "__dataframe__")
):
if data is not None:
raise TypeError("`data` given by both name and position.")
data, args = args[0], args[1:]
Expand Down
12 changes: 8 additions & 4 deletions seaborn/_core/typing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping
from datetime import date, datetime, timedelta
from typing import Any, Optional, Union, Mapping, Tuple, List, Dict
from collections.abc import Hashable, Iterable
from typing import Any, Optional, Union, Tuple, List, Dict

from numpy import ndarray # TODO use ArrayLike?
from pandas import DataFrame, Series, Index, Timestamp, Timedelta
from pandas import Series, Index, Timestamp, Timedelta
from matplotlib.colors import Colormap, Normalize


Expand All @@ -17,7 +17,11 @@
VariableSpec = Union[ColumnName, Vector, None]
VariableSpecList = Union[List[VariableSpec], Index, None]

DataSource = Union[DataFrame, Mapping[Hashable, Vector], None]
# A DataSource can be an object implementing __dataframe__, or a Mapping
# (and is optional in all contexts where it is used).
# I don't think there's an abc for "has __dataframe__", so we type as object
# but keep the (slightly odd) Union alias for better user-facing annotations.
DataSource = Union[object, Mapping, None]

OrderSpec = Union[Iterable, None] # TODO technically str is iterable
NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None]
Expand Down
3 changes: 3 additions & 0 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import matplotlib.pyplot as plt

from ._base import VectorPlotter, variable_type, categorical_order
from ._core.data import handle_data_source
from ._compat import share_axis, get_legend_handles
from . import utils
from .utils import (
Expand Down Expand Up @@ -374,6 +375,7 @@ def __init__(
):

super().__init__()
data = handle_data_source(data)

# Determine the hue facet layer information
hue_var = hue
Expand Down Expand Up @@ -1240,6 +1242,7 @@ def __init__(
"""

super().__init__()
data = handle_data_source(data)

# Sort out the variables that define the grid
numeric_cols = self._find_numeric_cols(data)
Expand Down
4 changes: 2 additions & 2 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2593,10 +2593,10 @@ def countplot(

if x is None and y is not None:
orient = "y"
x = 1
x = 1 if list(y) else None
elif x is not None and y is None:
orient = "x"
y = 1
y = 1 if list(x) else None
elif x is not None and y is not None:
raise TypeError("Cannot pass values for both `x` and `y`.")

Expand Down
1 change: 1 addition & 0 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,7 @@ def rugplot(
x = data
elif axis == "y":
y = data
data = None
msg = textwrap.dedent(f"""\n
The `axis` parameter has been deprecated; use the `{axis}` parameter instead.
Please update your code; this will become an error in seaborn v0.13.0.
Expand Down
Loading

0 comments on commit 58cf628

Please sign in to comment.