Skip to content

Commit

Permalink
Avoid a matplotlib warning for strip/swarmplot with unfilled marker
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Sep 26, 2023
1 parent 6b9e2fb commit a0c73b2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
8 changes: 8 additions & 0 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_scatter_legend_artist,
_version_predates,
)
from seaborn._compat import MarkerStyle
from seaborn._statistics import EstimateAggregator, LetterValues
from seaborn.palettes import light_palette
from seaborn.axisgrid import FacetGrid, _facet_docs
Expand Down Expand Up @@ -481,6 +482,9 @@ def plot_strips(
ax = self.ax
dodge_move = jitter_move = 0

if "marker" in plot_kws and not MarkerStyle(plot_kws["marker"]).is_filled():
plot_kws.pop("edgecolor", None)

for sub_vars, sub_data in self.iter_data(iter_vars,
from_comp_data=True,
allow_empty=True):
Expand Down Expand Up @@ -521,6 +525,9 @@ def plot_swarms(
point_collections = {}
dodge_move = 0

if "marker" in plot_kws and not MarkerStyle(plot_kws["marker"]).is_filled():
plot_kws.pop("edgecolor", None)

for sub_vars, sub_data in self.iter_data(iter_vars,
from_comp_data=True,
allow_empty=True):
Expand All @@ -534,6 +541,7 @@ def plot_swarms(
sub_data[self.orient] = sub_data[self.orient] + dodge_move

self._invert_scale(ax, sub_data)

points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)
if "hue" in self.variables:
points.set_facecolors(self._hue_map(sub_data["hue"]))
Expand Down
10 changes: 10 additions & 0 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from functools import partial
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -256,6 +257,15 @@ def test_supplied_color_array(self, long_df):
_draw_figure(ax.figure)
assert_array_equal(ax.collections[0].get_facecolors(), colors)

def test_unfilled_marker(self, long_df):

with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
ax = self.func(long_df, x="y", y="a", marker="x", color="r")
for points in ax.collections:
assert same_color(points.get_facecolors().squeeze(), "r")
assert same_color(points.get_edgecolors().squeeze(), "r")

@pytest.mark.parametrize(
"orient,data_type", [
("h", "dataframe"), ("h", "dict"),
Expand Down

0 comments on commit a0c73b2

Please sign in to comment.