Skip to content

Commit

Permalink
Merge pull request #44 from hanjinliu/update-vispy
Browse files Browse the repository at this point in the history
Vispy grid lines
  • Loading branch information
hanjinliu authored Mar 12, 2024
2 parents 246c74b + 337bf04 commit 2214c5f
Show file tree
Hide file tree
Showing 18 changed files with 169 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
dependencies = [
"typing_extensions>=4.5.0",
"numpy>=1.23.2",
"psygnal>=0.9.4,<0.10.0",
"psygnal>=0.9.4,!=0.10.0",
"cmap>=0.1.2",
]

Expand Down
2 changes: 1 addition & 1 deletion whitecanvas/backend/_window/_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, canvas: CanvasGrid):
dock = QtW.QDockWidget("Dimensions", self)
dock.setWidget(sl)
self.addDockWidget(QtCore.Qt.DockWidgetArea.BottomDockWidgetArea, dock)
canvas.events.drawn.connect(self._widget._update_widget_state)
canvas.events.drawn.connect(self._widget._update_widget_state, max_args=0)
self.__class__._instance = self


Expand Down
4 changes: 2 additions & 2 deletions whitecanvas/backend/_window/_tk.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, canvas: CanvasGrid):
sl = TkDimSliders(self)
sl.pack(fill="both", expand=True)
self.__class__._instance = self
canvas.events.drawn.connect(tkcanvas._update_imagetk)
canvas.events.drawn.connect(tkcanvas._update_imagetk, max_args=0)


class TkSlider(ttk.Scale):
Expand Down Expand Up @@ -98,7 +98,7 @@ def set_axes(self, axes: list[DimAxis]):
self._widgets[axis.name] = slider

def connect_changed(self, callback: Callable[[dict[str, object]], None]):
self.changed.connect(callback)
self.changed.connect(callback, max_args=1)

def _emit_changed(self):
values = {}
Expand Down
4 changes: 2 additions & 2 deletions whitecanvas/backend/mock/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ def _plt_draw(self):
def _plt_connect_xlim_changed(
self, callback: Callable[[tuple[float, float]], None]
):
self._xaxis.lim_changed.connect(callback)
self._xaxis.lim_changed.connect(callback, max_args=1)

def _plt_connect_ylim_changed(
self, callback: Callable[[tuple[float, float]], None]
):
self._yaxis.lim_changed.connect(callback)
self._yaxis.lim_changed.connect(callback, max_args=1)

def _plt_make_legend(self, *args, **kwargs):
pass
Expand Down
78 changes: 78 additions & 0 deletions whitecanvas/backend/vispy/_gridlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import warnings
import weakref
from typing import TYPE_CHECKING

import numpy as np
from vispy.scene import visuals

from whitecanvas.backend.vispy.line import _make_connection, _safe_concat

if TYPE_CHECKING:
from whitecanvas.backend.vispy.canvas import Canvas


class LineProperties:
def __init__(self):
self.visible = False
self.color = np.ones(4, dtype=np.float32)
self.width = 1.0
self.style = "solid"


class GridLines(visuals.Compound):
def __init__(self, canvas: Canvas):
self._xscale = 0
self._yscale = 0
self._canvas_ref = weakref.ref(canvas)
self._xlines = visuals.Line()
self._ylines = visuals.Line()
self._xprops = LineProperties()
self._yprops = LineProperties()
super().__init__([self._xlines, self._ylines])
canvas._viewbox.add(self)
self.order = -99999

def set_x_grid_lines(self, visible: bool, color, width: float, style):
self._xprops.visible = visible
self._xprops.color = color
self._xprops.width = width
self._xprops.style = style
self.update()

def set_y_grid_lines(self, visible: bool, color, width: float, style):
self._yprops.visible = visible
self._yprops.color = color
self._yprops.width = width
self._yprops.style = style
self.update()

def _prepare_draw(self, view):
if not (self._xprops.visible or self._yprops.visible):
return super()._prepare_draw(view)
if canvas := self._canvas_ref():
rect = canvas._camera.rect
xmin, xmax = rect.left, rect.right
ymin, ymax = rect.bottom, rect.top
if self._xprops.visible:
xmajor_pos = canvas._xticks._get_ticker()._get_tick_frac_labels()[0]
xmajor = xmajor_pos * (xmax - xmin) + xmin
xdata = [np.array([[x, ymin], [x, ymax]]) for x in xmajor]
self._xlines.set_data(
pos=_safe_concat(xdata),
color=self._xprops.color,
width=self._xprops.width,
connect=_make_connection(xdata),
)
if self._yprops.visible:
ymajor_pos = canvas._yticks._get_ticker()._get_tick_frac_labels()[0]
ymajor = ymajor_pos * (ymax - ymin) + ymin
ydata = [np.array([[xmin, y], [xmax, y]]) for y in ymajor]
self._ylines.set_data(
pos=_safe_concat(ydata),
color=self._yprops.color,
width=self._yprops.width,
connect=_make_connection(ydata),
)
return super()._prepare_draw(view)
52 changes: 29 additions & 23 deletions whitecanvas/backend/vispy/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from whitecanvas.backend.vispy.canvas import Camera, Canvas

FONT_SIZE_FACTOR = 2.0


class TextLabel(scene.Label):
_text_visual: TextVisual
Expand All @@ -39,10 +37,10 @@ def _plt_set_color(self, color):
self._text_visual.color = color

def _plt_get_size(self) -> int:
return self._text_visual.font_size * FONT_SIZE_FACTOR
return self._text_visual.font_size

def _plt_set_size(self, size: int):
self._text_visual.font_size = size / FONT_SIZE_FACTOR
self._text_visual.font_size = size

def _plt_get_fontfamily(self) -> str:
return self._text_visual.face
Expand Down Expand Up @@ -105,12 +103,11 @@ def _plt_flip(self) -> None:
camera.flip = tuple(flipped)

def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineStyle):
# if visible:
# self._canvas()._gridlines.visible = True
# self._canvas()._gridlines._grid_color_fn['color'] = color
# else:
# self._canvas()._gridlines.visible = False
pass # TODO: implement this
grid_lines = self._canvas_ref()._grid_lines
if self._dim == 0: # y
grid_lines.set_y_grid_lines(visible, color, width, style)
else:
grid_lines.set_x_grid_lines(visible, color, width, style)


class Ticks:
Expand Down Expand Up @@ -152,12 +149,17 @@ def _plt_get_visible(self) -> bool:

def _plt_set_visible(self, visible: bool):
self._get_ticker().visible = visible
# axis = self._axis()
# if axis._dim == 0: # y
# axis.width_min = axis.width_max = 40 if visible else 0
# else:
# axis.height_min = axis.height_max = 50 if visible else 0

def _plt_get_size(self) -> float:
return self._text.font_size * FONT_SIZE_FACTOR
return self._text.font_size

def _plt_set_size(self, size: float):
self._text.font_size = size / FONT_SIZE_FACTOR
self._text.font_size = size

def _plt_get_fontfamily(self) -> str:
return self._text.face
Expand Down Expand Up @@ -185,17 +187,21 @@ def __init__(self, axis: AxisVisual):

def _get_tick_frac_labels(self):
if not self._visible:
return np.zeros(0), np.zeros(0), np.zeros(0)
if self._categorical_labels is None:
return super()._get_tick_frac_labels()
pos, labels = self._categorical_labels
domain = self.axis.domain
scale = domain[1] - domain[0]
major_tick_fractions = (np.asarray(pos) - domain[0]) / scale
minor_tick_fractions = np.zeros(0)
ok = (0 <= major_tick_fractions) & (major_tick_fractions <= 1)
tick_labels = np.asarray(labels)[ok]
return major_tick_fractions[ok], minor_tick_fractions, tick_labels
major, minor, labels = np.zeros(0), np.zeros(0), np.zeros(0)
elif self._categorical_labels is None:
major, minor, labels = super()._get_tick_frac_labels()
else:
pos, labels = self._categorical_labels
domain = self.axis.domain
scale = domain[1] - domain[0]
major_tick_fractions = (np.asarray(pos) - domain[0]) / scale
minor_tick_fractions = np.zeros(0)
ok = (0 <= major_tick_fractions) & (major_tick_fractions <= 1)
tick_labels = np.asarray(labels)[ok]
major = major_tick_fractions[ok]
minor = minor_tick_fractions
labels = tick_labels
return major, minor, labels

@property
def visible(self):
Expand Down
3 changes: 2 additions & 1 deletion whitecanvas/backend/vispy/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vispy.util import keys

from whitecanvas import protocols
from whitecanvas.backend.vispy._gridlines import GridLines
from whitecanvas.backend.vispy._label import Axis, TextLabel, Ticks
from whitecanvas.types import Modifier, MouseButton, MouseEvent, MouseEventType

Expand Down Expand Up @@ -72,11 +73,11 @@ def __init__(self, viewbox: ViewBox):
y_axis.stretch = (0.1, 1)
grid.add_widget(y_axis, row=1, col=0)
y_axis.link_view(self._viewbox)
self._grid_lines = GridLines(self)
self._xaxis = x_axis
self._yaxis = y_axis
self._xticks = Ticks(x_axis)
self._yticks = Ticks(y_axis)
self._title = TextLabel("")
self._xlabel = TextLabel("")
self._ylabel = TextLabel("")
self._grid = grid
Expand Down
10 changes: 4 additions & 6 deletions whitecanvas/backend/vispy/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from whitecanvas.utils.normalize import as_color_array
from whitecanvas.utils.type_check import is_real_number

FONT_SIZE_FACTOR = 2.0


@check_protocol(TextProtocol)
class Texts(visuals.Text):
Expand Down Expand Up @@ -65,15 +63,15 @@ def _plt_set_text_color(self, color):
self.color = col

def _plt_get_text_size(self) -> NDArray[np.floating]:
return np.full(self._plt_get_ndata(), self.font_size * FONT_SIZE_FACTOR)
return np.full(self._plt_get_ndata(), self.font_size)

def _plt_set_text_size(self, size: float | NDArray[np.floating]):
if is_real_number(size):
self.font_size = size / FONT_SIZE_FACTOR
self.font_size = size
else:
candidates = np.unique(size)
if candidates.size == 1:
self.font_size = candidates[0] / FONT_SIZE_FACTOR
self.font_size = candidates[0]
elif candidates.size == 0:
pass
else:
Expand All @@ -83,7 +81,7 @@ def _plt_set_text_size(self, size: float | NDArray[np.floating]):
UserWarning,
stacklevel=4,
)
self.font_size = np.mean(size) / FONT_SIZE_FACTOR
self.font_size = np.mean(size)

def _plt_get_text_position(
self,
Expand Down
24 changes: 16 additions & 8 deletions whitecanvas/canvas/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,22 @@ def _init_canvas(self):
self.y.ticks.update(family=_ft.family, color=_ft.color, size=_ft.size)

# connect layer events
self.layers.events.inserted.connect(self._cb_inserted, unique=True)
self.layers.events.removed.connect(self._cb_removed, unique=True)
self.layers.events.reordered.connect(self._cb_reordered, unique=True)
self.layers.events.connect(self._draw_canvas, unique=True)

self.overlays.events.inserted.connect(self._cb_overlay_inserted, unique=True)
self.overlays.events.removed.connect(self._cb_removed, unique=True)
self.overlays.events.connect(self._draw_canvas, unique=True)
self.layers.events.inserted.connect(
self._cb_inserted, unique=True, max_args=None
)
self.layers.events.removed.connect(self._cb_removed, unique=True, max_args=None)
self.layers.events.reordered.connect(
self._cb_reordered, unique=True, max_args=None
)
self.layers.events.connect(self._draw_canvas, unique=True, max_args=None)

self.overlays.events.inserted.connect(
self._cb_overlay_inserted, unique=True, max_args=None
)
self.overlays.events.removed.connect(
self._cb_removed, unique=True, max_args=None
)
self.overlays.events.connect(self._draw_canvas, unique=True, max_args=None)

canvas = self._canvas()
canvas._plt_connect_xlim_changed(self._emit_xlim_changed)
Expand Down
2 changes: 1 addition & 1 deletion whitecanvas/canvas/_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, canvas: CanvasBase | None = None):
self.events = DimsEvents()
if canvas is not None:
self._canvas_ref = weakref.ref(canvas)
self.events.indices.connect(canvas._draw_canvas, unique=True)
self.events.indices.connect(canvas._draw_canvas, unique=True, max_args=0)
else:
self._canvas_ref = lambda: None

Expand Down
6 changes: 4 additions & 2 deletions whitecanvas/canvas/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def add_canvas(
self._x_linker_ref.link(canvas.x)
if self._y_linked:
self._y_linker_ref.link(canvas.y)
canvas.events.drawn.connect(self.events.drawn.emit, unique=True)
canvas.events.drawn.connect(self.events.drawn.emit, unique=True, max_args=None)
return canvas

def _iter_add_canvas(self, **kwargs) -> Iterator[Canvas]:
Expand Down Expand Up @@ -410,4 +410,6 @@ def __init__(self, grid: CanvasGrid):
# SingleCanvas instance. To avoid confusion, the first and the only canvas
# should be replaces with the SingleCanvas instance.
grid._canvas_array[0, 0] = self
self.events.drawn.connect(self._main_canvas.events.drawn.emit, unique=True)
self.events.drawn.connect(
self._main_canvas.events.drawn.emit, unique=True, max_args=None
)
2 changes: 1 addition & 1 deletion whitecanvas/canvas/_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def link(self, axis: AxisNamespace):
warnings.warn(f"Axis {axis} already linked", RuntimeWarning, stacklevel=2)
return
self._axis_set.add(axis)
axis.events.lim.connect(self.set_limits)
axis.events.lim.connect(self.set_limits, max_args=1)

def unlink(self, axis: AxisNamespace):
"""Unlink an axis."""
Expand Down
6 changes: 4 additions & 2 deletions whitecanvas/layers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def __repr__(self):

def _connect_canvas(self, canvas: CanvasBase):
"""If needed, do something when layer is added to a canvas."""
self.events._layer_grouped.connect(canvas._cb_layer_grouped, unique=True)
self.events.connect(canvas._draw_canvas, unique=True)
self.events._layer_grouped.connect(
canvas._cb_layer_grouped, unique=True, max_args=1
)
self.events.connect(canvas._draw_canvas, unique=True, max_args=0)
self._canvas_ref = weakref.ref(canvas)

def _disconnect_canvas(self, canvas: CanvasBase):
Expand Down
Loading

0 comments on commit 2214c5f

Please sign in to comment.