Skip to content

Commit

Permalink
refactor subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Aug 15, 2024
1 parent 9853e8e commit c46c6db
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 111 deletions.
29 changes: 27 additions & 2 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,39 @@ def test_bars(backend: str):

def test_step(backend: str):
canvas = new_canvas(backend=backend)
canvas.add_step(np.arange(10), np.zeros(10), where="pre")
layer = canvas.add_step(np.zeros(10))
layer = canvas.add_step(np.arange(10), np.arange(10), where="pre")

repr(layer)

assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))
layer.data = np.arange(10), np.arange(10)
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))

layer.where = "mid"
assert layer.where == "mid"
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))
layer.data = np.arange(10), np.arange(10)
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))

layer.where = "mid_t"
assert layer.where == "mid_t"
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))
layer.data = np.arange(10), np.arange(10)
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))

layer.where = "post"
assert layer.where == "post"
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))
layer.data = np.arange(10), np.arange(10)
assert_allclose(layer.data.x, np.arange(10))
assert_allclose(layer.data.y, np.arange(10))

canvas.add_step(np.arange(10), np.zeros(10), where="mid")
canvas.add_step(np.arange(10), np.zeros(10), where="post")
Expand Down
19 changes: 10 additions & 9 deletions whitecanvas/canvas/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,19 +938,20 @@ def add_markers(
@overload
def add_step(
self, ydata: ArrayLike1D, *, name: str | None = None,
where: Literal["pre", "post", "mid"] | StepStyle = "pre",
color: ColorType | None = None, width: float | None = None,
style: LineStyle | str | None = None, alpha: float = 1.0,
where: str | StepStyle = "pre", color: ColorType | None = None,
width: float | None = None, style: LineStyle | str | None = None,
alpha: float = 1.0, orient: OrientationLike = "horizontal",
antialias: bool = True,
) -> _l.LineStep: # fmt: skip
...

@overload
def add_step(
self, xdata: ArrayLike1D, ydata: ArrayLike1D, *, name: str | None = None,
where: Literal["pre", "post", "mid"] | StepStyle = "pre",
color: ColorType | None = None, width: float | None = None,
style: LineStyle | str | None = None, alpha: float = 1.0,
where: str | StepStyle = "pre", color: ColorType | None = None,
width: float | None = None, style: LineStyle | str | None = None,
alpha: float = 1.0, orient: OrientationLike = "horizontal",
antialias: bool = True,
) -> _l.LineStep: # fmt: skip
...

Expand All @@ -975,10 +976,10 @@ def add_step(
----------
name : str, optional
Name of the layer.
where : "pre", "post" or "mid", default "pre"
where : str or StepStyle, default "pre"
Where the step should be placed.
color : color-like, optional
Color of the bars.
Color of the steps.
width : float, optional
Line width. Use the theme default if not specified.
style : str or LineStyle, optional
Expand All @@ -1000,7 +1001,7 @@ def add_step(
style = theme._default("line.style", style)
layer = _l.LineStep(
xdata, ydata, name=name, color=color, width=width, style=style, where=where,
alpha=alpha, antialias=antialias, backend=self._get_backend(),
alpha=alpha, antialias=antialias, backend=self._get_backend()
) # fmt: skip
return self.add_layer(layer)

Expand Down
155 changes: 88 additions & 67 deletions whitecanvas/layers/_primitive/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,36 +145,13 @@ def _as_legend_item(self) -> _legend.LineLegendItem:
return _legend.LineLegendItem(self.color, self.width, self.style)


class Line(LineMixin[LineProtocol], HoverableDataBoundLayer[LineProtocol, XYData]):
class _SingleLine(
LineMixin[LineProtocol], HoverableDataBoundLayer[LineProtocol, XYData]
):
_backend_class_name = "MonoLine"
events: LineLayerEvents
_events_class = LineLayerEvents

def __init__(
self,
xdata: ArrayLike1D,
ydata: ArrayLike1D,
*,
name: str | None = None,
color: ColorType = "blue",
width: float = 1,
alpha: float = 1.0,
style: LineStyle | str = LineStyle.SOLID,
antialias: bool = True,
backend: Backend | str | None = None,
):
xdata, ydata = normalize_xy(xdata, ydata)
super().__init__(name=name)
self._backend = self._create_backend(Backend(backend), xdata, ydata)
self.update(
color=color, width=width, style=style, alpha=alpha, antialias=antialias
)
self._x_hint, self._y_hint = xy_size_hint(xdata, ydata)
self._backend._plt_connect_pick_event(self.events.clicked.emit)

def _get_layer_data(self) -> XYData:
return XYData(*self._backend._plt_get_data())

def _norm_layer_data(self, data: Any) -> XYData:
if isinstance(data, np.ndarray):
if data.ndim != 2 or data.shape[1] != 2:
Expand All @@ -197,31 +174,30 @@ def _norm_layer_data(self, data: Any) -> XYData:
)
return XYData(xdata, ydata)

def _set_layer_data(self, data: XYData):
x0, y0 = data
self._backend._plt_set_data(x0, y0)
self._x_hint, self._y_hint = xy_size_hint(x0, y0)

def set_data(
self,
xdata: ArrayLike1D | None = None,
ydata: ArrayLike1D | None = None,
):
"""Set x and y data of the line."""
self.data = xdata, ydata

@property
def ndata(self) -> int:
"""Number of data points."""
return self.data.x.size

def _data_to_backend_data(self, data: XYData) -> XYData:
return data

Check warning on line 191 in whitecanvas/layers/_primitive/line.py

View check run for this annotation

Codecov / codecov/patch

whitecanvas/layers/_primitive/line.py#L191

Added line #L191 was not covered by tests

def with_markers(
self,
symbol: Symbol | str = Symbol.CIRCLE,
size: float | None = None,
color: ColorType | _Void = _void,
alpha: float = 1.0,
hatch: str | Hatch = Hatch.SOLID,
) -> _lg.Plot:
) -> _lg.Plot[Self]:
"""
Add markers at each data point.
Expand Down Expand Up @@ -268,7 +244,7 @@ def with_xerr(
style: str | _Void = _void,
antialias: bool | _Void = _void,
capsize: float = 0,
) -> _lg.LabeledLine:
) -> _lg.LabeledLine[Self]:
from whitecanvas.layers._primitive import Errorbars
from whitecanvas.layers.group import LabeledLine

Expand Down Expand Up @@ -302,7 +278,7 @@ def with_yerr(
style: str | _Void = _void,
antialias: bool | _Void = _void,
capsize: float = 0,
) -> _lg.LabeledLine:
) -> _lg.LabeledLine[Self]:
from whitecanvas.layers._primitive import Errorbars
from whitecanvas.layers.group import LabeledLine

Expand Down Expand Up @@ -334,7 +310,7 @@ def with_xband(
color: ColorType | _Void = _void,
alpha: float = 0.3,
hatch: str | Hatch = Hatch.SOLID,
) -> _lg.LineBand:
) -> _lg.LineBand[Self]:
from whitecanvas.layers._primitive import Band
from whitecanvas.layers.group import LineBand

Expand All @@ -343,10 +319,11 @@ def with_xband(
if color is _void:
color = self.color
data = self.data
_x, _y0 = self._data_to_backend_data(XYData(data.y, data.x - err))
_, _y1 = self._data_to_backend_data(XYData(data.y, data.x + err_high))
band = Band(
data.y, data.x - err, data.x + err_high, orient="horizontal",
color=color, alpha=alpha, hatch=hatch, name=f"xband-of-{self.name}",
backend=self._backend_name,
_x, _y0, _y1, orient=Orientation.HORIZONTAL, color=color, alpha=alpha,
hatch=hatch, name=f"xband-of-{self.name}", backend=self._backend_name,
) # fmt: skip
old_name = self.name
self.name = f"line-of-{self.name}"
Expand All @@ -360,7 +337,7 @@ def with_yband(
color: ColorType | _Void = _void,
alpha: float = 0.3,
hatch: str | Hatch = Hatch.SOLID,
) -> _lg.LineBand:
) -> _lg.LineBand[Self]:
from whitecanvas.layers._primitive import Band
from whitecanvas.layers.group import LineBand

Expand All @@ -369,10 +346,11 @@ def with_yband(
if color is _void:
color = self.color
data = self.data
_x, _y0 = self._data_to_backend_data(XYData(data.x, data.y - err))
_, _y1 = self._data_to_backend_data(XYData(data.x, data.y + err_high))
band = Band(
data.x, data.y - err, data.y + err_high, orient=Orientation.VERTICAL,
color=color, alpha=alpha, hatch=hatch, name=f"yband-of-{self.name}",
backend=self._backend_name,
_x, _y0, _y1, orient=Orientation.VERTICAL, color=color, alpha=alpha,
hatch=hatch, name=f"yband-of-{self.name}", backend=self._backend_name,
) # fmt: skip
old_name = self.name
self.name = f"line-of-{self.name}"
Expand All @@ -385,7 +363,7 @@ def with_xfill(
color: ColorType | _Void = _void,
alpha: float = 0.3,
hatch: str | Hatch = Hatch.SOLID,
) -> _lg.LineBand:
) -> _lg.LineBand[Self]:
from whitecanvas.layers._primitive import Band
from whitecanvas.layers.group import LineBand

Expand All @@ -408,7 +386,7 @@ def with_yfill(
color: ColorType | _Void = _void,
alpha: float = 0.3,
hatch: str | Hatch = Hatch.SOLID,
) -> _lg.LineBand:
) -> _lg.LineBand[Self]:
from whitecanvas.layers._primitive import Band
from whitecanvas.layers.group import LineBand

Expand All @@ -424,6 +402,45 @@ def with_yfill(
self.name = f"line-of-{self.name}"
return LineBand(self, band, name=old_name)


class Line(_SingleLine):
_backend_class_name = "MonoLine"
events: LineLayerEvents
_events_class = LineLayerEvents

def __init__(
self,
xdata: ArrayLike1D,
ydata: ArrayLike1D,
*,
name: str | None = None,
color: ColorType = "blue",
width: float = 1.0,
alpha: float = 1.0,
style: LineStyle | str = LineStyle.SOLID,
antialias: bool = True,
backend: Backend | str | None = None,
):
xdata, ydata = normalize_xy(xdata, ydata)
super().__init__(name=name)
self._backend = self._create_backend(Backend(backend), xdata, ydata)
self.update(
color=color, width=width, style=style, alpha=alpha, antialias=antialias
)
self._x_hint, self._y_hint = xy_size_hint(xdata, ydata)
self._backend._plt_connect_pick_event(self.events.clicked.emit)

def _data_to_backend_data(self, data: XYData) -> XYData:
return data

def _get_layer_data(self) -> XYData:
return XYData(*self._backend._plt_get_data())

def _set_layer_data(self, data: XYData):
x0, y0 = data
self._backend._plt_set_data(x0, y0)
self._x_hint, self._y_hint = xy_size_hint(x0, y0)

def with_text(
self,
strings: list[str],
Expand Down Expand Up @@ -502,7 +519,7 @@ def build_cdf(
) # fmt: skip


class LineStep(LineMixin[LineProtocol], HoverableDataBoundLayer[LineProtocol, XYData]):
class LineStep(_SingleLine):
_backend_class_name = "MonoLine"
events: LineLayerEvents
_events_class = LineLayerEvents
Expand Down Expand Up @@ -535,31 +552,46 @@ def __init__(
def _data_to_backend_data(self, data: XYData) -> XYData:
if data.x.size < 2:
return data

Check warning on line 554 in whitecanvas/layers/_primitive/line.py

View check run for this annotation

Codecov / codecov/patch

whitecanvas/layers/_primitive/line.py#L554

Added line #L554 was not covered by tests
if self._where is StepStyle.PRE:
if self._where is StepStyle.PRE or self._where is StepStyle.POST_T:
xdata = np.repeat(data.x, 2)[:-1]
ydata = np.repeat(data.y, 2)[1:]
elif self._where is StepStyle.POST:
elif self._where is StepStyle.POST or self._where is StepStyle.PRE_T:
xdata = np.repeat(data.x, 2)[1:]
ydata = np.repeat(data.y, 2)[:-1]
else:
elif self._where is StepStyle.MID:
xrep = np.repeat((data.x[1:] + data.x[:-1]) / 2, 2)
xdata = np.concatenate([data.x[:1], xrep, data.x[-1:]])
ydata = np.repeat(data.y, 2)
elif self._where is StepStyle.MID_T:
yrep = np.repeat((data.y[1:] + data.y[:-1]) / 2, 2)
xdata = np.repeat(data.x, 2)
ydata = np.concatenate([data.y[:1], yrep, data.y[-1:]])
else: # pragma: no cover
raise ValueError(f"Invalid step style: {self._where}")
return XYData(xdata, ydata)

def _get_layer_data(self) -> XYData:
xback, yback = self._backend._plt_get_data()
if self._where is StepStyle.PRE:
if (
self._where is StepStyle.PRE
or self._where is StepStyle.POST_T
or self._where is StepStyle.POST
or self._where is StepStyle.PRE_T
):
xdata = xback[::2]
ydata = yback[::2]
elif self._where is StepStyle.POST:
xdata = xback[1::2]
ydata = yback[1::2]
else:
elif self._where is StepStyle.MID:
xmids = xback[1:-1:2]
xmid = (xmids[2:] + xmids[:-2]) / 2
xmid = (xmids[1:] + xmids[:-1]) / 2
xdata = np.concatenate([xback[:1], xmid, xback[-1:]])
ydata = yback[1::2]
elif self._where is StepStyle.MID_T:
ymids = yback[1:-1:2]
ymid = (ymids[1:] + ymids[:-1]) / 2
xdata = xback[::2]
ydata = np.concatenate([yback[:1], ymid, yback[-1:]])
else: # pragma: no cover
raise ValueError(f"Invalid step style: {self._where}")
return XYData(xdata, ydata)

def _norm_layer_data(self, data: Any) -> XYData:
Expand Down Expand Up @@ -589,26 +621,15 @@ def _set_layer_data(self, data: XYData):
self._backend._plt_set_data(x0, y0)
self._x_hint, self._y_hint = xy_size_hint(x0, y0)

def set_data(
self,
xdata: ArrayLike1D | None = None,
ydata: ArrayLike1D | None = None,
):
self.data = xdata, ydata

@property
def ndata(self) -> int:
"""Number of data points."""
return self.data.x.size

@property
def where(self) -> StepStyle:
return self._where

@where.setter
def where(self, where: str | StepStyle):
data = self.data
self._where = StepStyle(where)
self._set_layer_data(self.data)
self._set_layer_data(data)


class MultiLineEvents(LayerEvents):
Expand Down
Loading

0 comments on commit c46c6db

Please sign in to comment.