From 5cc2202b9576d9f627efa1c807d52575ac19de1b Mon Sep 17 00:00:00 2001
From: Sandro Campos
Date: Tue, 29 Oct 2024 11:47:26 -0400
Subject: [PATCH] Remove margin object from read_hats

---
 src/lsdb/loaders/hats/hats_loading_config.py |  6 ++--
 src/lsdb/loaders/hats/read_hats.py           | 31 +++++++++++++-------
 src/lsdb/loaders/hats/read_hats.pyi          |  4 +--
 tests/conftest.py                            |  8 ++---
 tests/lsdb/catalog/test_crossmatch.py        | 12 ++++----
 tests/lsdb/loaders/hats/test_read_hats.py    | 13 ++------
 6 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/src/lsdb/loaders/hats/hats_loading_config.py b/src/lsdb/loaders/hats/hats_loading_config.py
index 5962209e..f69416c2 100644
--- a/src/lsdb/loaders/hats/hats_loading_config.py
+++ b/src/lsdb/loaders/hats/hats_loading_config.py
@@ -8,7 +8,6 @@
 from pandas.io._util import _arrow_dtype_mapping
 from upath import UPath
 
-from lsdb.catalog.margin_catalog import MarginCatalog
 from lsdb.core.search.abstract_search import AbstractSearch
 
 
@@ -25,9 +24,8 @@ class HatsLoadingConfig:
     columns: List[str] | None = None
     """Columns to load from the catalog. If not specified, all columns are loaded"""
 
-    margin_cache: MarginCatalog | str | Path | UPath | None = None
-    """Margin cache for the catalog. It can be provided as a path for the margin on disk,
-    or as a margin object instance. By default, it is None."""
+    margin_cache: str | Path | UPath | None = None
+    """Path to the margin cache catalog. Defaults to None."""
 
     dtype_backend: str | None = "pyarrow"
     """The backend data type to apply to the catalog. It defaults to "pyarrow" and
diff --git a/src/lsdb/loaders/hats/read_hats.py b/src/lsdb/loaders/hats/read_hats.py
index 07c38b58..2bf64b6a 100644
--- a/src/lsdb/loaders/hats/read_hats.py
+++ b/src/lsdb/loaders/hats/read_hats.py
@@ -10,6 +10,7 @@
 import pyarrow as pa
 from hats.catalog import CatalogType
 from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
+from hats.io import paths
 from hats.io.file_io import file_io
 from hats.pixel_math import HealpixPixel
 from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
@@ -28,7 +29,7 @@ def read_hats(
     path: str | Path | UPath,
     search_filter: AbstractSearch | None = None,
     columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
     dtype_backend: str | None = "pyarrow",
     **kwargs,
 ) -> CatalogTypeVar | None:
@@ -50,8 +51,7 @@ def read_hats(
         path (UPath | Path): The path that locates the root of the HATS catalog
         search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
         columns (List[str]): Default `None`. The set of columns to filter the catalog on.
-        margin_cache (MarginCatalog or path-like): The margin cache for the main catalog,
-            provided as a path on disk or as an instance of the MarginCatalog object. Defaults to None.
+        margin_cache (path-like): Default `None`. The margin for the main catalog, provided as a path.
         dtype_backend (str): Backend data type to apply to the catalog. Defaults to "pyarrow".
             If None, no type conversion is performed.
         **kwargs: Arguments to pass to the pandas parquet file reader
@@ -146,17 +146,26 @@ def _load_object_catalog(hc_catalog, config):
     catalog = Catalog(dask_df, dask_df_pixel_map, hc_catalog)
     if config.search_filter is not None:
         catalog = catalog.search(config.search_filter)
-    if isinstance(config.margin_cache, MarginCatalog):
-        catalog.margin = config.margin_cache
-        if config.search_filter is not None:
-            # pylint: disable=protected-access
-            catalog.margin = catalog.margin.search(config.search_filter)
-    elif config.margin_cache is not None:
-        hc_catalog = hc.read_hats(config.margin_cache)
-        catalog.margin = _load_margin_catalog(hc_catalog, config)
+    if config.margin_cache is not None:
+        margin_hc_catalog = hc.read_hats(config.margin_cache)
+        margin = _load_margin_catalog(margin_hc_catalog, config)
+        _validate_margin_catalog(margin_hc_catalog, hc_catalog)
+        catalog.margin = margin
     return catalog
 
 
+def _validate_margin_catalog(margin_hc_catalog, hc_catalog):
+    """Validate that the margin catalog and the main catalog are compatible"""
+    pixel_columns = [paths.PARTITION_ORDER, paths.PARTITION_DIR, paths.PARTITION_PIXEL]
+    margin_pixel_columns = pixel_columns + ["margin_" + column for column in pixel_columns]
+    catalog_schema = pa.schema([field for field in hc_catalog.schema if field.name not in pixel_columns])
+    margin_schema = pa.schema(
+        [field for field in margin_hc_catalog.schema if field.name not in margin_pixel_columns]
+    )
+    if not catalog_schema.equals(margin_schema):
+        raise ValueError("The margin catalog and the main catalog must have the same schema")
+
+
 def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame:
     """Creates the Dask meta DataFrame from the HATS catalog schema."""
     dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper())
diff --git a/src/lsdb/loaders/hats/read_hats.pyi b/src/lsdb/loaders/hats/read_hats.pyi
index 68cfb2b5..0bbdc7e4 100644
--- a/src/lsdb/loaders/hats/read_hats.pyi
+++ b/src/lsdb/loaders/hats/read_hats.pyi
@@ -28,7 +28,7 @@ def read_hats(
     path: str | Path | UPath,
     search_filter: AbstractSearch | None = None,
     columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
     dtype_backend: str | None = "pyarrow",
     **kwargs,
 ) -> Dataset | None: ...
@@ -38,7 +38,7 @@ def read_hats(
     catalog_type: Type[CatalogTypeVar],
     search_filter: AbstractSearch | None = None,
     columns: List[str] | None = None,
-    margin_cache: MarginCatalog | str | Path | UPath | None = None,
+    margin_cache: str | Path | UPath | None = None,
     dtype_backend: str | None = "pyarrow",
     **kwargs,
 ) -> CatalogTypeVar | None: ...
diff --git a/tests/conftest.py b/tests/conftest.py
index 8f90bd6d..853ecc5e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -142,8 +142,8 @@ def small_sky_xmatch_margin_catalog(small_sky_xmatch_margin_dir):
 
 
 @pytest.fixture
-def small_sky_xmatch_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
-    return lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
+def small_sky_xmatch_with_margin(small_sky_xmatch_dir, small_sky_xmatch_margin_dir):
+    return lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir)
 
 
 @pytest.fixture
@@ -167,8 +167,8 @@ def small_sky_order1_catalog(small_sky_order1_dir):
 
 
 @pytest.fixture
-def small_sky_order1_source_with_margin(small_sky_order1_source_dir, small_sky_order1_source_margin_catalog):
-    return lsdb.read_hats(small_sky_order1_source_dir, margin_cache=small_sky_order1_source_margin_catalog)
+def small_sky_order1_source_with_margin(small_sky_order1_source_dir, small_sky_order1_source_margin_dir):
+    return lsdb.read_hats(small_sky_order1_source_dir, margin_cache=small_sky_order1_source_margin_dir)
 
 
 @pytest.fixture
diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py
index 5867c3ab..63d0d069 100644
--- a/tests/lsdb/catalog/test_crossmatch.py
+++ b/tests/lsdb/catalog/test_crossmatch.py
@@ -75,10 +75,10 @@ def test_kdtree_crossmatch_multiple_neighbors(
 
     @staticmethod
     def test_kdtree_crossmatch_multiple_neighbors_margin(
-        algo, small_sky_catalog, small_sky_xmatch_dir, small_sky_xmatch_margin_catalog, xmatch_correct_3n_2t
+        algo, small_sky_catalog, small_sky_xmatch_dir, small_sky_xmatch_margin_dir, xmatch_correct_3n_2t
     ):
         small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
         )
         xmatched = small_sky_catalog.crossmatch(
             small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
@@ -98,11 +98,11 @@ def test_crossmatch_negative_margin(
         algo,
         small_sky_left_xmatch_catalog,
         small_sky_xmatch_dir,
-        small_sky_xmatch_margin_catalog,
+        small_sky_xmatch_margin_dir,
         xmatch_correct_3n_2t_negative,
     ):
         small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
         )
         xmatched = small_sky_left_xmatch_catalog.crossmatch(
             small_sky_xmatch_catalog, n_neighbors=3, radius_arcsec=2 * 3600, algorithm=algo
@@ -154,11 +154,11 @@ def test_kdtree_crossmatch_min_thresh_multiple_neighbors_margin(
         algo,
         small_sky_catalog,
         small_sky_xmatch_dir,
-        small_sky_xmatch_margin_catalog,
+        small_sky_xmatch_margin_dir,
         xmatch_correct_05_2_3n_margin,
     ):
         small_sky_xmatch_catalog = lsdb.read_hats(
-            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog
+            small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir
         )
         xmatched = small_sky_catalog.crossmatch(
             small_sky_xmatch_catalog,
diff --git a/tests/lsdb/loaders/hats/test_read_hats.py b/tests/lsdb/loaders/hats/test_read_hats.py
index b1c77072..751e920d 100644
--- a/tests/lsdb/loaders/hats/test_read_hats.py
+++ b/tests/lsdb/loaders/hats/test_read_hats.py
@@ -134,15 +134,6 @@ def test_read_hats_specify_catalog_type(small_sky_catalog, small_sky_dir):
     assert isinstance(catalog.compute(), npd.NestedFrame)
 
 
-def test_catalog_with_margin_object(small_sky_xmatch_dir, small_sky_xmatch_margin_catalog):
-    catalog = lsdb.read_hats(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_catalog)
-    assert isinstance(catalog, lsdb.Catalog)
-    assert isinstance(catalog.margin, lsdb.MarginCatalog)
-    assert isinstance(catalog._ddf, nd.NestedFrame)
-    assert catalog.margin is small_sky_xmatch_margin_catalog
-    assert isinstance(catalog.margin._ddf, nd.NestedFrame)
-
-
 def test_catalog_with_margin_path(
     small_sky_xmatch_dir, small_sky_xmatch_margin_dir, small_sky_xmatch_margin_catalog
 ):
@@ -237,7 +228,7 @@ def test_read_hats_subset_no_partitions(small_sky_order1_dir, small_sky_order1_i
 
 
 def test_read_hats_with_margin_subset(
-    small_sky_order1_source_dir, small_sky_order1_source_with_margin, small_sky_order1_source_margin_catalog
+    small_sky_order1_source_dir, small_sky_order1_source_with_margin, small_sky_order1_source_margin_dir
 ):
     cone_search = ConeSearch(ra=315, dec=-66, radius_arcsec=20)
     # Filtering using catalog's cone_search
     cone_search_catalog = small_sky_order1_source_with_margin.cone_search(cone_search)
     cone_search_catalog_2 = lsdb.read_hats(
         small_sky_order1_source_dir,
         search_filter=cone_search,
-        margin_cache=small_sky_order1_source_margin_catalog,
+        margin_cache=small_sky_order1_source_margin_dir,
     )
     assert isinstance(cone_search_catalog_2, lsdb.Catalog)
     # The partitions of the catalogs are equivalent
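
Usage sketch (illustration only, not part of the patch): after this change, `margin_cache` accepts only a path-like value; the catalog paths below are hypothetical placeholders.

    import lsdb

    # Load a catalog together with its margin cache, both addressed by path.
    catalog = lsdb.read_hats("./small_sky_order1", margin_cache="./small_sky_order1_margin")

    # If the margin's schema (ignoring the HEALPix partition columns and their
    # "margin_"-prefixed counterparts) differs from the main catalog's schema,
    # _validate_margin_catalog raises:
    #   ValueError: The margin catalog and the main catalog must have the same schema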