Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement add_hist2d and add_kde2d with colors #20

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@ def test_cat(backend: str):
"y": rng.normal(size=30),
"label": np.repeat(["A", "B", "C"], 10),
}
canvas.cat(df, "x", "y").add_line()
canvas.cat(df, "x", "y").add_line(color="label")
canvas.cat(df, "x", "y").add_markers()
canvas.cat(df, "x", "y").add_markers(color="label")
canvas.cat(df, "x", "y").add_markers(hatch="label")
canvas.cat(df, "x", "y").add_hist2d(bins=5)
canvas.cat(df, "x", "y").add_hist2d(bins=(5, 4))
canvas.cat(df, "x", "y").add_hist2d(bins="auto")
canvas.cat(df, "x", "y").add_hist2d(bins=("auto", 5))
canvas.cat(df, "x", "y").along_x().add_hist(bins=5)
canvas.cat(df, "x", "y").along_x().add_hist(bins=5, color="label")
canvas.cat(df, "x", "y").along_y().add_hist(bins=6)
canvas.cat(df, "x", "y").along_y().add_hist(bins=6, color="label")
cplt = canvas.cat(df, "x", "y")
cplt.add_line()
cplt.add_line(color="label")
cplt.add_markers()
cplt.add_markers(color="label")
cplt.add_markers(hatch="label")
cplt.add_hist2d(bins=5)
cplt.add_hist2d(bins=(5, 4))
cplt.add_hist2d(bins="auto")
cplt.add_hist2d(bins=(5, 4), color="label")
cplt.add_hist2d(bins=("auto", 5))
cplt.add_kde2d()
cplt.add_kde2d(color="label")
cplt.along_x().add_hist(bins=5)
cplt.along_x().add_hist(bins=5, color="label")
cplt.along_y().add_hist(bins=6)
cplt.along_y().add_hist(bins=6, color="label")

@pytest.mark.parametrize("orient", ["v", "h"])
def test_cat_plots(backend: str, orient: str):
Expand Down
51 changes: 42 additions & 9 deletions whitecanvas/canvas/dataframe/_feature_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from whitecanvas.canvas.dataframe._base import BaseCatPlotter
from whitecanvas.layers import tabular as _lt
from whitecanvas.layers.tabular import _jitter
from whitecanvas.types import ColormapType, HistBinType, KdeBandWidthType, Orientation
from whitecanvas.types import HistBinType, KdeBandWidthType, Orientation

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -214,12 +214,12 @@ def add_markers(
def add_hist2d(
self,
*,
cmap: ColormapType = "inferno",
name: str | None = None,
color: str | None = None,
bins: HistBinType | tuple[HistBinType, HistBinType] = "auto",
rangex: tuple[float, float] | None = None,
rangey: tuple[float, float] | None = None,
density: bool = False,
cmap=None, # deprecated
) -> _lt.DFHeatmap[_DF]:
"""
Add 2-D histogram of given x/y columns.
Expand All @@ -240,19 +240,52 @@ def add_hist2d(
Range of x values in which histogram will be built.
rangey : (float, float), optional
Range of y values in which histogram will be built.
density : bool, default False
If True, the result is the value of the probability density function at the
bin, normalized such that the integral over the range is 1.

Returns
-------
DFHeatmap
Dataframe bound heatmap layer.
"""
canvas = self._canvas()
layer = _lt.DFHeatmap.build_hist(
self._df, self._get_x(), self._get_y(), cmap=cmap, name=name, bins=bins,
range=(rangex, rangey), density=density, backend=canvas._get_backend(),
layer = _lt.DFMultiHeatmap.build_hist(
self._df, self._get_x(), self._get_y(), color=color,name=name, bins=bins,
range=(rangex, rangey), palette=canvas._color_palette,
backend=canvas._get_backend(),
) # fmt: skip
return canvas.add_layer(layer)

def add_kde2d(
self,
*,
name: str | None = None,
color: str | None = None,
band_width: KdeBandWidthType = "scott",
) -> _lt.DFHeatmap[_DF]:
"""
Add 2-D kernel density estimation of given x/y columns.

>>> ### Use "tip" column as x-axis and "total_bill" column as y-axis
>>> canvas.cat(df, "tip", "total_bill").add_kde2d()

Parameters
----------
cmap : colormap-like, default "inferno"
Colormap to use for the heatmap.
name : str, optional
Name of the layer.
band_width : float, default None
Bandwidth of the kernel density estimation. If None, use Scott's rule.

Returns
-------
DFHeatmap
Dataframe bound heatmap layer.
"""
canvas = self._canvas()
layer = _lt.DFMultiHeatmap.build_kde(
self._df, self._get_x(), self._get_y(), color=color, name=name,
band_width=band_width, palette=canvas._color_palette,
backend=canvas._get_backend(),
) # fmt: skip
return canvas.add_layer(layer)

Expand Down
45 changes: 44 additions & 1 deletion whitecanvas/layers/_primitive/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from whitecanvas.backend import Backend
from whitecanvas.layers._base import DataBoundLayer, LayerEvents
from whitecanvas.protocols import ImageProtocol
from whitecanvas.types import ArrayLike1D, ColormapType, HistBinType, Origin, _Void
from whitecanvas.types import (
ArrayLike1D,
ColormapType,
HistBinType,
KdeBandWidthType,
Origin,
_Void,
)
from whitecanvas.utils.normalize import as_array_1d
from whitecanvas.utils.type_check import is_real_number

Expand Down Expand Up @@ -284,6 +291,42 @@ def build_hist(
self.origin = Origin.EDGE
return self

@classmethod
def build_kde(
cls,
x: ArrayLike1D,
y: ArrayLike1D,
shape: tuple[int, int] = (256, 256),
range=None,
band_width: KdeBandWidthType = "scott",
name: str | None = None,
cmap: ColormapType = "inferno",
backend: Backend | str | None = None,
) -> Image:
from whitecanvas.utils.kde import gaussian_kde

_x = as_array_1d(x)
_y = as_array_1d(y)
kde = gaussian_kde([_x, _y], bw_method=band_width)
if range is None:
xrange = yrange = None
else:
xrange, yrange = range
if xrange is None:
xrange = _x.min(), _x.max()
if yrange is None:
yrange = _y.min(), _y.max()
xedges = np.linspace(*xrange, shape[0])
yedges = np.linspace(*yrange, shape[1])
xx, yy = np.meshgrid(xedges, yedges)
positions = np.vstack([xx.ravel(), yy.ravel()])
val = np.reshape(kde(positions).T, xx.shape)
shift = (xedges[0], yedges[0])
scale = (xedges[1] - xedges[0], yedges[1] - yedges[0])
self = cls(val, name=name, cmap=cmap, shift=shift, scale=scale, backend=backend)
self.origin = Origin.EDGE
return self


def _normalize_image(image):
img = np.asarray(image)
Expand Down
2 changes: 2 additions & 0 deletions whitecanvas/layers/tabular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DFHistograms,
DFKde,
DFLines,
DFMultiHeatmap,
DFPointPlot2D,
)
from whitecanvas.layers.tabular._df_compat import parse
Expand All @@ -27,6 +28,7 @@
"DFMarkerGroups",
"DFPointPlot",
"DFMarkers",
"DFMultiHeatmap",
"DFBars",
"DFBoxPlot",
"DFHeatmap",
Expand Down
149 changes: 128 additions & 21 deletions whitecanvas/layers/tabular/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from itertools import cycle
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -37,6 +38,7 @@
if TYPE_CHECKING:
from typing_extensions import Self


_DF = TypeVar("_DF")


Expand Down Expand Up @@ -230,43 +232,148 @@ def clim(self) -> tuple[float, float]:
def clim(self, clim: tuple[float, float]):
self._base_layer.clim = clim

@classmethod
def from_array(
cls,
src: DataFrameWrapper[_DF],
arr: np.ndarray,
name: str | None = None,
cmap: ColormapType = "gray",
clim: tuple[float | None, float | None] | None = None,
backend: Backend | str | None = None,
) -> DFHeatmap[_DF]:
return cls(_l.Image(arr, name=name, cmap=cmap, clim=clim, backend=backend), src)


class DFMultiHeatmap(
_shared.DataFrameLayerWrapper[_lg.LayerCollectionBase[_l.Image], _DF],
Generic[_DF],
):
def __init__(
self,
base: _lg.LayerCollectionBase[_l.Image],
source: DataFrameWrapper[_DF],
color_by: _p.ColorPlan,
categories: list[tuple],
):
self._color_by = color_by
self._categories = categories
super().__init__(base, source)

@classmethod
def build_hist(
cls,
df: _DF,
x: str,
y: str,
name: str | None = None,
cmap: ColormapType = "gray",
color: str | list[str] | None = None,
bins: HistBinType | tuple[HistBinType, HistBinType] = "auto",
range=None,
density: bool = False,
palette: ColormapType = "tab10",
backend: Backend | str | None = None,
) -> Self:
src = parse(df)
xdata = src[x]
ydata = src[y]
if xdata.dtype.kind not in "fiub":
raise ValueError(f"Column {x!r} is not numeric.")
if ydata.dtype.kind not in "fiub":
raise ValueError(f"Column {y!r} is not numeric.")
base = _l.Image.build_hist(
xdata, ydata, name=name, cmap=cmap, bins=bins, range=range,
density=density, backend=backend,
) # fmt: skip
return cls(base, src)
src, color = cls._norm_df_xy_color(df, x, y, color)
# normalize bins
if isinstance(bins, tuple):
xbins, ybins = bins
else:
xbins = ybins = bins
if range is None:
xrange = yrange = None
else:
xrange, yrange = range
_bins = (
np.histogram_bin_edges(src[x], xbins, xrange),
np.histogram_bin_edges(src[y], ybins, yrange),
)

color_by = _p.ColorPlan.from_palette(color, palette)
image_layers: list[_l.Image] = []
categories = []
color_iter = cycle(color_by.values)
for sl, sub in src.group_by(color):
categories.append(sl)
xdata, ydata = sub[x], sub[y]
next_color = next(color_iter)
next_background = Color([*next_color.rgba[:3], 0.0])
cmap = [next_background, next_color]
img = _l.Image.build_hist(
xdata, ydata, name=name, cmap=cmap, bins=_bins, density=True,
backend=backend,
) # fmt: skip
image_layers.append(img)
base = _lg.LayerCollectionBase(image_layers)
return cls(base, src, color_by, categories)

@classmethod
def from_array(
def build_kde(
cls,
src: DataFrameWrapper[_DF],
arr: np.ndarray,
df: _DF,
x: str,
y: str,
name: str | None = None,
cmap: ColormapType = "gray",
clim: tuple[float | None, float | None] | None = None,
color: str | list[str] | None = None,
band_width: KdeBandWidthType = "scott",
palette: ColormapType = "tab10",
backend: Backend | str | None = None,
) -> DFHeatmap[_DF]:
return cls(_l.Image(arr, name=name, cmap=cmap, clim=clim, backend=backend), src)
) -> Self:
src, color = cls._norm_df_xy_color(df, x, y, color)
color_by = _p.ColorPlan.from_palette(color, palette)
image_layers: list[_l.Image] = []
categories = []
color_iter = cycle(color_by.values)
xrange = src[x].min(), src[x].max()
yrange = src[y].min(), src[y].max()
for sl, sub in src.group_by(color):
categories.append(sl)
xdata, ydata = sub[x], sub[y]
next_color = next(color_iter)
next_background = Color([*next_color.rgba[:3], 0.0])
cmap = [next_background, next_color]
img = _l.Image.build_kde(
xdata,
ydata,
name=name,
cmap=cmap,
band_width=band_width,
range=(xrange, yrange),
backend=backend,
)
image_layers.append(img)
base = _lg.LayerCollectionBase(image_layers)
return cls(base, src, color_by, categories)

@staticmethod
def _norm_df_xy_color(df, x, y, color):
src = parse(df)
# dtype check
if src[x].dtype.kind not in "fiub":
raise ValueError(f"Column {x!r} is not numeric.")
if src[y].dtype.kind not in "fiub":
raise ValueError(f"Column {y!r} is not numeric.")

if isinstance(color, str):
color = (color,)
elif color is None:
color = ()
else:
color = tuple(color)
return src, color

def _as_legend_item(self) -> _legend.LegendItem:
if len(self._categories) == 1:
face = _legend.FaceInfo(self._color_by.values[0])
edge = _legend.EdgeInfo(self._color_by.values[0], width=0)
return _legend.BarLegendItem(face, edge)
df = _shared.list_to_df(self._categories, self._color_by.by)
colors = self._color_by.to_entries(df)
items = [(", ".join(self._color_by.by), _legend.TitleItem())]
for label, color in colors:
face = _legend.FaceInfo(color)
edge = _legend.EdgeInfo(color, width=0)
items.append((label, _legend.BarLegendItem(face, edge)))
return _legend.LegendItemCollection(items)


class DFPointPlot2D(_shared.DataFrameLayerWrapper[_lg.LabeledPlot, _DF], Generic[_DF]):
Expand Down
Loading