diff --git a/src/kbmod/regionsearch/abstractions.py b/src/kbmod/regionsearch/abstractions.py index 2747681..a6c31dc 100644 --- a/src/kbmod/regionsearch/abstractions.py +++ b/src/kbmod/regionsearch/abstractions.py @@ -59,7 +59,9 @@ def region_search(self, filter: Filter) -> numpy.ndarray: numpy.ndarray Observation identifiers of pointings that match the given filter. """ - pass + if not hasattr(self, "observations_to_indices"): + raise NotImplementedError("region_search requires an implementation of observations_to_indices") + return numpy.array([]) class ObservationIndexer(ABC): diff --git a/src/kbmod/regionsearch/backend.py b/src/kbmod/regionsearch/backend.py index 362f77d..5ecf3e0 100644 --- a/src/kbmod/regionsearch/backend.py +++ b/src/kbmod/regionsearch/backend.py @@ -1,5 +1,5 @@ """ -lincc-frameworks provides implementations of ``abstractions.Backend`` that may be composed with an implementation of ``abstractions.ObserverationIndexer`` to provide a complete region search. +Provides implementations of ``abstractions.Backend`` that may be composed with an implementation of ``abstractions.ObserverationIndexer`` to provide a complete region search. """ from dataclasses import dataclass @@ -91,13 +91,14 @@ def region_search(self, filter: Filter) -> np.ndarray: A list of matching indices. """ matching_observation_identifier = np.array([], dtype=self.observation_identifier.dtype) - if hasattr(self, "observations_to_indices"): - pointing = coord.SkyCoord(filter.search_ra, filter.search_dec) - matching_index = self.observations_to_indices(pointing, None, filter.search_fov, None) # type: ignore - pointing = coord.SkyCoord(self.observation_ra, self.observation_dec) - self.observation_index = self.observations_to_indices( # type: ignore - pointing, self.observation_time, self.observation_fov, self.observation_location - ) - index_list = np.nonzero(self.observation_index == matching_index)[0] - matching_observation_identifier = self.observation_identifier[index_list] + if not hasattr(self, "observations_to_indices"): + raise NotImplementedError("region_search requires an implementation of observations_to_indices") + pointing = coord.SkyCoord(filter.search_ra, filter.search_dec) + matching_index = self.observations_to_indices(pointing, None, filter.search_fov, None) # type: ignore + pointing = coord.SkyCoord(self.observation_ra, self.observation_dec) + self.observation_index = self.observations_to_indices( # type: ignore + pointing, self.observation_time, self.observation_fov, self.observation_location + ) + index_list = np.nonzero(self.observation_index == matching_index)[0] + matching_observation_identifier = self.observation_identifier[index_list] return matching_observation_identifier diff --git a/src/kbmod/regionsearch/indexers.py b/src/kbmod/regionsearch/indexers.py index b21882c..a627ea9 100644 --- a/src/kbmod/regionsearch/indexers.py +++ b/src/kbmod/regionsearch/indexers.py @@ -1,5 +1,5 @@ """ -lincc-frameworks provides implementations of ``abstractions.ObserverationIndexer`` that may be composed with an implementation of ``abstractions.Backend`` to provide a complete region search. +Provides implementations of ``abstractions.ObserverationIndexer`` that may be composed with an implementation of ``abstractions.Backend`` to provide a complete region search. """ import math diff --git a/src/kbmod/regionsearch/utilities.py b/src/kbmod/regionsearch/utilities.py index 196abc5..e852010 100644 --- a/src/kbmod/regionsearch/utilities.py +++ b/src/kbmod/regionsearch/utilities.py @@ -180,6 +180,7 @@ def read_table(self, filename: str, format: str, colnames: typing.List[str]): The format of the file to read or write. colnames : typing.List[str] The list of column names to read from the file. The file may contain other columns but it must have these columns. + Returns ------- bool @@ -204,7 +205,7 @@ def read_table(self, filename: str, format: str, colnames: typing.List[str]): def write_table(self, filename: str, format: str, colnames: typing.List[str]): """ - Write the table to the file including al the columns in colnames. + Write the table to the file including all the columns in colnames. Parameters ---------- diff --git a/tests/regionsearch/test_backend.py b/tests/regionsearch/test_backend.py index fef234c..0ba803a 100644 --- a/tests/regionsearch/test_backend.py +++ b/tests/regionsearch/test_backend.py @@ -12,10 +12,27 @@ ) from astropy.time import Time # type: ignore -from kbmod.regionsearch import backend, indexers, utilities +from kbmod.regionsearch import abstractions, backend, indexers, utilities from kbmod.regionsearch.region_search import Filter +def test_backend_abstract(): + """Tests that the backend raises an exception when region_search is called without an indexer.""" + + class TestBackend(abstractions.Backend): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def region_search(self, filter: Filter) -> np.ndarray: + return super().region_search(filter) + + b = TestBackend() + assert b is not None + with pytest.raises(NotImplementedError): + b.region_search(Filter()) + assert False, "expected NotImplementedError when calling region_search without an indexer" + + def test_observationlist_init(): """Tests the ObservationList backend's constructor.""" data = utilities.RegionSearchClusterData(clustercnt=5, samplespercluster=10, removecache=True) @@ -49,6 +66,25 @@ def test_observationlist_consistency(): backend.ObservationList(ra, dec, time, location, fov, observation_identifier) +def test_observationlist_missing_observation_to_indices(): + """Tests that the backend raises an exception when region_search is called without an indexer.""" + data = utilities.RegionSearchClusterData(clustercnt=5, samplespercluster=10, removecache=True) + + ra = data.observation_pointing.ra + dec = data.observation_pointing.dec + time = data.observation_time + location = data.observation_geolocation + fov = np.ones([data.rowcnt]) * Angle(1, "deg") + observation_identifier = data.cluster_id + b = backend.ObservationList(ra, dec, time, location, fov, observation_identifier) + with pytest.raises(NotImplementedError): + b.region_search(Filter()) + assert ( + False + ), "Expect NotImplementedError when region_search is called without an observation_to_indices method (missing ObservationIdexer)." + assert True + + def test_observationlist_partition(): """Tests the ObservationList backend with partition indexer."""