Skip to content

Commit

Permalink
more coverage, bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Feb 27, 2024
1 parent 22a8b73 commit 3808f3a
Show file tree
Hide file tree
Showing 16 changed files with 152 additions and 127 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ testing = [
"plotly>=5.3.1",
"vispy>=0.14.1",
"bokeh>=3.3.1",
"pandas>=1.3.3",
"polars>=0.20.10",
]

docs = [
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import matplotlib.pyplot as plt
from whitecanvas.canvas import Canvas
import pytest

ALL_BACKENDS = ["mock", "matplotlib", "pyqtgraph", "plotly", "bokeh", "vispy"]
Expand All @@ -19,3 +20,6 @@ def backend(request: pytest.FixtureRequest):
# TODO: how to skip tests if failed in mock backend?
if request.param == "matplotlib":
plt.close("all")
elif request.param == "pyqtgraph":
if Canvas._CURRENT_INSTANCE is not None:
Canvas._CURRENT_INSTANCE.native.update()
49 changes: 48 additions & 1 deletion tests/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
import whitecanvas as wc
from whitecanvas import new_canvas
from whitecanvas import new_canvas, wrap_canvas

from ._utils import assert_color_equal

Expand All @@ -20,6 +20,10 @@ def test_namespaces(backend: str):
assert canvas.title.size == 20
canvas.title.family = "Arial"
assert canvas.title.family == "Arial"
canvas.title.visible = False
assert not canvas.title.visible
canvas.title.visible = True
assert canvas.title.visible

canvas.x.label.text = "X-Label-0"
assert canvas.x.label.text == "X-Label-0"
Expand All @@ -29,6 +33,10 @@ def test_namespaces(backend: str):
assert canvas.x.label.size == 20
canvas.x.label.family = "Arial"
assert canvas.x.label.family == "Arial"
canvas.x.label.visible = False
assert not canvas.x.label.visible
canvas.x.label.visible = True
assert canvas.x.label.visible

canvas.y.label.text = "Y-Label-0"
assert canvas.y.label.text == "Y-Label-0"
Expand All @@ -39,6 +47,24 @@ def test_namespaces(backend: str):
canvas.y.label.family = "Arial"
assert canvas.y.label.family == "Arial"

if backend != "pyqtgraph": # not implemented in pyqtgraph
canvas.x.ticks.rotation = 45
assert canvas.x.ticks.rotation == pytest.approx(45)

canvas.x.ticks.visible = False
assert not canvas.x.ticks.visible
canvas.x.ticks.visible = True
assert canvas.x.ticks.visible
canvas.x.ticks.set_labels([0, 1, 2], ["a", "b", "c"])

# get tick positions and labels are still hard to implement
if backend in ("mock", "pyqtgraph", "plotly"):
assert_allclose(canvas.x.ticks.pos, [0, 1, 2])
assert canvas.x.ticks.labels == ["a", "b", "c"]
canvas.x.ticks.reset_labels()

canvas.x.set_gridlines()

def test_namespace_pointing_at_different_objects():
c0 = new_canvas(backend="matplotlib")
c1 = new_canvas(backend="matplotlib")
Expand All @@ -53,12 +79,29 @@ def test_namespace_pointing_at_different_objects():
assert_color_equal(c0.x.color, "red")
assert_color_equal(c1.x.color, "blue")

def test_update_methods():
canvas = new_canvas(backend="mock")
canvas.update_axes(visible=False)
canvas.update_labels(title="Title", x="X", y="Y")
canvas.update_font(size=24, color="red", family="Arial")

def test_native_and_wrapping(backend: str):
canvas = new_canvas(backend=backend)
assert canvas.native is not None
if backend == "mock":
return
new = wrap_canvas(canvas.native)
repr(new._get_backend())
assert new._get_backend() == canvas._get_backend()

def test_grid(backend: str):
cgrid = wc.new_grid(2, 2, backend=backend).link_x().link_y()
c00 = cgrid.add_canvas(0, 0)
c01 = cgrid.add_canvas(0, 1)
c10 = cgrid.add_canvas(1, 0)
c11 = cgrid.add_canvas(1, 1)
assert cgrid[0, 0] is not cgrid[1, 0]
assert cgrid[0, 1] is not cgrid[1, 1]

c00.add_line([0, 1, 2], [0, 1, 2])
c01.add_hist([0, 1, 2, 3, 4, 3, 2, 1])
Expand Down Expand Up @@ -106,6 +149,7 @@ def test_vgrid_hgrid(backend: str):
cgrid = wc.new_col(2, backend=backend, size=(100, 100)).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)
assert cgrid[0] is not cgrid[1]

c0.add_line([0, 1, 2], [0, 1, 2])
c1.add_hist([0, 1, 2, 3, 4, 3, 2, 1])
Expand All @@ -119,6 +163,7 @@ def test_vgrid_hgrid(backend: str):
cgrid = wc.new_row(2, backend=backend, size=(100, 100)).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)
assert cgrid[0] is not cgrid[1]

c0.add_line([0, 1, 2], [0, 1, 2])
c1.add_hist([0, 1, 2, 3, 4, 3, 2, 1])
Expand All @@ -144,6 +189,8 @@ def test_jointgrid(backend: str):
rng = np.random.default_rng(0)
joint = wc.new_jointgrid(backend=backend, size=(100, 100)).with_hist().with_kde().with_rug()
joint.add_markers(rng.random(100), rng.random(100), color="red")
if backend != "vispy":
joint.add_legend()

def test_legend(backend: str):
if backend == "vispy":
Expand Down
39 changes: 38 additions & 1 deletion tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
import numpy as np

from whitecanvas import new_canvas
from whitecanvas.core import new_jointgrid
from ._utils import assert_color_array_equal, filter_warning
import pytest

Expand Down Expand Up @@ -33,6 +33,9 @@ def test_cat(backend: str):
hist = cplt.along_y().add_hist(bins=6, color="label")
cplt.along_x().add_kde()
kde = cplt.along_x().add_kde(color="label")
cplt.along_x().add_rug()
with filter_warning(backend, "plotly"):
cplt.along_x().add_rug(color="label")
hist.update_color("black")
kde.update_color("black")
hist.update_width(1.5)
Expand Down Expand Up @@ -105,6 +108,7 @@ def test_markers(backend: str):
assert_color_array_equal(out._base_layer.face.color, "black")

out = _c.add_markers(color="transparent").update_edge_colormap("size")
_c.mean_for_each("label0").add_markers(symbol="D")

def test_heatmap(backend: str):
canvas = new_canvas(backend=backend)
Expand Down Expand Up @@ -245,3 +249,36 @@ def test_stack(backend: str):
cat_plt = canvas.cat_x(df, "label", "y", numeric_axis=True)
cat_plt.stack("c").add_bars(color="c")
cat_plt.stack("c").add_area(hatch="c")

def test_joint_cat(backend: str):
joint = new_jointgrid(backend=backend, loc=(0, 0), size=(180, 180))
df = {
"x": np.arange(30),
"y": np.arange(30),
"c": np.repeat(["A", "B", "C"], 10),
}
joint.cat(df, "x", "y").add_hist2d()
joint.cat(df, "x", "y").add_markers(color="c")

def test_pandas_and_polars():
import pandas as pd
import polars as pl

canvas = new_canvas(backend="mock")
_dict = {
"y": np.arange(30),
"label": np.repeat(["A", "B", "C"], 10),
"c": ["P", "Q"] * 15,
}
df_pd = pd.DataFrame(_dict)
df_pl = pl.DataFrame(_dict)

cat_pd = canvas.cat_x(df_pd, "label", "y")
cat_pl = canvas.cat_x(df_pl, "label", "y")
cat_pd.add_swarmplot(color="c")
cat_pd.mean().add_markers(color="c")
cat_pd.first().add_markers(color="c")

cat_pl.add_swarmplot(color="c")
cat_pl.mean().add_markers(color="c")
cat_pl.first().add_markers(color="c")
18 changes: 17 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_line(backend: str):
_test_visibility(layer)
layer.with_hover_template("x={x:.2f}, y={y:.2f}")
canvas.add_cdf(np.sqrt(np.arange(20)))
canvas.autoscale()

def test_markers(backend: str):
canvas = new_canvas(backend=backend)
Expand Down Expand Up @@ -70,6 +71,7 @@ def test_markers(backend: str):
layer.symbol = sym
assert layer.symbol == sym
_test_visibility(layer)
canvas.autoscale()

def test_bars(backend: str):
canvas = new_canvas(backend=backend)
Expand Down Expand Up @@ -101,6 +103,7 @@ def test_bars(backend: str):
layer.bar_width = 0.5
assert layer.bar_width == 0.5
_test_visibility(layer)
canvas.autoscale()

def test_infcurve(backend: str):
canvas = new_canvas(backend=backend)
Expand Down Expand Up @@ -129,6 +132,9 @@ def test_infcurve(backend: str):
canvas.x.lim = (-4, 4)
layer.angle = 90
canvas.x.lim = (-4, 4)
canvas.autoscale()
canvas.add_hline(1)
canvas.add_vline(1)

def test_band(backend: str):
canvas = new_canvas(backend=backend)
Expand All @@ -155,11 +161,13 @@ def test_band(backend: str):
layer.edge.width = 2
assert layer.edge.width == 2
_test_visibility(layer)
canvas.autoscale()

def test_image(backend: str):
canvas = new_canvas(backend=backend)

layer = canvas.add_image(np.random.random((10, 10)) * 2)
rng = np.random.default_rng(0)
layer = canvas.add_image(rng.random((10, 10)) * 2)

layer.cmap = "viridis"
assert layer.cmap == "viridis"
Expand All @@ -173,6 +181,8 @@ def test_image(backend: str):
layer.origin = "edge"
layer.shift = (-1, -1)
layer.fit_to(2, 2, 5, 5)
canvas.autoscale()
canvas.add_heatmap(rng.random((10, 10)))

def test_errorbars(backend: str):
canvas = new_canvas(backend=backend)
Expand All @@ -196,6 +206,7 @@ def test_errorbars(backend: str):
layer.width = 3
assert all(w == 3 for w in layer.width)
_test_visibility(layer)
canvas.autoscale()

def test_texts(backend: str):
canvas = new_canvas(backend=backend)
Expand Down Expand Up @@ -246,6 +257,8 @@ def test_texts(backend: str):
assert layer.rotation == 10
layer.color = "red"
layer.family = "Arial"
canvas.autoscale()
canvas.add_text(0, 0, "Hello, World!")


def test_with_text(backend: str):
Expand All @@ -270,6 +283,7 @@ def test_with_text(backend: str):
canvas.add_bars(x, y).with_yerr(y/4).with_text([f"{i}" for i in range(10)])
canvas.add_bars(x, y).with_xerr(y/4).with_text("{x:1f}, {y:1f},")
canvas.add_bars(x, y).with_yerr(y/4).with_text("{x:1f}, {y:1f}")
canvas.autoscale()

def test_rug(backend: str):
canvas = new_canvas(backend=backend)
Expand All @@ -289,6 +303,7 @@ def test_rug(backend: str):
layer.high = 1.5
assert np.allclose(layer.low, 0.5)
assert np.allclose(layer.high, 1.5)
canvas.autoscale(xpad=(0.01, 0.02), ypad=(0.01, 0.02))


def test_spans(backend: str):
Expand All @@ -308,3 +323,4 @@ def test_spans(backend: str):

if backend != "vispy":
canvas.add_legend()
canvas.autoscale(xpad=0.01, ypad=0.01)
5 changes: 5 additions & 0 deletions whitecanvas/backend/_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def __init__(self, name: Backend | str | None = None) -> None:
def __repr__(self) -> str:
return f"<Backend {self._name!r} (app: {self._app!r})>"

def __eq__(self, other) -> bool:
if not isinstance(other, Backend):
return False
return self._name == other._name and self._app == other._app

@property
def name(self) -> str:
"""Name of the backend."""
Expand Down
4 changes: 2 additions & 2 deletions whitecanvas/backend/bokeh/_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ def _plt_get_axis(self) -> BokehAxis:
raise NotImplementedError

def _plt_get_tick_labels(self) -> tuple[list[float], list[str]]:
return tuple(zip(*self._plt_get_axis().ticker))
return tuple(zip(*self._plt_get_axis().ticker.ticks))

def _plt_override_labels(self, pos: list[float], labels: list[str]):
self._plt_get_axis().ticker = pos
self._plt_get_axis().major_label_overrides = dict(zip(pos, labels))

def _plt_reset_override(self):
self._plt_get_axis().ticker = []
self._plt_get_axis().major_label_overrides = None
self._plt_get_axis().major_label_overrides = {}

def _plt_get_visible(self) -> bool:
return self._visible
Expand Down
10 changes: 6 additions & 4 deletions whitecanvas/backend/mock/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ def __init__(self):
self._xaxis = Axis()
self._yaxis = Axis()
self._title = Title()
self._xlabel = Label()
self._ylabel = Label()
self._xticks = Ticks()
self._yticks = Ticks()
self._aspect_ratio = None
Expand All @@ -40,10 +38,10 @@ def _plt_get_yaxis(self):
return self._yaxis

def _plt_get_xlabel(self):
return self._xlabel
return self._xaxis._label

def _plt_get_ylabel(self):
return self._ylabel
return self._yaxis._label

def _plt_get_xticks(self):
return self._xticks
Expand Down Expand Up @@ -191,6 +189,7 @@ def __init__(self):
self._limits = (0, 1)
self._flipped = False
self._color = np.array([0, 0, 0, 1], dtype=np.float32)
self._label = Label()

def _plt_get_visible(self) -> bool:
return self._visible
Expand All @@ -216,6 +215,9 @@ def _plt_get_limits(self) -> tuple[float, float]:
def _plt_set_limits(self, limits: tuple[float, float]):
self._limits = limits

def _plt_set_grid_state(self, *args, **kwargs):
pass


class Ticks(_SupportsText):
def __init__(self):
Expand Down
9 changes: 9 additions & 0 deletions whitecanvas/backend/plotly/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def asdict(self):
out["secondary_y"] = True
return out

def asdictn(self):
out = {}
if self.row > 1 or self.col > 1:
out["rows"] = self.row
out["cols"] = self.col
if self.secondary_y:
out["secondary_y"] = True
return out


_LINE_STYLES = {
"solid": LineStyle.SOLID,
Expand Down
2 changes: 1 addition & 1 deletion whitecanvas/backend/plotly/_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _plt_get_axis(self):
def _plt_get_tick_labels(self) -> tuple[list[float], list[str]]:
return (
self._plt_get_axis().tickvals,
self._plt_get_axis().ticktext,
list(self._plt_get_axis().ticktext),
)

def _plt_override_labels(self, pos: list[float], labels: list[str]):
Expand Down
2 changes: 1 addition & 1 deletion whitecanvas/backend/plotly/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _plt_make_legend(
sample.name = label
plotly_traces.append(sample)
legend_kwargs = _LEGEND_KWARGS[anchor]
self._fig.add_traces(plotly_traces, **self._loc.asdict())
self._fig.add_traces(plotly_traces, **self._loc.asdictn())
self._fig.update_layout(showlegend=True, legend=legend_kwargs, overwrite=True)

def _repr_mimebundle_(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 3808f3a

Please sign in to comment.