Skip to content

Commit

Permalink
Generalize internal scaling operation (#3440)
Browse files Browse the repository at this point in the history
* Small renaming and reorganization

* Convert data using transform attached to axis in _base

* Replace simple inverse scaling patterns

* Fix generalized scaling in boxplot/stripplot

* Address final remaining scaling steps

* Fix boxplot stat unscaling
  • Loading branch information
mwaskom authored Aug 20, 2023
1 parent 2386036 commit af613f1
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 138 deletions.
52 changes: 25 additions & 27 deletions seaborn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,7 @@ def iter_data(
)

# Reduce to the semantics used in this plot
grouping_vars = [
var for var in grouping_vars if var in self.variables
]
grouping_vars = [var for var in grouping_vars if var in self.variables]

if from_comp_data:
data = self.comp_data
Expand All @@ -1040,22 +1038,21 @@ def iter_data(
levels = self.var_levels.copy()
if from_comp_data:
for axis in {"x", "y"} & set(grouping_vars):
converter = self.converters[axis].iloc[0]
if self.var_types[axis] == "categorical":
if self._var_ordered[axis]:
# If the axis is ordered, then the axes in a possible
# facet grid are by definition "shared", or there is a
# single axis with a unique cat -> idx mapping.
# So we can just take the first converter object.
converter = self.converters[axis].iloc[0]
levels[axis] = converter.convert_units(levels[axis])
else:
# Otherwise, the mappings may not be unique, but we can
# use the unique set of index values in comp_data.
levels[axis] = np.sort(data[axis].unique())
elif self.var_types[axis] == "datetime":
levels[axis] = mpl.dates.date2num(levels[axis])
elif self.var_types[axis] == "numeric" and self._log_scaled(axis):
levels[axis] = np.log10(levels[axis])
else:
transform = converter.get_transform().transform
levels[axis] = transform(converter.convert_units(levels[axis]))

if grouping_vars:

Expand Down Expand Up @@ -1129,9 +1126,8 @@ def comp_data(self):
# supporting `order` in categorical plots is tricky
orig = orig[orig.isin(self.var_levels[var])]
comp = pd.to_numeric(converter.convert_units(orig)).astype(float)
if converter.get_scale() == "log":
comp = np.log10(comp)
parts.append(pd.Series(comp, orig.index, name=orig.name))
transform = converter.get_transform().transform
parts.append(pd.Series(transform(comp), orig.index, name=orig.name))
if parts:
comp_col = pd.concat(parts)
else:
Expand Down Expand Up @@ -1300,25 +1296,27 @@ def _attach(

# TODO -- Add axes labels

def _log_scaled(self, axis):
"""Return True if specified axis is log scaled on all attached axes."""
if not hasattr(self, "ax"):
return False

def _get_scale_transforms(self, axis):
"""Return a function implementing the scale transform (or its inverse)."""
if self.ax is None:
axes_list = self.facets.axes.flatten()
axis_list = [getattr(ax, f"{axis}axis") for ax in self.facets.axes.flat]
scales = {axis.get_scale() for axis in axis_list}
if len(scales) > 1:
# It is a simplifying assumption that faceted axes will always have
# the same scale (even if they are unshared and have distinct limits).
# Nothing in the seaborn API allows you to create a FacetGrid with
# a mixture of scales, although it's possible via matplotlib.
# This is constraining, but no more so than previous behavior that
# only (properly) handled log scales, and there are some places where
# it would be much too complicated to use axes-specific transforms.
err = "Cannot determine transform with mixed scales on faceted axes."
raise RuntimeError(err)
transform_obj = axis_list[0].get_transform()
else:
axes_list = [self.ax]

log_scaled = []
for ax in axes_list:
data_axis = getattr(ax, f"{axis}axis")
log_scaled.append(data_axis.get_scale() == "log")

if any(log_scaled) and not all(log_scaled):
raise RuntimeError("Axis scaling is not consistent")
# This case is more straightforward
transform_obj = getattr(self.ax, f"{axis}axis").get_transform()

return any(log_scaled)
return transform_obj.transform, transform_obj.inverted().transform

def _add_axis_labels(self, ax, default_x="", default_y=""):
"""Add axis labels if not present, set visibility to match ticklabels."""
Expand Down
79 changes: 37 additions & 42 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_check_argument,
_draw_figure,
_default_color,
_get_transform_functions,
_normalize_kwargs,
_version_predates,
)
Expand Down Expand Up @@ -371,7 +372,7 @@ def _dodge(self, keys, data):
def _invert_scale(self, ax, data, vars=("x", "y")):
"""Undo scaling after computation so data are plotted correctly."""
for var in vars:
_, inv = utils._get_transform_functions(ax, var[0])
_, inv = _get_transform_functions(ax, var[0])
if var == self.orient and "width" in data:
hw = data["width"] / 2
data["edge"] = inv(data[var] - hw)
Expand Down Expand Up @@ -528,9 +529,7 @@ def plot_swarms(
if not sub_data.empty:
point_collections[(ax, sub_data[self.orient].iloc[0])] = points

beeswarm = Beeswarm(
width=width, orient=self.orient, warn_thresh=warn_thresh,
)
beeswarm = Beeswarm(width=width, orient=self.orient, warn_thresh=warn_thresh)
for (ax, center), points in point_collections.items():
if points.get_offsets().shape[0] > 1:

Expand Down Expand Up @@ -627,6 +626,12 @@ def get_props(element, artist=mpl.lines.Line2D):
capwidth = plot_kws.get("capwidths", 0.5 * data["width"])

self._invert_scale(ax, data)
_, inv = _get_transform_functions(ax, value_var)
for stat in ["mean", "med", "q1", "q3", "cilo", "cihi", "whislo", "whishi"]:
stats[stat] = inv(stats[stat])
stats["fliers"] = stats["fliers"].map(inv)

linear_orient_scale = getattr(ax, f"get_{self.orient}scale")() == "linear"

maincolor = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color
if fill:
Expand All @@ -651,8 +656,8 @@ def get_props(element, artist=mpl.lines.Line2D):
default_kws = dict(
bxpstats=stats.to_dict("records"),
positions=data[self.orient],
# Set width to 0 with log scaled orient axis to avoid going < 0
widths=0 if self._log_scaled(self.orient) else data["width"],
# Set width to 0 to avoid going out of domain
widths=data["width"] if linear_orient_scale else 0,
patch_artist=fill,
vert=self.orient == "x",
manage_ticks=False,
Expand All @@ -673,7 +678,8 @@ def get_props(element, artist=mpl.lines.Line2D):

# Reset artist widths after adding so everything stays positive
ori_idx = ["x", "y"].index(self.orient)
if self._log_scaled(self.orient):

if not linear_orient_scale:
for i, box in enumerate(data.to_dict("records")):
p0 = box["edge"]
p1 = box["edge"] + box["width"]
Expand Down Expand Up @@ -702,9 +708,10 @@ def get_props(element, artist=mpl.lines.Line2D):
artists["medians"][i].set_data(verts)

if artists["caps"]:
f_fwd, f_inv = _get_transform_functions(ax, self.orient)
for line in artists["caps"][2 * i:2 * i + 2]:
p0 = 10 ** (np.log10(box[self.orient]) - capwidth[i] / 2)
p1 = 10 ** (np.log10(box[self.orient]) + capwidth[i] / 2)
p0 = f_inv(f_fwd(box[self.orient]) - capwidth[i] / 2)
p1 = f_inv(f_fwd(box[self.orient]) + capwidth[i] / 2)
verts = line.get_xydata().T
verts[ori_idx][:] = p0, p1
line.set_data(verts)
Expand Down Expand Up @@ -769,8 +776,8 @@ def plot_boxens(
allow_empty=False):

ax = self._get_axes(sub_vars)
_, inv_ori = utils._get_transform_functions(ax, self.orient)
_, inv_val = utils._get_transform_functions(ax, value_var)
_, inv_ori = _get_transform_functions(ax, self.orient)
_, inv_val = _get_transform_functions(ax, value_var)

# Statistics
lv_data = estimator(sub_data[value_var])
Expand Down Expand Up @@ -1010,8 +1017,8 @@ def vars_to_key(sub_vars):
offsets = span, span

ax = violin["ax"]
_, invx = utils._get_transform_functions(ax, "x")
_, invy = utils._get_transform_functions(ax, "y")
_, invx = _get_transform_functions(ax, "x")
_, invy = _get_transform_functions(ax, "y")
inv_pos = {"x": invx, "y": invy}[self.orient]
inv_val = {"x": invx, "y": invy}[value_var]

Expand Down Expand Up @@ -1168,17 +1175,11 @@ def plot_points(
markers = self._map_prop_with_hue("marker", markers, "o", plot_kws)
linestyles = self._map_prop_with_hue("linestyle", linestyles, "-", plot_kws)

positions = self.var_levels[self.orient]
base_positions = self.var_levels[self.orient]
if self.var_types[self.orient] == "categorical":
min_cat_val = int(self.comp_data[self.orient].min())
max_cat_val = int(self.comp_data[self.orient].max())
positions = [i for i in range(min_cat_val, max_cat_val + 1)]
else:
if self._log_scaled(self.orient):
positions = np.log10(positions)
if self.var_types[self.orient] == "datetime":
positions = mpl.dates.date2num(positions)
positions = pd.Index(positions, name=self.orient)
base_positions = [i for i in range(min_cat_val, max_cat_val + 1)]

n_hue_levels = 0 if self._hue_map.levels is None else len(self._hue_map.levels)
if dodge is True:
Expand All @@ -1192,11 +1193,14 @@ def plot_points(

ax = self._get_axes(sub_vars)

ori_axis = getattr(ax, f"{self.orient}axis")
transform, _ = _get_transform_functions(ax, self.orient)
positions = transform(ori_axis.convert_units(base_positions))
agg_data = sub_data if sub_data.empty else (
sub_data
.groupby(self.orient)
.apply(aggregator, agg_var)
.reindex(positions)
.reindex(pd.Index(positions, name=self.orient))
.reset_index()
)

Expand Down Expand Up @@ -1316,14 +1320,12 @@ def plot_errorbars(self, ax, data, capsize, err_kws):
pos = np.array([row[self.orient], row[self.orient]])
val = np.array([row[f"{var}min"], row[f"{var}max"]])

cw = capsize * self._native_width / 2
if self._log_scaled(self.orient):
log_pos = np.log10(pos)
cap = 10 ** (log_pos[0] - cw), 10 ** (log_pos[1] + cw)
else:
cap = pos[0] - cw, pos[1] + cw

if capsize:

cw = capsize * self._native_width / 2
scl, inv = _get_transform_functions(ax, self.orient)
cap = inv(scl(pos[0]) - cw), inv(scl(pos[1]) + cw)

pos = np.concatenate([
[*cap, np.nan], pos, [np.nan, *cap]
])
Expand Down Expand Up @@ -3220,13 +3222,12 @@ def __call__(self, points, center):
new_xy = new_xyr[:, :2]
new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T

log_scale = getattr(ax, f"get_{self.orient}scale")() == "log"

# Add gutters
t_fwd, t_inv = _get_transform_functions(ax, self.orient)
if self.orient == "y":
self.add_gutters(new_y_data, center, log_scale=log_scale)
self.add_gutters(new_y_data, center, t_fwd, t_inv)
else:
self.add_gutters(new_x_data, center, log_scale=log_scale)
self.add_gutters(new_x_data, center, t_fwd, t_inv)

# Reposition the points so they do not overlap
if self.orient == "y":
Expand Down Expand Up @@ -3330,20 +3331,14 @@ def first_non_overlapping_candidate(self, candidates, neighbors):
"No non-overlapping candidates found. This should not happen."
)

def add_gutters(self, points, center, trans_fwd, trans_inv):
    """Stop points from extending beyond their territory.

    Parameters
    ----------
    points : array
        Positions along the categorical axis; clamped in place.
    center : float
        Position of the group's category on that axis.
    trans_fwd, trans_inv : callables
        Forward/inverse axis-scale transform functions, so the half-width
        offset is applied in scale coordinates (handles nonlinear scales).
    """
    half_width = self.width / 2
    # Compute the gutter bounds in transformed space and map back to data
    # coordinates, so the band stays symmetric on e.g. a log-scaled axis.
    low_gutter = trans_inv(trans_fwd(center) - half_width)
    off_low = points < low_gutter
    if off_low.any():
        points[off_low] = low_gutter
    high_gutter = trans_inv(trans_fwd(center) + half_width)
    off_high = points > high_gutter
    if off_high.any():
        points[off_high] = high_gutter
Expand Down
Loading

0 comments on commit af613f1

Please sign in to comment.