Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement joint plot #17

Merged
merged 6 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions docs/canvas/grid.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ The signature of the method differs between 1D and 2D grid.

``` python
#!name: canvas_grid_vertical
from whitecanvas import vgrid
from whitecanvas import new_col

grid = vgrid(3, backend="matplotlib")
grid = new_col(3, backend="matplotlib")

c0 = grid.add_canvas(0)
c0.add_text(0, 0, "Canvas 0")
Expand All @@ -27,9 +27,9 @@ grid.show()

``` python
#!name: canvas_grid_horizontal
from whitecanvas import hgrid
from whitecanvas import new_row

grid = hgrid(3, backend="matplotlib")
grid = new_row(3, backend="matplotlib")

c0 = grid.add_canvas(0)
c0.add_text(0, 0, "Canvas 0")
Expand All @@ -44,9 +44,9 @@ grid.show()

``` python
#!name: canvas_grid_2d
from whitecanvas import grid as grid2d
from whitecanvas import new_grid

grid = grid2d(2, 2, backend="matplotlib")
grid = new_grid(2, 2, backend="matplotlib")

for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
c = grid.add_canvas(i, j)
Expand Down
4 changes: 2 additions & 2 deletions docs/categorical/cat_num.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ canvas.show()
``` python
#!name: categorical_axis_many_plots
#!width: 700
from whitecanvas import hgrid
from whitecanvas import new_row

canvas = hgrid(ncols=3, size=(1600, 600), backend="matplotlib")
canvas = new_row(3, size=(1600, 600), backend="matplotlib")

c0 = canvas.add_canvas(0)
c0.cat_x(df, x="category", y="observation").add_boxplot()
Expand Down
4 changes: 2 additions & 2 deletions docs/layers/markers.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ colors the markers by the density of the points using kernel density estimation.
#!name: markers_layer_color_by_density
#!width: 500
import numpy as np
from whitecanvas import hgrid
from whitecanvas import new_row

rng = np.random.default_rng(999)
x = rng.normal(size=1000)
y = rng.normal(size=1000)

grid = hgrid(2, backend="matplotlib")
grid = new_row(2, backend="matplotlib")
(
grid
.add_canvas(0)
Expand Down
23 changes: 23 additions & 0 deletions examples/joint_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pandas as pd
from whitecanvas import new_jointgrid

def main():
url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv"
df = pd.read_csv(url).dropna()

joint = (
new_jointgrid("matplotlib:qt")
.with_hist_x(shape="step") # show histogram as the x-marginal distribution
.with_kde_y(width=2) # show kde as the y-marginal distribution
.with_rug(width=1) # show rug plot for both marginal distributions
)

layer = (
joint.cat(df, x="bill_length_mm", y="flipper_length_mm")
.add_markers(color="species")
)

joint.show(block=True)

if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions examples/show_image_on_pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import numpy as np
from whitecanvas import hgrid
from whitecanvas import new_row

def make_images() -> np.ndarray:
# prepare sample image data
Expand All @@ -23,7 +23,7 @@ def main():
images = make_images()
means = np.mean(images, axis=(1, 2)) # calculate mean intensity to plot

g = hgrid(2, backend="matplotlib:qt")
g = new_row(2, backend="matplotlib:qt")

# markers to be picked
markers = (
Expand Down
26 changes: 22 additions & 4 deletions tests/test_canvas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from numpy.testing import assert_allclose

import pytest
import whitecanvas as wc
from whitecanvas import new_canvas

Expand Down Expand Up @@ -50,7 +52,7 @@ def test_namespace_pointing_at_different_objects():
assert_color_equal(c1.x.color, "blue")

def test_grid(backend: str):
cgrid = wc.grid(2, 2, backend=backend).link_x().link_y()
cgrid = wc.new_grid(2, 2, backend=backend).link_x().link_y()
c00 = cgrid.add_canvas(0, 0)
c01 = cgrid.add_canvas(0, 1)
c10 = cgrid.add_canvas(1, 0)
Expand All @@ -75,7 +77,7 @@ def test_grid(backend: str):


def test_grid_nonuniform(backend: str):
cgrid = wc.grid_nonuniform(
cgrid = wc.new_grid(
[2, 1], [2, 1], backend=backend
).link_x().link_y()
c00 = cgrid.add_canvas(0, 0)
Expand All @@ -101,7 +103,7 @@ def test_grid_nonuniform(backend: str):
assert len(c11.layers) == 1

def test_vgrid_hgrid(backend: str):
cgrid = wc.vgrid(2, backend=backend).link_x().link_y()
cgrid = wc.new_col(2, backend=backend).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)

Expand All @@ -114,7 +116,7 @@ def test_vgrid_hgrid(backend: str):
assert len(c0.layers) == 1
assert len(c1.layers) == 1

cgrid = wc.hgrid(2, backend=backend).link_x().link_y()
cgrid = wc.new_row(2, backend=backend).link_x().link_y()
c0 = cgrid.add_canvas(0)
c1 = cgrid.add_canvas(1)

Expand All @@ -126,3 +128,19 @@ def test_vgrid_hgrid(backend: str):

assert len(c0.layers) == 1
assert len(c1.layers) == 1

def test_unlink(backend: str):
grid = wc.new_row(2, backend=backend).fill()
linker = wc.link_axes(grid[0].x, grid[1].x)
grid[0].x.lim = (10, 11)
assert grid[0].x.lim == pytest.approx((10, 11))
assert grid[1].x.lim == pytest.approx((10, 11))
linker.unlink_all()
grid[0].x.lim = (20, 21)
assert grid[0].x.lim == pytest.approx((20, 21))
assert grid[1].x.lim == pytest.approx((10, 11))

def test_jointgrid(backend: str):
rng = np.random.default_rng(0)
joint = wc.new_jointgrid(backend=backend).with_hist().with_kde().with_rug()
joint.add_markers(rng.random(100), rng.random(100), color="red")
53 changes: 38 additions & 15 deletions whitecanvas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,51 @@
__version__ = "0.2.2.dev0"

from whitecanvas import theme
from whitecanvas.canvas import Canvas, CanvasGrid
from whitecanvas.canvas import link_axes
from whitecanvas.core import (
grid,
grid_nonuniform,
hgrid,
hgrid_nonuniform,
new_canvas,
vgrid,
vgrid_nonuniform,
new_col,
new_grid,
new_jointgrid,
new_row,
wrap_canvas,
)

__all__ = [
"Canvas",
"CanvasGrid",
"grid",
"grid_nonuniform",
"hgrid",
"hgrid_nonuniform",
"vgrid",
"vgrid_nonuniform",
"new_canvas",
"new_col",
"new_grid",
"new_row",
"new_jointgrid",
"wrap_canvas",
"theme",
"link_axes",
]


def __getattr__(name: str):
import warnings

if name in ("grid", "grid_nonuniform"):
warnings.warn(
f"{name!r} is deprecated. Use `new_grid` instead",
DeprecationWarning,
stacklevel=2,
)
return new_grid
elif name in ("vgrid", "vgrid_nonuniform"):
warnings.warn(
f"{name!r} is deprecated. Use `new_col` instead",
DeprecationWarning,
stacklevel=2,
)
return new_col
elif name in ("hgrid", "hgrid_nonuniform"):
warnings.warn(
f"{name!r} is deprecated. Use `new_row` instead",
DeprecationWarning,
stacklevel=2,
)
return new_row
else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
45 changes: 29 additions & 16 deletions whitecanvas/backend/bokeh/canvas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Iterator

import numpy as np
from bokeh import events as bk_events
Expand All @@ -24,7 +24,7 @@
from whitecanvas.utils.normalize import arr_color, hex_color


def _prep_plot(width=400, height=300):
def _prep_plot(width=400, height=300) -> bk_plotting.figure:
plot = bk_plotting.figure(
width=width,
height=height,
Expand Down Expand Up @@ -246,23 +246,30 @@ def _translate_modifiers(mod: bk_events.KeyModifiers | None) -> tuple[Modifier,

@protocols.check_protocol(protocols.CanvasGridProtocol)
class CanvasGrid:
def __init__(self, heights: list[int], widths: list[int], app: str = "default"):
nr, nc = len(heights), len(widths)
def __init__(self, heights: list[float], widths: list[float], app: str = "default"):
hsum = sum(heights)
wsum = sum(widths)
children = []
for _ in range(nr):
for h in heights:
row = []
for _ in range(nc):
row.append(_prep_plot())
for w in widths:
p = _prep_plot(width=int(w / wsum * 600), height=int(h / hsum * 600))
p.visible = False
row.append(p)
children.append(row)
self._grid_plot: bk_layouts.GridPlot = bk_layouts.gridplot(
children, sizing_mode="fixed"
)
self._shape = (nr, nc)
self._widths = widths
self._heights = heights
self._width_total = wsum
self._height_total = hsum
self._app = app

def _plt_add_canvas(self, row: int, col: int, rowspan: int, colspan: int) -> Canvas:
for plot, r0, c0 in self._grid_plot.children:
for r0, c0, plot in self._iter_bokeh_subplots():
if r0 == row and c0 == col:
plot.visible = True
return Canvas(plot)
raise ValueError(f"Canvas at ({row}, {col}) not found")

Expand All @@ -280,12 +287,8 @@ def _plt_get_background_color(self):

def _plt_set_background_color(self, color):
color = hex_color(color)
for r in range(self._shape[0]):
for c in range(self._shape[1]):
child = self._grid_plot.children[r][c]
if not hasattr(child, "background_fill_color"):
continue
child.background_fill_color = color
for _, _, child in self._iter_bokeh_subplots():
child.background_fill_color = color

def _plt_screenshot(self):
import io
Expand All @@ -296,10 +299,20 @@ def _plt_screenshot(self):
export_png(self._grid_plot, filename=buff)
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = self._grid_plot.plot_width, self._grid_plot.plot_height
w, h = self._grid_plot.width, self._grid_plot.height
img = data.reshape((int(h), int(w), -1))
return img

def _plt_set_figsize(self, width: int, height: int):
for r, c, child in self._iter_bokeh_subplots():
child.height = int(self._heights[r] / self._height_total * width)
child.width = int(self._widths[c] / self._width_total * height)
self._grid_plot.width = width
self._grid_plot.height = height

def _plt_set_spacings(self, wspace: float, hspace: float):
self._grid_plot.spacing = (int(wspace), int(hspace))

def _iter_bokeh_subplots(self) -> Iterator[tuple[int, int, bk_plotting.figure]]:
for child, r, c in self._grid_plot.children:
yield r, c, child
21 changes: 15 additions & 6 deletions whitecanvas/backend/matplotlib/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from matplotlib.lines import Line2D

from whitecanvas import protocols
from whitecanvas.backend.matplotlib._base import MplLayer
from whitecanvas.backend.matplotlib._base import MplLayer, MplMouseEventsMixin
from whitecanvas.backend.matplotlib._labels import (
Title,
XAxis,
Expand Down Expand Up @@ -59,16 +59,18 @@ def __init__(self, ax: plt.Axes | None = None):
return
fig.canvas.mpl_connect("motion_notify_event", self._on_hover)
fig.canvas.mpl_connect("figure_leave_event", self._hide_tooltip)
self._hoverable_artists: list[Artist] = []
self._hoverable_artists: list[MplMouseEventsMixin] = []
self._last_hover = -1.0

def _on_hover(self, event):
def _on_hover(self, event: mplMouseEvent):
if default_timer() - self._last_hover < 0.1:
# avoid calling the tooltip too often
return
if event.button is not None:
return self._hide_tooltip()
self._last_hover = default_timer()
if event.inaxes is not self._axes:
return
return self._hide_tooltip()
for layer in reversed(self._hoverable_artists):
text = layer._on_hover(event)
if text:
Expand All @@ -77,7 +79,7 @@ def _on_hover(self, event):
return
self._hide_tooltip()

def _set_tooltip(self, pos, text: str):
def _set_tooltip(self, pos: tuple[float, float], text: str):
# determine in which direction to show the tooltip
x, y = pos
x0, x1 = self._axes.get_xlim()
Expand Down Expand Up @@ -290,7 +292,7 @@ def _plt_connect_ylim_changed(

@protocols.check_protocol(protocols.CanvasGridProtocol)
class CanvasGrid:
def __init__(self, heights: list[int], widths: list[int], app: str = "default"):
def __init__(self, heights: list[float], widths: list[float], app: str = "default"):
nr, nc = len(heights), len(widths)
self._gridspec = plt.GridSpec(
nr, nc, height_ratios=heights, width_ratios=widths
Expand Down Expand Up @@ -347,3 +349,10 @@ def _plt_set_figsize(self, width: int, height: int):
dpi = self._fig.get_dpi()
self._fig.set_size_inches(width / dpi, height / dpi)
self._fig.tight_layout()

def _plt_set_spacings(self, wspace: float, hspace: float):
dpi = self._fig.get_dpi()
nh, nw = self._gridspec.get_geometry()
w_avg = self._fig.get_figwidth() / nw * dpi
h_avg = self._fig.get_figheight() / nh * dpi
self._gridspec.update(hspace=hspace / h_avg, wspace=wspace / w_avg)
Loading
Loading