Skip to content

Commit

Permalink
More optimizations, change API
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Sep 15, 2022
1 parent 792d38e commit 6dd089f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 77 deletions.
14 changes: 9 additions & 5 deletions examples/vector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"from geocube.rasterize import rasterize_image\n",
"from rasterio.enums import MergeAlg\n",
"import geopandas as gpd\n",
"from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n",
"from ipyleaflet import LocalTileLayer, LayersControl, Map, WidgetControl, basemaps\n",
"from ipywidgets import FloatSlider\n",
"import xarray_leaflet\n",
"import matplotlib.pyplot as plt"
Expand All @@ -32,8 +35,7 @@
"metadata": {},
"outputs": [],
"source": [
"df = gpd.read_file(\"bldg_footprints.shp\")\n",
"df[\"mask\"] = 1"
"df = gpd.read_file(\"bldg_footprints.shp\")"
]
},
{
Expand All @@ -54,7 +56,9 @@
"metadata": {},
"outputs": [],
"source": [
"l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)"
"rasterize_function = partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=False)\n",
"layer = partial(LocalTileLayer, max_zoom=20)\n",
"l = df.leaflet.plot(m, measurement=\"Height\", layer=layer, dynamic=False, rasterize_function=rasterize_function, colormap=plt.cm.viridis)"
]
},
{
Expand Down Expand Up @@ -94,7 +98,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.5"
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ install_requires =
jupyter_server >=0.2.0
rioxarray >=0.0.30
ipyleaflet >=0.13.1
ipywidgets >=7.7.2
pillow >=7
matplotlib >=3
affine >=2
mercantile >=1
ipyspin >=0.1.6
ipyurl >=0.1.3
jupyterlab-widgets >=1.0.0,<2
geocube <1.0.0
pygeos >=0.12,<1.0.0
zarr >=2.0.0,<3.0.0
Expand Down
73 changes: 23 additions & 50 deletions xarray_leaflet/vector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import math
from functools import partial
from pathlib import Path
from typing import Optional
from typing import Callable, Optional

import mercantile
import numpy as np
import pyproj
import xarray as xr
import zarr
from geocube.api.core import make_geocube
Expand All @@ -23,13 +23,15 @@ def __init__(
self,
df: GeoDataFrame,
measurement: str,
rasterize_function: Optional[Callable],
width: int,
height: int,
root_path: str = "",
):
# reproject to Web Mercator
self.df = df.to_crs(epsg=3857)
self.measurement = measurement
self.rasterize_function = rasterize_function or partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True)
self.width = width
self.height = height
self.zzarr = Zzarr(root_path, width, height)
Expand All @@ -56,9 +58,7 @@ def get_da_tile(self, tile: mercantile.Tile) -> Optional[xr.DataArray]:
vector_data=df_tile,
resolution=(-dy, dx),
measurements=[self.measurement],
rasterize_function=partial(
rasterize_image, merge_alg=MergeAlg.add, all_touched=True
),
rasterize_function=self.rasterize_function,
fill=0,
geom=geom,
)
Expand All @@ -82,15 +82,10 @@ def get_da_llbbox(
self.tiles.append(tile)
if all_none:
return None
project = pyproj.Transformer.from_crs(
pyproj.CRS("EPSG:4326"), pyproj.CRS("EPSG:3857"), always_xy=True
).transform
b = box(*bbox)
polygon = transform(project, b)
left, bottom, right, top = polygon.bounds
return self.zzarr.get_ds(z)["da"].sel(
x=slice(left, right), y=slice(top, bottom)
)
da = self.get_da(z)
y0, x0 = deg2idx(bbox.north, bbox.west, z, self.height, self.width, math.floor)
y1, x1 = deg2idx(bbox.south, bbox.east, z, self.height, self.width, math.ceil)
return da[y0:y1, x0:x1]

def get_da(self, z: int) -> xr.DataArray:
return self.zzarr.get_ds(z)["da"]
Expand All @@ -101,7 +96,7 @@ def __init__(self, root_path: str, width: int, height: int):
self.root_path = Path(root_path)
self.width = width
self.height = height
self.ds = {}
self.z = None

def open_zarr(self, mode: str, z: int) -> zarr.Array:
path = self.root_path / str(z)
Expand All @@ -114,32 +109,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
)
if mode == "w":
# write Dataset to zarr
mi, ma = mercantile.minmax(z)
ul = mercantile.xy_bounds(mi, mi, z)
lr = mercantile.xy_bounds(ma, ma, z)
bbox = mercantile.Bbox(ul.left, lr.bottom, lr.right, ul.top)
x = zarr.open(
path / "x",
mode="w",
shape=(2**z * self.width,),
chunks=(2**z * self.width,),
dtype="<f8",
)
x[:] = np.linspace(bbox.left, bbox.right, 2**z * self.width)
x_zattrs = dict(_ARRAY_DIMENSIONS=["x"])
(path / "x" / ".zattrs").write_text(json.dumps(x_zattrs))
y = zarr.open(
path / "y",
mode="w",
shape=(2**z * self.height,),
chunks=(2**z * self.height,),
dtype="<f8",
)
y[:] = np.linspace(bbox.top, bbox.bottom, 2**z * self.height)
x_zarray = json.loads((path / "x" / ".zarray").read_text())
y_zarray = json.loads((path / "y" / ".zarray").read_text())
y_zattrs = dict(_ARRAY_DIMENSIONS=["y"])
(path / "y" / ".zattrs").write_text(json.dumps(y_zattrs))
(path / ".zattrs").write_text(json.dumps(dict()))
zarray = json.loads((path / "da" / ".zarray").read_text())
zattrs = dict(_ARRAY_DIMENSIONS=["y", "x"])
Expand All @@ -153,10 +122,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
".zgroup": zgroup,
"da/.zarray": zarray,
"da/.zattrs": zattrs,
"x/.zarray": x_zarray,
"x/.zattrs": x_zattrs,
"y/.zarray": y_zarray,
"y/.zattrs": y_zattrs,
},
zarr_consolidated_format=1,
)
Expand All @@ -172,14 +137,22 @@ def write_to_zarr(self, tile: mercantile.Tile, data: np.ndarray):
mode = "a"
else:
mode = "w"
self.array = self.open_zarr(mode, z)
self.array[
array = self.open_zarr(mode, z)
array[
y * self.height : (y + 1) * self.height, # noqa
x * self.width : (x + 1) * self.width, # noqa
] = data

def get_ds(self, z: int) -> xr.Dataset:
path = self.root_path / str(z)
if z not in self.ds:
self.ds[z] = xr.open_zarr(path)
return self.ds[z]
if z != self.z:
self.ds_z = xr.open_zarr(path)
self.z = z
return self.ds_z

def deg2idx(lat_deg, lon_deg, zoom, height, width, round_fun):
lat_rad = math.radians(lat_deg)
n = 2 ** zoom
xtile = round_fun(((lon_deg + 180) % 360) / 360 * n * width)
ytile = round_fun((1 - math.asinh(math.tan(lat_rad)) / math.pi) / 2 * n * height)
return ytile, xtile
68 changes: 47 additions & 21 deletions xarray_leaflet/xarray_leaflet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def _map_ready_changed(self, change):
def plot(
self,
m,
*,
# raster or vector options:
get_base_url: Optional[Callable] = None,
dynamic: Optional[bool] = None,
persist: bool = True,
tile_dir=None,
tile_height: int = 256,
tile_width: int = 256,
layer: Callable = LocalTileLayer,
# raster-only options:
x_dim="x",
y_dim="y",
fit_bounds=True,
Expand All @@ -54,15 +64,11 @@ def plot(
transform3=passthrough,
colormap=None,
colorbar_position="topright",
persist=True,
dynamic=False,
tile_dir=None,
tile_height=256,
tile_width=256,
resampling=Resampling.nearest,
get_base_url=None,
# vector-only options:
measurement: Optional[str] = None,
visible_callback: Optional[Callable] = None,
rasterize_function: Optional[Callable] = None,
):
"""Display an array as an interactive map.
Expand Down Expand Up @@ -122,30 +128,39 @@ def plot(
- the mercantile.LngLatBbox of the visible region
and returning True if the layer should be shown, False otherwise.
rasterize_function: callable, optional
A callable passed to make_geocube. Defaults to:
partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True)
Returns
-------
layer : ipyleaflet.LocalTileLayer
A handler to the layer that is added to the map.
"""

self.layer = LocalTileLayer()
self.layer = layer()

if self.is_vector:
# source is a GeoDataFrame (vector)
self.visible_callback = visible_callback
if measurement is None:
raise RuntimeError("You must provide a 'measurement'.")
if dynamic is None:
dynamic = True
if not dynamic:
self.vmin = self._df[measurement].min()
self.vmax = self._df[measurement].max()
self.measurement = measurement
dynamic = True
zarr_temp_dir = tempfile.TemporaryDirectory(prefix="xarray_leaflet_zarr_")
self.zvect = Zvect(
self._df, measurement, tile_width, tile_height, zarr_temp_dir.name
self._df, measurement, rasterize_function, tile_width, tile_height, zarr_temp_dir.name
)
if colormap is None:
colormap = plt.cm.viridis
else:
# source is a DataArray (raster)
if dynamic is None:
dynamic = False
if "proj4def" in m.crs:
# it's a custom projection
if dynamic:
Expand Down Expand Up @@ -363,6 +378,7 @@ def _get_vector_tiles(self, change=None):
tiles = mercantile.tiles(west, south, east, north, z)

if self.dynamic:
# get DataArray for the visible map
llbbox = mercantile.LngLatBbox(west, south, east, north)
da_visible = self.zvect.get_da_llbbox(llbbox, z)
# check if we must show the layer
Expand All @@ -372,32 +388,42 @@ def _get_vector_tiles(self, change=None):
self.m.remove_control(self.spinner_control)
return
if da_visible is None:
self.max_value = 0
vmin = vmax = 0
else:
self.max_value = da_visible.max()
vmin = da_visible.min()
vmax = da_visible.max()
else:
vmin = self.vmin
vmax = self.vmax
da_visible_computed = False

for tile in tiles:
x, y, z = tile
path = f"{self.tile_path}/{z}/{x}/{y}.png"
if self.dynamic or not os.path.exists(path):
xy_bbox = mercantile.xy_bounds(tile)
if self.dynamic:
if da_visible is not None:
da_tile = self.zvect.get_da(z).sel(
y=slice(xy_bbox.top, xy_bbox.bottom),
x=slice(xy_bbox.left, xy_bbox.right),
)
else:
da_tile = None
if not self.dynamic and not da_visible_computed:
# get DataArray for the visible map
llbbox = mercantile.LngLatBbox(west, south, east, north)
da_visible = self.zvect.get_da_llbbox(llbbox, z)
da_visible_computed = True
if self.dynamic and da_visible is None:
da_tile = None
else:
da_tile = self.zvect.get_da(z)[
y * self.tile_height : (y + 1) * self.tile_height,
x * self.tile_width : (x + 1) * self.tile_width,
]
if da_tile is None:
write_image(path, None)
else:
da_tile /= self.max_value
# normalize
da_tile = (da_tile - vmin) / (vmax - vmin)
da_tile = self.colormap(da_tile)
write_image(path, da_tile * 255)

if self.dynamic:
self.layer.redraw()

self.m.remove_control(self.spinner_control)

def _get_raster_tiles(self, change=None):
Expand Down

0 comments on commit 6dd089f

Please sign in to comment.