diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index c58c312c..edce7b8f 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -614,21 +614,17 @@ def nest_lists( recommend setting the following dask config setting to prevent this: `dask.config.set({"dataframe.convert-string":False})` """ - new_ddf = super().nest_lists( + catalog = super().nest_lists( base_columns=base_columns, list_columns=list_columns, name=name, ) - - catalog = Catalog(new_ddf._ddf, self._ddf_pixel_map, self.hc_structure) - if self.margin is not None: catalog.margin = self.margin.nest_lists( base_columns=base_columns, list_columns=list_columns, name=name, ) - return catalog def dropna( @@ -708,6 +704,53 @@ def dropna( return catalog def reduce(self, func, *args, meta=None, **kwargs) -> Catalog: + """ + Takes a function and applies it to each top-level row of the Catalog. + + docstring copied from nested-pandas + + The user may specify which columns the function is applied to, with + columns from the 'base' layer being passsed to the function as + scalars and columns from the nested layers being passed as numpy arrays. + + Parameters + ---------- + func : callable + Function to apply to each nested dataframe. The first arguments to `func` should be which + columns to apply the function to. See the Notes for recommendations + on writing func outputs. + args : positional arguments + Positional arguments to pass to the function, the first *args should be the names of the + columns to apply the function to. + meta : dataframe or series-like, optional + The dask meta of the output. If append_columns is True, the meta should specify just the + additional columns output by func. + append_columns : bool + If the output columns should be appended to the orignal dataframe. + kwargs : keyword arguments, optional + Keyword arguments to pass to the function. + + Returns + ------- + `HealpixDataset` + `HealpixDataset` with the results of the function applied to the columns of the frame. + + Notes + ----- + By default, `reduce` will produce a `NestedFrame` with enumerated + column names for each returned value of the function. For more useful + naming, it's recommended to have `func` return a dictionary where each + key is an output column of the dataframe returned by `reduce`. + + Example User Function: + + >>> def my_sum(col1, col2): + >>> '''reduce will return a NestedFrame with two columns''' + >>> return {"sum_col1": sum(col1), "sum_col2": sum(col2)} + >>> + >>> catalog.reduce(my_sum, 'sources.col1', 'sources.col2') + + """ catalog = super().reduce(func, *args, meta=meta, **kwargs) if self.margin is not None: catalog.margin = self.margin.reduce(func, *args, meta=meta, **kwargs) diff --git a/src/lsdb/catalog/dataset/healpix_dataset.py b/src/lsdb/catalog/dataset/healpix_dataset.py index 4946b5d9..c8225c22 100644 --- a/src/lsdb/catalog/dataset/healpix_dataset.py +++ b/src/lsdb/catalog/dataset/healpix_dataset.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import warnings from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Tuple, cast @@ -82,6 +81,20 @@ def __len__(self): """ return len(self.hc_structure) + def _create_modified_hc_structure(self, **kwargs) -> HCHealpixDataset: + """Copy the catalog structure and override the specified catalog info parameters. + + Returns: + A copy of the catalog's structure with updated info parameters. + """ + return self.hc_structure.__class__( + catalog_info=self.hc_structure.catalog_info.copy_and_update(**kwargs), + pixels=self.hc_structure.pixel_tree, + catalog_path=self.hc_structure.catalog_path, + schema=self.hc_structure.schema, + moc=self.hc_structure.moc, + ) + def get_healpix_pixels(self) -> List[HealpixPixel]: """Get all HEALPix pixels that are contained in the catalog @@ -146,8 +159,7 @@ def query(self, expr: str) -> Self: with the query expression """ ndf = self._ddf.query(expr) - hc_structure = copy.copy(self.hc_structure) - hc_structure.catalog_info.total_rows = 0 + hc_structure = self._create_modified_hc_structure(total_rows=0) return self.__class__(ndf, self._ddf_pixel_map, hc_structure) def _perform_search( @@ -527,8 +539,7 @@ def drop_na_part(df: npd.NestedFrame): return df ndf = self._ddf.map_partitions(drop_na_part, meta=self._ddf._meta) - hc_structure = copy.copy(self.hc_structure) - hc_structure.catalog_info.total_rows = 0 + hc_structure = self._create_modified_hc_structure(total_rows=0) return self.__class__(ndf, self._ddf_pixel_map, hc_structure) def nest_lists( @@ -574,9 +585,7 @@ def nest_lists( list_columns=list_columns, name=name, ) - - hc_structure = copy.copy(self.hc_structure) - hc_structure.catalog_info.total_rows = 0 + hc_structure = self._create_modified_hc_structure(total_rows=0) return self.__class__(new_ddf, self._ddf_pixel_map, hc_structure) def reduce(self, func, *args, meta=None, append_columns=False, **kwargs) -> Self: @@ -643,13 +652,10 @@ def reduce_part(df): ndf = nd.NestedFrame.from_dask_dataframe(self._ddf.map_partitions(reduce_part, meta=meta)) - hc_catalog = self.hc_structure + hc_updates: dict = {"total_rows": 0} if not append_columns: - new_catalog_info = self.hc_structure.catalog_info.copy_and_update(ra_column="", dec_column="") - hc_catalog = self.hc_structure.__class__( - new_catalog_info, - self.hc_structure.pixel_tree, - schema=get_arrow_schema(ndf), - moc=self.hc_structure.moc, - ) + hc_updates = {**hc_updates, "ra_column": "", "dec_column": ""} + + hc_catalog = self._create_modified_hc_structure(**hc_updates) + hc_catalog.schema = get_arrow_schema(ndf) return self.__class__(ndf, self._ddf_pixel_map, hc_catalog) diff --git a/tests/lsdb/catalog/test_catalog.py b/tests/lsdb/catalog/test_catalog.py index 08c80794..0b94cf8a 100644 --- a/tests/lsdb/catalog/test_catalog.py +++ b/tests/lsdb/catalog/test_catalog.py @@ -658,3 +658,27 @@ def test_joined_catalog_has_undetermined_len( ) with pytest.raises(ValueError, match="undetermined"): len(small_sky_order1_catalog.merge_asof(small_sky_xmatch_catalog)) + + +def test_modified_hc_structure_is_a_deep_copy(small_sky_order1_catalog): + assert small_sky_order1_catalog.hc_structure.pixel_tree is not None + assert small_sky_order1_catalog.hc_structure.catalog_path is not None + assert small_sky_order1_catalog.hc_structure.schema is not None + assert small_sky_order1_catalog.hc_structure.moc is not None + assert small_sky_order1_catalog.hc_structure.catalog_info.total_rows == 131 + + modified_hc_structure = small_sky_order1_catalog._create_modified_hc_structure(total_rows=0) + modified_hc_structure.pixel_tree = None + modified_hc_structure.catalog_path = None + modified_hc_structure.schema = None + modified_hc_structure.moc = None + + # The original catalog structure is not modified + assert small_sky_order1_catalog.hc_structure.pixel_tree is not None + assert small_sky_order1_catalog.hc_structure.catalog_path is not None + assert small_sky_order1_catalog.hc_structure.schema is not None + assert small_sky_order1_catalog.hc_structure.moc is not None + assert small_sky_order1_catalog.hc_structure.catalog_info.total_rows == 131 + + # The rows of the new structure are invalidated + assert modified_hc_structure.catalog_info.total_rows == 0 diff --git a/tests/lsdb/catalog/test_nested.py b/tests/lsdb/catalog/test_nested.py index 1222cf6c..48a2da73 100644 --- a/tests/lsdb/catalog/test_nested.py +++ b/tests/lsdb/catalog/test_nested.py @@ -72,6 +72,9 @@ def mean_mag(ra, dec, mag): assert isinstance(reduced_cat, Catalog) assert isinstance(reduced_cat._ddf, nd.NestedFrame) + assert reduced_cat.hc_structure.catalog_info.ra_column == "" + assert reduced_cat.hc_structure.catalog_info.dec_column == "" + reduced_cat_compute = reduced_cat.compute() assert isinstance(reduced_cat_compute, npd.NestedFrame)