Skip to content

Commit

Permalink
Allow x=/y= to control paired scales and limits (#3458)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom authored Sep 1, 2023
1 parent 5d530ac commit 74a8301
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
18 changes: 12 additions & 6 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,11 +1256,16 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
data.frame = res

def _get_scale(
self, spec: Plot, var: str, prop: Property, values: Series
self, p: Plot, var: str, prop: Property, values: Series
) -> Scale:

if var in spec._scales:
arg = spec._scales[var]
if re.match(r"[xy]\d+", var):
key = var if var in p._scales else var[0]
else:
key = var

if key in p._scales:
arg = p._scales[key]
if arg is None or isinstance(arg, Scale):
scale = arg
else:
Expand Down Expand Up @@ -1293,7 +1298,8 @@ def _get_subplot_data(self, df, var, view, share_state):
return seed_values

def _setup_scales(
self, p: Plot,
self,
p: Plot,
common: PlotData,
layers: list[Layer],
variables: list[str] | None = None,
Expand Down Expand Up @@ -1786,9 +1792,9 @@ def _finalize_figure(self, p: Plot) -> None:
axis_obj = getattr(ax, f"{axis}axis")

# Axis limits
if axis_key in p._limits:
if axis_key in p._limits or axis in p._limits:
convert_units = getattr(ax, f"{axis}axis").convert_units
a, b = p._limits[axis_key]
a, b = p._limits.get(axis_key) or p._limits[axis]
lo = a if a is None else convert_units(a)
hi = b if b is None else convert_units(b)
if isinstance(a, str):
Expand Down
18 changes: 14 additions & 4 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,16 @@ def test_paired_single_log_scale(self):
xfm_log = ax_log.xaxis.get_transform().transform
assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

def test_paired_with_common_fallback(self):

x0, x1 = [1, 2, 3], [1, 10, 100]
p = Plot().pair(x=[x0, x1]).scale(x="pow", x1="log").plot()
ax_pow, ax_log = p._figure.axes
xfm_pow = ax_pow.xaxis.get_transform().transform
assert_array_equal(xfm_pow([1, 2, 3]), [1, 4, 9])
xfm_log = ax_log.xaxis.get_transform().transform
assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

@pytest.mark.xfail(reason="Custom log scale needs log name for consistency")
def test_log_scale_name(self):

Expand Down Expand Up @@ -1734,10 +1744,10 @@ def test_two_variables_single_order_error(self, long_df):

def test_limits(self, long_df):

limit = (-2, 24)
p = Plot(long_df, y="y").pair(x=["x", "z"]).limit(x1=limit).plot()
ax1 = p._figure.axes[1]
assert ax1.get_xlim() == limit
lims = (-3, 10), (-2, 24)
p = Plot(long_df, y="y").pair(x=["x", "z"]).limit(x=lims[0], x1=lims[1]).plot()
for ax, lim in zip(p._figure.axes, lims):
assert ax.get_xlim() == lim

def test_labels(self, long_df):

Expand Down

0 comments on commit 74a8301

Please sign in to comment.