Skip to content

Commit

Permalink
Fix shallow copy of catalog structure (#447)
Browse files Browse the repository at this point in the history
* Rebase branch

* Fix shallow copy of hc structure
  • Loading branch information
camposandro authored Oct 23, 2024
1 parent 5ad61e1 commit 84dc3cb
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 21 deletions.
53 changes: 48 additions & 5 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,21 +614,17 @@ def nest_lists(
recommend setting the following dask config setting to prevent this:
`dask.config.set({"dataframe.convert-string":False})`
"""
new_ddf = super().nest_lists(
catalog = super().nest_lists(
base_columns=base_columns,
list_columns=list_columns,
name=name,
)

catalog = Catalog(new_ddf._ddf, self._ddf_pixel_map, self.hc_structure)

if self.margin is not None:
catalog.margin = self.margin.nest_lists(
base_columns=base_columns,
list_columns=list_columns,
name=name,
)

return catalog

def dropna(
Expand Down Expand Up @@ -708,6 +704,53 @@ def dropna(
return catalog

def reduce(self, func, *args, meta=None, **kwargs) -> Catalog:
"""
Takes a function and applies it to each top-level row of the Catalog.
docstring copied from nested-pandas
The user may specify which columns the function is applied to, with
columns from the 'base' layer being passed to the function as
scalars and columns from the nested layers being passed as numpy arrays.
Parameters
----------
func : callable
Function to apply to each nested dataframe. The first arguments to `func` should be which
columns to apply the function to. See the Notes for recommendations
on writing func outputs.
args : positional arguments
Positional arguments to pass to the function, the first *args should be the names of the
columns to apply the function to.
meta : dataframe or series-like, optional
The dask meta of the output. If append_columns is True, the meta should specify just the
additional columns output by func.
append_columns : bool
If the output columns should be appended to the original dataframe.
kwargs : keyword arguments, optional
Keyword arguments to pass to the function.
Returns
-------
`HealpixDataset`
`HealpixDataset` with the results of the function applied to the columns of the frame.
Notes
-----
By default, `reduce` will produce a `NestedFrame` with enumerated
column names for each returned value of the function. For more useful
naming, it's recommended to have `func` return a dictionary where each
key is an output column of the dataframe returned by `reduce`.
Example User Function:
>>> def my_sum(col1, col2):
>>> '''reduce will return a NestedFrame with two columns'''
>>> return {"sum_col1": sum(col1), "sum_col2": sum(col2)}
>>>
>>> catalog.reduce(my_sum, 'sources.col1', 'sources.col2')
"""
catalog = super().reduce(func, *args, meta=meta, **kwargs)
if self.margin is not None:
catalog.margin = self.margin.reduce(func, *args, meta=meta, **kwargs)
Expand Down
38 changes: 22 additions & 16 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import copy
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Tuple, cast
Expand Down Expand Up @@ -82,6 +81,20 @@ def __len__(self):
"""
return len(self.hc_structure)

def _create_modified_hc_structure(self, **kwargs) -> HCHealpixDataset:
"""Copy the catalog structure and override the specified catalog info parameters.
Returns:
A copy of the catalog's structure with updated info parameters.
"""
return self.hc_structure.__class__(
catalog_info=self.hc_structure.catalog_info.copy_and_update(**kwargs),
pixels=self.hc_structure.pixel_tree,
catalog_path=self.hc_structure.catalog_path,
schema=self.hc_structure.schema,
moc=self.hc_structure.moc,
)

def get_healpix_pixels(self) -> List[HealpixPixel]:
"""Get all HEALPix pixels that are contained in the catalog
Expand Down Expand Up @@ -146,8 +159,7 @@ def query(self, expr: str) -> Self:
with the query expression
"""
ndf = self._ddf.query(expr)
hc_structure = copy.copy(self.hc_structure)
hc_structure.catalog_info.total_rows = 0
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(ndf, self._ddf_pixel_map, hc_structure)

def _perform_search(
Expand Down Expand Up @@ -527,8 +539,7 @@ def drop_na_part(df: npd.NestedFrame):
return df

ndf = self._ddf.map_partitions(drop_na_part, meta=self._ddf._meta)
hc_structure = copy.copy(self.hc_structure)
hc_structure.catalog_info.total_rows = 0
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(ndf, self._ddf_pixel_map, hc_structure)

def nest_lists(
Expand Down Expand Up @@ -574,9 +585,7 @@ def nest_lists(
list_columns=list_columns,
name=name,
)

hc_structure = copy.copy(self.hc_structure)
hc_structure.catalog_info.total_rows = 0
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(new_ddf, self._ddf_pixel_map, hc_structure)

def reduce(self, func, *args, meta=None, append_columns=False, **kwargs) -> Self:
Expand Down Expand Up @@ -643,13 +652,10 @@ def reduce_part(df):

ndf = nd.NestedFrame.from_dask_dataframe(self._ddf.map_partitions(reduce_part, meta=meta))

hc_catalog = self.hc_structure
hc_updates: dict = {"total_rows": 0}
if not append_columns:
new_catalog_info = self.hc_structure.catalog_info.copy_and_update(ra_column="", dec_column="")
hc_catalog = self.hc_structure.__class__(
new_catalog_info,
self.hc_structure.pixel_tree,
schema=get_arrow_schema(ndf),
moc=self.hc_structure.moc,
)
hc_updates = {**hc_updates, "ra_column": "", "dec_column": ""}

hc_catalog = self._create_modified_hc_structure(**hc_updates)
hc_catalog.schema = get_arrow_schema(ndf)
return self.__class__(ndf, self._ddf_pixel_map, hc_catalog)
24 changes: 24 additions & 0 deletions tests/lsdb/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,27 @@ def test_joined_catalog_has_undetermined_len(
)
with pytest.raises(ValueError, match="undetermined"):
len(small_sky_order1_catalog.merge_asof(small_sky_xmatch_catalog))


def test_modified_hc_structure_is_a_deep_copy(small_sky_order1_catalog):
    """Clearing fields on the modified structure must not touch the original one."""
    structure_fields = ("pixel_tree", "catalog_path", "schema", "moc")

    def check_original_structure():
        # Re-read the structure through the catalog on every check, so that the
        # assertions always look at what the catalog currently exposes.
        for field in structure_fields:
            assert getattr(small_sky_order1_catalog.hc_structure, field) is not None
        assert small_sky_order1_catalog.hc_structure.catalog_info.total_rows == 131

    check_original_structure()

    modified_hc_structure = small_sky_order1_catalog._create_modified_hc_structure(total_rows=0)
    for field in structure_fields:
        setattr(modified_hc_structure, field, None)

    # The original catalog structure is not modified
    check_original_structure()

    # The rows of the new structure are invalidated
    assert modified_hc_structure.catalog_info.total_rows == 0
3 changes: 3 additions & 0 deletions tests/lsdb/catalog/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def mean_mag(ra, dec, mag):
assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

assert reduced_cat.hc_structure.catalog_info.ra_column == ""
assert reduced_cat.hc_structure.catalog_info.dec_column == ""

reduced_cat_compute = reduced_cat.compute()
assert isinstance(reduced_cat_compute, npd.NestedFrame)

Expand Down

0 comments on commit 84dc3cb

Please sign in to comment.