",
@@ -208,6 +222,7 @@ def _repr_html_(self) -> str:
class DisplayConfig(TypedDict):
"""Configuration for IPython's rich display hooks."""
+
format: Literal["png", "svg"]
scaling: float
hidpi: bool
@@ -215,10 +230,10 @@ class DisplayConfig(TypedDict):
class PlotConfig:
"""Configuration for default behavior / appearance of class:`Plot` instances."""
- def __init__(self):

+ def __init__(self):
self._theme = ThemeConfig()
- self._display = {"format": "png", "scaling": .85, "hidpi": True}
+ self._display = {"format": "png", "scaling": 0.85, "hidpi": True}
@property
def theme(self) -> dict[str, Any]:
@@ -284,6 +299,7 @@ class Plot:
the plot without rendering it to access the lower-level representation.
"""
+
config = PlotConfig()
_data: PlotData
@@ -308,7 +324,6 @@ def __init__(
data: DataSource = None,
**variables: VariableSpec,
):
-
if args:
data, variables = self._resolve_positionals(args, data, variables)
@@ -347,9 +362,8 @@ def _resolve_positionals(
err = "Plot() accepts no more than 3 positional arguments (data, x, y)."
raise TypeError(err)
- if (
- isinstance(args[0], (abc.Mapping, pd.DataFrame))
- or hasattr(args[0], "__dataframe__")
+ if isinstance(args[0], (abc.Mapping, pd.DataFrame)) or hasattr(
+ args[0], "__dataframe__"
):
if data is not None:
raise TypeError("`data` given by both name and position.")
@@ -373,7 +387,6 @@ def _resolve_positionals(
return data, variables
def __add__(self, other):
-
if isinstance(other, Mark) or isinstance(other, Stat):
raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?")
@@ -381,13 +394,11 @@ def __add__(self, other):
raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}")
def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
-
if Plot.config.display["format"] != "png":
return None
return self.plot()._repr_png_()
def _repr_svg_(self) -> str | None:
-
if Plot.config.display["format"] != "svg":
return None
return self.plot()._repr_svg_()
@@ -419,14 +430,12 @@ def _clone(self) -> Plot:
return new
def _theme_with_defaults(self) -> dict[str, Any]:
-
theme = self.config.theme.copy()
theme.update(self._theme)
return theme
@property
def _variables(self) -> list[str]:
-
variables = (
list(self._data.frame)
+ list(self._pair_spec.get("variables", []))
@@ -462,9 +471,7 @@ def on(self, target: Axes | SubFigure | Figure) -> Plot:
"""
accepted_types: tuple # Allow tuple of various length
- accepted_types = (
- mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure
- )
+ accepted_types = (mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure)
accepted_types_str = (
f"{mpl.axes.Axes}, {mpl.figure.SubFigure}, or {mpl.figure.Figure}"
)
@@ -552,25 +559,29 @@ def add(
error = len(move) != len(transforms)
if error:
- msg = " ".join([
- "Transforms must have at most one Stat type (in the first position),",
- "and all others must be a Move type. Given transform type(s):",
- ", ".join(str(type(t).__name__) for t in transforms) + "."
- ])
+ msg = " ".join(
+ [
+ "Transforms must have at most one Stat type (in the first position),",
+ "and all others must be a Move type. Given transform type(s):",
+ ", ".join(str(type(t).__name__) for t in transforms) + ".",
+ ]
+ )
raise TypeError(msg)
new = self._clone()
- new._layers.append({
- "mark": mark,
- "stat": stat,
- "move": move,
- # TODO it doesn't work to supply scalars to variables, but it should
- "vars": variables,
- "source": data,
- "legend": legend,
- "label": label,
- "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore
- })
+ new._layers.append(
+ {
+ "mark": mark,
+ "stat": stat,
+ "move": move,
+ # TODO it doesn't work to supply scalars to variables, but it should
+ "vars": variables,
+ "source": data,
+ "legend": legend,
+ "label": label,
+ "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore
+ }
+ )
return new
@@ -677,10 +688,12 @@ def facet(
structure[dim] = list(dim_order)
elif order is not None:
if col is not None and row is not None:
- err = " ".join([
- "When faceting on both col= and row=, passing `order` as a list"
- "is ambiguous. Use a dict with 'col' and/or 'row' keys instead."
- ])
+ err = " ".join(
+ [
+ "When faceting on both col= and row=, passing `order` as a list"
+ "is ambiguous. Use a dict with 'col' and/or 'row' keys instead."
+ ]
+ )
raise RuntimeError(err)
elif col is not None:
structure["col"] = list(order)
@@ -768,10 +781,11 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot:
return new
def label(
- self, *,
+ self,
+ *,
title: str | None = None,
legend: str | None = None,
- **variables: str | Callable[[str], str]
+ **variables: str | Callable[[str], str],
) -> Plot:
"""
Control the labels and titles for axes, legends, and subplots.
@@ -932,7 +946,6 @@ def plot(self, pyplot: bool = False) -> Plotter:
return self._plot(pyplot)
def _plot(self, pyplot: bool = False) -> Plotter:
-
# TODO if we have _target object, pyplot should be determined by whether it
# is hooked into the pyplot state machine (how do we check?)
@@ -978,18 +991,22 @@ class Plotter:
This class is not intended to be instantiated directly by users.
"""
+
# TODO decide if we ever want these (Plot.plot(debug=True))?
_data: PlotData
_layers: list[Layer]
_figure: Figure
def __init__(self, pyplot: bool, theme: dict[str, Any]):
-
self._pyplot = pyplot
self._theme = theme
- self._legend_contents: list[tuple[
- tuple[str, str | int], list[Artist], list[str],
- ]] = []
+ self._legend_contents: list[
+ tuple[
+ tuple[str, str | int],
+ list[Artist],
+ list[str],
+ ]
+ ] = []
self._scales: dict[str, Scale] = {}
def save(self, loc, **kwargs) -> Plotter: # TODO type args
@@ -1012,6 +1029,7 @@ def show(self, **kwargs) -> None:
# TODO if we did not create the Plotter with pyplot, is it possible to do this?
# If not we should clearly raise.
import matplotlib.pyplot as plt
+
with theme_context(self._theme):
plt.show(**kwargs)
@@ -1019,7 +1037,6 @@ def show(self, **kwargs) -> None:
# TODO what else is useful in the public API for this class?
def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
-
# TODO use matplotlib backend directly instead of going through savefig?
# TODO perhaps have self.show() flip a switch to disable this, so that
@@ -1048,7 +1065,6 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
return data, metadata
def _repr_svg_(self) -> str | None:
-
if Plot.config.display["format"] != "svg":
return None
@@ -1069,11 +1085,8 @@ def _repr_svg_(self) -> str | None:
return out.getvalue().decode()
def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:
-
- common_data = (
- p._data
- .join(None, p._facet_spec.get("variables"))
- .join(None, p._pair_spec.get("variables"))
+ common_data = p._data.join(None, p._facet_spec.get("variables")).join(
+ None, p._pair_spec.get("variables")
)
layers: list[Layer] = []
@@ -1085,7 +1098,6 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:
return common_data, layers
def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str:
-
if re.match(r"[xy]\d+", var):
key = var if var in p._labels else var[0]
else:
@@ -1105,7 +1117,6 @@ def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str:
return label
def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
-
# --- Parsing the faceting/pairing parameterization to specify figure grid
subplot_spec = p._subplot_spec.copy()
@@ -1125,7 +1136,10 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
# --- Figure initialization
self._figure = subplots.init_figure(
- pair_spec, self._pyplot, p._figure_spec, p._target,
+ pair_spec,
+ self._pyplot,
+ p._figure_spec,
+ p._target,
)
# --- Figure annotation
@@ -1142,7 +1156,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
# something to be desired (in terms of how it defines 'centered').
names = [
common.names.get(axis_key),
- *(layer["data"].names.get(axis_key) for layer in layers)
+ *(layer["data"].names.get(axis_key) for layer in layers),
]
auto_label = next((name for name in names if name is not None), None)
label = self._resolve_label(p, axis_key, auto_label)
@@ -1165,12 +1179,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
)
axis_obj.get_label().set_visible(show_axis_label)
- show_tick_labels = (
- show_axis_label
- or subplot_spec.get(f"share{axis}") not in (
- True, "all", {"x": "col", "y": "row"}[axis]
- )
- )
+ show_tick_labels = show_axis_label or subplot_spec.get(
+ f"share{axis}"
+ ) not in (True, "all", {"x": "col", "y": "row"}[axis])
for group in ("major", "minor"):
side = {"x": "bottom", "y": "left"}[axis]
axis_obj.set_tick_params(**{f"label{side}": show_tick_labels})
@@ -1192,8 +1203,10 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
has_col = sub["col"] is not None
has_row = sub["row"] is not None
show_title = (
- has_col and has_row
- or (has_col or has_row) and p._facet_spec.get("wrap")
+ has_col
+ and has_row
+ or (has_col or has_row)
+ and p._facet_spec.get("wrap")
or (has_col and sub["top"])
# TODO or has_row and sub["right"] and
or has_row # TODO and not
@@ -1207,14 +1220,12 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
title_text = ax.set_title(title)
def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
-
grouping_vars = [v for v in PROPERTIES if v not in "xy"]
grouping_vars += ["col", "row", "group"]
pair_vars = spec._pair_spec.get("structure", {})
for layer in layers:
-
data = layer["data"]
mark = layer["mark"]
stat = layer["stat"]
@@ -1222,9 +1233,9 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
if stat is None:
continue
- iter_axes = itertools.product(*[
- pair_vars.get(axis, [axis]) for axis in "xy"
- ])
+ iter_axes = itertools.product(
+ *[pair_vars.get(axis, [axis]) for axis in "xy"]
+ )
old = data.frame
@@ -1233,7 +1244,6 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
data.frame = data.frame.iloc[:0] # TODO to simplify typing
for coord_vars in iter_axes:
-
pairings = "xy", coord_vars
df = old.copy()
@@ -1260,10 +1270,7 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
else:
data.frame = res
- def _get_scale(
- self, p: Plot, var: str, prop: Property, values: Series
- ) -> Scale:
-
+ def _get_scale(self, p: Plot, var: str, prop: Property, values: Series) -> Scale:
if re.match(r"[xy]\d+", var):
key = var if var in p._scales else var[0]
else:
@@ -1281,7 +1288,6 @@ def _get_scale(
return scale
def _get_subplot_data(self, df, var, view, share_state):
-
if share_state in [True, "all"]:
# The all-shared case is easiest, every subplot sees all the data
seed_values = df[var]
@@ -1309,7 +1315,6 @@ def _setup_scales(
layers: list[Layer],
variables: list[str] | None = None,
) -> None:
-
if variables is None:
# Add variables that have data but not a scale, which happens
# because this method can be called multiple time, to handle
@@ -1322,7 +1327,6 @@ def _setup_scales(
variables = [v for v in variables if v not in self._scales]
for var in variables:
-
# Determine whether this is a coordinate variable
# (i.e., x/y, paired x/y, or derivative such as xmax)
m = re.match(r"^(?P(?Px|y)\d*).*", var)
@@ -1390,7 +1394,6 @@ def _setup_scales(
transformed_data.append(empty_series)
for view in subplots:
-
axis_obj = getattr(view["ax"], f"{axis}axis")
seed_values = self._get_subplot_data(var_df, var, view, share_state)
view_scale = scale._setup(seed_values, prop, axis=axis_obj)
@@ -1415,7 +1418,6 @@ def _setup_scales(
layer_df[var] = pd.to_numeric(new_series)
def _plot_layer(self, p: Plot, layer: Layer) -> None:
-
data = layer["data"]
mark = layer["mark"]
move = layer["move"]
@@ -1426,7 +1428,6 @@ def _plot_layer(self, p: Plot, layer: Layer) -> None:
pair_variables = p._pair_spec.get("structure", {})
for subplots, df, scales in self._generate_pairings(data, pair_variables):
-
orient = layer["orient"] or mark._infer_orient(scales)
def get_order(var):
@@ -1449,9 +1450,11 @@ def get_order(var):
elif "width" in df:
view_width = view_df["width"]
else:
- view_width = 0.8 # TODO what default?
- spacing = scales[orient]._spacing(view_df.loc[view_idx, orient])
- width.loc[view_idx] = view_width * spacing
+ # TODO what default?
+ view_width = 0.8 * scales[orient]._spacing(
+ view_df.loc[view_idx, orient]
+ )
+ width.loc[view_idx] = view_width
df["width"] = width
if "baseline" in mark._mappable_props:
@@ -1495,13 +1498,15 @@ def get_order(var):
self._update_legend_contents(p, mark, data, scales, layer["label"])
def _unscale_coords(
- self, subplots: list[dict], df: DataFrame, orient: str,
+ self,
+ subplots: list[dict],
+ df: DataFrame,
+ orient: str,
) -> DataFrame:
# TODO do we still have numbers in the variable name at this point?
coord_cols = [c for c in df if re.match(r"^[xy]\D*$", str(c))]
out_df = (
- df
- .drop(coord_cols, axis=1)
+ df.drop(coord_cols, axis=1)
.reindex(df.columns, axis=1) # So unscaled columns retain their place
.copy(deep=False)
)
@@ -1510,7 +1515,6 @@ def _unscale_coords(
view_df = self._filter_subplot_data(df, view)
axes_df = view_df[coord_cols]
for var, values in axes_df.items():
-
axis = getattr(view["ax"], f"{str(var)[0]}axis")
# TODO see https://github.com/matplotlib/matplotlib/issues/22713
transform = axis.get_transform().inverted().transform
@@ -1520,18 +1524,17 @@ def _unscale_coords(
return out_df
def _generate_pairings(
- self, data: PlotData, pair_variables: dict,
- ) -> Generator[
- tuple[list[dict], DataFrame, dict[str, Scale]], None, None
- ]:
+ self,
+ data: PlotData,
+ pair_variables: dict,
+ ) -> Generator[tuple[list[dict], DataFrame, dict[str, Scale]], None, None]:
# TODO retype return with subplot_spec or similar
- iter_axes = itertools.product(*[
- pair_variables.get(axis, [axis]) for axis in "xy"
- ])
+ iter_axes = itertools.product(
+ *[pair_variables.get(axis, [axis]) for axis in "xy"]
+ )
for x, y in iter_axes:
-
subplots = []
for view in self._subplots:
if (view["x"] == x) and (view["y"] == y):
@@ -1562,7 +1565,6 @@ def _generate_pairings(
yield subplots, out_df, scales
def _get_subplot_index(self, df: DataFrame, subplot: dict) -> Index:
-
dims = df.columns.intersection(["col", "row"])
if dims.empty:
return df.index
@@ -1584,9 +1586,11 @@ def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame:
return df[keep_rows]
def _setup_split_generator(
- self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]],
+ self,
+ grouping_vars: list[str],
+ df: DataFrame,
+ subplots: list[dict[str, Any]],
) -> Callable[[], Generator]:
-
grouping_keys = []
grouping_vars = [
v for v in grouping_vars if v in df and v not in ["col", "row"]
@@ -1598,9 +1602,7 @@ def _setup_split_generator(
grouping_keys.append(order)
def split_generator(keep_na=False) -> Generator:
-
for view in subplots:
-
axes_df = self._filter_subplot_data(df, view)
axes_df_inf_as_nan = axes_df.copy()
@@ -1633,13 +1635,16 @@ def split_generator(keep_na=False) -> Generator:
continue
grouped_df = axes_df.groupby(
- grouping_vars, sort=False, as_index=False, observed=False,
+ grouping_vars,
+ sort=False,
+ as_index=False,
+ observed=False,
)
for key in itertools.product(*grouping_keys):
-
pd_key = (
- key[0] if len(key) == 1 and _version_predates(pd, "2.2.0")
+ key[0]
+ if len(key) == 1 and _version_predates(pd, "2.2.0")
else key
)
try:
@@ -1696,9 +1701,9 @@ def _update_legend_contents(
# Then handle the scale legends
# First pass: Identify the values that will be shown for each variable
- schema: list[tuple[
- tuple[str, str | int], list[str], tuple[list[Any], list[str]]
- ]] = []
+ schema: list[
+ tuple[tuple[str, str | int], list[str], tuple[list[Any], list[str]]]
+ ] = []
schema = []
for var in legend_vars:
var_legend = scales[var]._legend
@@ -1733,7 +1738,8 @@ def _make_legend(self, p: Plot) -> None:
# Input list has an entry for each distinct variable in each layer
# Output dict has an entry for each distinct variable
merged_contents: dict[
- tuple[str, str | int], tuple[list[tuple[Artist, ...]], list[str]],
+ tuple[str, str | int],
+ tuple[list[tuple[Artist, ...]], list[str]],
] = {}
for key, new_artists, labels in self._legend_contents:
# Key is (name, id); we need the id to resolve variable uniqueness,
@@ -1755,14 +1761,13 @@ def _make_legend(self, p: Plot) -> None:
base_legend = None
for (name, _), (handles, labels) in merged_contents.items():
-
legend = mpl.legend.Legend(
self._figure,
handles, # type: ignore # matplotlib/issues/26639
labels,
title=name,
loc=loc,
- bbox_to_anchor=(.98, .55),
+ bbox_to_anchor=(0.98, 0.55),
)
if base_legend:
@@ -1777,7 +1782,6 @@ def _make_legend(self, p: Plot) -> None:
self._figure.legends.append(legend)
def _finalize_figure(self, p: Plot) -> None:
-
for sub in self._subplots:
ax = sub["ax"]
for axis in "xy":
diff --git a/seaborn/_stats/counting.py b/seaborn/_stats/counting.py
index 0c2fb7d499..9e4f7e35fc 100644
--- a/seaborn/_stats/counting.py
+++ b/seaborn/_stats/counting.py
@@ -178,7 +178,7 @@ def _eval(self, data, orient, bin_kws):
width = np.diff(edges)
center = edges[:-1] + width / 2
- return pd.DataFrame({orient: center, "count": hist, "space": width})
+ return pd.DataFrame({orient: center, "count": hist, "width": width})
def _normalize(self, data):
@@ -188,11 +188,11 @@ def _normalize(self, data):
elif self.stat == "percent":
hist = hist.astype(float) / hist.sum() * 100
elif self.stat == "frequency":
- hist = hist.astype(float) / data["space"]
+ hist = hist.astype(float) / data["width"]
if self.cumulative:
if self.stat in ["density", "frequency"]:
- hist = (hist * data["space"]).cumsum()
+ hist = (hist * data["width"]).cumsum()
else:
hist = hist.cumsum()
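
To make the rename concrete: `Hist` now reports the per-bin width under a "width" column instead of "space", and the normalization and cumulation arithmetic above consumes that column. A standalone sketch of the same math on synthetic data (this mirrors the stat's output layout; it is not the seaborn implementation itself):

import numpy as np
import pandas as pd

# Synthetic stand-in for the stat's output after the rename: bin centers,
# density values, and the per-bin widths formerly stored as "space".
x = np.random.default_rng(0).normal(size=500)
density, edges = np.histogram(x, bins="auto", density=True)
out = pd.DataFrame({
    "x": edges[:-1] + np.diff(edges) / 2,  # bin centers (the orient column)
    "count": density,
    "width": np.diff(edges),
})

# Same invariant the updated tests check: density times bin width sums to 1.
assert np.isclose((out["count"] * out["width"]).sum(), 1.0)
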
diff --git a/seaborn/distributions.py b/seaborn/distributions.py
index f8ec166cf4..40d9115a9e 100644
--- a/seaborn/distributions.py
+++ b/seaborn/distributions.py
@@ -470,7 +470,7 @@ def plot_univariate_histogram(
bin_kws = estimator._define_bin_params(sub_data, orient, None)
res = estimator._normalize(estimator._eval(sub_data, orient, bin_kws))
heights = res[estimator.stat].to_numpy()
- widths = res["space"].to_numpy()
+ widths = res["width"].to_numpy()
edges = res[orient].to_numpy() - widths / 2
# Rescale the smoothed curve to match the histogram
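
The consumer change in `plot_univariate_histogram` is mechanical: the same per-bin values are read under the new column name, and the left bar edges are rebuilt from centers and widths exactly as before. A toy illustration of that arithmetic (numbers are made up):

import numpy as np

centers = np.array([0.5, 1.5, 2.5])  # res[orient]
widths = np.array([1.0, 1.0, 1.0])   # res["width"], formerly res["space"]
edges = centers - widths / 2         # left bar edges: [0., 1., 2.]
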
diff --git a/tests/_stats/test_counting.py b/tests/_stats/test_counting.py
index 7656654492..bdfe2857ab 100644
--- a/tests/_stats/test_counting.py
+++ b/tests/_stats/test_counting.py
@@ -152,13 +152,13 @@ def test_density_stat(self, long_df, single_args):
h = Hist(stat="density")
out = h(long_df, *single_args)
- assert (out["y"] * out["space"]).sum() == 1
+ assert (out["y"] * out["width"]).sum() == 1
def test_frequency_stat(self, long_df, single_args):
h = Hist(stat="frequency")
out = h(long_df, *single_args)
- assert (out["y"] * out["space"]).sum() == len(long_df)
+ assert (out["y"] * out["width"]).sum() == len(long_df)
def test_invalid_stat(self):
@@ -248,7 +248,7 @@ def test_histogram_single(self, long_df, single_args):
out = h(long_df, *single_args)
hist, edges = np.histogram(long_df["x"], bins="auto")
assert_array_equal(out["y"], hist)
- assert_array_equal(out["space"], np.diff(edges))
+ assert_array_equal(out["width"], np.diff(edges))
def test_histogram_multiple(self, long_df, triple_args):
@@ -259,4 +259,4 @@ def test_histogram_multiple(self, long_df, triple_args):
x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"]
hist, edges = np.histogram(x, bins=bins)
assert_array_equal(out_part["y"], hist)
- assert_array_equal(out_part["space"], np.diff(edges))
+ assert_array_equal(out_part["width"], np.diff(edges))