Skip to content

Commit

Permalink
Add size parameter to scatter artist
Browse files Browse the repository at this point in the history
  • Loading branch information
jo-mueller committed Jul 23, 2024
1 parent 65e2e9b commit 064248d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/artists_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
~Scatter.data
~Scatter.visible
~Scatter.color_indices
~Scatter.size
.. rubric:: Attributes Summary
Expand All @@ -77,6 +78,7 @@
.. autoattribute:: data
.. autoattribute:: visible
.. autoattribute:: color_indices
.. autoattribute:: size
.. rubric:: Attributes Documentation
Expand Down
18 changes: 18 additions & 0 deletions src/biaplotter/_tests/test_artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ def on_color_indices_changed(color_indices):
assert np.all(colors[0] == scatter.categorical_colormap(0))
assert np.all(colors[50] == scatter.categorical_colormap(2))

# Test size property
scatter.size = 5.0
assert scatter.size == 5.0
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == 5.0)

scatter.size = np.linspace(1, 10, size)
assert np.all(scatter.size == np.linspace(1, 10, size))
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == np.linspace(1, 10, size))

# Test size reset when new data is set
new_data = np.random.rand(size, 2)
scatter.data = new_data
assert scatter.size == 1.0
sizes = scatter._scatter.get_sizes()
assert np.all(sizes == 1.0)


def test_histogram2d():
# Inputs
Expand Down
25 changes: 24 additions & 1 deletion src/biaplotter/artists.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self, ax: plt.Axes = None, data: np.ndarray = None, categorical_col
super().__init__(ax, data, categorical_colormap, color_indices)
#: Stores the scatter plot matplotlib object
self._scatter = None
self._size = 1 # Default size
self.data = data
self.draw() # Initial draw of the scatter plot

Expand Down Expand Up @@ -160,7 +161,7 @@ def data(self, value: np.ndarray):
# emit signal
self.data_changed_signal.emit(self._data)
if self._scatter is None:
self._scatter = self.ax.scatter(value[:, 0], value[:, 1])
self._scatter = self.ax.scatter(value[:, 0], value[:, 1], s=self._size)
self.color_indices = 0 # Set default color index
else:
# If the scatter plot already exists, just update its data
Expand All @@ -176,6 +177,7 @@ def data(self, value: np.ndarray):
# fill with zeros where new data is larger
color_indices[color_indices_size:] = 0
self.color_indices = color_indices
self.size = 1 # Reset size to default
self.draw()

@property
Expand Down Expand Up @@ -236,6 +238,27 @@ def color_indices(self, indices: np.ndarray):
self.color_indices_changed_signal.emit(self._color_indices)
self.draw()

@property
def size(self) -> float | np.ndarray:
"""Gets or sets the size of the points in the scatter plot.
Triggers a draw idle command.
Returns
-------
size : float or (N,) np.ndarray[float]
size of the points in the scatter plot. Accepts a scalar or an array of floats.
"""
return self._size

@size.setter
def size(self, value: float | np.ndarray):
"""Sets the size of the points in the scatter plot."""
self._size = value
if self._scatter is not None:
self._scatter.set_sizes(np.full(len(self._data), value) if np.isscalar(value) else value)
self.draw()

def draw(self):
"""Draws or redraws the scatter plot."""
self.ax.figure.canvas.draw_idle()
Expand Down

0 comments on commit 064248d

Please sign in to comment.