Skip to content

Commit

Permalink
use xarray.open_zarr and make aiohttp and s3fs optional (#1016)
Browse files Browse the repository at this point in the history
* use xarray.open_zarr and make aiohttp and s3fs optional

* add support for references

* tests prefixed protocol

* use tmp_dir for reference

* add parquet support

* remove kerchunk support
  • Loading branch information
vincentsarago authored Nov 4, 2024
1 parent 691eeed commit d0804ec
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Test titiler.xarray
run: |
python -m pip install -e src/titiler/xarray["test"]
python -m pip install -e src/titiler/xarray["test,all"]
python -m pytest src/titiler/xarray --cov=titiler.xarray --cov-report=xml --cov-append --cov-report=term-missing
- name: Test titiler.mosaic
Expand Down
20 changes: 13 additions & 7 deletions src/titiler/xarray/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,25 @@ classifiers = [
dynamic = ["version"]
dependencies = [
"titiler.core==0.19.0.dev",
"cftime",
"h5netcdf",
"xarray",
"rioxarray",
"zarr",
"fsspec",
"s3fs",
"aiohttp",
"pandas",
"httpx",
"zarr",
"h5netcdf",
"cftime",
]

[project.optional-dependencies]
s3 = [
"s3fs",
]
http = [
"aiohttp",
]
all = [
"s3fs",
"aiohttp",
]
test = [
"pytest",
"pytest-cov",
Expand Down
24 changes: 24 additions & 0 deletions src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,30 @@
" ds.to_zarr(store=f\"pyramid.zarr\", mode=\"w\", group=ix)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import fsspec\n",
"from kerchunk.hdf import SingleHdf5ToZarr\n",
"\n",
"with fsspec.open(\"dataset_3d.nc\", mode=\"rb\", anon=True) as infile:\n",
" h5chunks = SingleHdf5ToZarr(infile, \"dataset_3d.nc\", inline_threshold=100)\n",
"\n",
" with open(\"reference.json\", 'w') as f:\n",
" f.write(json.dumps(h5chunks.translate()));\n"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
15 changes: 11 additions & 4 deletions src/titiler/xarray/tests/test_io_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,19 @@ def test_get_variable():


@pytest.mark.parametrize(
"filename",
["dataset_2d.nc", "dataset_3d.nc", "dataset_3d.zarr"],
"protocol,filename",
[
("file://", "dataset_2d.nc"),
("file://", "dataset_3d.nc"),
("file://", "dataset_3d.zarr"),
("", "dataset_2d.nc"),
("", "dataset_3d.nc"),
("", "dataset_3d.zarr"),
],
)
def test_reader(filename):
def test_reader(protocol, filename):
"""test reader."""
src_path = os.path.join(prefix, filename)
src_path = protocol + os.path.join(protocol, prefix, filename)
assert Reader.list_variables(src_path) == ["dataset"]

with Reader(src_path, variable="dataset") as src:
Expand Down
16 changes: 0 additions & 16 deletions src/titiler/xarray/titiler/xarray/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@ class XarrayIOParams(DefaultDependency):
),
] = None

reference: Annotated[
Optional[bool],
Query(
title="reference",
description="Whether the dataset is a kerchunk reference",
),
] = None

decode_times: Annotated[
Optional[bool],
Query(
Expand All @@ -38,14 +30,6 @@ class XarrayIOParams(DefaultDependency):
),
] = None

consolidated: Annotated[
Optional[bool],
Query(
title="consolidated",
description="Whether to expect and open zarr store with consolidated metadata",
),
] = None

# cache_client


Expand Down
83 changes: 42 additions & 41 deletions src/titiler/xarray/titiler/xarray/io.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
"""titiler.xarray.io"""

import pickle
import re
from typing import Any, Callable, Dict, List, Optional, Protocol
from urllib.parse import urlparse

import attr
import fsspec
import numpy
import s3fs
import xarray
from morecantile import TileMatrixSet
from rio_tiler.constants import WEB_MERCATOR_TMS
from rio_tiler.io.xarray import XarrayReader

try:
import s3fs
except ImportError: # pragma: nocover
s3fs = None # type: ignore

try:
import aiohttp
except ImportError: # pragma: nocover
aiohttp = None # type: ignore


class CacheClient(Protocol):
"""CacheClient Protocol."""
Expand All @@ -26,27 +35,19 @@ def set(self, key: str, body: bytes) -> None:
...


def parse_protocol(src_path: str) -> str:
    """Return the protocol (URL scheme) of ``src_path``.

    Args:
        src_path: Dataset location, either a URL such as ``s3://bucket/key``
            or ``https://host/file.nc``, or a plain filesystem path.

    Returns:
        The lower-cased URL scheme reported by ``urlparse``, or ``"file"``
        when ``src_path`` carries no scheme.
    """
    scheme = urlparse(src_path).scheme

    # ``urlparse`` reads a Windows drive letter ("C:/data/x.nc") as a
    # one-character scheme. No protocol handled here is a single letter, so
    # such inputs are treated as local paths, like scheme-less ones.
    if len(scheme) <= 1:
        return "file"

    return scheme


def xarray_engine(src_path: str) -> str:
    """Choose the xarray engine name from the extension of ``src_path``.

    Args:
        src_path: Dataset path or URL; only its ending is inspected.

    Returns:
        ``"h5netcdf"`` when the path ends with ``.nc`` or ``.nc4``
        (case-insensitive), otherwise ``"zarr"``.
    """
    # ".hdf", ".hdf5", ".h5" will be supported once we have tests + expand the type permitted for the group parameter
    if src_path.lower().endswith((".nc", ".nc4")):
        return "h5netcdf"

    return "zarr"


def get_filesystem(
Expand All @@ -59,18 +60,21 @@ def get_filesystem(
Get the filesystem for the given source path.
"""
if protocol == "s3":
assert s3fs is not None, "s3fs must be installed to support S3:// url"

s3_filesystem = s3fs.S3FileSystem()
return (
s3_filesystem.open(src_path)
if xr_engine == "h5netcdf"
else s3fs.S3Map(root=src_path, s3=s3_filesystem)
)

elif protocol == "reference":
reference_args = {"fo": src_path, "remote_options": {"anon": anon}}
return fsspec.filesystem("reference", **reference_args).get_mapper("")

elif protocol in ["https", "http", "file"]:
if protocol.startswith("http"):
assert (
aiohttp is not None
), "aiohttp must be installed to support HTTP:// url"

filesystem = fsspec.filesystem(protocol) # type: ignore
return (
filesystem.open(src_path)
Expand All @@ -85,9 +89,7 @@ def get_filesystem(
def xarray_open_dataset(
src_path: str,
group: Optional[Any] = None,
reference: Optional[bool] = False,
decode_times: Optional[bool] = True,
consolidated: Optional[bool] = True,
cache_client: Optional[CacheClient] = None,
) -> xarray.Dataset:
"""Open dataset."""
Expand All @@ -98,7 +100,7 @@ def xarray_open_dataset(
if data_bytes:
return pickle.loads(data_bytes)

protocol = parse_protocol(src_path, reference=reference)
protocol = parse_protocol(src_path)
xr_engine = xarray_engine(src_path)
file_handler = get_filesystem(src_path, protocol, xr_engine)

Expand All @@ -115,19 +117,26 @@ def xarray_open_dataset(

# NetCDF arguments
if xr_engine == "h5netcdf":
xr_open_args["engine"] = "h5netcdf"
xr_open_args["lock"] = False
else:
# Zarr arguments
xr_open_args["engine"] = "zarr"
xr_open_args["consolidated"] = consolidated
xr_open_args.update(
{
"engine": "h5netcdf",
"lock": False,
}
)

# Additional arguments when dealing with a reference file.
if reference:
xr_open_args["consolidated"] = False
xr_open_args["backend_kwargs"] = {"consolidated": False}
ds = xarray.open_dataset(file_handler, **xr_open_args)

# Fallback to Zarr
else:
if protocol == "reference":
xr_open_args.update(
{
"consolidated": False,
"backend_kwargs": {"consolidated": False},
}
)

ds = xarray.open_dataset(file_handler, **xr_open_args)
ds = xarray.open_zarr(file_handler, **xr_open_args)

if cache_client:
# Serialize the dataset to bytes using pickle
Expand Down Expand Up @@ -245,9 +254,7 @@ class Reader(XarrayReader):
opener: Callable[..., xarray.Dataset] = attr.ib(default=xarray_open_dataset)

group: Optional[Any] = attr.ib(default=None)
reference: bool = attr.ib(default=False)
decode_times: bool = attr.ib(default=False)
consolidated: Optional[bool] = attr.ib(default=True)
cache_client: Optional[CacheClient] = attr.ib(default=None)

# xarray.DataArray options
Expand All @@ -266,9 +273,7 @@ def __attrs_post_init__(self):
self.ds = self.opener(
self.src_path,
group=self.group,
reference=self.reference,
decode_times=self.decode_times,
consolidated=self.consolidated,
cache_client=self.cache_client,
)

Expand All @@ -293,14 +298,10 @@ def list_variables(
cls,
src_path: str,
group: Optional[Any] = None,
reference: Optional[bool] = False,
consolidated: Optional[bool] = True,
) -> List[str]:
"""List available variable in a dataset."""
with xarray_open_dataset(
src_path,
group=group,
reference=reference,
consolidated=consolidated,
) as ds:
return list(ds.data_vars) # type: ignore

0 comments on commit d0804ec

Please sign in to comment.