From 10f471f1210a68ecf0f24638f7ad0b8448e1bb51 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Thu, 15 Feb 2024 15:07:24 +0900 Subject: [PATCH 1/6] fix bokeh subplots --- whitecanvas/backend/bokeh/canvas.py | 40 +++++++++++++++--------- whitecanvas/canvas/_grid.py | 4 ++- whitecanvas/layers/tabular/_dataframe.py | 36 ++++----------------- 3 files changed, 34 insertions(+), 46 deletions(-) diff --git a/whitecanvas/backend/bokeh/canvas.py b/whitecanvas/backend/bokeh/canvas.py index 145a7507..52d0ec4c 100644 --- a/whitecanvas/backend/bokeh/canvas.py +++ b/whitecanvas/backend/bokeh/canvas.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Iterator import numpy as np from bokeh import events as bk_events @@ -24,7 +24,7 @@ from whitecanvas.utils.normalize import arr_color, hex_color -def _prep_plot(width=400, height=300): +def _prep_plot(width=400, height=300) -> bk_plotting.figure: plot = bk_plotting.figure( width=width, height=height, @@ -247,22 +247,29 @@ def _translate_modifiers(mod: bk_events.KeyModifiers | None) -> tuple[Modifier, @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: def __init__(self, heights: list[int], widths: list[int], app: str = "default"): - nr, nc = len(heights), len(widths) + hsum = sum(heights) + wsum = sum(widths) children = [] - for _ in range(nr): + for h in heights: row = [] - for _ in range(nc): - row.append(_prep_plot()) + for w in widths: + p = _prep_plot(width=int(w / wsum * 600), height=int(h / hsum * 600)) + p.visible = False + row.append(p) children.append(row) self._grid_plot: bk_layouts.GridPlot = bk_layouts.gridplot( children, sizing_mode="fixed" ) - self._shape = (nr, nc) + self._widths = widths + self._heights = heights + self._width_total = wsum + self._height_total = hsum self._app = app def _plt_add_canvas(self, row: int, col: int, rowspan: int, colspan: int) -> Canvas: - for plot, r0, c0 in self._grid_plot.children: + for r0, c0, plot in self._iter_bokeh_subplots(): if r0 == row and c0 == col: + plot.visible = True return Canvas(plot) raise ValueError(f"Canvas at ({row}, {col}) not found") @@ -280,12 +287,8 @@ def _plt_get_background_color(self): def _plt_set_background_color(self, color): color = hex_color(color) - for r in range(self._shape[0]): - for c in range(self._shape[1]): - child = self._grid_plot.children[r][c] - if not hasattr(child, "background_fill_color"): - continue - child.background_fill_color = color + for _, _, child in self._iter_bokeh_subplots(): + child.background_fill_color = color def _plt_screenshot(self): import io @@ -296,10 +299,17 @@ def _plt_screenshot(self): export_png(self._grid_plot, filename=buff) buff.seek(0) data = np.frombuffer(buff.getvalue(), dtype=np.uint8) - w, h = self._grid_plot.plot_width, self._grid_plot.plot_height + w, h = self._grid_plot.width, self._grid_plot.height img = data.reshape((int(h), int(w), -1)) return img def _plt_set_figsize(self, width: int, height: int): + for r, c, child in self._iter_bokeh_subplots(): + child.height = int(self._heights[r] / self._height_total * width) + child.width = int(self._widths[c] / self._width_total * height) self._grid_plot.width = width self._grid_plot.height = height + + def _iter_bokeh_subplots(self) -> Iterator[tuple[int, int, bk_plotting.figure]]: + for child, r, c in self._grid_plot.children: + yield r, c, child diff --git a/whitecanvas/canvas/_grid.py b/whitecanvas/canvas/_grid.py index 565ba735..2576dddc 100644 --- a/whitecanvas/canvas/_grid.py +++ b/whitecanvas/canvas/_grid.py @@ -167,7 +167,9 @@ def __repr__(self) -> str: def __getitem__(self, key: tuple[int, int]) -> Canvas: canvas = self._canvas_array[key] if canvas is None: - raise IndexError(f"Canvas at {key} is not set") + raise ValueError(f"Canvas at {key} is not set") + elif isinstance(canvas, np.ndarray): + raise ValueError(f"Cannot index by {key}.") return canvas def _create_backend(self) -> protocols.CanvasGridProtocol: diff --git a/whitecanvas/layers/tabular/_dataframe.py b/whitecanvas/layers/tabular/_dataframe.py index 9abd1a15..1f3e06c9 100644 --- a/whitecanvas/layers/tabular/_dataframe.py +++ b/whitecanvas/layers/tabular/_dataframe.py @@ -8,7 +8,6 @@ Generic, Iterable, TypeVar, - Union, overload, ) @@ -28,7 +27,6 @@ KdeBandWidthType, LineStyle, Orientation, - _Void, ) from whitecanvas.utils.hist import histograms @@ -36,8 +34,6 @@ from typing_extensions import Self _DF = TypeVar("_DF") -_Cols = Union[str, "tuple[str, ...]"] -_void = _Void() class DFLines(_shared.DataFrameLayerWrapper[_lg.LineCollection, _DF], Generic[_DF]): @@ -46,9 +42,9 @@ def __init__( source: DataFrameWrapper[_DF], segs: list[np.ndarray], labels: list[tuple[Any, ...]], - color: _Cols | None = None, + color: str | tuple[str, ...] | None = None, width: float = 1.0, - style: _Cols | None = None, + style: str | tuple[str, ...] | None = None, name: str | None = None, backend: str | Backend | None = None, ): @@ -250,9 +246,9 @@ def __init__( source: DataFrameWrapper[_DF], base: _lg.LayerCollectionBase[_lg.Histogram], labels: list[tuple[Any, ...]], - color: _Cols | None = None, + color: str | tuple[str, ...] | None = None, width: str | None = None, - style: _Cols | None = None, + style: str | tuple[str, ...] | None = None, ): splitby = _shared.join_columns(color, style, source=source) self._color_by = _p.ColorPlan.default() @@ -356,9 +352,9 @@ def __init__( source: DataFrameWrapper[_DF], base: _lg.LayerCollectionBase[_lg.Kde], labels: list[tuple[Any, ...]], - color: _Cols | None = None, + color: str | tuple[str, ...] | None = None, width: str | None = None, - style: _Cols | None = None, + style: str | tuple[str, ...] | None = None, ): splitby = _shared.join_columns(color, style, source=source) self._color_by = _p.ColorPlan.default() @@ -445,23 +441,3 @@ def update_style(self, by: str | Iterable[str], styles=None) -> Self: self._base_layer[i].line.style = st self._style_by = style_by return self - - -def default_template(it: Iterable[tuple[str, np.ndarray]], max_rows: int = 10) -> str: - """ - Default template string for markers - - This template can only be used for those plot that has one tooltip for each data - point, which includes markers, bars and rugs. - """ - fmt_list = list[str]() - for ikey, (key, value) in enumerate(it): - if not key: - continue - if ikey >= max_rows: - break - if value.dtype.kind == "f": - fmt_list.append(f"{key}: {{{key}:.4g}}") - else: - fmt_list.append(f"{key}: {{{key}!r}}") - return "\n".join(fmt_list) From f3813fbd1fa9432d5fff3d8409e1bd27d4ca9e57 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Thu, 15 Feb 2024 20:53:10 +0900 Subject: [PATCH 2/6] implement linker, jointplot --- docs/canvas/grid.md | 12 +- docs/categorical/cat_num.md | 4 +- docs/layers/markers.md | 4 +- examples/show_image_on_pick.py | 4 +- tests/test_canvas.py | 8 +- whitecanvas/__init__.py | 48 ++++-- whitecanvas/canvas/__init__.py | 9 +- whitecanvas/canvas/_grid.py | 208 ++++++++++++++------------ whitecanvas/canvas/_linker.py | 97 +++++++++++++ whitecanvas/canvas/_namespaces.py | 6 +- whitecanvas/core.py | 233 ++++++++++-------------------- whitecanvas/plot/_canvases.py | 4 +- 12 files changed, 359 insertions(+), 278 deletions(-) create mode 100644 whitecanvas/canvas/_linker.py diff --git a/docs/canvas/grid.md b/docs/canvas/grid.md index 6eec605f..bacb3f41 100644 --- a/docs/canvas/grid.md +++ b/docs/canvas/grid.md @@ -11,9 +11,9 @@ The signature of the method differs between 1D and 2D grid. ``` python #!name: canvas_grid_vertical -from whitecanvas import vgrid +from whitecanvas import new_col -grid = vgrid(3, backend="matplotlib") +grid = new_col(3, backend="matplotlib") c0 = grid.add_canvas(0) c0.add_text(0, 0, "Canvas 0") @@ -27,9 +27,9 @@ grid.show() ``` python #!name: canvas_grid_horizontal -from whitecanvas import hgrid +from whitecanvas import new_row -grid = hgrid(3, backend="matplotlib") +grid = new_row(3, backend="matplotlib") c0 = grid.add_canvas(0) c0.add_text(0, 0, "Canvas 0") @@ -44,9 +44,9 @@ grid.show() ``` python #!name: canvas_grid_2d -from whitecanvas import grid as grid2d +from whitecanvas import new_grid -grid = grid2d(2, 2, backend="matplotlib") +grid = new_grid(2, 2, backend="matplotlib") for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]: c = grid.add_canvas(i, j) diff --git a/docs/categorical/cat_num.md b/docs/categorical/cat_num.md index 0e557a46..91818dab 100644 --- a/docs/categorical/cat_num.md +++ b/docs/categorical/cat_num.md @@ -153,9 +153,9 @@ canvas.show() ``` python #!name: categorical_axis_many_plots #!width: 700 -from whitecanvas import hgrid +from whitecanvas import new_row -canvas = hgrid(ncols=3, size=(1600, 600), backend="matplotlib") +canvas = new_row(3, size=(1600, 600), backend="matplotlib") c0 = canvas.add_canvas(0) c0.cat_x(df, x="category", y="observation").add_boxplot() diff --git a/docs/layers/markers.md b/docs/layers/markers.md index 5a29be96..8cb42307 100644 --- a/docs/layers/markers.md +++ b/docs/layers/markers.md @@ -174,13 +174,13 @@ colors the markers by the density of the points using kernel density estimation. #!name: markers_layer_color_by_density #!width: 500 import numpy as np -from whitecanvas import hgrid +from whitecanvas import new_row rng = np.random.default_rng(999) x = rng.normal(size=1000) y = rng.normal(size=1000) -grid = hgrid(2, backend="matplotlib") +grid = new_row(2, backend="matplotlib") ( grid .add_canvas(0) diff --git a/examples/show_image_on_pick.py b/examples/show_image_on_pick.py index 69b781e2..49ff4870 100644 --- a/examples/show_image_on_pick.py +++ b/examples/show_image_on_pick.py @@ -3,7 +3,7 @@ from __future__ import annotations import numpy as np -from whitecanvas import hgrid +from whitecanvas import new_row def make_images() -> np.ndarray: # prepare sample image data @@ -23,7 +23,7 @@ def main(): images = make_images() means = np.mean(images, axis=(1, 2)) # calculate mean intensity to plot - g = hgrid(2, backend="matplotlib:qt") + g = new_row(2, backend="matplotlib:qt") # markers to be picked markers = ( diff --git a/tests/test_canvas.py b/tests/test_canvas.py index dae6ca34..0b17c758 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -50,7 +50,7 @@ def test_namespace_pointing_at_different_objects(): assert_color_equal(c1.x.color, "blue") def test_grid(backend: str): - cgrid = wc.grid(2, 2, backend=backend).link_x().link_y() + cgrid = wc.new_grid(2, 2, backend=backend).link_x().link_y() c00 = cgrid.add_canvas(0, 0) c01 = cgrid.add_canvas(0, 1) c10 = cgrid.add_canvas(1, 0) @@ -75,7 +75,7 @@ def test_grid(backend: str): def test_grid_nonuniform(backend: str): - cgrid = wc.grid_nonuniform( + cgrid = wc.new_grid( [2, 1], [2, 1], backend=backend ).link_x().link_y() c00 = cgrid.add_canvas(0, 0) @@ -101,7 +101,7 @@ def test_grid_nonuniform(backend: str): assert len(c11.layers) == 1 def test_vgrid_hgrid(backend: str): - cgrid = wc.vgrid(2, backend=backend).link_x().link_y() + cgrid = wc.new_col(2, backend=backend).link_x().link_y() c0 = cgrid.add_canvas(0) c1 = cgrid.add_canvas(1) @@ -114,7 +114,7 @@ def test_vgrid_hgrid(backend: str): assert len(c0.layers) == 1 assert len(c1.layers) == 1 - cgrid = wc.hgrid(2, backend=backend).link_x().link_y() + cgrid = wc.new_row(2, backend=backend).link_x().link_y() c0 = cgrid.add_canvas(0) c1 = cgrid.add_canvas(1) diff --git a/whitecanvas/__init__.py b/whitecanvas/__init__.py index 34a8b71d..46948467 100644 --- a/whitecanvas/__init__.py +++ b/whitecanvas/__init__.py @@ -3,26 +3,50 @@ from whitecanvas import theme from whitecanvas.canvas import Canvas, CanvasGrid from whitecanvas.core import ( - grid, - grid_nonuniform, - hgrid, - hgrid_nonuniform, new_canvas, - vgrid, - vgrid_nonuniform, + new_col, + new_grid, + new_jointcanvas, + new_row, wrap_canvas, ) __all__ = [ "Canvas", "CanvasGrid", - "grid", - "grid_nonuniform", - "hgrid", - "hgrid_nonuniform", - "vgrid", - "vgrid_nonuniform", "new_canvas", + "new_col", + "new_grid", + "new_row", + "new_jointcanvas", "wrap_canvas", "theme", ] + + +def __getattr__(name: str): + import warnings + + if name in ("grid", "grid_nonuniform"): + warnings.warn( + f"{name!r} is deprecated. Use `new_grid` instead", + DeprecationWarning, + stacklevel=2, + ) + return new_grid + elif name in ("vgrid", "vgrid_nonuniform"): + warnings.warn( + f"{name!r} is deprecated. Use `new_col` instead", + DeprecationWarning, + stacklevel=2, + ) + return new_col + elif name in ("hgrid", "hgrid_nonuniform"): + warnings.warn( + f"{name!r} is deprecated. Use `new_row` instead", + DeprecationWarning, + stacklevel=2, + ) + return new_row + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/whitecanvas/canvas/__init__.py b/whitecanvas/canvas/__init__.py index 78f5a03f..c0651697 100644 --- a/whitecanvas/canvas/__init__.py +++ b/whitecanvas/canvas/__init__.py @@ -1,5 +1,11 @@ from whitecanvas.canvas._base import Canvas, CanvasBase -from whitecanvas.canvas._grid import CanvasGrid, CanvasHGrid, CanvasVGrid, SingleCanvas +from whitecanvas.canvas._grid import ( + CanvasGrid, + CanvasHGrid, + CanvasVGrid, + JointCanvas, + SingleCanvas, +) __all__ = [ "CanvasBase", @@ -7,5 +13,6 @@ "CanvasGrid", "CanvasHGrid", "CanvasVGrid", + "JointCanvas", "SingleCanvas", ] diff --git a/whitecanvas/canvas/_grid.py b/whitecanvas/canvas/_grid.py index 2576dddc..bac42d94 100644 --- a/whitecanvas/canvas/_grid.py +++ b/whitecanvas/canvas/_grid.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator +from typing import TYPE_CHECKING, Any, Iterator, Literal import numpy as np from numpy.typing import NDArray @@ -11,6 +11,7 @@ from whitecanvas import protocols from whitecanvas.backend import Backend from whitecanvas.canvas import Canvas, CanvasBase +from whitecanvas.canvas._linker import link_axes from whitecanvas.theme import get_theme from whitecanvas.utils.normalize import arr_color @@ -33,8 +34,6 @@ def __init__( heights: list[int], widths: list[int], *, - link_x: bool = False, - link_y: bool = False, backend: Backend | str | None = None, ) -> None: self._heights = heights @@ -45,8 +44,10 @@ def __init__( self._canvas_array.fill(None) # link axes - self._x_linked = link_x - self._y_linked = link_y + self._x_linked = False + self._y_linked = False + self._x_linker_ref = None + self._y_linker_ref = None # update settings theme = get_theme() @@ -82,30 +83,12 @@ def shape(self) -> tuple[int, int]: """The (row, col) shape of the grid""" return self._canvas_array.shape - @property - def x_linked(self) -> bool: - """Whether the x-axis of all canvases are linked.""" - return self._x_linked - - @x_linked.setter - def x_linked(self, value: bool): - self.link_x() if value else self.unlink_x() - - @property - def y_linked(self) -> bool: - """Whether the y-axis of all canvases are linked.""" - return self._y_linked - - @y_linked.setter - def y_linked(self, value: bool): - self.link_y() if value else self.unlink_y() - - def link_x(self, future: bool = True) -> Self: + def link_x(self, *, future: bool = True, hide_ticks: bool = True) -> Self: """ Link all the x-axes of the canvases in the grid. - >>> from whitecanvas import grid - >>> g = grid(2, 2).link_x() # link x-axes of all canvases + >>> from whitecanvas import new_grid + >>> g = new_grid(2, 2).link_x() # link x-axes of all canvases Parameters ---------- @@ -113,19 +96,24 @@ def link_x(self, future: bool = True) -> Self: If Ture, all the canvases added in the future will also be linked. Only link the existing canvases if False. """ - if not self._x_linked: - for _, canvas in self.iter_canvas(): - canvas.x.events.lim.connect(self._align_xlims, unique=True) - if future: - self._x_linked = True + if self._x_linker_ref is not None: + self._x_linker_ref.unlink_all() # initialize linker + to_link = [] + for (_r, _), _canvas in self.iter_canvas(): + to_link.append(_canvas) + if hide_ticks and _r != self.shape[0] - 1: + _canvas.x.ticks.visible = False + self._x_linker_ref = link_axes(to_link) + if future: + self._x_linked = True return self - def link_y(self, future: bool = True) -> Self: + def link_y(self, *, future: bool = True, hide_ticks: bool = True) -> Self: """ Link all the y-axes of the canvases in the grid. - >>> from whitecanvas import grid - >>> g = grid(2, 2).link_y() # link y-axes of all canvases + >>> from whitecanvas import new_grid + >>> g = new_grid(2, 2).link_y() # link y-axes of all canvases Parameters ---------- @@ -133,29 +121,16 @@ def link_y(self, future: bool = True) -> Self: If Ture, all the canvases added in the future will also be linked. Only link the existing canvases if False. """ - if not self._y_linked: - for _, canvas in self.iter_canvas(): - canvas.y.events.lim.connect(self._align_ylims, unique=True) - if future: - self._y_linked = True - return self - - def unlink_x(self, future: bool = True) -> Self: - """Unlink all the x-axes of the canvases in the grid.""" - if self._x_linked: - for _, canvas in self.iter_canvas(): - canvas.x.events.lim.disconnect(self._align_xlims) - if future: - self._x_linked = False - return self - - def unlink_y(self, future: bool = True) -> Self: - """Unlink all the y-axes of the canvases in the grid.""" - if self._y_linked: - for _, canvas in self.iter_canvas(): - canvas.y.events.lim.disconnect(self._align_ylims) - if future: - self._y_linked = False + if self._y_linker_ref is not None: + self._y_linker_ref.unlink_all() + to_link = [] + for (_, _c), _canvas in self.iter_canvas(): + to_link.append(_canvas) + if hide_ticks and _c != self.shape[1] - 1: + _canvas.y.ticks.visible = False + self._y_linker_ref = link_axes(to_link) + if future: + self._y_linked = True return self def __repr__(self) -> str: @@ -177,16 +152,6 @@ def _create_backend(self) -> protocols.CanvasGridProtocol: self._heights, self._widths, self._backend._app ) - def _align_xlims(self, lim: tuple[float, float]): - for _, canvas in self.iter_canvas(): - with canvas.x.events.lim.blocked(): - canvas.x.lim = lim - - def _align_ylims(self, lim: tuple[float, float]): - for _, canvas in self.iter_canvas(): - with canvas.y.events.lim.blocked(): - canvas.y.lim = lim - def fill(self, palette: ColormapType | None = None) -> Self: """Fill the grid with canvases.""" for _ in self.iter_add_canvas(palette=palette): @@ -199,6 +164,7 @@ def add_canvas( col: int, rowspan: int = 1, colspan: int = 1, + *, palette: str | None = None, ) -> Canvas: """Add a canvas to the grid at the given position""" @@ -217,10 +183,10 @@ def add_canvas( canvas._install_mouse_events() # link axes if needed - if self.x_linked: - canvas.x.events.lim.connect(self._align_xlims, unique=True) - if self.y_linked: - canvas.y.events.lim.connect(self._align_ylims, unique=True) + if self._x_linked: + self._x_linker_ref.link(canvas.x) + if self._y_linked: + self._y_linker_ref.link(canvas.y) canvas.events.drawn.connect(self.events.drawn.emit, unique=True) return canvas @@ -337,7 +303,7 @@ def __init__( def __getitem__(self, key: int) -> Canvas: canvas = self._canvas_array[key, 0] if canvas is None: - raise IndexError(f"Canvas at {key} is not set") + raise ValueError(f"Canvas at {key} is not set") return canvas @override @@ -374,7 +340,7 @@ def __init__( def __getitem__(self, key: int) -> Canvas: canvas = self._canvas_array[0, key] if canvas is None: - raise IndexError(f"Canvas at {key} is not set") + raise ValueError(f"Canvas at {key} is not set") return canvas @override @@ -397,23 +363,11 @@ def iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: yield self.add_canvas(col, **kwargs) -class SingleCanvas(CanvasBase): - def __init__(self, grid: CanvasGrid): - if grid.shape != (1, 1): - raise ValueError(f"Grid shape must be (1, 1), got {grid.shape}") - self._grid = grid - _it = grid.iter_canvas() - _, canvas = next(_it) - if next(_it, None) is not None: - raise ValueError("Grid must have only one canvas") +class _CanvasWithGrid(CanvasBase): + def __init__(self, canvas: Canvas, grid: CanvasGrid): self._main_canvas = canvas - super().__init__(palette=self._main_canvas._color_palette) - - # NOTE: events, dims etc are not shared between the main canvas and the - # SingleCanvas instance. To avoid confusion, the first and the only canvas - # should be replaces with the SingleCanvas instance. - grid._canvas_array[0, 0] = self - self.events.drawn.connect(self._main_canvas.events.drawn.emit, unique=True) + self._grid = grid + super().__init__(palette=canvas._color_palette) def _get_backend(self) -> Backend: """Return the backend.""" @@ -469,3 +423,79 @@ def _repr_html_(self, *args: Any, **kwargs: Any) -> str: def to_html(self, file: str | None = None) -> str: """Return HTML representation of the canvas.""" return self._grid.to_html(file=file) + + +class SingleCanvas(_CanvasWithGrid): + """ + A canvas without other subplots. + + This class is the simplest form of canvas. In `matplotlib` terms, it is a figure + with a single axes. + """ + + def __init__(self, grid: CanvasGrid): + if grid.shape != (1, 1): + raise ValueError(f"Grid shape must be (1, 1), got {grid.shape}") + self._grid = grid + _it = grid.iter_canvas() + _, canvas = next(_it) + if next(_it, None) is not None: + raise ValueError("Grid must have only one canvas") + self._main_canvas = canvas + super().__init__(canvas, grid) + + # NOTE: events, dims etc are not shared between the main canvas and the + # SingleCanvas instance. To avoid confusion, the first and the only canvas + # should be replaces with the SingleCanvas instance. + grid._canvas_array[0, 0] = self + self.events.drawn.connect(self._main_canvas.events.drawn.emit, unique=True) + + +_0or1 = Literal[0, 1] + + +class JointCanvas(_CanvasWithGrid): + """ + Grid with a main (joint) canvas and two marginal canvases. + + The marginal canvases shares the x-axis and y-axis with the main canvas. + """ + + def __init__( + self, + loc: tuple[_0or1, _0or1] = (1, 0), + palette: str | ColormapType | None = None, + backend: Backend | str | None = None, + ): + widths = [1, 1] + heights = [1, 1] + rloc, cloc = loc + if rloc not in (0, 1) or cloc not in (0, 1): + raise ValueError(f"Invalid location {loc!r}.") + widths[rloc] = heights[cloc] = 3 + grid = CanvasGrid(widths, heights, backend=Backend(backend)) + canvas = grid.add_canvas(rloc, cloc, palette=palette) + self._x_canvas = grid.add_canvas(1 - rloc, cloc) + self._y_canvas = grid.add_canvas(rloc, 1 - cloc) + + super().__init__(canvas, grid) + + # NOTE: events, dims etc are not shared between the main canvas and the + # JointCanvas instance. To avoid confusion, the main canvas should be replaces + # with the JointCanvas instance. + grid._canvas_array[rloc, cloc] = self + self.events.drawn.connect(canvas.events.drawn.emit, unique=True) + + # link axes + self._x_linker = link_axes([self._main_canvas.x, self._x_canvas.x]) + self._y_linker = link_axes([self._main_canvas.y, self._y_canvas.y]) + + @property + def x_canvas(self) -> Canvas: + """The canvas at the x-axis.""" + return self._x_canvas + + @property + def y_canvas(self) -> Canvas: + """The canvas at the y-axis.""" + return self._y_canvas diff --git a/whitecanvas/canvas/_linker.py b/whitecanvas/canvas/_linker.py new file mode 100644 index 00000000..ccab4c2c --- /dev/null +++ b/whitecanvas/canvas/_linker.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, ClassVar +from weakref import WeakSet + +if TYPE_CHECKING: + from whitecanvas.canvas._namespaces import AxisNamespace + + +class AxisLinker: + _GLOBAL_LINKERS: ClassVar[set[AxisLinker]] = set() + + def __new__(cls): + self = super().__new__(cls) + cls._GLOBAL_LINKERS.add(self) + return self + + def __init__(self): + self._axis_set = WeakSet["AxisNamespace"]() + self._updating = False + + def link(self, axis: AxisNamespace): + """Link an axis.""" + axis._get_canvas() # raise error if the parent canvas is deleted. + if axis in self._axis_set: + warnings.warn(f"Axis {axis} already linked", RuntimeWarning, stacklevel=2) + return + self._axis_set.add(axis) + axis.events.lim.connect(self.set_limits) + + def unlink(self, axis: AxisNamespace): + """Unlink an axis.""" + if axis not in self._axis_set: + warnings.warn(f"Axis {axis} was not linked", RuntimeWarning, stacklevel=2) + self._axis_set.discard(axis) + axis.events.lim.disconnect(self.set_limits) + + def unlink_all(self) -> None: + """Unlink all axes.""" + for axis in self._axis_set: + self.unlink(axis) + self.__class__._GLOBAL_LINKERS.discard(self) + + def is_alive(self) -> bool: + """Check if the linker is still alive.""" + return self in self.__class__._GLOBAL_LINKERS + + def set_limits(self, limits: tuple[float, float]): + if self._updating: + return + self._updating = True + try: + for axis in self._axis_set: + axis.lim = limits + finally: + self._updating = False + + @classmethod + def link_axes(cls, *axes: AxisNamespace): + """Link multiple axes.""" + self = cls() + if len(axes) == 1 and hasattr(axes[0], "__iter__"): + axes = axes[0] + for axis in axes: + self.link(axis) + return self + + +class AxisLinkerRef: + def __init__(self, linker: AxisLinker): + self._linker = linker + + def _get_linker(self): + if self._linker.is_alive(): + return self._linker + raise RuntimeError("Linker has been deleted") + + def link(self, axis: AxisNamespace): + """Link an axis.""" + self._get_linker().link(axis) + return self + + def unlink(self, axis: AxisNamespace): + """Unlink an axis.""" + self._get_linker().unlink(axis) + return self + + def unlink_all(self) -> None: + """Unlink all axes.""" + self._get_linker().unlink_all() + + +def link_axes(*axes: AxisNamespace): + """Link multiple axes.""" + linker = AxisLinker.link_axes(*axes) + return AxisLinkerRef(linker) diff --git a/whitecanvas/canvas/_namespaces.py b/whitecanvas/canvas/_namespaces.py index f398d4bd..7d1403b7 100644 --- a/whitecanvas/canvas/_namespaces.py +++ b/whitecanvas/canvas/_namespaces.py @@ -262,7 +262,7 @@ def _get_object(self): return self._get_canvas()._plt_get_ylabel() -class _AxisNamespace(Namespace): +class AxisNamespace(Namespace): events: AxisSignals def __init__(self, canvas: CanvasBase | None = None): @@ -330,7 +330,7 @@ def set_gridlines( self._get_object()._plt_set_grid_state(visible, color, width, style) -class XAxisNamespace(_AxisNamespace): +class XAxisNamespace(AxisNamespace): label = XLabelNamespace() ticks = XTickNamespace() @@ -338,7 +338,7 @@ def _get_object(self): return self._get_canvas()._plt_get_xaxis() -class YAxisNamespace(_AxisNamespace): +class YAxisNamespace(AxisNamespace): label = YLabelNamespace() ticks = YTickNamespace() diff --git a/whitecanvas/core.py b/whitecanvas/core.py index 666eb53b..4dda7433 100644 --- a/whitecanvas/core.py +++ b/whitecanvas/core.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import Any +from typing import TYPE_CHECKING, Any, Sequence from whitecanvas.backend import Backend from whitecanvas.canvas import ( @@ -9,216 +9,125 @@ CanvasGrid, CanvasHGrid, CanvasVGrid, + JointCanvas, SingleCanvas, ) from whitecanvas.types import ColormapType +if TYPE_CHECKING: + from typing import Literal -def grid( - nrows: int = 1, - ncols: int = 1, - *, - size: tuple[int, int] | None = None, - backend: Backend | str | None = None, -) -> CanvasGrid: - """ - Create a canvas grid with uniform cell sizes. - - Parameters - ---------- - nrows : int, default 1 - Number of rows. - ncols : int, default 1 - Number of columns. - size : (int, int), optional - Displaying size of the grid (in pixels). - backend : Backend or str, optional - Backend name. - - Returns - ------- - CanvasGrid - Grid of empty canvases. - """ - g = CanvasGrid.uniform(nrows, ncols, backend=backend) - if size is not None: - g.size = size - return g + _0_or_1 = Literal[0, 1] -def grid_nonuniform( - heights: list[int], - widths: list[int], +def new_canvas( + backend: Backend | str | None = None, *, size: tuple[int, int] | None = None, - backend: Backend | str | None = None, -) -> CanvasGrid: + palette: str | ColormapType | None = None, +) -> SingleCanvas: """ - Create a canvas grid with non-uniform cell sizes. + Create a new canvas with a single cell. Parameters ---------- - heights : list of int - Height ratio of the rows. - widths : list of int - Width ratio the columns. - size : (int, int), optional - Displaying size of the grid (in pixels). backend : Backend or str, optional Backend name. - - Returns - ------- - CanvasGrid - Grid of empty canvases. + size : (int, int), optional + Displaying size of the canvas (in pixels). + palette : str or ColormapType, optional + Color palette of the canvas. This color palette will be used to generate colors + for the plots. """ - g = CanvasGrid(heights, widths, backend=backend) + _grid = CanvasGrid([1], [1], backend=backend) + _grid.add_canvas(0, 0, palette=palette) + cvs = SingleCanvas(_grid) if size is not None: - g.size = size - return g + cvs.size = size + return cvs -def vgrid( - nrows: int = 1, +def new_grid( + rows: int | Sequence[int] = 1, + cols: int | Sequence[int] = 1, *, size: tuple[int, int] | None = None, backend: Backend | str | None = None, -) -> CanvasVGrid: - """ - Create a vertical canvas grid with uniform cell sizes. - - Parameters - ---------- - nrows : int, default 1 - Number of rows. - size : (int, int), optional - Displaying size of the grid (in pixels). - backend : Backend or str, optional - Backend name. - - Returns - ------- - CanvasVGrid - 1D Grid of empty canvases. +) -> CanvasGrid: """ - g = CanvasVGrid.uniform(nrows, backend=backend) - if size is not None: - g.size = size - return g + Create a new canvas grid with uniform or non-uniform cell sizes. + >>> grid = new_grid(2, 3) # 2x3 grid + >>> grid = new_grid(2, 3, size=(800, 600)) # 2x3 grid with size 800x600 + >>> grid = new_grid([1, 2], [2, 1]) # 2x2 grid with non-uniform sizes -def vgrid_nonuniform( - heights: list[int], - *, - size: tuple[int, int] | None = None, - backend: Backend | str | None = None, -) -> CanvasVGrid: - """ - Create a vertical canvas grid with non-uniform cell sizes. + If you want to create a 1D grid, use `new_row` or `new_col` instead. Parameters ---------- - heights : list of int - Height ratios of rows. + rows : int or sequence of int, default 1 + Number of rows (if an integer is given) or height ratio of the rows (if a + sequence of intergers is given). + cols : int or sequence of int, default 1 + Number of columns (if an integer is given) or width ratio of the columns (if a + sequence of intergers is given). size : (int, int), optional Displaying size of the grid (in pixels). backend : Backend or str, optional - Backend name. + Backend name, such as "matplotlib:qt". Returns ------- - CanvasVGrid - 1D Grid of empty canvases. + CanvasGrid + Grid of empty canvases. """ - g = CanvasVGrid(heights, backend=backend) + heights = _norm_ratio(rows) + widths = _norm_ratio(cols) + grid = CanvasGrid(heights, widths, backend=backend) if size is not None: - g.size = size - return g + grid.size = size + return grid -def hgrid( - ncols: int = 1, +def new_row( + cols: int | Sequence[int] = 1, *, size: tuple[int, int] | None = None, backend: Backend | str | None = None, ) -> CanvasHGrid: - """ - Create a horizontal canvas grid with uniform cell sizes. - - Parameters - ---------- - ncols : int, default 1 - Number of columns. - size : (int, int), optional - Displaying size of the grid (in pixels). - backend : Backend or str, optional - Backend name. - - Returns - ------- - CanvasHGrid - 1D Grid of empty canvases. - """ - g = CanvasHGrid.uniform(ncols, backend=backend) + """Create a new horizontal canvas grid with uniform or non-uniform cell sizes.""" + widths = _norm_ratio(cols) + grid = CanvasHGrid(widths, backend=backend) if size is not None: - g.size = size - return g + grid.size = size + return grid -def hgrid_nonuniform( - widths: list[int], +def new_col( + rows: int | Sequence[int] = 1, *, size: tuple[int, int] | None = None, backend: Backend | str | None = None, -) -> CanvasHGrid: - """ - Create a horizontal canvas grid with non-uniform cell sizes. - - Parameters - ---------- - widths : list of int - Width ratios of columns. - size : (int, int), optional - Displaying size of the grid (in pixels). - backend : Backend or str, optional - Backend name. - - Returns - ------- - CanvasHGrid - 1D Grid of empty canvases. - """ - g = CanvasHGrid(widths, backend=backend) +) -> CanvasVGrid: + """Create a new vertical canvas grid with uniform or non-uniform cell sizes.""" + heights = _norm_ratio(rows) + grid = CanvasVGrid(heights, backend=backend) if size is not None: - g.size = size - return g + grid.size = size + return grid -def new_canvas( +def new_jointcanvas( backend: Backend | str | None = None, *, + loc: tuple[_0_or_1, _0_or_1] = (1, 0), size: tuple[int, int] | None = None, palette: str | ColormapType | None = None, -) -> SingleCanvas: - """ - Create a new canvas with a single cell. - - Parameters - ---------- - backend : Backend or str, optional - Backend name. - size : (int, int), optional - Displaying size of the canvas (in pixels). - palette : str or ColormapType, optional - Color palette of the canvas. This color palette will be used to generate colors - for the plots. - """ - _grid = grid(backend=backend) - _grid.add_canvas(0, 0, palette=palette) - cvs = SingleCanvas(_grid) +) -> JointCanvas: + joint = JointCanvas(loc, palette=palette, backend=backend) if size is not None: - cvs.size = size - return cvs + joint.size = size + return joint def wrap_canvas(obj: Any, palette=None) -> Canvas: @@ -278,3 +187,17 @@ def wrap_canvas(obj: Any, palette=None) -> Canvas: def _is_in_module(typ_str: str, mod_name: str, cls_name: str) -> bool: return mod_name in sys.modules and typ_str.split(".")[-1] == cls_name + + +def _norm_ratio(r: int | Sequence[int]) -> list[int]: + if hasattr(r, "__int__"): + out = [1] * int(r) + else: + out: list[int] = [] + for x in r: + if not hasattr(x, "__int__"): + raise ValueError(f"Invalid value for size ratio: {r!r}.") + out.append(int(x)) + if len(out) == 0: + raise ValueError("Size ratio must not be empty.") + return out diff --git a/whitecanvas/plot/_canvases.py b/whitecanvas/plot/_canvases.py index 71b313c5..47e3bd32 100644 --- a/whitecanvas/plot/_canvases.py +++ b/whitecanvas/plot/_canvases.py @@ -2,7 +2,7 @@ from whitecanvas.backend import Backend from whitecanvas.canvas import Canvas, CanvasGrid -from whitecanvas.core import grid, new_canvas +from whitecanvas.core import new_canvas, new_grid def current_grid() -> CanvasGrid: @@ -33,4 +33,4 @@ def subplots( backend: Backend | str | None = None, ) -> CanvasGrid: """Create a new grid of subplots.""" - return grid(nrows, ncols, backend=backend).fill() + return new_grid(nrows, ncols, backend=backend).fill() From 63aab794eb6bdc2b3301038cdac12cae2b33ea6b Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Fri, 16 Feb 2024 01:00:39 +0900 Subject: [PATCH 3/6] methods for jointgrid --- tests/test_canvas.py | 12 + whitecanvas/__init__.py | 9 +- whitecanvas/backend/bokeh/canvas.py | 2 +- whitecanvas/backend/matplotlib/canvas.py | 12 +- whitecanvas/backend/mock/canvas.py | 2 +- whitecanvas/backend/plotly/canvas.py | 2 +- whitecanvas/backend/pyqtgraph/canvas.py | 2 +- whitecanvas/backend/vispy/canvas.py | 4 +- whitecanvas/canvas/__init__.py | 6 +- whitecanvas/canvas/_base.py | 2 +- whitecanvas/canvas/_grid.py | 110 +--- whitecanvas/canvas/_joint.py | 522 +++++++++++++++++++ whitecanvas/canvas/_linker.py | 3 +- whitecanvas/canvas/dataframe/__init__.py | 2 + whitecanvas/canvas/dataframe/_base.py | 3 +- whitecanvas/canvas/dataframe/_both_cat.py | 6 +- whitecanvas/canvas/dataframe/_feature_cat.py | 10 +- whitecanvas/canvas/dataframe/_joint_cat.py | 141 +++++ whitecanvas/canvas/dataframe/_one_cat.py | 8 +- whitecanvas/core.py | 27 +- whitecanvas/layers/group/line_fill.py | 30 +- whitecanvas/layers/tabular/_dataframe.py | 45 +- whitecanvas/types/__init__.py | 4 + whitecanvas/types/_enums.py | 14 + 24 files changed, 816 insertions(+), 162 deletions(-) create mode 100644 whitecanvas/canvas/_joint.py create mode 100644 whitecanvas/canvas/dataframe/_joint_cat.py diff --git a/tests/test_canvas.py b/tests/test_canvas.py index 0b17c758..5b83b0e2 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -1,5 +1,6 @@ from numpy.testing import assert_allclose +import pytest import whitecanvas as wc from whitecanvas import new_canvas @@ -126,3 +127,14 @@ def test_vgrid_hgrid(backend: str): assert len(c0.layers) == 1 assert len(c1.layers) == 1 + +def test_unlink(backend: str): + grid = wc.new_row(2, backend=backend).fill() + linker = wc.link_axes(grid[0].x, grid[1].x) + grid[0].x.lim = (10, 11) + assert grid[0].x.lim == pytest.approx((10, 11)) + assert grid[1].x.lim == pytest.approx((10, 11)) + linker.unlink_all() + grid[0].x.lim = (20, 21) + assert grid[0].x.lim == pytest.approx((20, 21)) + assert grid[1].x.lim == pytest.approx((10, 11)) diff --git a/whitecanvas/__init__.py b/whitecanvas/__init__.py index 46948467..7b501040 100644 --- a/whitecanvas/__init__.py +++ b/whitecanvas/__init__.py @@ -1,26 +1,25 @@ __version__ = "0.2.2.dev0" from whitecanvas import theme -from whitecanvas.canvas import Canvas, CanvasGrid +from whitecanvas.canvas import link_axes from whitecanvas.core import ( new_canvas, new_col, new_grid, - new_jointcanvas, + new_jointgrid, new_row, wrap_canvas, ) __all__ = [ - "Canvas", - "CanvasGrid", "new_canvas", "new_col", "new_grid", "new_row", - "new_jointcanvas", + "new_jointgrid", "wrap_canvas", "theme", + "link_axes", ] diff --git a/whitecanvas/backend/bokeh/canvas.py b/whitecanvas/backend/bokeh/canvas.py index 52d0ec4c..c2f7a8fb 100644 --- a/whitecanvas/backend/bokeh/canvas.py +++ b/whitecanvas/backend/bokeh/canvas.py @@ -246,7 +246,7 @@ def _translate_modifiers(mod: bk_events.KeyModifiers | None) -> tuple[Modifier, @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): hsum = sum(heights) wsum = sum(widths) children = [] diff --git a/whitecanvas/backend/matplotlib/canvas.py b/whitecanvas/backend/matplotlib/canvas.py index 5422b568..f6184fa0 100644 --- a/whitecanvas/backend/matplotlib/canvas.py +++ b/whitecanvas/backend/matplotlib/canvas.py @@ -17,7 +17,7 @@ from matplotlib.lines import Line2D from whitecanvas import protocols -from whitecanvas.backend.matplotlib._base import MplLayer +from whitecanvas.backend.matplotlib._base import MplLayer, MplMouseEventsMixin from whitecanvas.backend.matplotlib._labels import ( Title, XAxis, @@ -59,16 +59,18 @@ def __init__(self, ax: plt.Axes | None = None): return fig.canvas.mpl_connect("motion_notify_event", self._on_hover) fig.canvas.mpl_connect("figure_leave_event", self._hide_tooltip) - self._hoverable_artists: list[Artist] = [] + self._hoverable_artists: list[MplMouseEventsMixin] = [] self._last_hover = -1.0 - def _on_hover(self, event): + def _on_hover(self, event: mplMouseEvent): if default_timer() - self._last_hover < 0.1: # avoid calling the tooltip too often return + if event.button is not None: + return self._hide_tooltip() self._last_hover = default_timer() if event.inaxes is not self._axes: - return + return self._hide_tooltip() for layer in reversed(self._hoverable_artists): text = layer._on_hover(event) if text: @@ -290,7 +292,7 @@ def _plt_connect_ylim_changed( @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): nr, nc = len(heights), len(widths) self._gridspec = plt.GridSpec( nr, nc, height_ratios=heights, width_ratios=widths diff --git a/whitecanvas/backend/mock/canvas.py b/whitecanvas/backend/mock/canvas.py index fcd83239..c1b11d08 100644 --- a/whitecanvas/backend/mock/canvas.py +++ b/whitecanvas/backend/mock/canvas.py @@ -109,7 +109,7 @@ def _plt_connect_ylim_changed( @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): self._background_color = np.array([1, 1, 1, 1], dtype=np.float32) self._figsize = (100, 100) diff --git a/whitecanvas/backend/plotly/canvas.py b/whitecanvas/backend/plotly/canvas.py index 245c17a5..45cdd4d1 100644 --- a/whitecanvas/backend/plotly/canvas.py +++ b/whitecanvas/backend/plotly/canvas.py @@ -200,7 +200,7 @@ def _repr_mimebundle_(self, *args, **kwargs): @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): from plotly.subplots import make_subplots if app == "notebook": diff --git a/whitecanvas/backend/pyqtgraph/canvas.py b/whitecanvas/backend/pyqtgraph/canvas.py index 10792022..9e70cc4c 100644 --- a/whitecanvas/backend/pyqtgraph/canvas.py +++ b/whitecanvas/backend/pyqtgraph/canvas.py @@ -242,7 +242,7 @@ def _translate_mouse_event( @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): if app == "notebook": from pyqtgraph.jupyter import GraphicsLayoutWidget elif app in ("default", "qt"): diff --git a/whitecanvas/backend/vispy/canvas.py b/whitecanvas/backend/vispy/canvas.py index 96bc3ee0..86ddcccf 100644 --- a/whitecanvas/backend/vispy/canvas.py +++ b/whitecanvas/backend/vispy/canvas.py @@ -193,13 +193,13 @@ def _plt_draw(self): @protocols.check_protocol(protocols.CanvasGridProtocol) class CanvasGrid: - def __init__(self, heights: list[int], widths: list[int], app: str = "default"): + def __init__(self, heights: list[float], widths: list[float], app: str = "default"): if app != "default": vispy_use(_APP_NAMES.get(app, app)) self._scene = SceneCanvasExt(keys="interactive") self._grid: Grid = self._scene.central_widget.add_grid() self._scene.create_native() - self._heights = heights # TODO: not used + self._heights = heights self._widths = widths def _plt_add_canvas(self, row: int, col: int, rowspan: int, colspan: int): diff --git a/whitecanvas/canvas/__init__.py b/whitecanvas/canvas/__init__.py index c0651697..0d50f596 100644 --- a/whitecanvas/canvas/__init__.py +++ b/whitecanvas/canvas/__init__.py @@ -3,9 +3,10 @@ CanvasGrid, CanvasHGrid, CanvasVGrid, - JointCanvas, SingleCanvas, ) +from whitecanvas.canvas._joint import JointGrid +from whitecanvas.canvas._linker import link_axes __all__ = [ "CanvasBase", @@ -13,6 +14,7 @@ "CanvasGrid", "CanvasHGrid", "CanvasVGrid", - "JointCanvas", + "JointGrid", "SingleCanvas", + "link_axes", ] diff --git a/whitecanvas/canvas/_base.py b/whitecanvas/canvas/_base.py index d29ae2c6..5bc7a1ec 100644 --- a/whitecanvas/canvas/_base.py +++ b/whitecanvas/canvas/_base.py @@ -411,7 +411,7 @@ def cat( CatPlotter Plotter object. """ - plotter = _df.CatPlotter(self, data, x, y, update_label=update_labels) + plotter = _df.CatPlotter(self, data, x, y, update_labels=update_labels) return plotter def cat_x( diff --git a/whitecanvas/canvas/_grid.py b/whitecanvas/canvas/_grid.py index bac42d94..58587394 100644 --- a/whitecanvas/canvas/_grid.py +++ b/whitecanvas/canvas/_grid.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Literal +from typing import TYPE_CHECKING, Any, Iterator import numpy as np from numpy.typing import NDArray @@ -56,28 +56,6 @@ def __init__( self.events = GridEvents() self.__class__._CURRENT_INSTANCE = self - @classmethod - def uniform( - cls, - nrows: int = 1, - ncols: int = 1, - *, - backend: Backend | str | None = None, - ) -> CanvasGrid: - """ - Create a canvas grid with uniform row and column sizes. - - Parameters - ---------- - nrows : int - The number of rows in the grid. - ncols : int - The number of columns in the grid. - backend : backend-like, optional - The backend to use for the grid. - """ - return CanvasGrid([10] * nrows, [10] * ncols, backend=backend) - @property def shape(self) -> tuple[int, int]: """The (row, col) shape of the grid""" @@ -99,7 +77,7 @@ def link_x(self, *, future: bool = True, hide_ticks: bool = True) -> Self: if self._x_linker_ref is not None: self._x_linker_ref.unlink_all() # initialize linker to_link = [] - for (_r, _), _canvas in self.iter_canvas(): + for (_r, _), _canvas in self._iter_canvas(): to_link.append(_canvas) if hide_ticks and _r != self.shape[0] - 1: _canvas.x.ticks.visible = False @@ -124,7 +102,7 @@ def link_y(self, *, future: bool = True, hide_ticks: bool = True) -> Self: if self._y_linker_ref is not None: self._y_linker_ref.unlink_all() to_link = [] - for (_, _c), _canvas in self.iter_canvas(): + for (_, _c), _canvas in self._iter_canvas(): to_link.append(_canvas) if hide_ticks and _c != self.shape[1] - 1: _canvas.y.ticks.visible = False @@ -154,7 +132,7 @@ def _create_backend(self) -> protocols.CanvasGridProtocol: def fill(self, palette: ColormapType | None = None) -> Self: """Fill the grid with canvases.""" - for _ in self.iter_add_canvas(palette=palette): + for _ in self._iter_add_canvas(palette=palette): pass return self @@ -190,12 +168,12 @@ def add_canvas( canvas.events.drawn.connect(self.events.drawn.emit, unique=True) return canvas - def iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: + def _iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: for row in range(len(self._heights)): for col in range(len(self._widths)): yield self.add_canvas(row, col, **kwargs) - def iter_canvas(self) -> Iterator[tuple[tuple[int, int], Canvas]]: + def _iter_canvas(self) -> Iterator[tuple[tuple[int, int], Canvas]]: yielded: set[int] = set() for idx, canvas in np.ndenumerate(self._canvas_array): _id = id(canvas) @@ -306,22 +284,12 @@ def __getitem__(self, key: int) -> Canvas: raise ValueError(f"Canvas at {key} is not set") return canvas - @override - @classmethod - def uniform( - cls, - nrows: int = 1, - *, - backend: Backend | str | None = None, - ) -> CanvasVGrid: - return CanvasVGrid([1] * nrows, backend=backend) - @override def add_canvas(self, row: int, **kwargs) -> Canvas: return super().add_canvas(row, 0, **kwargs) @override - def iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: + def _iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: for row in range(len(self._heights)): yield self.add_canvas(row, **kwargs) @@ -343,22 +311,12 @@ def __getitem__(self, key: int) -> Canvas: raise ValueError(f"Canvas at {key} is not set") return canvas - @override - @classmethod - def uniform( - cls, - ncols: int = 1, - *, - backend: Backend | str | None = None, - ) -> CanvasHGrid: - return CanvasHGrid([1] * ncols, backend=backend) - @override def add_canvas(self, col: int, **kwargs) -> Canvas: return super().add_canvas(0, col, **kwargs) @override - def iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: + def _iter_add_canvas(self, **kwargs) -> Iterator[Canvas]: for col in range(len(self._widths)): yield self.add_canvas(col, **kwargs) @@ -437,7 +395,7 @@ def __init__(self, grid: CanvasGrid): if grid.shape != (1, 1): raise ValueError(f"Grid shape must be (1, 1), got {grid.shape}") self._grid = grid - _it = grid.iter_canvas() + _it = grid._iter_canvas() _, canvas = next(_it) if next(_it, None) is not None: raise ValueError("Grid must have only one canvas") @@ -449,53 +407,3 @@ def __init__(self, grid: CanvasGrid): # should be replaces with the SingleCanvas instance. grid._canvas_array[0, 0] = self self.events.drawn.connect(self._main_canvas.events.drawn.emit, unique=True) - - -_0or1 = Literal[0, 1] - - -class JointCanvas(_CanvasWithGrid): - """ - Grid with a main (joint) canvas and two marginal canvases. - - The marginal canvases shares the x-axis and y-axis with the main canvas. - """ - - def __init__( - self, - loc: tuple[_0or1, _0or1] = (1, 0), - palette: str | ColormapType | None = None, - backend: Backend | str | None = None, - ): - widths = [1, 1] - heights = [1, 1] - rloc, cloc = loc - if rloc not in (0, 1) or cloc not in (0, 1): - raise ValueError(f"Invalid location {loc!r}.") - widths[rloc] = heights[cloc] = 3 - grid = CanvasGrid(widths, heights, backend=Backend(backend)) - canvas = grid.add_canvas(rloc, cloc, palette=palette) - self._x_canvas = grid.add_canvas(1 - rloc, cloc) - self._y_canvas = grid.add_canvas(rloc, 1 - cloc) - - super().__init__(canvas, grid) - - # NOTE: events, dims etc are not shared between the main canvas and the - # JointCanvas instance. To avoid confusion, the main canvas should be replaces - # with the JointCanvas instance. - grid._canvas_array[rloc, cloc] = self - self.events.drawn.connect(canvas.events.drawn.emit, unique=True) - - # link axes - self._x_linker = link_axes([self._main_canvas.x, self._x_canvas.x]) - self._y_linker = link_axes([self._main_canvas.y, self._y_canvas.y]) - - @property - def x_canvas(self) -> Canvas: - """The canvas at the x-axis.""" - return self._x_canvas - - @property - def y_canvas(self) -> Canvas: - """The canvas at the y-axis.""" - return self._y_canvas diff --git a/whitecanvas/canvas/_joint.py b/whitecanvas/canvas/_joint.py new file mode 100644 index 00000000..84087eeb --- /dev/null +++ b/whitecanvas/canvas/_joint.py @@ -0,0 +1,522 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Iterator, + Literal, + Sequence, + TypeVar, +) + +from whitecanvas import layers as _l +from whitecanvas import theme +from whitecanvas.backend import Backend +from whitecanvas.canvas._grid import CanvasGrid +from whitecanvas.canvas._linker import link_axes +from whitecanvas.layers import group as _lg +from whitecanvas.layers import tabular as _lt +from whitecanvas.types import ( + ArrayLike1D, + ColormapType, + ColorType, + Hatch, + HistBinType, + HistogramKind, + HistogramShape, + KdeBandWidthType, + Orientation, + Symbol, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + from whitecanvas.canvas import Canvas + from whitecanvas.canvas import _namespaces as _ns + from whitecanvas.canvas.dataframe import JointCatPlotter + from whitecanvas.layers import _mixin + from whitecanvas.layers.tabular._dataframe import DataFrameWrapper + + NStr = str | Sequence[str] + +_C = TypeVar("_C", bound="JointGrid") +_DF = TypeVar("_DF") + + +_0_or_1 = Literal[0, 1] + + +class JointGrid(CanvasGrid): + """ + Grid with a main (joint) canvas and two marginal canvases. + + The marginal canvases shares the x-axis and y-axis with the main canvas. + """ + + def __init__( + self, + loc: tuple[_0_or_1, _0_or_1] = (1, 0), + palette: str | ColormapType | None = None, + ratio: int = 4, + backend: Backend | str | None = None, + ): + widths = [1, 1] + heights = [1, 1] + rloc, cloc = loc + if rloc not in (0, 1) or cloc not in (0, 1): + raise ValueError(f"Invalid location {loc!r}.") + widths[rloc] = heights[cloc] = ratio + super().__init__(widths, heights, backend=Backend(backend)) + self._main_canvas = self.add_canvas(rloc, cloc, palette=palette) + self._x_canvas = self.add_canvas(1 - rloc, cloc) + self._y_canvas = self.add_canvas(rloc, 1 - cloc) + + # flip the axes if needed + if rloc == 0: + self._x_canvas.y.flipped = True + self._x_namespace_canvas = self._x_canvas + self._main_canvas.x.ticks.visible = False + self._title_namespace_canvas = self._main_canvas + else: + self._x_namespace_canvas = self._main_canvas + self._x_canvas.x.ticks.visible = False + self._title_namespace_canvas = self._x_canvas + if cloc == 0: + self._ynamespace_canvas = self._main_canvas + self._y_canvas.y.ticks.visible = False + else: + self._y_canvas.x.flipped = True + self._ynamespace_canvas = self._y_canvas + self._main_canvas.y.ticks.visible = False + + # link axes + self._x_linker = link_axes([self._main_canvas.x, self._x_canvas.x]) + self._y_linker = link_axes([self._main_canvas.y, self._y_canvas.y]) + + # joint plotter + self._x_plotters = [] + self._y_plotters = [] + + def _iter_x_plotters(self) -> Iterator[MarginalPlotter]: + if len(self._x_plotters) == 0: + yield MarginalHistPlotter(Orientation.VERTICAL) + else: + yield from self._x_plotters + + def _iter_y_plotters(self) -> Iterator[MarginalPlotter]: + if len(self._y_plotters) == 0: + yield MarginalHistPlotter(Orientation.HORIZONTAL) + else: + yield from self._y_plotters + + @property + def x_canvas(self) -> Canvas: + """The canvas at the x-axis.""" + return self._x_canvas + + @property + def y_canvas(self) -> Canvas: + """The canvas at the y-axis.""" + return self._y_canvas + + @property + def main_canvas(self) -> Canvas: + """The main (joint) canvas.""" + return self._main_canvas + + @property + def x(self) -> _ns.XAxisNamespace: + """The x-axis namespace of the joint grid.""" + return self._x_namespace_canvas.x + + @property + def y(self) -> _ns.YAxisNamespace: + """The y-axis namespace of the joint grid.""" + return self._ynamespace_canvas.y + + @property + def title(self) -> _ns.TitleNamespace: + """Title namespace of the joint grid.""" + return self._title_namespace_canvas.title + + def cat( + self, + data: _DF, + x: str | None = None, + y: str | None = None, + *, + update_labels: bool = True, + ) -> JointCatPlotter[Self, _DF]: + """Create a joint categorical canvas.""" + from whitecanvas.canvas.dataframe import JointCatPlotter + + return JointCatPlotter(self, data, x, y, update_labels=update_labels) + + def add_markers( + self, + xdata: ArrayLike1D, + ydata: ArrayLike1D, + *, + name: str | None = None, + symbol: Symbol | str | None = None, + size: float | None = None, + color: ColorType | None = None, + alpha: float = 1.0, + hatch: str | Hatch | None = None, + ) -> _l.Markers[_mixin.ConstFace, _mixin.ConstEdge, float]: + out = self._main_canvas.add_markers( + xdata, + ydata, + name=name, + symbol=symbol, + size=size, + color=color, + alpha=alpha, + hatch=hatch, + ) + for _x_plt in self._iter_x_plotters(): + xlayer = _x_plt.add_layer_for_markers( + xdata, color, hatch, backend=self._backend + ) + self.x_canvas.add_layer(xlayer) + for _y_plt in self._iter_y_plotters(): + ylayer = _y_plt.add_layer_for_markers( + ydata, color, hatch, backend=self._backend + ) + self.y_canvas.add_layer(ylayer) + return out + + def with_hist_x( + self, + *, + bins: HistBinType = "auto", + limits: tuple[float, float] | None = None, + kind: str | HistogramKind = HistogramKind.density, + shape: str | HistogramShape = HistogramShape.bars, + ) -> Self: + self._x_plotters.append( + MarginalHistPlotter( + Orientation.VERTICAL, bins=bins, limits=limits, kind=kind, shape=shape + ) + ) + return self + + def with_hist_y( + self, + *, + bins: HistBinType = "auto", + limits: tuple[float, float] | None = None, + kind: str | HistogramKind = HistogramKind.density, + shape: str | HistogramShape = HistogramShape.bars, + ) -> Self: + self._y_plotters.append( + MarginalHistPlotter( + Orientation.HORIZONTAL, bins=bins, limits=limits, kind=kind, shape=shape + ) + ) + return self + + def with_hist( + self, + *, + bins: HistBinType = "auto", + limits: tuple[float, float] | None = None, + kind: str | HistogramKind = HistogramKind.density, + shape: str | HistogramShape = HistogramShape.bars, + ) -> Self: + self.with_hist_x(bins=bins, limits=limits, kind=kind, shape=shape) + self.with_hist_y(bins=bins, limits=limits, kind=kind, shape=shape) + return self + + def with_kde_x( + self, + *, + width: float | None = None, + band_width: KdeBandWidthType = "scott", + fill_alpha: float = 0.2, + ) -> Self: + width = theme._default("line.width", width) + self._x_plotters.append( + MarginalKdePlotter( + Orientation.VERTICAL, + width=width, + band_width=band_width, + fill_alpha=fill_alpha, + ) + ) + return self + + def with_kde_y( + self, + *, + width: float | None = None, + band_width: KdeBandWidthType = "scott", + fill_alpha: float = 0.2, + ) -> Self: + width = theme._default("line.width", width) + self._y_plotters.append( + MarginalKdePlotter( + Orientation.HORIZONTAL, + width=width, + band_width=band_width, + fill_alpha=fill_alpha, + ) + ) + return self + + def with_kde( + self, + *, + width: float | None = None, + band_width: KdeBandWidthType = "scott", + fill_alpha: float = 0.2, + ) -> Self: + self.with_kde_x(width=width, band_width=band_width, fill_alpha=fill_alpha) + self.with_kde_y(width=width, band_width=band_width, fill_alpha=fill_alpha) + return self + + def with_rug_x(self, *, width: float | None = None) -> Self: + width = theme._default("line.width", width) + self._x_plotters.append(MarginalRugPlotter(Orientation.VERTICAL, width=width)) + return self + + def with_rug_y(self, *, width: float | None = None) -> Self: + width = theme._default("line.width", width) + self._y_plotters.append(MarginalRugPlotter(Orientation.HORIZONTAL, width=width)) + return self + + def with_rug(self, *, width: float | None = None) -> Self: + self.with_rug_x(width=width) + self.with_rug_y(width=width) + return self + + +class MarginalPlotter(ABC): + def __init__(self, orient: str | Orientation): + self._orient = Orientation.parse(orient) + + @abstractmethod + def add_layer_for_markers( + self, + data: ArrayLike1D, + color: ColorType, + hatch: Hatch = Hatch.SOLID, + backend: str | Backend | None = None, + ) -> _l.Layer: + ... + + @abstractmethod + def add_layer_for_cat_markers( + self, + df: DataFrameWrapper[_DF], + value: str, + color: NStr | None = None, + hatch: NStr | None = None, + backend: str | Backend | None = None, + ) -> _l.Layer: + ... + + @abstractmethod + def add_layer_for_cat_hist2d( + self, + df: DataFrameWrapper[_DF], + value: str, + color: str | None = None, + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", + limits: tuple[float, float] | None = None, + backend: str | Backend | None = None, + ) -> _l.Layer: + ... + + +class MarginalHistPlotter(MarginalPlotter): + def __init__( + self, + orient: str | Orientation, + bins: HistBinType = "auto", + limits: tuple[float, float] | None = None, + kind: str | HistogramKind = "density", + shape: str | HistogramShape = "bars", + ): + super().__init__(orient) + self._bins = bins + self._limits = limits + self._kind = HistogramKind(kind) + self._shape = HistogramShape(shape) + + def add_layer_for_markers( + self, + data: ArrayLike1D, + color: ColorType, + hatch: Hatch = Hatch.SOLID, + backend: str | Backend | None = None, + ) -> _lg.Histogram: + return _lg.Histogram.from_array( + data, + shape=self._shape, + kind=self._kind, + color=color, + orient=self._orient, + bins=self._bins, + limits=self._limits, + backend=backend, + ) + + def add_layer_for_cat_markers( + self, + df: DataFrameWrapper[_DF], + value: str, + color: NStr | None = None, + hatch: NStr | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFHistograms[_DF]: + return _lt.DFHistograms.from_table( + df, value, orient=self._orient, color=color, hatch=hatch, + bins=self._bins, limits=self._limits, kind=self._kind, shape=self._shape, + backend=backend, + ) # fmt: skip + + def add_layer_for_cat_hist2d( + self, + df: DataFrameWrapper[_DF], + value: str, + color: str | None = None, + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", + limits: tuple[float, float] | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFHistograms[_DF]: + if self._bins != "auto": + bins = self._bins + if self._limits is not None: + limits = self._limits + return _lt.DFHistograms.from_table( + df, value, orient=self._orient, color=color, bins=bins, limits=limits, + kind=self._kind, shape=self._shape, backend=backend, + ) # fmt: skip + + +class MarginalKdePlotter(MarginalPlotter): + def __init__( + self, + orient: str | Orientation, + width: float = 1.0, + band_width: KdeBandWidthType = "scott", + fill_alpha: float = 0.2, + ): + super().__init__(orient) + self._width = width + self._band_width = band_width + self._fill_alpha = fill_alpha + + def add_layer_for_markers( + self, + data: ArrayLike1D, + color: ColorType, + hatch: Hatch = Hatch.SOLID, + backend: str | Backend | None = None, + ) -> _lg.Kde: + out = _lg.Kde.from_array( + data, color=color, orient=self._orient, band_width=self._band_width, + width=self._width, backend=backend, + ) # fmt: skip + out.fill_alpha = self._fill_alpha + return out + + def add_layer_for_cat_markers( + self, + df: DataFrameWrapper[_DF], + value: str, + color: NStr | None = None, + hatch: NStr | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFKde[_DF]: + out = _lt.DFKde.from_table( + df, value, orient=self._orient, color=color, hatch=hatch, + width=self._width, backend=backend, + ) # fmt: skip + for layer in out.base: + layer.fill_alpha = self._fill_alpha + return out + + def add_layer_for_cat_hist2d( + self, + df: DataFrameWrapper[_DF], + value: str, + color: str | None = None, + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", + limits: tuple[float, float] | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFKde[_DF]: + out = _lt.DFKde.from_table( + df, value, orient=self._orient, color=color, band_width=self._band_width, + width=self._width, backend=backend, + ) # fmt: skip + for layer in out.base: + layer.fill_alpha = self._fill_alpha + return out + + +class MarginalRugPlotter(MarginalPlotter): + def __init__( + self, + orient: str | Orientation, + width: float = 1.0, + length: float = 0.1, + ): + super().__init__(orient) + self._width = width + self._length = length + + def add_layer_for_markers( + self, + data: ArrayLike1D, + color: ColorType, + hatch: Hatch = Hatch.SOLID, + backend: str | Backend | None = None, + ) -> _l.Rug: + return _l.Rug( + data, high=self._length, color=color, orient=self._orient, + width=self._width, backend=backend + ) # fmt: skip + + def add_layer_for_cat_markers( + self, + df: DataFrameWrapper[_DF], + value: str, + color: NStr | None = None, + hatch: NStr | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFRug[_DF]: + return _lt.DFRug.from_table( + df, + value, + high=self._length, + orient=self._orient, + color=color, + width=self._width, + backend=backend, + ) + + def add_layer_for_cat_hist2d( + self, + df: DataFrameWrapper[_DF], + value: str, + color: str | None = None, + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", + limits: tuple[float, float] | None = None, + backend: str | Backend | None = None, + ) -> _lt.DFRug[_DF]: + return _lt.DFRug.from_table( + df, + value, + high=self._length, + orient=self._orient, + color=color, + width=self._width, + backend=backend, + ) + + +# class MarginalBoxPlotter +# class MarginalViolinPlotter diff --git a/whitecanvas/canvas/_linker.py b/whitecanvas/canvas/_linker.py index ccab4c2c..ec5f4a90 100644 --- a/whitecanvas/canvas/_linker.py +++ b/whitecanvas/canvas/_linker.py @@ -38,7 +38,8 @@ def unlink(self, axis: AxisNamespace): def unlink_all(self) -> None: """Unlink all axes.""" - for axis in self._axis_set: + axes = list(self._axis_set) # avoid the size changing during iteration + for axis in axes: self.unlink(axis) self.__class__._GLOBAL_LINKERS.discard(self) diff --git a/whitecanvas/canvas/dataframe/__init__.py b/whitecanvas/canvas/dataframe/__init__.py index f8365657..d3a629a2 100644 --- a/whitecanvas/canvas/dataframe/__init__.py +++ b/whitecanvas/canvas/dataframe/__init__.py @@ -1,5 +1,6 @@ from whitecanvas.canvas.dataframe._both_cat import XYCatPlotter from whitecanvas.canvas.dataframe._feature_cat import CatPlotter +from whitecanvas.canvas.dataframe._joint_cat import JointCatPlotter from whitecanvas.canvas.dataframe._one_cat import XCatPlotter, YCatPlotter __all__ = [ @@ -7,4 +8,5 @@ "XCatPlotter", "YCatPlotter", "XYCatPlotter", + "JointCatPlotter", ] diff --git a/whitecanvas/canvas/dataframe/_base.py b/whitecanvas/canvas/dataframe/_base.py index e362ab39..127fbf8e 100644 --- a/whitecanvas/canvas/dataframe/_base.py +++ b/whitecanvas/canvas/dataframe/_base.py @@ -20,10 +20,9 @@ if TYPE_CHECKING: from typing_extensions import Self - from whitecanvas.canvas._base import CanvasBase from whitecanvas.layers.tabular._dataframe import DataFrameWrapper -_C = TypeVar("_C", bound="CanvasBase") +_C = TypeVar("_C") # NOTE: don't have to be a canvas _DF = TypeVar("_DF") NStr = Union[str, Sequence[str]] AggMethods = Literal["min", "max", "mean", "median", "sum", "std"] diff --git a/whitecanvas/canvas/dataframe/_both_cat.py b/whitecanvas/canvas/dataframe/_both_cat.py index b0f95d1c..fd6f66fe 100644 --- a/whitecanvas/canvas/dataframe/_both_cat.py +++ b/whitecanvas/canvas/dataframe/_both_cat.py @@ -60,7 +60,7 @@ def __init__( df: _DF, x: str | tuple[str, ...], y: str | tuple[str, ...], - update_label: bool = False, + update_labels: bool = False, ): super().__init__(canvas, df) if isinstance(x, str): @@ -69,10 +69,10 @@ def __init__( y = (y,) self._x: tuple[str, ...] = x self._y: tuple[str, ...] = y - self._update_label = update_label + self._update_label = update_labels self._cat_iter_x = CatIterator(self._df, x) self._cat_iter_y = CatIterator(self._df, y) - if update_label: + if update_labels: self._update_xy_label(x, y) self._update_axis_labels() diff --git a/whitecanvas/canvas/dataframe/_feature_cat.py b/whitecanvas/canvas/dataframe/_feature_cat.py index af295523..4dfc50b2 100644 --- a/whitecanvas/canvas/dataframe/_feature_cat.py +++ b/whitecanvas/canvas/dataframe/_feature_cat.py @@ -24,10 +24,6 @@ _DF = TypeVar("_DF") -_C = TypeVar("_C", bound="CanvasBase") -_DF = TypeVar("_DF") - - class _Aggregator(Generic[_C, _DF]): def __init__(self, method: str, plotter: CatPlotter[_C, _DF] = None): self._method = method @@ -69,13 +65,13 @@ def __init__( df: _DF, x: str | None, y: str | None, - update_label: bool = False, + update_labels: bool = False, ): super().__init__(canvas, df) self._x = x self._y = y - self._update_label = update_label - if update_label: + self._update_label = update_labels + if update_labels: self._update_xy_label(x, y) def _get_x(self) -> str: diff --git a/whitecanvas/canvas/dataframe/_joint_cat.py b/whitecanvas/canvas/dataframe/_joint_cat.py new file mode 100644 index 00000000..8384ceb0 --- /dev/null +++ b/whitecanvas/canvas/dataframe/_joint_cat.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Sequence, + TypeVar, +) + +from whitecanvas.canvas.dataframe._feature_cat import CatPlotter +from whitecanvas.layers import tabular as _lt +from whitecanvas.layers.tabular import _jitter +from whitecanvas.types import ColormapType, HistBinType + +if TYPE_CHECKING: + from whitecanvas.canvas import JointGrid + + NStr = str | Sequence[str] + +_C = TypeVar("_C", bound="JointGrid") +_DF = TypeVar("_DF") + + +class JointCatPlotter(CatPlotter[_C, _DF]): + def __init__( + self, + canvas: _C, + df: _DF, + x: str | None, + y: str | None, + update_labels: bool = False, + ): + super().__init__(canvas, df, x, y, update_labels=update_labels) + + def add_markers( + self, + *, + name: str | None = None, + color: NStr | None = None, + hatch: NStr | None = None, + size: str | None = None, + symbol: NStr | None = None, + ) -> _lt.DFMarkers[_DF]: + """ + Add a categorical marker plot. + + Parameters + ---------- + name : str, optional + Name of the layer. + color : str or sequence of str, optional + Column name(s) for coloring the lines. Must be categorical. + hatch : str or sequence of str, optional + Column name(s) for hatches. Must be categorical. + size : str, optional + Column name for marker size. Must be numerical. + symbol : str or sequence of str, optional + Column name(s) for symbols. Must be categorical. + + Returns + ------- + DFMarkers + Marker collection layer. + """ + grid = self._canvas() + main = grid.main_canvas + xj = _jitter.IdentityJitter(self._get_x()) + yj = _jitter.IdentityJitter(self._get_y()) + layer = _lt.DFMarkers( + self._df, xj, yj, name=name, color=color, hatch=hatch, + size=size, symbol=symbol, backend=grid._backend, + ) # fmt: skip + if color is not None and not layer._color_by.is_const(): + layer.update_color(layer._color_by.by, palette=main._color_palette) + elif color is None: + layer.update_color(main._color_palette.next()) + main.add_layer(layer) + for _x_plt in grid._iter_x_plotters(): + xlayer = _x_plt.add_layer_for_cat_markers( + self._df, self._get_x(), color=color, hatch=hatch, backend=grid._backend + ) # fmt: skip + grid.x_canvas.add_layer(xlayer) + for _y_plt in grid._iter_y_plotters(): + ylayer = _y_plt.add_layer_for_cat_markers( + self._df, self._get_y(), color=color, hatch=hatch, backend=grid._backend + ) + grid.y_canvas.add_layer(ylayer) + return layer + + def add_hist2d( + self, + *, + cmap: ColormapType = "inferno", + name: str | None = None, + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", + rangex: tuple[float, float] | None = None, + rangey: tuple[float, float] | None = None, + density: bool = False, + ) -> _lt.DFHeatmap[_DF]: + """ + Add 2-D histogram of given x/y columns. + + Parameters + ---------- + cmap : colormap-like, default "inferno" + Colormap to use for the heatmap. + name : str, optional + Name of the layer. + bins : int, array, str or tuple of them, default "auto" + If int, the number of bins for both x and y. If tuple, the number of bins + for x and y respectively. + rangex : (float, float), optional + Range of x values in which histogram will be built. + rangey : (float, float), optional + Range of y values in which histogram will be built. + density : bool, default False + If True, the result is the value of the probability density function at the + bin, normalized such that the integral over the range is 1. + + Returns + ------- + DFHeatmap + Dataframe bound heatmap layer. + """ + grid = self._canvas() + main = grid.main_canvas + layer = _lt.DFHeatmap.build_hist( + self._df, self._get_x(), self._get_y(), cmap=cmap, name=name, bins=bins, + range=(rangex, rangey), density=density, backend=grid._backend, + ) # fmt: skip + main.add_layer(layer) + for _x_plt in grid._iter_x_plotters(): + xlayer = _x_plt.add_layer_for_cat_hist2d( + self._df, self._get_x(), bins=bins, limits=rangex, backend=grid._backend + ) # fmt: skip + grid.x_canvas.add_layer(xlayer) + for _y_plt in grid._iter_y_plotters(): + ylayer = _y_plt.add_layer_for_cat_hist2d( + self._df, self._get_y(), bins=bins, limits=rangey, backend=grid._backend + ) # fmt: skip + grid.y_canvas.add_layer(ylayer) + return layer diff --git a/whitecanvas/canvas/dataframe/_one_cat.py b/whitecanvas/canvas/dataframe/_one_cat.py index cb80c591..983b3579 100644 --- a/whitecanvas/canvas/dataframe/_one_cat.py +++ b/whitecanvas/canvas/dataframe/_one_cat.py @@ -78,7 +78,7 @@ def __call__(self, by: str | tuple[str, ...]) -> OneAxisCatPlotter[_C, _DF]: plotter._df.agg_by((*plotter._offset, *by), [plotter._value], self._method), offset=plotter._offset, value=plotter._value, - update_label=plotter._update_label, + update_labels=plotter._update_labels, ) @@ -91,7 +91,7 @@ def __init__( df: _DF, offset: str | tuple[str, ...] | None, value: str | None, - update_label: bool = False, + update_labels: bool = False, ): super().__init__(canvas, df) if isinstance(offset, str): @@ -107,8 +107,8 @@ def __init__( self._offset: tuple[str, ...] = offset self._cat_iter = CatIterator(self._df, offset) self._value = value - self._update_label = update_label - if update_label: + self._update_labels = update_labels + if update_labels: if value is not None: self._update_axis_labels(value) pos, label = self._cat_iter.axis_ticks() diff --git a/whitecanvas/core.py b/whitecanvas/core.py index 4dda7433..4bad02b4 100644 --- a/whitecanvas/core.py +++ b/whitecanvas/core.py @@ -9,7 +9,7 @@ CanvasGrid, CanvasHGrid, CanvasVGrid, - JointCanvas, + JointGrid, SingleCanvas, ) from whitecanvas.types import ColormapType @@ -117,14 +117,33 @@ def new_col( return grid -def new_jointcanvas( +def new_jointgrid( backend: Backend | str | None = None, *, loc: tuple[_0_or_1, _0_or_1] = (1, 0), size: tuple[int, int] | None = None, palette: str | ColormapType | None = None, -) -> JointCanvas: - joint = JointCanvas(loc, palette=palette, backend=backend) +) -> JointGrid: + """ + Create a new joint grid. + + Parameters + ---------- + backend : Backend or str, optional + Backend of the canvas. + loc : (int, int), default (1, 0) + Location of the main canvas. Each integer must be 0 or 1. + size : (int, int), optional + Size of the canvas in pixel. + palette : colormap type, optional + Color palette used for the canvases. + + Returns + ------- + JointGrid + Joint grid object. + """ + joint = JointGrid(loc, palette=palette, backend=backend) if size is not None: joint.size = size return joint diff --git a/whitecanvas/layers/group/line_fill.py b/whitecanvas/layers/group/line_fill.py index 2757cb2d..d18ea32e 100644 --- a/whitecanvas/layers/group/line_fill.py +++ b/whitecanvas/layers/group/line_fill.py @@ -1,6 +1,5 @@ from __future__ import annotations -from enum import Enum from typing import overload import numpy as np @@ -13,6 +12,8 @@ ArrayLike1D, ColorType, HistBinType, + HistogramKind, + HistogramShape, KdeBandWidthType, LineStyle, Orientation, @@ -21,23 +22,10 @@ from whitecanvas.utils.normalize import as_array_1d -class HistogramShape(Enum): - step = "step" - polygon = "polygon" - bars = "bars" - - -class HistogramKind(Enum): - count = "count" - density = "density" - probability = "probability" - frequency = "frequency" - percent = "percent" - - class LineFillBase(LayerContainer): def __init__(self, line: Line, fill: Band, name: str | None = None): super().__init__([line, fill], name=name) + self._fill_alpha = 0.2 @property def line(self) -> Line: @@ -62,9 +50,19 @@ def color(self) -> NDArray[np.float32]: @color.setter def color(self, color: ColorType): self.line.color = color - self.fill.face.update(color=color, alpha=0.2) + self.fill.face.update(color=color, alpha=self._fill_alpha) self.fill.edge.width = 0.0 + @property + def fill_alpha(self) -> float: + """The alpha value applied to the fill region compared to the line.""" + return self._fill_alpha + + @fill_alpha.setter + def fill_alpha(self, alpha: float): + self._fill_alpha = alpha + self.fill.face.alpha = alpha + class Histogram(LineFillBase): def __init__( diff --git a/whitecanvas/layers/tabular/_dataframe.py b/whitecanvas/layers/tabular/_dataframe.py index 1f3e06c9..7320ef76 100644 --- a/whitecanvas/layers/tabular/_dataframe.py +++ b/whitecanvas/layers/tabular/_dataframe.py @@ -249,11 +249,13 @@ def __init__( color: str | tuple[str, ...] | None = None, width: str | None = None, style: str | tuple[str, ...] | None = None, + hatch: str | tuple[str, ...] | None = None, ): splitby = _shared.join_columns(color, style, source=source) self._color_by = _p.ColorPlan.default() self._width_by = _p.WidthPlan.default() self._style_by = _p.StylePlan.default() + self._hatch_by = _p.HatchPlan.default() self._labels = labels self._splitby = splitby super().__init__(base, source) @@ -263,6 +265,8 @@ def __init__( self.update_width(width) if style is not None: self.update_style(style) + if hatch is not None: + self.update_hatch(hatch) @classmethod def from_table( @@ -276,6 +280,7 @@ def from_table( color: str | None = None, width: float = 1.0, style: str | None = None, + hatch: str | None = None, name: str | None = None, orient: str | Orientation = Orientation.VERTICAL, backend: str | Backend | None = None, @@ -297,7 +302,7 @@ def from_table( ) # fmt: skip layers.append(each_layer) base = _lg.LayerCollectionBase(layers, name=name) - return cls(df, base, labels, color=color, width=width, style=style) + return cls(df, base, labels, color=color, width=width, style=style, hatch=hatch) @overload def update_color(self, value: ColorType) -> Self: @@ -329,12 +334,12 @@ def update_width(self, value: float) -> Self: hist.line.width = value return self - def update_style(self, by: str | Iterable[str], styles=None) -> Self: + def update_style(self, by: str | Iterable[str], palette=None) -> Self: cov = _shared.ColumnOrValue(by, self._source) if cov.is_column: if set(cov.columns) > set(self._splitby): raise ValueError(f"Cannot style by a column other than {self._splitby}") - style_by = _p.StylePlan.new(cov.columns, values=styles) + style_by = _p.StylePlan.new(cov.columns, values=palette) else: style_by = _p.StylePlan.from_const(LineStyle(cov.value)) for i, st in enumerate(style_by.generate(self._labels, self._splitby)): @@ -342,6 +347,19 @@ def update_style(self, by: str | Iterable[str], styles=None) -> Self: self._style_by = style_by return self + def update_hatch(self, by: str | Iterable[str], styles=None) -> Self: + cov = _shared.ColumnOrValue(by, self._source) + if cov.is_column: + if set(cov.columns) > set(self._splitby): + raise ValueError(f"Cannot hatch by a column other than {self._splitby}") + hatch_by = _p.HatchPlan.new(cov.columns, values=styles) + else: + hatch_by = _p.HatchPlan.from_const(cov.value) + for i, st in enumerate(hatch_by.generate(self._labels, self._splitby)): + self._base_layer[i].fill.face.hatch = st + self._hatch_by = hatch_by + return self + class DFKde( _shared.DataFrameLayerWrapper[_lg.LayerCollectionBase[_lg.Kde], _DF], @@ -355,6 +373,7 @@ def __init__( color: str | tuple[str, ...] | None = None, width: str | None = None, style: str | tuple[str, ...] | None = None, + hatch: str | tuple[str, ...] | None = None, ): splitby = _shared.join_columns(color, style, source=source) self._color_by = _p.ColorPlan.default() @@ -369,6 +388,8 @@ def __init__( self.update_width(width) if style is not None: self.update_style(style) + if hatch is not None: + self.update_hatch(hatch) @classmethod def from_table( @@ -379,10 +400,11 @@ def from_table( color: str | None = None, width: float = 1.0, style: str | None = None, + hatch: str | None = None, name: str | None = None, orient: str | Orientation = Orientation.VERTICAL, backend: str | Backend | None = None, - ) -> DFHistograms[_DF]: + ) -> DFKde[_DF]: splitby = _shared.join_columns(color, style, source=df) ori = Orientation.parse(orient) arrays: list[np.ndarray] = [] @@ -397,7 +419,7 @@ def from_table( ) # fmt: skip layers.append(each_layer) base = _lg.LayerCollectionBase(layers, name=name) - return cls(df, base, labels, color=color, width=width, style=style) + return cls(df, base, labels, color=color, width=width, style=style, hatch=hatch) @overload def update_color(self, value: ColorType) -> Self: @@ -441,3 +463,16 @@ def update_style(self, by: str | Iterable[str], styles=None) -> Self: self._base_layer[i].line.style = st self._style_by = style_by return self + + def update_hatch(self, by: str | Iterable[str], styles=None) -> Self: + cov = _shared.ColumnOrValue(by, self._source) + if cov.is_column: + if set(cov.columns) > set(self._splitby): + raise ValueError(f"Cannot hatch by a column other than {self._splitby}") + hatch_by = _p.HatchPlan.new(cov.columns, values=styles) + else: + hatch_by = _p.HatchPlan.from_const(cov.value) + for i, st in enumerate(hatch_by.generate(self._labels, self._splitby)): + self._base_layer[i].fill.face.hatch = st + self._hatch_by = hatch_by + return self diff --git a/whitecanvas/types/__init__.py b/whitecanvas/types/__init__.py index 4553ec96..a3350fe7 100644 --- a/whitecanvas/types/__init__.py +++ b/whitecanvas/types/__init__.py @@ -9,6 +9,8 @@ from whitecanvas.types._enums import ( Alignment, Hatch, + HistogramKind, + HistogramShape, LineStyle, Modifier, MouseButton, @@ -28,6 +30,8 @@ "Symbol", "Hatch", "HistBinType", + "HistogramKind", + "HistogramShape", "KdeBandWidthType", "Orientation", "Origin", diff --git a/whitecanvas/types/_enums.py b/whitecanvas/types/_enums.py index 5e8bc321..e7ef6e35 100644 --- a/whitecanvas/types/_enums.py +++ b/whitecanvas/types/_enums.py @@ -186,3 +186,17 @@ class Origin(_StrEnum): CORNER = "corner" EDGE = "edge" CENTER = "center" + + +class HistogramShape(_StrEnum): + step = "step" + polygon = "polygon" + bars = "bars" + + +class HistogramKind(_StrEnum): + count = "count" + density = "density" + probability = "probability" + frequency = "frequency" + percent = "percent" From 2a1ef94f4ff49d15736d4412d297f9c1e161d380 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Fri, 16 Feb 2024 15:12:08 +0900 Subject: [PATCH 4/6] minor bug fixes in desing --- tests/test_canvas.py | 6 ++ whitecanvas/backend/matplotlib/canvas.py | 2 +- whitecanvas/canvas/_grid.py | 6 +- whitecanvas/canvas/_joint.py | 86 ++++++++++++++++++---- whitecanvas/canvas/dataframe/_joint_cat.py | 4 + whitecanvas/layers/_mixin.py | 9 ++- whitecanvas/layers/_primitive/image.py | 6 +- whitecanvas/layers/group/line_fill.py | 3 +- 8 files changed, 100 insertions(+), 22 deletions(-) diff --git a/tests/test_canvas.py b/tests/test_canvas.py index 5b83b0e2..2f6540d0 100644 --- a/tests/test_canvas.py +++ b/tests/test_canvas.py @@ -1,3 +1,4 @@ +import numpy as np from numpy.testing import assert_allclose import pytest @@ -138,3 +139,8 @@ def test_unlink(backend: str): grid[0].x.lim = (20, 21) assert grid[0].x.lim == pytest.approx((20, 21)) assert grid[1].x.lim == pytest.approx((10, 11)) + +def test_jointgrid(backend: str): + rng = np.random.default_rng(0) + joint = wc.new_jointgrid(backend=backend).with_hist().with_kde().with_rug() + joint.add_markers(rng.random(100), rng.random(100), color="red") diff --git a/whitecanvas/backend/matplotlib/canvas.py b/whitecanvas/backend/matplotlib/canvas.py index f6184fa0..812cdbce 100644 --- a/whitecanvas/backend/matplotlib/canvas.py +++ b/whitecanvas/backend/matplotlib/canvas.py @@ -79,7 +79,7 @@ def _on_hover(self, event: mplMouseEvent): return self._hide_tooltip() - def _set_tooltip(self, pos, text: str): + def _set_tooltip(self, pos: tuple[float, float], text: str): # determine in which direction to show the tooltip x, y = pos x0, x1 = self._axes.get_xlim() diff --git a/whitecanvas/canvas/_grid.py b/whitecanvas/canvas/_grid.py index 58587394..808532e5 100644 --- a/whitecanvas/canvas/_grid.py +++ b/whitecanvas/canvas/_grid.py @@ -78,7 +78,7 @@ def link_x(self, *, future: bool = True, hide_ticks: bool = True) -> Self: self._x_linker_ref.unlink_all() # initialize linker to_link = [] for (_r, _), _canvas in self._iter_canvas(): - to_link.append(_canvas) + to_link.append(_canvas.x) if hide_ticks and _r != self.shape[0] - 1: _canvas.x.ticks.visible = False self._x_linker_ref = link_axes(to_link) @@ -103,8 +103,8 @@ def link_y(self, *, future: bool = True, hide_ticks: bool = True) -> Self: self._y_linker_ref.unlink_all() to_link = [] for (_, _c), _canvas in self._iter_canvas(): - to_link.append(_canvas) - if hide_ticks and _c != self.shape[1] - 1: + to_link.append(_canvas.y) + if hide_ticks and _c != 0: _canvas.y.ticks.visible = False self._y_linker_ref = link_axes(to_link) if future: diff --git a/whitecanvas/canvas/_joint.py b/whitecanvas/canvas/_joint.py index 84087eeb..13b0eed4 100644 --- a/whitecanvas/canvas/_joint.py +++ b/whitecanvas/canvas/_joint.py @@ -40,7 +40,6 @@ NStr = str | Sequence[str] -_C = TypeVar("_C", bound="JointGrid") _DF = TypeVar("_DF") @@ -153,6 +152,10 @@ def cat( return JointCatPlotter(self, data, x, y, update_labels=update_labels) + def _link_marginal_to_main(self, layer: _l.Layer, main: _l.Layer) -> None: + # TODO: this is not the only thing to be done + main.events.visible.connect_setattr(layer, "visible") + def add_markers( self, xdata: ArrayLike1D, @@ -166,25 +169,21 @@ def add_markers( hatch: str | Hatch | None = None, ) -> _l.Markers[_mixin.ConstFace, _mixin.ConstEdge, float]: out = self._main_canvas.add_markers( - xdata, - ydata, - name=name, - symbol=symbol, - size=size, - color=color, - alpha=alpha, - hatch=hatch, - ) + xdata, ydata, name=name, symbol=symbol, size=size, color=color, + alpha=alpha, hatch=hatch, + ) # fmt: skip for _x_plt in self._iter_x_plotters(): xlayer = _x_plt.add_layer_for_markers( xdata, color, hatch, backend=self._backend ) self.x_canvas.add_layer(xlayer) + self._link_marginal_to_main(xlayer, out) for _y_plt in self._iter_y_plotters(): ylayer = _y_plt.add_layer_for_markers( ydata, color, hatch, backend=self._backend ) self.y_canvas.add_layer(ylayer) + self._link_marginal_to_main(ylayer, out) return out def with_hist_x( @@ -195,6 +194,25 @@ def with_hist_x( kind: str | HistogramKind = HistogramKind.density, shape: str | HistogramShape = HistogramShape.bars, ) -> Self: + """ + Configure the x-marginal canvas to have a histogram. + + Parameters + ---------- + bins : int or 1D array-like, default "auto" + Bins of the histogram. This parameter will directly be passed + to `np.histogram`. + limits : (float, float), optional + Limits in which histogram will be built. This parameter will equivalent to + the `range` paraneter of `np.histogram`. + name : str, optional + Name of the layer. + shape : {"step", "polygon", "bars"}, default "bars" + Shape of the histogram. This parameter defines how to convert the data into + the line nodes. + kind : {"count", "density", "probability", "frequency", "percent"}, optional + Kind of the histogram. + """ self._x_plotters.append( MarginalHistPlotter( Orientation.VERTICAL, bins=bins, limits=limits, kind=kind, shape=shape @@ -210,6 +228,25 @@ def with_hist_y( kind: str | HistogramKind = HistogramKind.density, shape: str | HistogramShape = HistogramShape.bars, ) -> Self: + """ + Configure the y-marginal canvas to have a histogram. + + Parameters + ---------- + bins : int or 1D array-like, default "auto" + Bins of the histogram. This parameter will directly be passed + to `np.histogram`. + limits : (float, float), optional + Limits in which histogram will be built. This parameter will equivalent to + the `range` paraneter of `np.histogram`. + name : str, optional + Name of the layer. + shape : {"step", "polygon", "bars"}, default "bars" + Shape of the histogram. This parameter defines how to convert the data into + the line nodes. + kind : {"count", "density", "probability", "frequency", "percent"}, optional + Kind of the histogram. + """ self._y_plotters.append( MarginalHistPlotter( Orientation.HORIZONTAL, bins=bins, limits=limits, kind=kind, shape=shape @@ -220,13 +257,36 @@ def with_hist_y( def with_hist( self, *, - bins: HistBinType = "auto", + bins: HistBinType | tuple[HistBinType, HistBinType] = "auto", limits: tuple[float, float] | None = None, kind: str | HistogramKind = HistogramKind.density, shape: str | HistogramShape = HistogramShape.bars, ) -> Self: - self.with_hist_x(bins=bins, limits=limits, kind=kind, shape=shape) - self.with_hist_y(bins=bins, limits=limits, kind=kind, shape=shape) + """ + Configure both of the marginal canvases to have histograms. + + Parameters + ---------- + bins : int or 1D array-like, default "auto" + Bins of the histogram. This parameter will directly be passed + to `np.histogram`. + limits : (float, float), optional + Limits in which histogram will be built. This parameter will equivalent to + the `range` paraneter of `np.histogram`. + name : str, optional + Name of the layer. + shape : {"step", "polygon", "bars"}, default "bars" + Shape of the histogram. This parameter defines how to convert the data into + the line nodes. + kind : {"count", "density", "probability", "frequency", "percent"}, optional + Kind of the histogram. + """ + if isinstance(bins, tuple): + bins_x, bins_y = bins + else: + bins_x = bins_y = bins + self.with_hist_x(bins=bins_x, limits=limits, kind=kind, shape=shape) + self.with_hist_y(bins=bins_y, limits=limits, kind=kind, shape=shape) return self def with_kde_x( diff --git a/whitecanvas/canvas/dataframe/_joint_cat.py b/whitecanvas/canvas/dataframe/_joint_cat.py index 8384ceb0..13b0fb7a 100644 --- a/whitecanvas/canvas/dataframe/_joint_cat.py +++ b/whitecanvas/canvas/dataframe/_joint_cat.py @@ -79,11 +79,13 @@ def add_markers( self._df, self._get_x(), color=color, hatch=hatch, backend=grid._backend ) # fmt: skip grid.x_canvas.add_layer(xlayer) + grid._link_marginal_to_main(xlayer, layer) for _y_plt in grid._iter_y_plotters(): ylayer = _y_plt.add_layer_for_cat_markers( self._df, self._get_y(), color=color, hatch=hatch, backend=grid._backend ) grid.y_canvas.add_layer(ylayer) + grid._link_marginal_to_main(ylayer, layer) return layer def add_hist2d( @@ -133,9 +135,11 @@ def add_hist2d( self._df, self._get_x(), bins=bins, limits=rangex, backend=grid._backend ) # fmt: skip grid.x_canvas.add_layer(xlayer) + grid._link_marginal_to_main(xlayer, layer) for _y_plt in grid._iter_y_plotters(): ylayer = _y_plt.add_layer_for_cat_hist2d( self._df, self._get_y(), bins=bins, limits=rangey, backend=grid._backend ) # fmt: skip grid.y_canvas.add_layer(ylayer) + grid._link_marginal_to_main(ylayer, layer) return layer diff --git a/whitecanvas/layers/_mixin.py b/whitecanvas/layers/_mixin.py index 4126b325..ed04b7cd 100644 --- a/whitecanvas/layers/_mixin.py +++ b/whitecanvas/layers/_mixin.py @@ -596,7 +596,7 @@ def __init__(self): super().__init__(MonoFace(self), MonoEdge(self)) def _make_sure_hatch_visible(self): - if self.edge.width == 0: + if self.face.hatch is not Hatch.SOLID and self.edge.width == 0: self.edge.width = 1 self.edge.color = get_theme().foreground_color @@ -669,6 +669,13 @@ def with_edge_multi( return self def _make_sure_hatch_visible(self): + # TODO: following lines are needed, but it might be slow. + # if isinstance(self.face, MonoFace): + # if self.face.hatch is Hatch.SOLID: + # return + # else: + # if np.all(self.face.hatch == Hatch.SOLID): + # return _is_no_width = self.edge.width == 0 if isinstance(self._edge_namespace, MultiEdge): if np.any(_is_no_width): diff --git a/whitecanvas/layers/_primitive/image.py b/whitecanvas/layers/_primitive/image.py index 6f4fa321..b3ed1118 100644 --- a/whitecanvas/layers/_primitive/image.py +++ b/whitecanvas/layers/_primitive/image.py @@ -260,10 +260,10 @@ def build_hist( _y = as_array_1d(y) if _x.size != _y.size: raise ValueError("x and y must have the same size.") - if isinstance(bins, (int, np.number, str)): - xbins = ybins = bins - else: + if isinstance(bins, tuple): xbins, ybins = bins + else: + xbins = ybins = bins if range is None: xrange = yrange = None else: diff --git a/whitecanvas/layers/group/line_fill.py b/whitecanvas/layers/group/line_fill.py index d18ea32e..59c8bb5e 100644 --- a/whitecanvas/layers/group/line_fill.py +++ b/whitecanvas/layers/group/line_fill.py @@ -51,7 +51,7 @@ def color(self) -> NDArray[np.float32]: def color(self, color: ColorType): self.line.color = color self.fill.face.update(color=color, alpha=self._fill_alpha) - self.fill.edge.width = 0.0 + self.fill.edge.update(color=color, alpha=self._fill_alpha) @property def fill_alpha(self) -> float: @@ -62,6 +62,7 @@ def fill_alpha(self) -> float: def fill_alpha(self, alpha: float): self._fill_alpha = alpha self.fill.face.alpha = alpha + self.fill.edge.alpha = alpha class Histogram(LineFillBase): From 84a925bf6802d6846fca622e46755e00f3bbeeb2 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Fri, 16 Feb 2024 17:10:34 +0900 Subject: [PATCH 5/6] new example, better autoscaling --- examples/joint_grid.py | 23 ++++++++++++++++ whitecanvas/backend/bokeh/canvas.py | 3 +++ whitecanvas/backend/matplotlib/canvas.py | 7 +++++ whitecanvas/backend/mock/canvas.py | 3 +++ whitecanvas/backend/plotly/canvas.py | 6 ++++- whitecanvas/backend/pyqtgraph/canvas.py | 4 +++ whitecanvas/backend/vispy/canvas.py | 3 +++ whitecanvas/canvas/_base.py | 15 +++-------- whitecanvas/canvas/_grid.py | 4 +++ whitecanvas/canvas/_joint.py | 13 +++++++++ whitecanvas/canvas/dataframe/_joint_cat.py | 31 +++++++++++++++++++--- whitecanvas/layers/_base.py | 22 ++++----------- whitecanvas/layers/_primitive/bars.py | 1 + whitecanvas/layers/_primitive/line.py | 3 ++- whitecanvas/layers/_primitive/rug.py | 1 + whitecanvas/layers/group/line_fill.py | 2 ++ whitecanvas/layers/group/stemplot.py | 2 ++ whitecanvas/layers/tabular/_marker_like.py | 31 ++++++++++++++++++++++ whitecanvas/protocols/canvas_protocol.py | 3 +++ 19 files changed, 144 insertions(+), 33 deletions(-) create mode 100644 examples/joint_grid.py diff --git a/examples/joint_grid.py b/examples/joint_grid.py new file mode 100644 index 00000000..42ab4574 --- /dev/null +++ b/examples/joint_grid.py @@ -0,0 +1,23 @@ +import pandas as pd +from whitecanvas import new_jointgrid + +def main(): + url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv" + df = pd.read_csv(url).dropna() + + joint = ( + new_jointgrid("matplotlib:qt") + .with_hist_x(shape="step") # show histogram as the x-marginal distribution + .with_kde_y(width=2) # show kde as the y-marginal distribution + .with_rug(width=1) # show rug plot for both marginal distributions + ) + + layer = ( + joint.cat(df, x="bill_length_mm", y="flipper_length_mm") + .add_markers(color="species") + ) + + joint.show(block=True) + +if __name__ == "__main__": + main() diff --git a/whitecanvas/backend/bokeh/canvas.py b/whitecanvas/backend/bokeh/canvas.py index c2f7a8fb..a975b4ed 100644 --- a/whitecanvas/backend/bokeh/canvas.py +++ b/whitecanvas/backend/bokeh/canvas.py @@ -310,6 +310,9 @@ def _plt_set_figsize(self, width: int, height: int): self._grid_plot.width = width self._grid_plot.height = height + def _plt_set_spacings(self, wspace: float, hspace: float): + self._grid_plot.spacing = (int(wspace), int(hspace)) + def _iter_bokeh_subplots(self) -> Iterator[tuple[int, int, bk_plotting.figure]]: for child, r, c in self._grid_plot.children: yield r, c, child diff --git a/whitecanvas/backend/matplotlib/canvas.py b/whitecanvas/backend/matplotlib/canvas.py index 812cdbce..785eaa30 100644 --- a/whitecanvas/backend/matplotlib/canvas.py +++ b/whitecanvas/backend/matplotlib/canvas.py @@ -349,3 +349,10 @@ def _plt_set_figsize(self, width: int, height: int): dpi = self._fig.get_dpi() self._fig.set_size_inches(width / dpi, height / dpi) self._fig.tight_layout() + + def _plt_set_spacings(self, wspace: float, hspace: float): + dpi = self._fig.get_dpi() + nh, nw = self._gridspec.get_geometry() + w_avg = self._fig.get_figwidth() / nw * dpi + h_avg = self._fig.get_figheight() / nh * dpi + self._gridspec.update(hspace=hspace / h_avg, wspace=wspace / w_avg) diff --git a/whitecanvas/backend/mock/canvas.py b/whitecanvas/backend/mock/canvas.py index c1b11d08..9e159eb0 100644 --- a/whitecanvas/backend/mock/canvas.py +++ b/whitecanvas/backend/mock/canvas.py @@ -131,6 +131,9 @@ def _plt_screenshot(self): def _plt_set_figsize(self, width: int, height: int): self._figsize = (width, height) + def _plt_set_spacings(self, wspace: float, hspace: float): + pass + class _SupportsText: def __init__(self): diff --git a/whitecanvas/backend/plotly/canvas.py b/whitecanvas/backend/plotly/canvas.py index 45cdd4d1..552442ae 100644 --- a/whitecanvas/backend/plotly/canvas.py +++ b/whitecanvas/backend/plotly/canvas.py @@ -215,7 +215,7 @@ def __init__(self, heights: list[float], widths: list[float], app: str = "defaul column_widths=widths, ) ) - self._figs.update_layout(margin={"l": 6, "r": 6, "t": 6, "b": 6}) + self._figs.update_layout(margin={"l": 6, "r": 6, "t": 30, "b": 6}) self._app = app self._heights = heights self._widths = widths @@ -256,3 +256,7 @@ def _plt_screenshot(self): def _plt_set_figsize(self, width: int, height: int): self._figs.layout.width = width self._figs.layout.height = height + + def _plt_set_spacings(self, wspace: float, hspace: float): + # plotly does not have a flexible way to set spacings + pass diff --git a/whitecanvas/backend/pyqtgraph/canvas.py b/whitecanvas/backend/pyqtgraph/canvas.py index 9e70cc4c..50b4dc6e 100644 --- a/whitecanvas/backend/pyqtgraph/canvas.py +++ b/whitecanvas/backend/pyqtgraph/canvas.py @@ -294,6 +294,10 @@ def _plt_screenshot(self): def _plt_set_figsize(self, width: int, height: int): self._layoutwidget.resize(width, height) + def _plt_set_spacings(self, wspace: float, hspace: float): + self._layoutwidget.ci.layout.setHorizontalSpacing(wspace) + self._layoutwidget.ci.layout.setVerticalSpacing(hspace) + class SignalListener(pg.GraphicsObject): # Mouse events in pyqtgraph is very complicated. diff --git a/whitecanvas/backend/vispy/canvas.py b/whitecanvas/backend/vispy/canvas.py index 86ddcccf..344b666b 100644 --- a/whitecanvas/backend/vispy/canvas.py +++ b/whitecanvas/backend/vispy/canvas.py @@ -227,6 +227,9 @@ def _plt_show(self): def _plt_set_figsize(self, width: int, height: int): self._scene.size = (width, height) + def _plt_set_spacings(self, wspace: float, hspace: float): + self._grid.spacing = wspace + _APP_NAMES = { "qt4": "pyqt4", diff --git a/whitecanvas/canvas/_base.py b/whitecanvas/canvas/_base.py index 5bc7a1ec..089c8a5a 100644 --- a/whitecanvas/canvas/_base.py +++ b/whitecanvas/canvas/_base.py @@ -66,13 +66,6 @@ _L0 = TypeVar("_L0", _l.Bars, _l.Band) _void = _Void() -_ATTACH_TO_AXIS = ( - _l.Bars, - _lg.Histogram, - _lg.Kde, - _lg.StemPlot, -) - class CanvasEvents(SignalGroup): lims = Signal(Rect) @@ -1658,8 +1651,8 @@ def _autoscale_for_layer( dx = (xmax - xmin) * pad_rel if ( xmin != 0 - or not isinstance(layer, _ATTACH_TO_AXIS) - or layer.orient.is_vertical + or not layer._ATTACH_TO_AXIS + or getattr(layer, "orient", None) is not Orientation.HORIZONTAL ): xmin -= dx xmax += dx @@ -1672,8 +1665,8 @@ def _autoscale_for_layer( dy = (ymax - ymin) * pad_rel if ( ymin != 0 - or not isinstance(layer, _ATTACH_TO_AXIS) - or layer.orient.is_horizontal + or not layer._ATTACH_TO_AXIS + or getattr(layer, "orient", None) is not Orientation.VERTICAL ): ymin -= dy ymax += dy diff --git a/whitecanvas/canvas/_grid.py b/whitecanvas/canvas/_grid.py index 808532e5..5e12c19d 100644 --- a/whitecanvas/canvas/_grid.py +++ b/whitecanvas/canvas/_grid.py @@ -84,6 +84,8 @@ def link_x(self, *, future: bool = True, hide_ticks: bool = True) -> Self: self._x_linker_ref = link_axes(to_link) if future: self._x_linked = True + if hide_ticks: + self._backend_object._plt_set_spacings(6, 6) return self def link_y(self, *, future: bool = True, hide_ticks: bool = True) -> Self: @@ -109,6 +111,8 @@ def link_y(self, *, future: bool = True, hide_ticks: bool = True) -> Self: self._y_linker_ref = link_axes(to_link) if future: self._y_linked = True + if hide_ticks: + self._backend_object._plt_set_spacings(6, 6) return self def __repr__(self) -> str: diff --git a/whitecanvas/canvas/_joint.py b/whitecanvas/canvas/_joint.py index 13b0eed4..1b88d16a 100644 --- a/whitecanvas/canvas/_joint.py +++ b/whitecanvas/canvas/_joint.py @@ -89,6 +89,8 @@ def __init__( self._ynamespace_canvas = self._y_canvas self._main_canvas.y.ticks.visible = False + self._backend_object._plt_set_spacings(10, 10) + # link axes self._x_linker = link_axes([self._main_canvas.x, self._x_canvas.x]) self._y_linker = link_axes([self._main_canvas.y, self._y_canvas.y]) @@ -184,6 +186,7 @@ def add_markers( ) self.y_canvas.add_layer(ylayer) self._link_marginal_to_main(ylayer, out) + self._autoscale_layers() return out def with_hist_x( @@ -351,6 +354,16 @@ def with_rug(self, *, width: float | None = None) -> Self: self.with_rug_y(width=width) return self + def _autoscale_layers(self): + for layer in self.x_canvas.layers: + if isinstance(layer, (_l.Rug, _lt.DFRug)): + ylow, yhigh = self.x_canvas.y.lim + layer.update_length((yhigh - ylow) * 0.1) + for layer in self.y_canvas.layers: + if isinstance(layer, (_l.Rug, _lt.DFRug)): + xlow, xhigh = self.y_canvas.x.lim + layer.update_length((xhigh - xlow) * 0.1) + class MarginalPlotter(ABC): def __init__(self, orient: str | Orientation): diff --git a/whitecanvas/canvas/dataframe/_joint_cat.py b/whitecanvas/canvas/dataframe/_joint_cat.py index 13b0fb7a..341d8905 100644 --- a/whitecanvas/canvas/dataframe/_joint_cat.py +++ b/whitecanvas/canvas/dataframe/_joint_cat.py @@ -6,7 +6,7 @@ TypeVar, ) -from whitecanvas.canvas.dataframe._feature_cat import CatPlotter +from whitecanvas.canvas.dataframe._base import BaseCatPlotter from whitecanvas.layers import tabular as _lt from whitecanvas.layers.tabular import _jitter from whitecanvas.types import ColormapType, HistBinType @@ -20,7 +20,7 @@ _DF = TypeVar("_DF") -class JointCatPlotter(CatPlotter[_C, _DF]): +class JointCatPlotter(BaseCatPlotter[_C, _DF]): def __init__( self, canvas: _C, @@ -29,7 +29,30 @@ def __init__( y: str | None, update_labels: bool = False, ): - super().__init__(canvas, df, x, y, update_labels=update_labels) + super().__init__(canvas, df) + self._x = x + self._y = y + self._update_label = update_labels + if update_labels: + self._update_xy_label(x, y) + + def _get_x(self) -> str: + if self._x is None: + raise ValueError("Column for x-axis is not set") + return self._x + + def _get_y(self) -> str: + if self._y is None: + raise ValueError("Column for y-axis is not set") + return self._y + + def _update_xy_label(self, x: str | None, y: str | None) -> None: + """Update the x and y labels using the column names""" + canvas = self._canvas() + if isinstance(x, str): + canvas.x.label.text = x + if isinstance(y, str): + canvas.y.label.text = y def add_markers( self, @@ -86,6 +109,7 @@ def add_markers( ) grid.y_canvas.add_layer(ylayer) grid._link_marginal_to_main(ylayer, layer) + grid._autoscale_layers() return layer def add_hist2d( @@ -142,4 +166,5 @@ def add_hist2d( ) # fmt: skip grid.y_canvas.add_layer(ylayer) grid._link_marginal_to_main(ylayer, layer) + grid._autoscale_layers() return layer diff --git a/whitecanvas/layers/_base.py b/whitecanvas/layers/_base.py index c929a47e..0cc5078c 100644 --- a/whitecanvas/layers/_base.py +++ b/whitecanvas/layers/_base.py @@ -32,6 +32,7 @@ class LayerEvents(SignalGroup): class Layer(ABC): events: LayerEvents _events_class: type[LayerEvents] + _ATTACH_TO_AXIS = False def __init__(self, name: str | None = None): if not hasattr(self.__class__, "_events_class"): @@ -62,23 +63,6 @@ def name(self, name: str): """Set the name of this layer.""" self._name = str(name) - def expect(self, layer_type: _L, /) -> _L: - """ - A type guard for layers. - - >>> canvas.layers["scatter-layer-name"].expect(Line).color - """ - if not isinstance(layer_type, type) or issubclass(layer_type, PrimitiveLayer): - raise TypeError( - "Argument of `expect` must be a layer class, " - f"got {layer_type!r} (type: {type(layer_type).__name__}))" - ) - if not isinstance(self, layer_type): - raise TypeError( - f"Expected {layer_type.__name__}, got {type(self).__name__}" - ) - return self - def __repr__(self): return f"{self.__class__.__name__}<{self.name!r}>" @@ -320,6 +304,10 @@ def _disconnect_canvas(self, canvas: CanvasBase): self._base_layer._disconnect_canvas(canvas) return super()._disconnect_canvas(canvas) + @property + def _ATTACH_TO_AXIS(self) -> bool: + return self._base_layer._ATTACH_TO_AXIS + # deprecated, new _DEPRECATED = [ diff --git a/whitecanvas/layers/_primitive/bars.py b/whitecanvas/layers/_primitive/bars.py index 1a1094c5..240adbcb 100644 --- a/whitecanvas/layers/_primitive/bars.py +++ b/whitecanvas/layers/_primitive/bars.py @@ -65,6 +65,7 @@ class Bars( Edge properties of the bars. """ + _ATTACH_TO_AXIS = True events: BarEvents _events_class = BarEvents diff --git a/whitecanvas/layers/_primitive/line.py b/whitecanvas/layers/_primitive/line.py index de39e2d4..458ed314 100644 --- a/whitecanvas/layers/_primitive/line.py +++ b/whitecanvas/layers/_primitive/line.py @@ -488,6 +488,7 @@ def build_cdf( alpha: float = 1.0, width: float = 1.0, style: LineStyle | str = LineStyle.SOLID, + antialias: bool = True, backend: Backend | str | None = None, ): """Construct a line from a cumulative histogram.""" @@ -498,7 +499,7 @@ def build_cdf( if not Orientation.parse(orient).is_vertical: xdata, ydata = ydata, xdata return Line( - xdata, ydata, name=name, color=color, alpha=alpha, + xdata, ydata, name=name, color=color, alpha=alpha, antialias=antialias, width=width, style=style, backend=backend, ) # fmt: skip diff --git a/whitecanvas/layers/_primitive/rug.py b/whitecanvas/layers/_primitive/rug.py index aa55cd07..7f3e1be0 100644 --- a/whitecanvas/layers/_primitive/rug.py +++ b/whitecanvas/layers/_primitive/rug.py @@ -37,6 +37,7 @@ class Rug(MultiLine, HoverableDataBoundLayer[MultiLineProtocol, NDArray[np.numbe ──┴─┴┴──┴───┴──> """ + _ATTACH_TO_AXIS = True events: MultiLineEvents _events_class = MultiLineEvents diff --git a/whitecanvas/layers/group/line_fill.py b/whitecanvas/layers/group/line_fill.py index 59c8bb5e..9f978255 100644 --- a/whitecanvas/layers/group/line_fill.py +++ b/whitecanvas/layers/group/line_fill.py @@ -23,6 +23,8 @@ class LineFillBase(LayerContainer): + _ATTACH_TO_AXIS = True + def __init__(self, line: Line, fill: Band, name: str | None = None): super().__init__([line, fill], name=name) self._fill_alpha = 0.2 diff --git a/whitecanvas/layers/group/stemplot.py b/whitecanvas/layers/group/stemplot.py index ab976d2f..6375585c 100644 --- a/whitecanvas/layers/group/stemplot.py +++ b/whitecanvas/layers/group/stemplot.py @@ -11,6 +11,8 @@ class StemPlot(LayerContainer): + _ATTACH_TO_AXIS = True + def __init__( self, markers: Markers, diff --git a/whitecanvas/layers/tabular/_marker_like.py b/whitecanvas/layers/tabular/_marker_like.py index 610e0f64..edd0cde1 100644 --- a/whitecanvas/layers/tabular/_marker_like.py +++ b/whitecanvas/layers/tabular/_marker_like.py @@ -11,6 +11,7 @@ import numpy as np from cmap import Color, Colormap +from numpy.typing import NDArray from whitecanvas import layers as _l from whitecanvas import theme @@ -543,6 +544,36 @@ def _apply_style(self, style): # def update_scale(self, by: str | float, align: str = "low") -> Self: # ... + def update_length( + self, + lengths: float | NDArray[np.number], + *, + offset: float | None = None, + align: str = "low", + ) -> Self: + """ + Update the length of the rug lines. + + Parameters + ---------- + lengths : float or array-like + Length of the rug lines. If a scalar, all the lines have the same length. + If an array, each line has a different length. + offset : float, optional + Offset of the lines. If not given, the mean of the lower and upper bounds is + used. + align : {'low', 'high', 'center'}, optional + How to align the rug lines around the offset. This parameter is defined as + follows. + + ``` + "low" "high" "center" + ──┴─┴── ──┬─┬── ──┼─┼── + ``` + """ + self.base.update_length(lengths=lengths, offset=offset, align=align) + return self + def with_hover_template(self, template: str) -> Self: """Set the hover tooltip template for the layer.""" extra = dict(self._source.iter_items()) diff --git a/whitecanvas/protocols/canvas_protocol.py b/whitecanvas/protocols/canvas_protocol.py index e553bb12..ebbd5520 100644 --- a/whitecanvas/protocols/canvas_protocol.py +++ b/whitecanvas/protocols/canvas_protocol.py @@ -214,3 +214,6 @@ def _plt_show(self): def _plt_set_figsize(self, width: int, height: int): """Set size of canvas in pixels.""" + + def _plt_set_spacings(self, wspace: float, hspace: float): + """Set spacing between subplots""" From 976a3f2d7cf724d41e9be7f377a3bac5bc5c7d6b Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Fri, 16 Feb 2024 17:56:52 +0900 Subject: [PATCH 6/6] fix docs --- whitecanvas/canvas/_joint.py | 66 ++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/whitecanvas/canvas/_joint.py b/whitecanvas/canvas/_joint.py index 1b88d16a..0cad8646 100644 --- a/whitecanvas/canvas/_joint.py +++ b/whitecanvas/canvas/_joint.py @@ -208,8 +208,6 @@ def with_hist_x( limits : (float, float), optional Limits in which histogram will be built. This parameter will equivalent to the `range` paraneter of `np.histogram`. - name : str, optional - Name of the layer. shape : {"step", "polygon", "bars"}, default "bars" Shape of the histogram. This parameter defines how to convert the data into the line nodes. @@ -242,8 +240,6 @@ def with_hist_y( limits : (float, float), optional Limits in which histogram will be built. This parameter will equivalent to the `range` paraneter of `np.histogram`. - name : str, optional - Name of the layer. shape : {"step", "polygon", "bars"}, default "bars" Shape of the histogram. This parameter defines how to convert the data into the line nodes. @@ -276,8 +272,6 @@ def with_hist( limits : (float, float), optional Limits in which histogram will be built. This parameter will equivalent to the `range` paraneter of `np.histogram`. - name : str, optional - Name of the layer. shape : {"step", "polygon", "bars"}, default "bars" Shape of the histogram. This parameter defines how to convert the data into the line nodes. @@ -299,6 +293,18 @@ def with_kde_x( band_width: KdeBandWidthType = "scott", fill_alpha: float = 0.2, ) -> Self: + """ + Configure the x-marginal canvas to have a kernel density estimate (KDE) plot. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + band_width : "scott", "silverman" or float, default "scott" + Bandwidth of the kernel. + fill_alpha : float, default 0.2 + Alpha value of the fill color. + """ width = theme._default("line.width", width) self._x_plotters.append( MarginalKdePlotter( @@ -317,6 +323,18 @@ def with_kde_y( band_width: KdeBandWidthType = "scott", fill_alpha: float = 0.2, ) -> Self: + """ + Configure the y-marginal canvas to have a kernel density estimate (KDE) plot. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + band_width : "scott", "silverman" or float, default "scott" + Bandwidth of the kernel. + fill_alpha : float, default 0.2 + Alpha value of the fill color. + """ width = theme._default("line.width", width) self._y_plotters.append( MarginalKdePlotter( @@ -335,21 +353,57 @@ def with_kde( band_width: KdeBandWidthType = "scott", fill_alpha: float = 0.2, ) -> Self: + """ + Configure both of the marginal canvases to have KDE plots. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + band_width : "scott", "silverman" or float, default "scott" + Bandwidth of the kernel. + fill_alpha : float, default 0.2 + Alpha value of the fill color. + """ self.with_kde_x(width=width, band_width=band_width, fill_alpha=fill_alpha) self.with_kde_y(width=width, band_width=band_width, fill_alpha=fill_alpha) return self def with_rug_x(self, *, width: float | None = None) -> Self: + """ + Configure the x-marginal canvas to have a rug plot. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + """ width = theme._default("line.width", width) self._x_plotters.append(MarginalRugPlotter(Orientation.VERTICAL, width=width)) return self def with_rug_y(self, *, width: float | None = None) -> Self: + """ + Configure the y-marginal canvas to have a rug plot. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + """ width = theme._default("line.width", width) self._y_plotters.append(MarginalRugPlotter(Orientation.HORIZONTAL, width=width)) return self def with_rug(self, *, width: float | None = None) -> Self: + """ + Configure both of the marginal canvases to have rug plots. + + Parameters + ---------- + width : float, optional + Width of the line. Use theme default if not specified. + """ self.with_rug_x(width=width) self.with_rug_y(width=width) return self