Remove margin object from read_hats (#475)
* Remove margin object from read_hats

* Improve test coverage
camposandro authored Oct 29, 2024
1 parent 6888b52 commit 4fef428
Showing 6 changed files with 42 additions and 39 deletions.
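
For downstream users, the practical effect of this change is that read_hats now accepts the margin only as a path and loads the margin catalog from disk itself. A minimal usage sketch of the changed call (the paths below are placeholders, not from this commit):

import lsdb

# Placeholder paths; margin_cache must now be a str, Path, or UPath,
# not a MarginCatalog instance.
catalog = lsdb.read_hats(
    "data/small_sky_order1",
    margin_cache="data/small_sky_order1_margin",
)
assert catalog.margin is not None  # loaded from disk and validated against the main catalog
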
6 changes: 2 additions & 4 deletions src/lsdb/loaders/hats/hats_loading_config.py
@@ -8,7 +8,6 @@
from pandas.io._util import _arrow_dtype_mapping
from upath import UPath

-from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.search.abstract_search import AbstractSearch


@@ -25,9 +24,8 @@ class HatsLoadingConfig:
columns: List[str] | None = None
"""Columns to load from the catalog. If not specified, all columns are loaded"""

-    margin_cache: MarginCatalog | str | Path | UPath | None = None
-    """Margin cache for the catalog. It can be provided as a path for the margin on disk,
-    or as a margin object instance. By default, it is None."""
+    margin_cache: str | Path | UPath | None = None
+    """Path to the margin cache catalog. Defaults to None."""

dtype_backend: str | None = "pyarrow"
"""The backend data type to apply to the catalog. It defaults to "pyarrow" and
31 changes: 20 additions & 11 deletions src/lsdb/loaders/hats/read_hats.py
@@ -10,6 +10,7 @@
import pyarrow as pa
from hats.catalog import CatalogType
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
+from hats.io import paths
from hats.io.file_io import file_io
from hats.pixel_math import HealpixPixel
from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
@@ -28,7 +29,7 @@ def read_hats(
path: str | Path | UPath,
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
dtype_backend: str | None = "pyarrow",
**kwargs,
) -> CatalogTypeVar | None:
@@ -50,8 +51,7 @@
path (UPath | Path): The path that locates the root of the HATS catalog
search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
columns (List[str]): Default `None`. The set of columns to filter the catalog on.
-        margin_cache (MarginCatalog or path-like): The margin cache for the main catalog,
-            provided as a path on disk or as an instance of the MarginCatalog object. Defaults to None.
+        margin_cache (path-like): Default `None`. The margin for the main catalog, provided as a path.
dtype_backend (str): Backend data type to apply to the catalog.
Defaults to "pyarrow". If None, no type conversion is performed.
**kwargs: Arguments to pass to the pandas parquet file reader
@@ -146,17 +146,26 @@ def _load_object_catalog(hc_catalog, config):
catalog = Catalog(dask_df, dask_df_pixel_map, hc_catalog)
if config.search_filter is not None:
catalog = catalog.search(config.search_filter)
-    if isinstance(config.margin_cache, MarginCatalog):
-        catalog.margin = config.margin_cache
-        if config.search_filter is not None:
-            # pylint: disable=protected-access
-            catalog.margin = catalog.margin.search(config.search_filter)
-    elif config.margin_cache is not None:
-        hc_catalog = hc.read_hats(config.margin_cache)
-        catalog.margin = _load_margin_catalog(hc_catalog, config)
+    if config.margin_cache is not None:
+        margin_hc_catalog = hc.read_hats(config.margin_cache)
+        margin = _load_margin_catalog(margin_hc_catalog, config)
+        _validate_margin_catalog(margin_hc_catalog, hc_catalog)
+        catalog.margin = margin
return catalog


+def _validate_margin_catalog(margin_hc_catalog, hc_catalog):
+    """Validate that the margin catalog and the main catalog are compatible"""
+    pixel_columns = [paths.PARTITION_ORDER, paths.PARTITION_DIR, paths.PARTITION_PIXEL]
+    margin_pixel_columns = pixel_columns + ["margin_" + column for column in pixel_columns]
+    catalog_schema = pa.schema([field for field in hc_catalog.schema if field.name not in pixel_columns])
+    margin_schema = pa.schema(
+        [field for field in margin_hc_catalog.schema if field.name not in margin_pixel_columns]
+    )
+    if not catalog_schema.equals(margin_schema):
+        raise ValueError("The margin catalog and the main catalog must have the same schema")


def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame:
"""Creates the Dask meta DataFrame from the HATS catalog schema."""
dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper())
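
The new _validate_margin_catalog compares the two pyarrow schemas after dropping the partition bookkeeping columns (and their margin_ counterparts on the margin side). A standalone sketch of the same comparison, with assumed values for the paths.PARTITION_* constants and made-up science columns:

import pyarrow as pa

# Assumed values for paths.PARTITION_ORDER / PARTITION_DIR / PARTITION_PIXEL.
pixel_columns = ["Norder", "Dir", "Npix"]
margin_pixel_columns = pixel_columns + ["margin_" + column for column in pixel_columns]

catalog_schema = pa.schema(
    [("id", pa.int64()), ("ra", pa.float64()), ("dec", pa.float64()),
     ("Norder", pa.uint8()), ("Dir", pa.uint64()), ("Npix", pa.uint64())]
)
margin_schema = pa.schema(
    [("id", pa.int64()), ("ra", pa.float64()), ("dec", pa.float64()),
     ("Norder", pa.uint8()), ("Dir", pa.uint64()), ("Npix", pa.uint64()),
     ("margin_Norder", pa.uint8()), ("margin_Dir", pa.uint64()), ("margin_Npix", pa.uint64())]
)

def data_schema(schema: pa.Schema, drop: list[str]) -> pa.Schema:
    # Keep only the science columns; the bookkeeping columns differ by design.
    return pa.schema([field for field in schema if field.name not in drop])

# Compatible: the science columns match once bookkeeping columns are ignored.
assert data_schema(catalog_schema, pixel_columns).equals(data_schema(margin_schema, margin_pixel_columns))
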
4 changes: 2 additions & 2 deletions src/lsdb/loaders/hats/read_hats.pyi
@@ -28,7 +28,7 @@ def read_hats(
path: str | Path | UPath,
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
dtype_backend: str | None = "pyarrow",
**kwargs,
) -> Dataset | None: ...
@@ -38,7 +38,7 @@ def read_hats(
catalog_type: Type[CatalogTypeVar],
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
dtype_backend: str | None = "pyarrow",
**kwargs,
) -> CatalogTypeVar | None: ...
8 changes: 4 additions & 4 deletions tests/conftest.py
@@ -142,8 +142,8 @@ def small_sky_xmatch_margin_catalog(small_sky_xmatch_margin_dir):


@pytest.fixture
-def small_sky_xmatch_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
-    return lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
+def small_sky_xmatch_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_dir):
+    return lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir)


@pytest.fixture
@@ -167,8 +167,8 @@ def small_sky_order1_catalog(small_sky_order1_dir):


@pytest.fixture
-def small_sky_order1_source_with_margin(small_sky_order1_source_dir, small_sky_order1_source_margin_catalog):
-    return lsdb.read_hats(small_sky_order1_source_dir, margin_cache=small_sky_order1_source_margin_catalog)
+def small_sky_order1_source_with_margin(small_sky_order1_source_dir, small_sky_order1_source_margin_dir):
+    return lsdb.read_hats(small_sky_order1_source_dir, margin_cache=small_sky_order1_source_margin_dir)


@pytest.fixture
12 changes: 6 additions & 6 deletions tests/lsdb/catalog/test_crossmatch.py
@@ -75,10 +75,10 @@ def test_kdtree_crossmatch_multiple_neighbors(

@staticmethod
def test_kdtree_crossmatch_multiple_neighbors_margin(
-        algo, small_sky_catalog, small_sky_xmatch_dir, small_sky_xmatch_margin_catalog, xmatch_correct_3n_2t
+        algo, small_sky_catalog, small_sky_xmatch_dir, small_sky_xmatch_margin_dir, xmatch_correct_3n_2t
):
small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
)
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
@@ -98,11 +98,11 @@ def test_crossmatch_negative_margin(
algo,
small_sky_left_xmatch_catalog,
small_sky_xmatch_dir,
-        small_sky_xmatch_margin_catalog,
+        small_sky_xmatch_margin_dir,
xmatch_correct_3n_2t_negative,
):
small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
)
xmatched = small_sky_left_xmatch_catalog.crossmatch(
small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
@@ -154,11 +154,11 @@ def test_kdtree_crossmatch_min_thresh_multiple_neighbors_margin(
algo,
small_sky_catalog,
small_sky_xmatch_dir,
-        small_sky_xmatch_margin_catalog,
+        small_sky_xmatch_margin_dir,
xmatch_correct_05_2_3n_margin,
):
small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
)
xmatched = small_sky_catalog.crossmatch(
small_sky_xmatch_catalog,
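
As in the updated tests, a margin-backed catalog is now built by pointing margin_cache at the margin directory before crossmatching. A sketch with placeholder paths, using only the crossmatch arguments that appear in the tests above:

import lsdb

# Placeholder paths for a catalog, a second catalog, and the second catalog's margin cache.
small_sky = lsdb.read_hats("data/small_sky")
xmatch = lsdb.read_hats("data/small_sky_xmatch", margin_cache="data/small_sky_xmatch_margin")

# The margin lets neighbors near partition edges be matched correctly.
result = small_sky.crossmatch(xmatch, n_neighbors=3, radius_arcsec=2 * 3600)
print(result.compute())
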
20 changes: 8 additions & 12 deletions tests/lsdb/loaders/hats/test_read_hats.py
@@ -134,16 +134,7 @@ def test_read_hats_specify_catalog_type(small_sky_catalog, small_sky_dir):
assert isinstance(catalog.compute(), npd.NestedFrame)


-def test_catalog_with_margin_object(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
-    catalog = lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
-    assert isinstance(catalog, lsdb.Catalog)
-    assert isinstance(catalog.margin, lsdb.MarginCatalog)
-    assert isinstance(catalog._ddf, nd.NestedFrame)
-    assert catalog.margin is small_sky_xmatch_margin_catalog
-    assert isinstance(catalog.margin._ddf, nd.NestedFrame)
-
-
-def test_catalog_with_margin_path(
+def test_catalog_with_margin(
small_sky_xmatch_dir, small_sky_xmatch_margin_dir, small_sky_xmatch_margin_catalog
):
assert isinstance(small_sky_xmatch_margin_dir, Path)
@@ -165,6 +156,11 @@ def test_catalog_without_margin_is_none(small_sky_xmatch_dir):
assert catalog.margin is None


+def test_catalog_with_wrong_margin(small_sky_order1_dir, small_sky_order1_source_margin_dir):
+    with pytest.raises(ValueError, match="must have the same schema"):
+        lsdb.read_hats(small_sky_order1_dir, margin_cache=small_sky_order1_source_margin_dir)


def test_read_hats_subset_with_cone_search(small_sky_order1_dir, small_sky_order1_catalog):
cone_search = ConeSearch(ra=0, dec=-80, radius_arcsec=20 * 3600)
# Filtering using catalog's cone_search
@@ -237,7 +233,7 @@ def test_read_hats_subset_no_partitions(small_sky_order1_dir, small_sky_order1_i


def test_read_hats_with_margin_subset(
-    small_sky_order1_source_dir, small_sky_order1_source_with_margin, small_sky_order1_source_margin_catalog
+    small_sky_order1_source_dir, small_sky_order1_source_with_margin, small_sky_order1_source_margin_dir
):
cone_search = ConeSearch(ra=315, dec=-66, radius_arcsec=20)
# Filtering using catalog's cone_search
Expand All @@ -246,7 +242,7 @@ def test_read_hats_with_margin_subset(
cone_search_catalog_2 = lsdb.read_hats(
small_sky_order1_source_dir,
search_filter=cone_search,
-        margin_cache=small_sky_order1_source_margin_catalog,
+        margin_cache=small_sky_order1_source_margin_dir,
)
assert isinstance(cone_search_catalog_2, lsdb.Catalog)
# The partitions of the catalogs are equivalent
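
Mirroring test_catalog_with_wrong_margin above, loading a margin built for a different catalog now fails fast instead of attaching an inconsistent margin. A sketch of how that surfaces to a user (paths below are placeholders):

import lsdb

try:
    lsdb.read_hats(
        "data/small_sky_order1",            # main catalog (placeholder path)
        margin_cache="data/source_margin",  # margin generated for a different catalog
    )
except ValueError as error:
    # Raised by the new schema check: "The margin catalog and the main catalog
    # must have the same schema"
    print(f"Rejected margin: {error}")
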
