Use HATS filter_by_* methods for spatial filtering (#497)
* Use validation for search filters

* Validate search object on construction

* Update convexity test
camposandro authored Nov 14, 2024
1 parent e86782b commit 76a8ab1
Showing 7 changed files with 51 additions and 71 deletions.
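
For orientation, here is a minimal usage sketch of the three spatial filters touched by this commit, as exposed on lsdb catalogs (the catalog path is hypothetical):

```python
import lsdb

# Hypothetical path to any HATS-format catalog.
catalog = lsdb.read_hats("path/to/small_sky_order1")

# Each call filters partitions via the HATS filter_by_* methods and,
# with fine=True (the default), also filters individual points.
cone = catalog.cone_search(ra=300, dec=-50, radius_arcsec=3600)
box = catalog.box_search(ra=(270, 300), dec=(-60, -45))
poly = catalog.polygon_search([(300, -50), (300, -55), (272, -55), (272, -50)])

# Searches are lazy; compute() materializes the surviving rows.
print(len(cone.compute()))
```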
5 changes: 2 additions & 3 deletions src/lsdb/catalog/catalog.py
@@ -7,7 +7,6 @@
import nested_pandas as npd
import pandas as pd
from hats.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from hats.pixel_math.polygon_filter import SphericalCoordinates
from pandas._libs import lib
from pandas._typing import AnyAll, Axis, IndexLabel
from pandas.api.extensions import no_default
@@ -266,14 +265,14 @@ def box_search(
"""
return self.search(BoxSearch(ra, dec, fine))

def polygon_search(self, vertices: List[SphericalCoordinates], fine: bool = True) -> Catalog:
def polygon_search(self, vertices: list[tuple[float, float]], fine: bool = True) -> Catalog:
"""Perform a polygonal search to filter the catalog.
Filters to points within the polygonal region specified in ra and dec, in degrees.
Filters partitions in the catalog to those that have some overlap with the region.
Args:
vertices (List[Tuple[float, float]): The list of vertices of the polygon to
vertices (list[tuple[float, float]]): The list of vertices of the polygon to
filter pixels with, as a list of (ra,dec) coordinates, in degrees.
fine (bool): True if points are to be filtered, False if not. Defaults to True.
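
As the `box_search` body above shows, the convenience methods simply build a search object and hand it to `Catalog.search`; the equivalent explicit form (a sketch, reusing the hypothetical catalog path from the previous example) is:

```python
import lsdb
from lsdb.core.search.box_search import BoxSearch
from lsdb.core.search.polygon_search import PolygonSearch

catalog = lsdb.read_hats("path/to/small_sky_order1")  # hypothetical path

# Validation now runs at construction time, so an invalid region
# fails here, before any partition is read.
box = BoxSearch(ra=(270, 300), dec=(-60, -45))
poly = PolygonSearch([(300, -50), (300, -55), (272, -55), (272, -50)])

# Equivalent to catalog.box_search(...) / catalog.polygon_search(...).
filtered = catalog.search(box)
```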
14 changes: 2 additions & 12 deletions src/lsdb/core/search/abstract_search.py
@@ -25,19 +25,9 @@ class AbstractSearch(ABC):
def __init__(self, fine: bool = True):
self.fine = fine

def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
"""Filters the hats catalog object to the partitions included in the search"""
if len(hc_structure.get_healpix_pixels()) == 0:
return hc_structure
max_order = hc_structure.get_max_coverage_order()
search_moc = self.generate_search_moc(max_order)
return hc_structure.filter_by_moc(search_moc)

def generate_search_moc(self, max_order: int) -> MOC:
def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> MOC:
"""Determine the target partitions for further filtering."""
raise NotImplementedError(
"Search Class must implement `generate_search_moc` method or overwrite `filter_hc_catalog`"
)
raise NotImplementedError("Search Class must implement `filter_hc_catalog` method")

@abstractmethod
def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
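
With `generate_search_moc` removed, a custom region now overrides `filter_hc_catalog` directly (typically delegating to one of the HATS `filter_by_*` methods, or to `filter_by_moc` as the old default did) plus `search_points`. A minimal sketch of a hypothetical subclass:

```python
import nested_pandas as npd
from hats.catalog import TableProperties
from mocpy import MOC

from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import HCCatalogTypeVar


class CustomMOCSearch(AbstractSearch):
    """Hypothetical search that keeps partitions overlapping an arbitrary MOC."""

    def __init__(self, moc: MOC, fine: bool = True):
        super().__init__(fine)
        self.moc = moc

    def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
        # Mirrors filter_by_box/filter_by_cone/filter_by_polygon in the subclasses below.
        return hc_structure.filter_by_moc(self.moc)

    def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
        # No fine filtering in this sketch; return the partition's rows unchanged.
        return frame
```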
28 changes: 14 additions & 14 deletions src/lsdb/core/search/box_search.py
@@ -1,15 +1,14 @@
from __future__ import annotations

from typing import Tuple

import nested_pandas as npd
import numpy as np
from hats.catalog import TableProperties
from hats.pixel_math.box_filter import generate_box_moc, wrap_ra_angles
from hats.pixel_math.validators import validate_box_search
from hats.pixel_math.box_filter import wrap_ra_angles
from hats.pixel_math.validators import validate_box
from mocpy import MOC

from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import HCCatalogTypeVar


class BoxSearch(AbstractSearch):
@@ -22,17 +21,18 @@ class BoxSearch(AbstractSearch):

def __init__(
self,
ra: Tuple[float, float] | None = None,
dec: Tuple[float, float] | None = None,
ra: tuple[float, float] | None = None,
dec: tuple[float, float] | None = None,
fine: bool = True,
):
super().__init__(fine)
ra = tuple(wrap_ra_angles(ra)) if ra else None
validate_box_search(ra, dec)
validate_box(ra, dec)
self.ra, self.dec = ra, dec

def generate_search_moc(self, max_order: int) -> MOC:
return generate_box_moc(self.ra, self.dec, max_order)
def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> MOC:
"""Filters catalog pixels according to the box"""
return hc_structure.filter_by_box(self.ra, self.dec)

def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
"""Determine the search results within a data frame"""
@@ -41,16 +41,16 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:

def box_filter(
data_frame: npd.NestedFrame,
ra: Tuple[float, float] | None,
dec: Tuple[float, float] | None,
ra: tuple[float, float] | None,
dec: tuple[float, float] | None,
metadata: TableProperties,
) -> npd.NestedFrame:
"""Filters a dataframe to only include points within the specified box region.
Args:
data_frame (npd.NestedFrame): DataFrame containing points in the sky
ra (Tuple[float, float]): Right ascension range, in degrees
dec (Tuple[float, float]): Declination range, in degrees
ra (tuple[float, float]): Right ascension range, in degrees
dec (tuple[float, float]): Declination range, in degrees
metadata (hc.catalog.Catalog): hats `Catalog` with catalog_info that matches `data_frame`
Returns:
@@ -70,7 +70,7 @@ def box_filter(
return data_frame


def _create_ra_mask(ra: Tuple[float, float], values: np.ndarray) -> np.ndarray:
def _create_ra_mask(ra: tuple[float, float], values: np.ndarray) -> np.ndarray:
"""Creates the mask to filter right ascension values. If this range crosses
the discontinuity line (0 degrees), we have a branched logical operation."""
if ra[0] <= ra[1]:
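
The branch in `_create_ra_mask` matters when the right-ascension range wraps through 0°. A small sketch of the intended behaviour (the helper itself is truncated in this hunk):

```python
import numpy as np

ra_values = np.array([10.0, 50.0, 350.0])

# Non-wrapping range: both bounds on the same side of 0 deg, use AND.
lo, hi = 0.0, 60.0
mask = (ra_values >= lo) & (ra_values <= hi)   # -> [True, True, False]

# Wrapping range (330 deg -> 30 deg crosses 0 deg): use OR instead.
lo, hi = 330.0, 30.0
mask = (ra_values >= lo) | (ra_values <= hi)   # -> [True, False, True]
```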
7 changes: 4 additions & 3 deletions src/lsdb/core/search/cone_search.py
@@ -1,11 +1,11 @@
import nested_pandas as npd
from astropy.coordinates import SkyCoord
from hats.catalog import TableProperties
from hats.pixel_math.cone_filter import generate_cone_moc
from hats.pixel_math.validators import validate_declination_values, validate_radius
from mocpy import MOC

from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import HCCatalogTypeVar


class ConeSearch(AbstractSearch):
@@ -23,8 +23,9 @@ def __init__(self, ra: float, dec: float, radius_arcsec: float, fine: bool = True):
self.dec = dec
self.radius_arcsec = radius_arcsec

def generate_search_moc(self, max_order: int) -> MOC:
return generate_cone_moc(self.ra, self.dec, self.radius_arcsec, max_order)
def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> MOC:
"""Filters catalog pixels according to the cone"""
return hc_structure.filter_by_cone(self.ra, self.dec, self.radius_arcsec)

def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
"""Determine the search results within a data frame"""
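
Because the imported validators run in the constructor (per the "Validate search object on construction" note in the commit message), an out-of-range cone should be rejected before any partition is touched; a small sketch of the expected behaviour:

```python
import pytest
from lsdb.core.search.cone_search import ConeSearch

# A valid cone; partition filtering is delegated to hc_structure.filter_by_cone(...).
search = ConeSearch(ra=300, dec=-50, radius_arcsec=3600)

# Declinations outside [-90, 90] degrees should fail immediately.
with pytest.raises(ValueError):
    ConeSearch(ra=300, dec=100, radius_arcsec=3600)

# Negative radii should fail as well.
with pytest.raises(ValueError):
    ConeSearch(ra=300, dec=-50, radius_arcsec=-1)
```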
40 changes: 17 additions & 23 deletions src/lsdb/core/search/polygon_search.py
@@ -1,15 +1,12 @@
from typing import List, Tuple

import hats.pixel_math.healpix_shim as hp
import nested_pandas as npd
import numpy as np
from hats.catalog import TableProperties
from hats.pixel_math.polygon_filter import CartesianCoordinates, SphericalCoordinates, generate_polygon_moc
from hats.pixel_math.validators import validate_declination_values, validate_polygon
from hats.pixel_math.validators import validate_polygon
from lsst.sphgeom import ConvexPolygon, UnitVector3d
from mocpy import MOC

from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.types import HCCatalogTypeVar


class PolygonSearch(AbstractSearch):
@@ -19,15 +16,15 @@ class PolygonSearch(AbstractSearch):
Filters partitions in the catalog to those that have some overlap with the region.
"""

def __init__(self, vertices: List[SphericalCoordinates], fine: bool = True):
def __init__(self, vertices: list[tuple[float, float]], fine: bool = True):
super().__init__(fine)
_, dec = np.array(vertices).T
validate_declination_values(dec)
self.vertices = np.array(vertices)
self.polygon, self.vertices_xyz = get_cartesian_polygon(vertices)
validate_polygon(vertices)
self.vertices = vertices
self.polygon = get_cartesian_polygon(vertices)

def generate_search_moc(self, max_order: int) -> MOC:
return generate_polygon_moc(self.vertices_xyz, max_order)
def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> HCCatalogTypeVar:
"""Filters catalog pixels according to the polygon"""
return hc_structure.filter_by_polygon(self.vertices)

def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
"""Determine the search results within a data frame"""
@@ -54,21 +51,18 @@ def polygon_filter(
return data_frame


def get_cartesian_polygon(
vertices: List[SphericalCoordinates],
) -> Tuple[ConvexPolygon, List[CartesianCoordinates]]:
"""Creates the convex polygon to filter pixels with. It transforms the vertices, provided
in sky coordinates of ra and dec, to their respective cartesian representation on the unit sphere.
def get_cartesian_polygon(vertices: list[tuple[float, float]]) -> ConvexPolygon:
"""Creates the convex polygon to filter pixels with. It transforms the
vertices, provided in sky coordinates of ra and dec, to their respective
cartesian representation on the unit sphere.
Args:
vertices (List[Tuple[float, float]): The list of vertices of the polygon to
filter pixels with, as a list of (ra,dec) coordinates, in degrees.
vertices (list[tuple[float, float]): The list of vertices of the polygon
to filter pixels with, as a list of (ra,dec) coordinates, in degrees.
Returns:
A tuple, where the first element is the convex polygon object and the second is the
list of cartesian coordinates of its vertices.
The convex polygon object.
"""
vertices_xyz = hp.ang2vec(*np.array(vertices).T, lonlat=True)
validate_polygon(vertices_xyz)
edge_vectors = [UnitVector3d(x, y, z) for x, y, z in vertices_xyz]
return ConvexPolygon(edge_vectors), vertices_xyz
return ConvexPolygon(edge_vectors)
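
Callers that previously unpacked the `(polygon, vertices_xyz)` tuple can recover the Cartesian vertices from the `ConvexPolygon` itself, which is what the updated wrapped-RA test below does; a short sketch:

```python
import numpy as np
from lsdb.core.search.polygon_search import get_cartesian_polygon

vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]

# The helper now returns only the ConvexPolygon.
polygon = get_cartesian_polygon(vertices)

# The unit-sphere (x, y, z) vertices are still reachable through the polygon.
vertices_xyz = np.asarray(polygon.getVertices())
print(vertices_xyz.shape)  # expected (4, 3)
```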
21 changes: 10 additions & 11 deletions tests/lsdb/catalog/test_polygon_search.py
@@ -1,5 +1,3 @@
import sys

import nested_dask as nd
import nested_pandas as npd
import numpy as np
@@ -12,7 +10,7 @@

def test_polygon_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
polygon, _ = get_cartesian_polygon(vertices)
polygon = get_cartesian_polygon(vertices)
polygon_search_catalog = small_sky_order1_catalog.polygon_search(vertices)
assert isinstance(polygon_search_catalog._ddf, nd.NestedFrame)
polygon_search_df = polygon_search_catalog.compute()
@@ -31,7 +29,7 @@ def test_polygon_search_filters_correct_points_margin(
small_sky_order1_source_with_margin, assert_divisions_are_correct
):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
polygon, _ = get_cartesian_polygon(vertices)
polygon = get_cartesian_polygon(vertices)
polygon_search_catalog = small_sky_order1_source_with_margin.polygon_search(vertices)
polygon_search_df = polygon_search_catalog.compute()
ra_values_radians = np.radians(
@@ -57,8 +55,7 @@ def test_polygon_search_filters_correct_points_margin(

def test_polygon_search_filters_partitions(small_sky_order1_catalog):
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
_, vertices_xyz = get_cartesian_polygon(vertices)
hc_polygon_search = small_sky_order1_catalog.hc_structure.filter_by_polygon(vertices_xyz)
hc_polygon_search = small_sky_order1_catalog.hc_structure.filter_by_polygon(vertices)
polygon_search_catalog = small_sky_order1_catalog.polygon_search(vertices, fine=False)
assert len(hc_polygon_search.get_healpix_pixels()) == len(polygon_search_catalog.get_healpix_pixels())
assert len(hc_polygon_search.get_healpix_pixels()) == polygon_search_catalog._ddf.npartitions
@@ -82,10 +79,9 @@ def test_polygon_search_invalid_dec(small_sky_order1_catalog):
small_sky_order1_catalog.polygon_search(vertices)


@pytest.mark.skipif(sys.platform == "darwin", reason="Test skipped on macOS")
def test_polygon_search_invalid_shape(small_sky_order1_catalog):
"""The polygon is not convex, so the shape is invalid"""
with pytest.raises(RuntimeError):
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_CONCAVE_SHAPE):
vertices = [(45, 30), (60, 60), (90, 45), (60, 50)]
small_sky_order1_catalog.polygon_search(vertices)

@@ -105,6 +101,9 @@ def test_polygon_search_invalid_polygon(small_sky_order1_catalog):
with pytest.raises(ValueError, match=ValidatorsErrors.DEGENERATE_POLYGON):
vertices = [(50.1, 0), (100.1, 0), (150.1, 0), (200.1, 0)]
small_sky_order1_catalog.polygon_search(vertices)
with pytest.raises(ValueError, match=ValidatorsErrors.INVALID_CONCAVE_SHAPE):
vertices = [(45, 30), (60, 60), (90, 45), (60, 50)]
small_sky_order1_catalog.polygon_search(vertices)


def test_polygon_search_wrapped_right_ascension():
@@ -134,10 +133,10 @@ def test_polygon_search_wrapped_right_ascension():
[(-20.1, 1), (-380.2, -1), (380.3, -1)],
[(-20.1, 1), (-380.2, -1), (-339.7, -1)],
]
_, vertices_xyz = get_cartesian_polygon(vertices)
polygon = get_cartesian_polygon(vertices)
for v in all_vertices_combinations:
_, wrapped_v_xyz = get_cartesian_polygon(v)
npt.assert_allclose(vertices_xyz, wrapped_v_xyz, rtol=1e-7)
polygon_2 = get_cartesian_polygon(v)
npt.assert_allclose(polygon.getVertices(), polygon_2.getVertices(), rtol=1e-7)


def test_empty_polygon_search_with_margin(small_sky_order1_source_with_margin):
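
The wrapped-right-ascension test works because RA is periodic: every vertex list above describes the same polygon once the angles are reduced modulo 360°. The `wrap_ra_angles` helper imported in `box_search.py` illustrates the same equivalence numerically (a sketch, assuming it wraps angles into the [0, 360) range):

```python
import numpy as np
from hats.pixel_math.box_filter import wrap_ra_angles

# -20.1, -380.2 and 380.3 degrees reduce to 339.9, 339.8 and 20.3 respectively.
wrapped = wrap_ra_angles([-20.1, -380.2, 380.3])
assert np.allclose(wrapped, [339.9, 339.8, 20.3])
```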
7 changes: 2 additions & 5 deletions tests/lsdb/loaders/hats/test_read_hats.py
@@ -299,11 +299,8 @@ def test_read_hats_margin_catalog_subset(

def test_read_hats_margin_catalog_subset_is_empty(small_sky_order1_source_margin_dir):
search_filter = ConeSearch(ra=100, dec=80, radius_arcsec=1)
margin_catalog = lsdb.read_hats(small_sky_order1_source_margin_dir, search_filter=search_filter)
assert len(margin_catalog.get_healpix_pixels()) == 0
assert len(margin_catalog._ddf_pixel_map) == 0
assert len(margin_catalog.compute()) == 0
assert len(margin_catalog.hc_structure.pixel_tree) == 0
with pytest.raises(ValueError, match="empty catalog"):
lsdb.read_hats(small_sky_order1_source_margin_dir, search_filter=search_filter)


def test_read_hats_schema_not_found(small_sky_no_metadata_dir):
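
From the user side, the behavioural change covered by the updated test above is that `read_hats` now refuses to build an empty catalog when the search filter excludes every partition; a sketch (path hypothetical, error text not quoted verbatim):

```python
import lsdb
from lsdb.core.search.cone_search import ConeSearch

# A tiny cone far from the catalog footprint; no partition survives the filter.
search_filter = ConeSearch(ra=100, dec=80, radius_arcsec=1)

try:
    lsdb.read_hats("path/to/small_sky_order1_source_margin", search_filter=search_filter)
except ValueError as err:
    # The loader now raises instead of returning an empty catalog object.
    print(err)
```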
