Skip to content

Commit

Permalink
Merge pull request #28 from hanjinliu/numeric-cat
Browse files Browse the repository at this point in the history
Categorical plot with numeric categories
  • Loading branch information
hanjinliu authored Feb 25, 2024
2 parents 8981d4d + 87c0473 commit 94a6c1d
Show file tree
Hide file tree
Showing 19 changed files with 346 additions and 124 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ testing = [
"pytest",
"pytest-qt",
"pytest-cov",
"imageio",
"qtpy>=2.4.1",
"pyqt5>=5.15.4",
"ipywidgets>=8.0.0",
Expand Down Expand Up @@ -208,7 +209,7 @@ omit = [
]

[tool.coverage.paths]
whitecanvas = ["whitecanvas", "*/whitecanvas/whitecanvas"]
whitecanvas = ["*/whitecanvas/whitecanvas"]
tests = ["tests", "*/whitecanvas/tests"]

[tool.coverage.report]
Expand Down
13 changes: 13 additions & 0 deletions tests/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import contextmanager
import warnings
from cmap import Color

def assert_color_equal(a, b):
Expand All @@ -18,3 +20,14 @@ def assert_color_array_equal(arr, b):
ok = all([a == b for a, b in zip(cols, other)])
if not ok:
raise AssertionError(f"Color {arr} != {b}")

@contextmanager
def filter_warning(backend: str, choices: "str | list[str]"):
if isinstance(choices, str):
choices = [choices]
if backend in choices:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
yield
else:
yield
33 changes: 27 additions & 6 deletions tests/test_canvas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path
import tempfile
import numpy as np
from numpy.testing import assert_allclose

Expand Down Expand Up @@ -77,9 +79,7 @@ def test_grid(backend: str):


def test_grid_nonuniform(backend: str):
cgrid = wc.new_grid(
[2, 1], [2, 1], backend=backend
).link_x().link_y()
cgrid = wc.new_grid([2, 1], [2, 1], backend=backend, size=(100, 100)).link_x().link_y()
c00 = cgrid.add_canvas(0, 0)
c01 = cgrid.add_canvas(0, 1)
c10 = cgrid.add_canvas(1, 0)
Expand All @@ -103,7 +103,7 @@ def test_grid_nonuniform(backend: str):
assert len(c11.layers) == 1

def test_vgrid_hgrid(backend: str):
cgrid = wc.new_col(2, backend=backend).link_x().link_y()
cgrid = wc.new_col(2, backend=backend, size=(100, 100)).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)

Expand All @@ -116,7 +116,7 @@ def test_vgrid_hgrid(backend: str):
assert len(c0.layers) == 1
assert len(c1.layers) == 1

cgrid = wc.new_row(2, backend=backend).link_x().link_y()
cgrid = wc.new_row(2, backend=backend, size=(100, 100)).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)

Expand All @@ -142,7 +142,7 @@ def test_unlink(backend: str):

def test_jointgrid(backend: str):
rng = np.random.default_rng(0)
joint = wc.new_jointgrid(backend=backend).with_hist().with_kde().with_rug()
joint = wc.new_jointgrid(backend=backend, size=(100, 100)).with_hist().with_kde().with_rug()
joint.add_markers(rng.random(100), rng.random(100), color="red")

def test_legend(backend: str):
Expand All @@ -158,3 +158,24 @@ def test_legend(backend: str):
canvas.add_line([3, 4, 5], [4, 5, 4], name="plot+err").with_markers().with_xerr([1, 1, 1])
canvas.add_markers([3, 4, 5], [5, 6, 5], name="markers+err+err").with_stem()
canvas.add_legend(location="bottom_right")

def test_animation():
from whitecanvas.animation import Animation

canvas = new_canvas(backend="matplotlib")
anim = Animation(canvas)
x = np.linspace(0, 2 * np.pi, 100)
line = canvas.add_line(x, np.sin(x + 0), name="line")
for i in anim.iter_range(3):
line.set_data(x, np.sin(x + i * np.pi / 3))
with tempfile.TemporaryDirectory() as tmpdir:
anim.save(Path(tmpdir) / "test.gif")
assert anim.asarray().ndim == 4

def test_multidim():
canvas = new_canvas(backend="matplotlib")
x = np.arange(5)
ys = [x, x ** 2, x ** 3]
canvas.dims.add_line(x, ys)
img = np.zeros((3, 5, 5))
canvas.dims.add_image(img)
48 changes: 38 additions & 10 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from whitecanvas import new_canvas
from ._utils import assert_color_array_equal
from ._utils import assert_color_array_equal, filter_warning
import pytest

def test_cat(backend: str):
Expand Down Expand Up @@ -45,16 +45,17 @@ def test_cat_plots(backend: str, orient: str):
cat_plt = canvas.cat_y(df, "y", "label")
cat_plt.add_stripplot(color="c")
cat_plt.add_swarmplot(color="c")
cat_plt.add_boxplot(color="c")
cat_plt.add_boxplot(color="c").with_outliers(ratio=0.5)
with filter_warning(backend, "plotly"):
cat_plt.add_boxplot(color="c").as_edge_only()
cat_plt.add_violinplot(color="c").with_rug()
cat_plt.add_pointplot(color="c").err_by_se()
cat_plt.add_barplot(color="c")
if backend == "plotly":
# NOTE: plotly does not support multiple colors for rugplot
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cat_plt.add_rugplot(color="c").scale_by_density()
else:
cat_plt.add_violinplot(color="c").with_outliers(ratio=0.5)
cat_plt.add_violinplot(color="c").with_box()
cat_plt.add_violinplot(color="c").as_edge_only().with_strip()
cat_plt.add_violinplot(color="c").with_swarm()
cat_plt.add_pointplot(color="c").err_by_se().err_by_sd().err_by_quantile().est_by_mean().est_by_median()
cat_plt.add_barplot(color="c").err_by_se().err_by_sd().err_by_quantile().est_by_mean().est_by_median()
with filter_warning(backend, "plotly"):
cat_plt.add_rugplot(color="c").scale_by_density()

def test_markers(backend: str):
Expand Down Expand Up @@ -139,3 +140,30 @@ def test_catx_legend(backend: str):
_c.add_pointplot(color="label").err_by_se()
_c.add_barplot(color="label")
canvas.add_legend()

@pytest.mark.parametrize("orient", ["v", "h"])
def test_numeric_axis(backend: str, orient: str):
canvas = new_canvas(backend=backend)
df = {
"y": np.arange(30),
"label": np.repeat([2, 5, 6], 10),
"c": ["P", "Q"] * 15,
}
if orient == "v":
cat_plt = canvas.cat_x(df, "label", "y", numeric_axis=True)
else:
cat_plt = canvas.cat_y(df, "y", "label", numeric_axis=True)
cat_plt.add_stripplot(color="c")
cat_plt.add_swarmplot(color="c")
cat_plt.add_boxplot(color="c").with_outliers(ratio=0.5)
with filter_warning(backend, "plotly"):
cat_plt.add_boxplot(color="c").as_edge_only()
cat_plt.add_violinplot(color="c").with_rug()
cat_plt.add_violinplot(color="c").with_outliers(ratio=0.5)
cat_plt.add_violinplot(color="c").with_box()
cat_plt.add_violinplot(color="c").as_edge_only().with_strip()
cat_plt.add_violinplot(color="c").with_swarm()
cat_plt.add_pointplot(color="c").err_by_se().err_by_sd().err_by_quantile().est_by_mean().est_by_median()
cat_plt.add_barplot(color="c").err_by_se().err_by_sd().err_by_quantile().est_by_mean().est_by_median()
with filter_warning(backend, "plotly"):
cat_plt.add_rugplot(color="c").scale_by_density()
4 changes: 2 additions & 2 deletions whitecanvas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.4"
__version__ = "0.2.5"

from whitecanvas import theme
from whitecanvas.canvas import link_axes
Expand All @@ -23,7 +23,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str): # pragma: no cover
import warnings

if name in ("grid", "grid_nonuniform"):
Expand Down
10 changes: 8 additions & 2 deletions whitecanvas/backend/bokeh/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def _plt_set_visible(self, visible: bool):
##### HasFace protocol #####

def _plt_get_face_color(self) -> NDArray[np.float32]:
return np.stack([arr_color(c) for c in self._data.data["face_color"]], axis=0)
colors = [arr_color(c) for c in self._data.data["face_color"]]
if len(colors) == 0:
return np.zeros((0, 4), dtype=np.float32)
return np.stack(colors, axis=0)

def _plt_set_face_color(self, color: NDArray[np.float32]):
if color.ndim == 1:
Expand Down Expand Up @@ -75,7 +78,10 @@ def _plt_set_edge_style(self, style: LineStyle | list[LineStyle]):
self._data.data["style"] = val

def _plt_get_edge_color(self) -> NDArray[np.float32]:
return np.stack([arr_color(c) for c in self._data.data["edge_color"]], axis=0)
colors = [arr_color(c) for c in self._data.data["edge_color"]]
if len(colors) == 0:
return np.zeros((0, 4), dtype=np.float32)
return np.stack(colors, axis=0)

def _plt_set_edge_color(self, color: NDArray[np.float32]):
if color.ndim == 1:
Expand Down
20 changes: 18 additions & 2 deletions whitecanvas/backend/bokeh/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,24 @@ def __init__(self, xdata, ydata):
def _plt_get_data(self):
return self._data.data["x"], self._data.data["y"]

def _plt_set_data(self, xdata, ydata):
self._data.data = {"x": xdata, "y": ydata}
def _plt_set_data(self, xdata: NDArray[np.number], ydata: NDArray[np.number]):
ndata = self._data.data["x"].size
cur_data = self._data.data.copy()
cur_data["x"] = xdata
cur_data["y"] = ydata
cols_to_update = [
"sizes", "face_color", "edge_color", "width", "pattern", "style",
"hovertexts"
] # fmt: skip
if xdata.size < ndata:
for key in cols_to_update:
cur_data[key] = cur_data[key][: xdata.size]
elif xdata.size > ndata:
for key in cols_to_update:
cur_data[key] = np.concatenate(
[cur_data[key], np.full(xdata.size - ndata, cur_data[key][-1])]
)
self._data.data = cur_data

def _plt_get_symbol(self) -> Symbol:
sym = self._model.marker
Expand Down
6 changes: 5 additions & 1 deletion whitecanvas/backend/matplotlib/canvas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from timeit import default_timer
from typing import Callable

Expand Down Expand Up @@ -408,7 +409,10 @@ def _plt_screenshot(self):
def _plt_set_figsize(self, width: int, height: int):
dpi = self._fig.get_dpi()
self._fig.set_size_inches(width / dpi, height / dpi)
self._fig.tight_layout()
with warnings.catch_warnings():
# if the size is small, tight_layout may raise a warning
warnings.simplefilter("ignore")
self._fig.tight_layout()

def _plt_set_spacings(self, wspace: float, hspace: float):
dpi = self._fig.get_dpi()
Expand Down
12 changes: 8 additions & 4 deletions whitecanvas/backend/pyqtgraph/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ def _get_brush(self) -> list[QtGui.QBrush]:
return brushes

def _plt_get_face_color(self) -> NDArray[np.float32]:
brushes = self._get_brush()
if len(brushes) == 0:
return np.zeros((0, 4), dtype=np.float32)
return np.array(
[brush.color().getRgbF() for brush in self._get_brush()], dtype=np.float32
[brush.color().getRgbF() for brush in brushes], dtype=np.float32
)

def _plt_set_face_color(self, color: NDArray[np.float32]):
Expand Down Expand Up @@ -118,9 +121,10 @@ def _get_pen(self) -> list[QtGui.QPen]:
return pens

def _plt_get_edge_color(self) -> NDArray[np.float32]:
return np.array(
[pen.color().getRgbF() for pen in self._get_pen()], dtype=np.float32
)
pens = self._get_pen()
if len(pens) == 0:
return np.zeros((0, 4), dtype=np.float32)
return np.array([pen.color().getRgbF() for pen in pens], dtype=np.float32)

def _plt_set_edge_color(self, color: NDArray[np.float32]):
color = as_color_array(color, len(self.data["x"]))
Expand Down
10 changes: 10 additions & 0 deletions whitecanvas/backend/vispy/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def _plt_set_data(self, xdata, ydata):

##### HasSymbol protocol #####
def _plt_get_symbol(self) -> Symbol:
if self._data["a_position"].shape[0] == 0:
return Symbol.CIRCLE
sym = self.symbol[0]
if sym == "clobber":
return Symbol.TRIANGLE_LEFT
Expand All @@ -76,6 +78,8 @@ def _plt_get_symbol_size(self) -> NDArray[np.floating]:
def _plt_set_symbol_size(self, size: float | NDArray[np.floating]):
if is_real_number(size):
size = np.full(self._plt_get_ndata(), size)
if size.shape[0] == 0:
return
self.set_data(
pos=self._data["a_position"],
size=size,
Expand All @@ -91,6 +95,8 @@ def _plt_get_face_color(self) -> NDArray[np.float32]:

def _plt_set_face_color(self, color: NDArray[np.float32]):
color = as_color_array(color, self._plt_get_ndata())
if color.shape[0] == 0:
return
self.set_data(
pos=self._data["a_position"],
size=self._plt_get_symbol_size(),
Expand All @@ -108,6 +114,8 @@ def _plt_get_edge_color(self) -> NDArray[np.float32]:

def _plt_set_edge_color(self, color: NDArray[np.float32]):
color = as_color_array(color, self._plt_get_ndata())
if color.shape[0] == 0:
return
self.set_data(
pos=self._data["a_position"],
size=self._plt_get_symbol_size(),
Expand All @@ -123,6 +131,8 @@ def _plt_get_edge_width(self) -> NDArray[np.floating]:
def _plt_set_edge_width(self, width: float):
if isinstance(width, float):
width = np.full(self._plt_get_ndata(), width)
if width.shape[0] == 0:
return
self.set_data(
pos=self._data["a_position"],
size=self._plt_get_symbol_size(),
Expand Down
Loading

0 comments on commit 94a6c1d

Please sign in to comment.