diff --git a/whitecanvas/tools/__init__.py b/whitecanvas/tools/__init__.py index a31a0f50..bc732c23 100644 --- a/whitecanvas/tools/__init__.py +++ b/whitecanvas/tools/__init__.py @@ -1,8 +1,10 @@ """Built-in tools.""" from whitecanvas.tools._selection import ( + SelectionManager, lasso_selector, line_selector, + point_selector, polygon_selector, rect_selector, xspan_selector, @@ -12,8 +14,10 @@ __all__ = [ "line_selector", "rect_selector", + "point_selector", "xspan_selector", "yspan_selector", "polygon_selector", "lasso_selector", + "SelectionManager", ] diff --git a/whitecanvas/tools/_selection.py b/whitecanvas/tools/_selection.py index 57efba53..07494e81 100644 --- a/whitecanvas/tools/_selection.py +++ b/whitecanvas/tools/_selection.py @@ -4,9 +4,9 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, + Any, Generic, Literal, - NamedTuple, Sequence, TypeVar, overload, @@ -17,17 +17,26 @@ from psygnal import Signal from whitecanvas.canvas import CanvasBase -from whitecanvas.layers import Layer, Line, Rects, Spans -from whitecanvas.tools._polygon_utils import is_in_polygon +from whitecanvas.layers import Layer, Line, Markers, Rects, Spans +from whitecanvas.tools._selection_types import ( + LineSelection, + PointSelection, + PolygonSelection, + RectSelection, + SelectionMode, + SpanSelection, + XSpanSelection, + YSpanSelection, +) from whitecanvas.types import ( ColorType, + Hatch, LineStyle, Modifier, MouseButton, MouseEvent, MouseEventType, Point, - Rect, XYData, _Void, ) @@ -75,7 +84,7 @@ def __exit__(self, *args): def _canvas(self) -> CanvasBase: canvas = self._canvas_ref() if canvas is None: - raise RuntimeError("Canvas has been deleted") + raise RuntimeError(f"Canvas is not connected to {self!r}.") return canvas @abstractmethod @@ -183,8 +192,8 @@ def _update_layer( self._layer.visible = True @property - def selection(self) -> Rect: - return self._layer.rects[0] + def selection(self) -> RectSelection: + return RectSelection(*self._layer.rects[0]) @property def face(self) -> ConstFace: @@ -202,29 +211,42 @@ def contains_point(self, point: tuple[float, float], /) -> bool: ... def contains_point(self, x: float, y: float, /) -> bool: ... def contains_point(self, *args) -> bool: - x, y = _point_to_xy(*args) - sel = self.selection - return sel.left <= x <= sel.right and sel.bottom <= y <= sel.top + return self.selection.contains_point(*args) def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: - points = _atleast_2d(points) - sel = self.selection - if points.ndim == 2 and points.shape[1] == 2: - xs = points[:, 0] - ys = points[:, 1] - return ( - (sel.left <= xs) - & (xs <= sel.right) - & (sel.bottom <= ys) - & (ys <= sel.top) - ) - else: - raise ValueError("points must be (2,) or (N, 2) array.") + return self.selection.contains_points(points) + +class PointSelectionTool(SelectionToolBase[Markers]): + def _create_layer(self) -> Markers: + return Markers(np.array([]), np.array([]), size=10, color="blue", alpha=0.4) -class LineSelection(NamedTuple): - start: Point - end: Point + def _update_layer( + self, + start: tuple[float, float], + now: tuple[float, float], + ): + x, y = now + self._layer.data = np.array([x]), np.array([y]) + + @property + def selection(self) -> Point: + xs, ys = self._layer.data + return PointSelection(xs[0], ys[0]) + + def update( + self, + *, + color: ColorType | _Void = _void, + symbol: str | _Void = _void, + size: float | _Void = _void, + alpha: float | _Void = _void, + hatch: str | Hatch | _Void = _void, + ): + self._layer.update( + color=color, symbol=symbol, size=size, alpha=alpha, hatch=hatch + ) + return self class LineSelectionTool(SelectionToolBase[Line]): @@ -254,6 +276,7 @@ def update( alpha: float | _Void = _void, ): self._layer.update(color=color, width=width, style=style, alpha=alpha) + return self @property def color(self) -> NDArray[np.float32]: @@ -292,16 +315,10 @@ def alpha(self, alpha: float): self._layer.alpha = alpha -class SpanSelection(NamedTuple): - start: float - end: float - - class _SpanSelectionTool(SelectionToolBase[Spans]): @property - def selection(self) -> SpanSelection: - span = self._layer.data[0] - return SpanSelection(*sorted(span)) + @abstractmethod + def selection(self) -> SpanSelection: ... @property def face(self) -> ConstFace: @@ -313,8 +330,24 @@ def edge(self) -> ConstEdge: """Edge color of the selection span.""" return self._layer.edge + @overload + def contains_point(self, point: tuple[float, float], /) -> bool: ... + @overload + def contains_point(self, x: float, y: float, /) -> bool: ... + + def contains_point(self, *args) -> bool: + return self.selection.contains_point(*args) + + def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: + return self.selection.contains_points(points) + class XSpanSelectionTool(_SpanSelectionTool): + @property + def selection(self) -> XSpanSelection: + span = self._layer.data[0] + return XSpanSelection(*sorted(span)) + def _create_layer(self) -> Spans: layer = Spans([[0, 1]], orient="vertical", color="red", alpha=0.25) layer.visible = False @@ -330,27 +363,13 @@ def _update_layer( self._layer.data = np.array([[x0, x1]], dtype=np.float32) self._layer.visible = True - @overload - def contains_point(self, point: tuple[float, float], /) -> bool: ... - @overload - def contains_point(self, x: float, y: float, /) -> bool: ... - - def contains_point(self, *args) -> bool: - x, _ = _point_to_xy(*args) - sel = self.selection - return sel.start <= x <= sel.end - - def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: - points = _atleast_2d(points) - sel = self.selection - if points.ndim == 2 and points.shape[1] == 2: - xs = points[:, 0] - return (sel.start <= xs) & (xs <= sel.end) - else: - raise ValueError("points must be (2,) or (N, 2) array.") - class YSpanSelectionTool(_SpanSelectionTool): + @property + def selection(self) -> YSpanSelection: + span = self._layer.data[0] + return YSpanSelection(*sorted(span)) + def _create_layer(self) -> Spans: layer = Spans([[0, 1]], orient="horizontal", color="red", alpha=0.25) layer.visible = False @@ -366,25 +385,6 @@ def _update_layer( self._layer.data = np.array([[y0, y1]], dtype=np.float32) self._layer.visible = True - @overload - def contains_point(self, point: tuple[float, float], /) -> bool: ... - @overload - def contains_point(self, x: float, y: float, /) -> bool: ... - - def contains_point(self, *args) -> bool: - _, y = _point_to_xy(*args) - sel = self.selection - return sel.start <= y <= sel.end - - def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: - points = _atleast_2d(points) - sel = self.selection - if points.ndim == 2 and points.shape[1] == 2: - ys = points[:, 1] - return (sel.start <= ys) & (ys <= sel.end) - else: - raise ValueError("points must be (2,) or (N, 2) array.") - class LassoSelectionTool(LineSelectionTool): def _update_layer( @@ -399,8 +399,8 @@ def _update_layer( self._layer.data = xs, ys @property - def selection(self) -> XYData: - return self._layer.data + def selection(self) -> PolygonSelection: + return PolygonSelection(*self._layer.data) def _on_press(self, start: tuple[float, float]): self._layer.data = np.array([start[0]]), np.array([start[1]]) @@ -429,20 +429,31 @@ def contains_point(self, point: tuple[float, float], /) -> bool: ... def contains_point(self, x: float, y: float, /) -> bool: ... def contains_point(self, *args) -> bool: - x, y = _point_to_xy(*args) - poly = self._layer.data - return is_in_polygon(np.array([[x, y]]), poly.stack())[0] + return self.selection.contains_point(*args) def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: - points = _atleast_2d(points) - poly = self._layer.data - if points.ndim == 2 and points.shape[1] == 2: - return is_in_polygon(points, poly.stack()) - else: - raise ValueError("points must be (2,) or (N, 2) array.") + return self.selection.contains_points(points) class PolygonSelectionTool(LassoSelectionTool): + def __init__( + self, + canvas: CanvasBase, + buttons: list[MouseButton], + modifiers: list[Modifier], + tracking: bool = False, + auto_close: bool = False, + ): + super().__init__(canvas, buttons, modifiers, tracking) + self._auto_close = auto_close + + def _redraw_layer(self, now: tuple[float, float]): + cur_data = self._layer.data + x0, y0 = now + xs = np.concatenate([cur_data.x[:-1], [x0]]) + ys = np.concatenate([cur_data.y[:-1], [y0]]) + self._layer.data = xs, ys + def callback(self, e: MouseEvent): """The callback function that is called when mouse is moved.""" if not self._enabled: @@ -459,19 +470,21 @@ def callback(self, e: MouseEvent): with canvas.autoscale_context(enabled=False): canvas.add_layer(self._layer) + while e.type is not MouseEventType.RELEASE: + yield # dragging + self._update_layer(pos_start, e.pos) while True: while e.type is not MouseEventType.RELEASE: yield # dragging - self._update_layer(pos_start, e.pos) + self._redraw_layer(e.pos) yield while e.button is MouseButton.NONE: - cur_data = self._layer.data - x0, y0 = e.pos - xs = np.concatenate([cur_data.x[:-1], [x0]]) - ys = np.concatenate([cur_data.y[:-1], [y0]]) - self._layer.data = xs, ys + self._redraw_layer(e.pos) yield if e.type is MouseEventType.DOUBLE_CLICK: + self._remove_duplicates() + if self._auto_close: + self.close_path(emit=True) break elif e.button in self._valid_buttons: if e.type is MouseEventType.PRESS: @@ -488,36 +501,40 @@ def callback(self, e: MouseEvent): with self.cleared.blocked(): self.clear_selection() + def _remove_duplicates(self): + xs, ys = self._layer.data + if xs.size >= 2 and xs[-2] == xs[-1] and ys[-2] == ys[-1]: + self._layer.data = xs[:-1], ys[:-1] + return + def _norm_input( buttons: _MouseButton | Sequence[_MouseButton] = "left", modifiers: _Modifier | Sequence[_Modifier] | None = None, ): + return _norm_button(buttons), _norm_modifier(modifiers) + + +def _norm_button( + buttons: _MouseButton | Sequence[_MouseButton] = "left", +) -> list[MouseButton]: if isinstance(buttons, (str, MouseButton)): buttons = [buttons] _buttons = [MouseButton(btn) for btn in buttons] if MouseButton.NONE in _buttons: raise ValueError("MouseButton.NONE is not allowed.") + return _buttons + + +def _norm_modifier( + modifiers: _Modifier | Sequence[_Modifier] | None = None, +) -> list[Modifier]: if modifiers is None: modifiers = [] elif isinstance(modifiers, (str, Modifier)): modifiers = [modifiers] _modifiers = [Modifier(mod) for mod in modifiers] - return _buttons, _modifiers - - -def _atleast_2d(points: NDArray[np.number]) -> NDArray[np.number]: - if isinstance(points, XYData): - return points.stack() - return np.atleast_2d(points) - - -def _point_to_xy(*args) -> tuple[float, float]: - if len(args) == 1: - x, y = args[0] - else: - x, y = args - return x, y + return _modifiers def line_selector( @@ -556,6 +573,37 @@ def line_selector( return LineSelectionTool(canvas, _buttons, _modifiers, tracking=tracking) +def point_selector( + canvas: CanvasBase, + buttons: _MouseButton | Sequence[_MouseButton] = "left", + modifiers: _Modifier | Sequence[_Modifier] | None = None, +) -> PointSelectionTool: + """ + Create a point selector tool with given settings. + + A point selector emits a Point object when a point is drawn. + A selection tool is constructed by specifying the canvas to attach the tool. + + >>> canvas = new_canvas("matplotlib:qt") + >>> tool = point_selector(canvas) + + Use `buttons` and `modifiers` to specify how to trigger the tool. + + >>> tool = point_selector(canvas, buttons="right", modifiers="ctrl") + + Parameters + ---------- + canvas : CanvasBase + The canvas to which the tool is attached. + buttons : MouseButton or Sequence[MouseButton], default "left" + The mouse buttons that can trigger the tool. + modifiers : Modifier or Sequence[Modifier], optional + The modifier keys that must be pressed to trigger the tool. + """ + _buttons, _modifiers = _norm_input(buttons, modifiers) + return PointSelectionTool(canvas, _buttons, _modifiers) + + def rect_selector( canvas: CanvasBase, buttons: _MouseButton | Sequence[_MouseButton] = "left", @@ -736,3 +784,217 @@ def polygon_selector( """ _buttons, _modifiers = _norm_input(buttons, modifiers) return PolygonSelectionTool(canvas, _buttons, _modifiers, tracking=tracking) + + +_S = TypeVar("_S", bound=SelectionToolBase) + + +class SelectionToolConstructor(ABC, Generic[_S]): + def __init__( + self, + buttons: _MouseButton | list[_MouseButton] = "left", + modifiers: _Modifier | list[_Modifier] | None = None, + tracking: bool = False, + ): + self._buttons = _norm_button(buttons) + self._modifiers = _norm_modifier(modifiers) + self._tracking = tracking + + @property + def buttons(self) -> list[MouseButton]: + return list(self._buttons) + + @buttons.setter + def buttons(self, value: list[MouseButton]): + self._buttons = _norm_button(value) + + @property + def modifiers(self) -> list[Modifier]: + return list(self._modifiers) + + @modifiers.setter + def modifiers(self, value: list[Modifier]): + self._modifiers = _norm_modifier(value) + + @property + def tracking(self) -> bool: + return self._tracking + + @tracking.setter + def tracking(self, value: bool): + if not isinstance(value, bool): + raise TypeError("tracking must be a boolean value.") + self._tracking = value + + def _prep_kwargs(self) -> dict[str, Any]: + return { + "buttons": self._buttons, + "modifiers": self._modifiers, + "tracking": self._tracking, + } + + def install(self, canvas: CanvasBase) -> _S: + return self._install(canvas, **self._prep_kwargs()) + + @abstractmethod + def _install(self, canvas: CanvasBase, **kwargs) -> _S: ... + + +class LineSelectorConstructor(SelectionToolConstructor[LineSelectionTool]): + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking): + return LineSelectionTool(canvas, buttons, modifiers, tracking) + + +class PointSelectorConstructor(SelectionToolConstructor[PointSelectionTool]): + def _prep_kwargs(self) -> dict[str, Any]: + return { + "buttons": self._buttons, + "modifiers": self._modifiers, + } + + def _install(self, canvas: CanvasBase, buttons, modifiers): + return PointSelectionTool(canvas, buttons, modifiers) + + +class RectSelectorConstructor(SelectionToolConstructor[RectSelectionTool]): + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking): + return RectSelectionTool(canvas, buttons, modifiers, tracking) + + +class XSpanSelectorConstructor(SelectionToolConstructor[XSpanSelectionTool]): + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking): + return XSpanSelectionTool(canvas, buttons, modifiers, tracking) + + +class YSpanSelectorConstructor(SelectionToolConstructor[YSpanSelectionTool]): + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking): + return YSpanSelectionTool(canvas, buttons, modifiers, tracking) + + +class LassoSelectorConstructor(SelectionToolConstructor[LassoSelectionTool]): + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking): + return LassoSelectionTool(canvas, buttons, modifiers, tracking) + + +class PolygonSelectorConstructor(SelectionToolConstructor[PolygonSelectionTool]): + def __init__( + self, + buttons: _MouseButton | list[_MouseButton] = "left", + modifiers: _Modifier | list[_Modifier] | None = None, + tracking: bool = False, + auto_close: bool = False, + ): + super().__init__(buttons, modifiers, tracking) + self._auto_close = auto_close + + @property + def auto_close(self) -> bool: + return self._auto_close + + @auto_close.setter + def auto_close(self, value: bool): + if not isinstance(value, bool): + raise TypeError("auto_close must be a boolean value.") + self._auto_close = value + + def _prep_kwargs(self) -> dict[str, Any]: + kwargs = super()._prep_kwargs() + kwargs["auto_close"] = self._auto_close + return kwargs + + def _install(self, canvas: CanvasBase, buttons, modifiers, tracking, auto_close): + return PolygonSelectionTool(canvas, buttons, modifiers, tracking, auto_close) + + +class SelectionManager: + def __init__(self, canvas: CanvasBase): + self._canvas_ref = weakref.ref(canvas) + self._current_tool: SelectionToolBase | None = None + self._line_constructor = LineSelectorConstructor() + self._point_constructor = PointSelectorConstructor() + self._rect_constructor = RectSelectorConstructor() + self._xspan_constructor = XSpanSelectorConstructor() + self._yspan_constructor = YSpanSelectorConstructor() + self._lasso_constructor = LassoSelectorConstructor() + self._polygon_constructor = PolygonSelectorConstructor() + + @property + def line(self) -> LineSelectorConstructor: + return self._line_constructor + + @property + def points(self) -> PointSelectorConstructor: + return self._point_constructor + + @property + def rect(self) -> RectSelectorConstructor: + return self._rect_constructor + + @property + def xspan(self) -> XSpanSelectorConstructor: + return self._xspan_constructor + + @property + def yspan(self) -> YSpanSelectorConstructor: + return self._yspan_constructor + + @property + def lasso(self) -> LassoSelectorConstructor: + return self._lasso_constructor + + @property + def polygon(self) -> PolygonSelectorConstructor: + return self._polygon_constructor + + @property + def current_tool(self) -> SelectionToolBase | None: + return self._current_tool + + @property + def mode(self) -> SelectionMode: + if self._current_tool is None: + return SelectionMode.NONE + elif isinstance(self._current_tool, LineSelectionTool): + return SelectionMode.LINE + elif isinstance(self._current_tool, PointSelectionTool): + return SelectionMode.POINT + elif isinstance(self._current_tool, RectSelectionTool): + return SelectionMode.RECT + elif isinstance(self._current_tool, XSpanSelectionTool): + return SelectionMode.XSPAN + elif isinstance(self._current_tool, YSpanSelectionTool): + return SelectionMode.YSPAN + elif isinstance(self._current_tool, LassoSelectionTool): + return SelectionMode.LASSO + elif isinstance(self._current_tool, PolygonSelectionTool): + return SelectionMode.POLYGON + else: + raise RuntimeError("Invalid tool mode.") + + @mode.setter + def mode(self, value: str | SelectionMode): + value = SelectionMode(value) + if self._current_tool is not None: + self._current_tool.disconnect() + self._current_tool = None + canvas = self._canvas_ref() + if canvas is None: + raise RuntimeError("Canvas is not connected.") + if value is SelectionMode.NONE: + self._current_tool = None + elif value is SelectionMode.LINE: + self._current_tool = self._line_constructor.install(canvas) + elif value is SelectionMode.POINT: + self._current_tool = self._point_constructor.install(canvas) + elif value is SelectionMode.RECT: + self._current_tool = self._rect_constructor.install(canvas) + elif value is SelectionMode.XSPAN: + self._current_tool = self._xspan_constructor.install(canvas) + elif value is SelectionMode.YSPAN: + self._current_tool = self._yspan_constructor.install(canvas) + elif value is SelectionMode.LASSO: + self._current_tool = self._lasso_constructor.install(canvas) + elif value is SelectionMode.POLYGON: + self._current_tool = self._polygon_constructor.install(canvas) + else: + raise ValueError("Invalid mode value.") diff --git a/whitecanvas/tools/_selection_types.py b/whitecanvas/tools/_selection_types.py new file mode 100644 index 00000000..31e6898a --- /dev/null +++ b/whitecanvas/tools/_selection_types.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from enum import Enum +from typing import NamedTuple, overload + +import numpy as np +from numpy.typing import NDArray + +from whitecanvas.tools._polygon_utils import is_in_polygon +from whitecanvas.types import Point, Rect, XYData + + +class PointSelection(Point): + """A selection that defines a point.""" + + +class LineSelection(NamedTuple): + """A selection that defines a line.""" + + start: Point + end: Point + + +class SpanSelection(NamedTuple): + """A selection that defines a span.""" + + start: float + end: float + + @overload + def contains_point(self, point: tuple[float, float], /) -> bool: ... + @overload + def contains_point(self, x: float, y: float, /) -> bool: ... + + def contains_point(self, *args) -> bool: + x_or_y = _point_to_xy(*args)[self._slice_index()] + return self.start <= x_or_y <= self.end + + def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: + points = _atleast_2d(points) + if points.ndim == 2 and points.shape[1] == 2: + xs = points[:, self._slice_index()] + return (self.start <= xs) & (xs <= self.end) + else: + raise ValueError("points must be (2,) or (N, 2) array.") + + def _slice_index(self): + raise NotImplementedError() + + +class XSpanSelection(SpanSelection): + """A selection that defines a span along the x-axis.""" + + def _slice_index(self) -> int: + return 0 + + +class YSpanSelection(SpanSelection): + """A selection that defines a span along the y-axis.""" + + def _slice_index(self) -> int: + return 1 + + +class RectSelection(Rect): + """A selection that defines a rectangle.""" + + @overload + def contains_point(self, point: tuple[float, float], /) -> bool: ... + @overload + def contains_point(self, x: float, y: float, /) -> bool: ... + + def contains_point(self, *args) -> bool: + x, y = _point_to_xy(*args) + return self.left <= x <= self.right and self.bottom <= y <= self.top + + def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: + points = _atleast_2d(points) + if points.ndim == 2 and points.shape[1] == 2: + xs = points[:, 0] + ys = points[:, 1] + return ( + (self.left <= xs) + & (xs <= self.right) + & (self.bottom <= ys) + & (ys <= self.top) + ) + else: + raise ValueError("points must be (2,) or (N, 2) array.") + + +class PolygonSelection(XYData): + """A selection that defines a polygon.""" + + @overload + def contains_point(self, point: tuple[float, float], /) -> bool: ... + @overload + def contains_point(self, x: float, y: float, /) -> bool: ... + + def contains_point(self, *args) -> bool: + x, y = _point_to_xy(*args) + return is_in_polygon(np.array([[x, y]]), self.stack())[0] + + def contains_points(self, points: XYData | NDArray[np.number]) -> NDArray[np.bool_]: + points = _atleast_2d(points) + if points.ndim == 2 and points.shape[1] == 2: + return is_in_polygon(points, self.stack()) + else: + raise ValueError("points must be (2,) or (N, 2) array.") + + +class SelectionMode(Enum): + NONE = "none" + LINE = "line" + POINT = "point" + RECT = "rect" + XSPAN = "xspan" + YSPAN = "yspan" + LASSO = "lasso" + POLYGON = "polygon" + + +def _point_to_xy(*args) -> tuple[float, float]: + if len(args) == 1: + x, y = args[0] + else: + x, y = args + return x, y + + +def _atleast_2d(points: NDArray[np.number]) -> NDArray[np.number]: + if isinstance(points, XYData): + return points.stack() + return np.atleast_2d(points) diff --git a/whitecanvas/types/_tuples.py b/whitecanvas/types/_tuples.py index 6d358588..c35eec4a 100644 --- a/whitecanvas/types/_tuples.py +++ b/whitecanvas/types/_tuples.py @@ -20,6 +20,11 @@ def stack(self) -> NDArray[np.floating]: """Data as a stacked (N, 2) array.""" return np.stack([self.x, self.y], axis=1) + @property + def ndata(self) -> int: + """Number of data points.""" + return int(self.x.size) + @classmethod def from_dict(cls, data: dict[str, NDArray[np.floating]]) -> XYData: """Create XYData from a dictionary."""