Skip to content

Commit

Permalink
modified ListBackend.regionsearch and abstrations.Backend to raise ex…
Browse files Browse the repository at this point in the history
…ception when indexer method is missing

added unit tests for exception
  • Loading branch information
cchris28 committed Jul 11, 2023
1 parent 318bddd commit 0f93c7d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/kbmod/regionsearch/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def region_search(self, filter: Filter) -> numpy.ndarray:
numpy.ndarray
Observation identifiers of pointings that match the given filter.
"""
pass
if not hasattr(self, "observations_to_indices"):
raise NotImplementedError("region_search requires an implementation of observations_to_indices")
return numpy.array([])


class ObservationIndexer(ABC):
Expand Down
21 changes: 11 additions & 10 deletions src/kbmod/regionsearch/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
lincc-frameworks provides implementations of ``abstractions.Backend`` that may be composed with an implementation of ``abstractions.ObserverationIndexer`` to provide a complete region search.
Provides implementations of ``abstractions.Backend`` that may be composed with an implementation of ``abstractions.ObserverationIndexer`` to provide a complete region search.
"""

from dataclasses import dataclass
Expand Down Expand Up @@ -91,13 +91,14 @@ def region_search(self, filter: Filter) -> np.ndarray:
A list of matching indices.
"""
matching_observation_identifier = np.array([], dtype=self.observation_identifier.dtype)
if hasattr(self, "observations_to_indices"):
pointing = coord.SkyCoord(filter.search_ra, filter.search_dec)
matching_index = self.observations_to_indices(pointing, None, filter.search_fov, None) # type: ignore
pointing = coord.SkyCoord(self.observation_ra, self.observation_dec)
self.observation_index = self.observations_to_indices( # type: ignore
pointing, self.observation_time, self.observation_fov, self.observation_location
)
index_list = np.nonzero(self.observation_index == matching_index)[0]
matching_observation_identifier = self.observation_identifier[index_list]
if not hasattr(self, "observations_to_indices"):
raise NotImplementedError("region_search requires an implementation of observations_to_indices")
pointing = coord.SkyCoord(filter.search_ra, filter.search_dec)
matching_index = self.observations_to_indices(pointing, None, filter.search_fov, None) # type: ignore
pointing = coord.SkyCoord(self.observation_ra, self.observation_dec)
self.observation_index = self.observations_to_indices( # type: ignore
pointing, self.observation_time, self.observation_fov, self.observation_location
)
index_list = np.nonzero(self.observation_index == matching_index)[0]
matching_observation_identifier = self.observation_identifier[index_list]
return matching_observation_identifier
2 changes: 1 addition & 1 deletion src/kbmod/regionsearch/indexers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
lincc-frameworks provides implementations of ``abstractions.ObserverationIndexer`` that may be composed with an implementation of ``abstractions.Backend`` to provide a complete region search.
Provides implementations of ``abstractions.ObserverationIndexer`` that may be composed with an implementation of ``abstractions.Backend`` to provide a complete region search.
"""
import math

Expand Down
3 changes: 2 additions & 1 deletion src/kbmod/regionsearch/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def read_table(self, filename: str, format: str, colnames: typing.List[str]):
The format of the file to read or write.
colnames : typing.List[str]
The list of column names to read from the file. The file may contain other columns but it must have these columns.
Returns
-------
bool
Expand All @@ -204,7 +205,7 @@ def read_table(self, filename: str, format: str, colnames: typing.List[str]):

def write_table(self, filename: str, format: str, colnames: typing.List[str]):
"""
Write the table to the file including al the columns in colnames.
Write the table to the file including all the columns in colnames.
Parameters
----------
Expand Down
38 changes: 37 additions & 1 deletion tests/regionsearch/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,27 @@
)
from astropy.time import Time # type: ignore

from kbmod.regionsearch import backend, indexers, utilities
from kbmod.regionsearch import abstractions, backend, indexers, utilities
from kbmod.regionsearch.region_search import Filter


def test_backend_abstract():
"""Tests that the backend raises an exception when region_search is called without an indexer."""

class TestBackend(abstractions.Backend):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def region_search(self, filter: Filter) -> np.ndarray:
return super().region_search(filter)

b = TestBackend()
assert b is not None
with pytest.raises(NotImplementedError):
b.region_search(Filter())
assert False, "expected NotImplementedError when calling region_search without an indexer"


def test_observationlist_init():
"""Tests the ObservationList backend's constructor."""
data = utilities.RegionSearchClusterData(clustercnt=5, samplespercluster=10, removecache=True)
Expand Down Expand Up @@ -49,6 +66,25 @@ def test_observationlist_consistency():
backend.ObservationList(ra, dec, time, location, fov, observation_identifier)


def test_observationlist_missing_observation_to_indices():
"""Tests that the backend raises an exception when region_search is called without an indexer."""
data = utilities.RegionSearchClusterData(clustercnt=5, samplespercluster=10, removecache=True)

ra = data.observation_pointing.ra
dec = data.observation_pointing.dec
time = data.observation_time
location = data.observation_geolocation
fov = np.ones([data.rowcnt]) * Angle(1, "deg")
observation_identifier = data.cluster_id
b = backend.ObservationList(ra, dec, time, location, fov, observation_identifier)
with pytest.raises(NotImplementedError):
b.region_search(Filter())
assert (
False
), "Expect NotImplementedError when region_search is called without an observation_to_indices method (missing ObservationIdexer)."
assert True


def test_observationlist_partition():
"""Tests the ObservationList backend with partition indexer."""

Expand Down

0 comments on commit 0f93c7d

Please sign in to comment.