Skip to content

Commit

Permalink
Disable fine filtering in margin generation for from_dataframe
Browse files: browse the repository at this point in the history
  • Loading branch information
delucchi-cmu committed Oct 23, 2024
1 parent 2225b1e commit 489d742
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 60 deletions.
66 changes: 21 additions & 45 deletions src/lsdb/loaders/dataframe/margin_catalog_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,9 +8,8 @@
import nested_pandas as npd
import numpy as np
import pandas as pd
from hats import pixel_math
from hats.catalog import CatalogType, TableProperties
from hats.pixel_math import HealpixPixel
from hats.pixel_math import HealpixPixel, get_margin
from hats.pixel_math.healpix_pixel_function import get_pixel_argsort

from lsdb import Catalog
Expand Down Expand Up @@ -45,31 +44,34 @@ def __init__(
self.dataframe: npd.NestedFrame = catalog.compute().copy()
self.hc_structure = catalog.hc_structure
self.margin_threshold = margin_threshold
self.margin_order = self._set_margin_order(margin_order)
self.margin_order = margin_order
self._resolve_margin_order()
self.use_pyarrow_types = use_pyarrow_types
self.catalog_info = self._create_catalog_info(**kwargs)

def _set_margin_order(self, margin_order: int | None) -> int:
def _resolve_margin_order(self) -> int:
"""Calculate the order of the margin cache to be generated. If not provided
the margin will be greater than that of the original catalog by 1.
Args:
margin_order (int): The order to generate the margin cache with
Returns:
The validated order of the margin catalog.
the margin will be calculated based on the smallest pixel possible for the threshold.
Raises:
ValueError: if the provided margin order is lower than that of the catalog.
ValueError: if the margin order and thresholds are incompatible with the catalog.
"""
margin_pixel_k = self.hc_structure.partition_info.get_highest_order() + 1
if margin_order is None or margin_order == -1:
margin_order = margin_pixel_k
elif margin_order < margin_pixel_k:

highest_order = int(self.hc_structure.partition_info.get_highest_order())

if self.margin_order < 0:
self.margin_order = hp.margin2order(margin_thr_arcmin=self.margin_threshold / 60.0)

if self.margin_order < highest_order + 1:
raise ValueError(
"margin_order must be of a higher order than the highest order catalog partition pixel."
)
return margin_order

margin_pixel_nside = hp.order2nside(self.margin_order)
margin_pixel_avgsize = hp.nside2resol(margin_pixel_nside, arcmin=True)
margin_pixel_mindist = hp.avgsize2mindist(margin_pixel_avgsize)
if margin_pixel_mindist * 60.0 < self.margin_threshold:
raise ValueError("margin pixels must be larger than margin_threshold")

Check warning on line 74 in src/lsdb/loaders/dataframe/margin_catalog_generator.py

View check run for this annotation

Codecov / codecov/patch

src/lsdb/loaders/dataframe/margin_catalog_generator.py#L74

Added line #L74 was not covered by tests

def create_catalog(self) -> MarginCatalog | None:
"""Create a margin catalog for another pre-computed catalog
Expand Down Expand Up @@ -147,7 +149,7 @@ def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
order = pixel.order
pix = pixel.pixel
d_order = self.margin_order - order
margins = pixel_math.get_margin(order, pix, d_order)
margins = get_margin(order, pix, d_order)
for m_p in margins:
n_orders.append(order)
part_pix.append(pix)
Expand Down Expand Up @@ -182,35 +184,9 @@ def _create_margins(self, margin_pairs_df: pd.DataFrame) -> Dict[HealpixPixel, p
["partition_order", "partition_pixel"]
):
margin_pixel = HealpixPixel(partition_group[0], partition_group[1])
df = self._get_data_in_margin(partition_df, margin_pixel)
if len(df):
df = _format_margin_partition_dataframe(df)
margin_pixel_df_map[margin_pixel] = df
margin_pixel_df_map[margin_pixel] = _format_margin_partition_dataframe(partition_df)
return margin_pixel_df_map

def _get_data_in_margin(
self, partition_df: npd.NestedFrame, margin_pixel: HealpixPixel
) -> npd.NestedFrame:
"""Calculate the margin boundaries for the HEALPix and include the points
on the margin according to the specified threshold
Args:
partition_df (pd.DataFrame): The margin pixel data
margin_pixel (HealpixPixel): The margin HEALPix
Returns:
A Pandas Dataframe with the points of the partition that are within
the specified threshold in the margin.
"""
margin_mask = pixel_math.check_margin_bounds(
partition_df[self.hc_structure.catalog_info.ra_column].to_numpy(),
partition_df[self.hc_structure.catalog_info.dec_column].to_numpy(),
margin_pixel.order,
margin_pixel.pixel,
self.margin_threshold,
)
return partition_df.iloc[margin_mask]

def _create_catalog_info(self, catalog_name: str | None = None, **kwargs) -> TableProperties:
"""Create the margin catalog info object
Expand Down
34 changes: 19 additions & 15 deletions tests/lsdb/loaders/dataframe/test_from_dataframe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -230,32 +230,36 @@ def test_from_dataframe_small_sky_source_with_margins(small_sky_source_df, small
lowest_order=0,
highest_order=2,
threshold=3000,
margin_order=8,
margin_threshold=180,
**kwargs,
)

assert catalog.margin is not None
assert isinstance(catalog.margin, MarginCatalog)
assert isinstance(catalog.margin._ddf, nd.NestedFrame)
assert catalog.margin.get_healpix_pixels() == small_sky_source_margin_catalog.get_healpix_pixels()
margin = catalog.margin
assert isinstance(margin, MarginCatalog)
assert isinstance(margin._ddf, nd.NestedFrame)
assert margin.get_healpix_pixels() == small_sky_source_margin_catalog.get_healpix_pixels()

# The points of this margin catalog are present in one partition only
# so we are able to perform the comparison between the computed results
pd.testing.assert_frame_equal(
catalog.margin.compute().sort_index(),
small_sky_source_margin_catalog.compute().sort_index(),
check_like=True,
)
assert isinstance(catalog.margin.compute(), npd.NestedFrame)
# The points of this margin catalog will be a superset of the hats-imported one,
# as fine filtering is not enabled here.
for hp_pixel in margin.hc_structure.get_healpix_pixels():
partition_from_df = margin.get_partition(hp_pixel.order, hp_pixel.pixel)
expected_df = small_sky_source_margin_catalog.get_partition(hp_pixel.order, hp_pixel.pixel)
assert len(expected_df) <= len(partition_from_df)

margin_source_ids = set(partition_from_df["source_id"])
expected_source_ids = set(expected_df["source_id"])
assert len(expected_source_ids - margin_source_ids) == 0

assert isinstance(margin.compute(), npd.NestedFrame)

assert catalog.hc_structure.catalog_info.__pydantic_extra__["obs_regime"] == "Optical"
assert catalog.margin.hc_structure.catalog_info.__pydantic_extra__["obs_regime"] == "Optical"
assert margin.hc_structure.catalog_info.__pydantic_extra__["obs_regime"] == "Optical"

assert catalog.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb")
assert catalog.margin.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb")
assert margin.hc_structure.catalog_info.__pydantic_extra__["hats_builder"].startswith("lsdb")
# The margin and main catalog's schemas are the same
assert catalog.margin.hc_structure.schema is catalog.hc_structure.schema
assert margin.hc_structure.schema is catalog.hc_structure.schema


def test_from_dataframe_invalid_margin_order(small_sky_source_df):
Expand Down

0 comments on commit 489d742

Please sign in to comment.