diff --git a/doc/changes/DM-47947.bugfix.md b/doc/changes/DM-47947.bugfix.md new file mode 100644 index 0000000000..f6b436dd60 --- /dev/null +++ b/doc/changes/DM-47947.bugfix.md @@ -0,0 +1 @@ +Fixed a bug in which projections spatial-join queries (particularly those where the dimensions of the actual regions being compared are not in the query result rows) could return additional records where there actually was no overlap. diff --git a/python/lsst/daf/butler/ddl.py b/python/lsst/daf/butler/ddl.py index bf769299c6..007805add9 100644 --- a/python/lsst/daf/butler/ddl.py +++ b/python/lsst/daf/butler/ddl.py @@ -49,7 +49,6 @@ "GUID", ) -import functools import logging import uuid from base64 import b64decode, b64encode @@ -60,7 +59,7 @@ import astropy.time import sqlalchemy -from lsst.sphgeom import Region, UnionRegion +from lsst.sphgeom import Region from lsst.utils.iteration import ensure_iterable from sqlalchemy.dialects.postgresql import UUID @@ -182,14 +181,7 @@ def process_bind_param(self, value: Region | None, dialect: sqlalchemy.engine.Di def process_result_value(self, value: str | None, dialect: sqlalchemy.engine.Dialect) -> Region | None: if value is None: return None - return functools.reduce( - UnionRegion, - [ - # For some reason super() doesn't work here! - Region.decode(Base64Bytes.process_result_value(self, union_member, dialect)) - for union_member in value.split(":") - ], - ) + return Region.decodeBase64(value) @property def python_type(self) -> type[Region]: diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index 1088cda7ec..1121327360 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -45,7 +45,7 @@ from .._collection_type import CollectionType from .._dataset_type import DatasetType from .._exceptions import InvalidQueryError -from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse +from ..dimensions import DataCoordinate, DataIdValue, DimensionElement, DimensionGroup, DimensionUniverse from ..dimensions.record_cache import DimensionRecordCache from ..queries import tree as qt from ..queries.driver import ( @@ -388,6 +388,7 @@ def count( select_builder = builder.finish_nested() # Replace the columns of the query with just COUNT(*). select_builder.columns = qt.ColumnSet(self._universe.empty) + select_builder.joins.special.clear() count_func: sqlalchemy.ColumnElement[int] = sqlalchemy.func.count() select_builder.joins.special["_ROWCOUNT"] = count_func # Render and run the query. @@ -655,6 +656,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis: # it here. postprocessing.spatial_join_filtering.extend(m_state.postprocessing.spatial_join_filtering) postprocessing.spatial_where_filtering.extend(m_state.postprocessing.spatial_where_filtering) + postprocessing.spatial_expression_filtering.extend( + m_state.postprocessing.spatial_expression_filtering + ) # Add data coordinate uploads. joins.data_coordinate_uploads.update(tree.data_coordinate_uploads) # Add dataset_searches and filter out collections that don't have the @@ -715,7 +719,7 @@ def _resolve_union_datasets( searches : `list` [ `ResolvedDatasetSearch` ] Resolved dataset searches for all union dataset types with these dimensions. Each item in the list groups dataset types with the - same colletion search path. + same collection search path. """ # Gather the filtered collection search path for each union dataset # type. @@ -849,6 +853,48 @@ def apply_missing_dimension_joins( joins_analysis.predicate.visit(SqlColumnVisitor(select_builder.joins, self)) ) + def project_spatial_join_filtering( + self, + columns: qt.ColumnSet, + postprocessing: Postprocessing, + select_builders: Iterable[SqlSelectBuilder], + ) -> None: + """Transform spatial join postprocessing into expressions that can be + OR'd together via an aggregate function in a GROUP BY. + + This only affects spatial join constraints involving region columns + whose dimensions are being projected away. + + Parameters + ---------- + columns : `.queries.tree.ColumnSet` + Columns that will be included in the final query. + postprocessing : `Postprocessing` + Object that describes post-query processing; modified in place. + select_builders : `~collections.abc.Iterable` [ `SqlSelectBuilder` ] + SQL Builder objects to be modified in place. + """ + kept: list[tuple[DimensionElement, DimensionElement]] = [] + for a, b in postprocessing.spatial_join_filtering: + if a.name not in columns.dimensions.elements or b.name not in columns.dimensions.elements: + expr_name = f"_{a}_OVERLAPS_{b}" + postprocessing.spatial_expression_filtering.append(expr_name) + for select_builder in select_builders: + expr = sqlalchemy.cast( + sqlalchemy.cast( + select_builder.joins.fields[a.name]["region"], type_=sqlalchemy.String + ) + + sqlalchemy.literal("&", type_=sqlalchemy.String) + + sqlalchemy.cast( + select_builder.joins.fields[b.name]["region"], type_=sqlalchemy.String + ), + type_=sqlalchemy.LargeBinary, + ) + select_builder.joins.special[expr_name] = expr + else: + kept.append((a, b)) + postprocessing.spatial_join_filtering = kept + def apply_query_projection( self, select_builder: SqlSelectBuilder, @@ -938,8 +984,8 @@ def apply_query_projection( # the data IDs for those regions are not wholly included in the # results (i.e. we need to postprocess on # visit_detector_region.region, but the output rows don't have - # detector, just visit - so we compute the union of the - # visit_detector region over all matched detectors). + # detector, just visit - so we pack the overlap expression into a + # blob via an aggregate function and interpret it later). if postprocessing.check_validity_match_count: if needs_validity_match_count: select_builder.joins.special[postprocessing.VALIDITY_MATCH_COUNT] = ( @@ -960,11 +1006,27 @@ def apply_query_projection( # might be collapsing the dimensions of the postprocessing # regions. When that happens, we want to apply an aggregate # function to them that computes the union of the regions that - # are grouped together. + # are grouped together. Note that this should only happen for + # constraints that involve a "given", external-to-the-database + # region (postprocessing.spatial_where_filtering); join + # constraints that need aggregates should have already been + # transformed in advance. select_builder.joins.fields[element.name]["region"] = ddl.Base64Region.union_aggregate( select_builder.joins.fields[element.name]["region"] ) have_aggregates = True + # Postprocessing spatial join constraints where at least one region's + # dimensions are being projected away will have already been turned + # into the kind of expression that sphgeom.Region.decodeOverlapsBase64 + # processes. We can just apply an aggregate function to these. Note + # that we don't do this to other constraints in order to minimize + # duplicate fetches of the same region blob. + for expr_name in postprocessing.spatial_expression_filtering: + select_builder.joins.special[expr_name] = sqlalchemy.cast( + sqlalchemy.func.aggregate_strings(select_builder.joins.special[expr_name], "|"), + type_=sqlalchemy.LargeBinary, + ) + have_aggregates = True # All dimension record fields are derived fields. for element_name, fields_for_element in projection_columns.dimension_fields.items(): diff --git a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py index 589c5e7af6..db577ec6a9 100644 --- a/python/lsst/daf/butler/direct_query_driver/_postprocessing.py +++ b/python/lsst/daf/butler/direct_query_driver/_postprocessing.py @@ -33,7 +33,7 @@ from typing import TYPE_CHECKING, ClassVar import sqlalchemy -from lsst.sphgeom import DISJOINT, Region +from lsst.sphgeom import Region from .._exceptions import CalibrationLookupError from ..queries import tree as qt @@ -59,6 +59,7 @@ class Postprocessing: def __init__(self) -> None: self.spatial_join_filtering = [] self.spatial_where_filtering = [] + self.spatial_expression_filtering = [] self.check_validity_match_count: bool = False self._limit: int | None = None @@ -79,6 +80,12 @@ def __init__(self) -> None: non-overlap pair will be filtered out. """ + spatial_expression_filtering: list[str] + """The names of calculated columns that can be parsed by + `lsst.sphgeom.Region.decodeOverlapsBase64` into a `bool` or `None` that + indicates whether regions definitely overlap. + """ + check_validity_match_count: bool """If `True`, result rows will include a special column that counts the number of matching datasets in each collection for each data ID, and @@ -104,7 +111,9 @@ def limit(self, value: int | None) -> None: self._limit = value def __bool__(self) -> bool: - return bool(self.spatial_join_filtering or self.spatial_where_filtering) + return bool( + self.spatial_join_filtering or self.spatial_where_filtering or self.spatial_expression_filtering + ) def gather_columns_required(self, columns: qt.ColumnSet) -> None: """Add all columns required to perform postprocessing to the given @@ -197,8 +206,11 @@ def apply(self, rows: Iterable[sqlalchemy.Row]) -> Iterable[sqlalchemy.Row]: for row in rows: m = row._mapping - if any(m[a].relate(m[b]) & DISJOINT for a, b in joins) or any( - m[field].relate(region) & DISJOINT for field, region in where + # Skip rows where at least one couple of regions do not overlap. + if ( + any(Region.decodeOverlapsBase64(m[c]) is False for c in self.spatial_expression_filtering) + or any(m[a].overlaps(m[b]) is False for a, b in joins) + or any(m[field].overlaps(region) is False for field, region in where) ): continue if self.check_validity_match_count and m[self.VALIDITY_MATCH_COUNT] > 1: diff --git a/python/lsst/daf/butler/direct_query_driver/_query_builder.py b/python/lsst/daf/butler/direct_query_driver/_query_builder.py index 3a2e8346f9..795c71916b 100644 --- a/python/lsst/daf/butler/direct_query_driver/_query_builder.py +++ b/python/lsst/daf/butler/direct_query_driver/_query_builder.py @@ -35,6 +35,7 @@ ) import dataclasses +import itertools from abc import ABC, abstractmethod from collections.abc import Iterable, Set from typing import TYPE_CHECKING, Literal, TypeVar, overload @@ -382,6 +383,9 @@ def apply_joins(self, driver: DirectQueryDriver) -> None: def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: # Docstring inherited. + driver.project_spatial_join_filtering( + self.projection_columns, self.postprocessing, [self._select_builder] + ) driver.apply_query_projection( self._select_builder, self.postprocessing, @@ -635,6 +639,11 @@ def apply_joins(self, driver: DirectQueryDriver) -> None: def apply_projection(self, driver: DirectQueryDriver, order_by: Iterable[qt.OrderExpression]) -> None: # Docstring inherited. + driver.project_spatial_join_filtering( + self.projection_columns, + self.postprocessing, + itertools.chain.from_iterable(union_term.select_builders for union_term in self.union_terms), + ) for union_term in self.union_terms: for builder in union_term.select_builders: driver.apply_query_projection( diff --git a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py index 92b0451a3c..1a537d5d28 100644 --- a/python/lsst/daf/butler/direct_query_driver/_sql_builders.py +++ b/python/lsst/daf/butler/direct_query_driver/_sql_builders.py @@ -199,7 +199,7 @@ def join(self, other: SqlJoinsBuilder) -> SqlSelectBuilder: self.joins.join(other) return self - def into_from_builder( + def into_joins_builder( self, cte: bool = False, force: bool = False, *, postprocessing: Postprocessing | None ) -> SqlJoinsBuilder: """Convert this builder into a `SqlJoinsBuilder`, nesting it in a @@ -265,7 +265,7 @@ def nested( object. """ return SqlSelectBuilder( - self.into_from_builder(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns + self.into_joins_builder(cte=cte, force=force, postprocessing=postprocessing), columns=self.columns ) def union_subquery( @@ -422,6 +422,8 @@ def extract_columns( self.fields[element.name]["region"] = column_collection[ self.db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region")) ] + for name in postprocessing.spatial_expression_filtering: + self.special[name] = column_collection[name] if postprocessing.check_validity_match_count: self.special[postprocessing.VALIDITY_MATCH_COUNT] = column_collection[ postprocessing.VALIDITY_MATCH_COUNT @@ -670,6 +672,8 @@ def make_table_spec( db.name_shrinker.shrink(columns.get_qualified_name(element.name, "region")) ) ) + for name in postprocessing.spatial_expression_filtering: + results.fields.add(ddl.FieldSpec(name, dtype=sqlalchemy.types.LargeBinary, nullable=True)) if not results.fields: results.fields.add( ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index 714f315aad..ca75ef5a5e 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -1453,9 +1453,9 @@ def make_joins_builder( # Need a UNION subquery. return tags_builder.union_subquery([calibs_builder]) else: - return tags_builder.into_from_builder(postprocessing=None) + return tags_builder.into_joins_builder(postprocessing=None) elif calibs_builder is not None: - return calibs_builder.into_from_builder(postprocessing=None) + return calibs_builder.into_joins_builder(postprocessing=None) else: raise AssertionError("Branch should be unreachable.") diff --git a/python/lsst/daf/butler/registry/dimensions/static.py b/python/lsst/daf/butler/registry/dimensions/static.py index 637b637dd5..e18c1cb94b 100644 --- a/python/lsst/daf/butler/registry/dimensions/static.py +++ b/python/lsst/daf/butler/registry/dimensions/static.py @@ -464,7 +464,7 @@ def make_joins_builder(self, element: DimensionElement, fields: Set[str]) -> Sql self.make_joins_builder(element.implied_union_target, fields), columns=qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True, - ).into_from_builder(postprocessing=None) + ).into_joins_builder(postprocessing=None) if not element.has_own_table: raise NotImplementedError(f"Cannot join dimension element {element} with no table.") table = self._tables[element.name] @@ -1082,7 +1082,7 @@ def visit_spatial_constraint( self.builder.join( joins_builder.to_select_builder( qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True - ).into_from_builder(postprocessing=None) + ).into_joins_builder(postprocessing=None) ) # Short circuit here since the SQL WHERE clause has already # been embedded in the subquery. @@ -1147,7 +1147,7 @@ def visit_spatial_join( qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(), distinct=True, ) - .into_from_builder(postprocessing=None) + .into_joins_builder(postprocessing=None) ) # In both cases we add postprocessing to check that the regions # really do overlap, since overlapping the same common skypix diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 5905df81ba..139ce47d4b 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -759,6 +759,18 @@ def test_spatial_overlaps(self) -> None: [1, 2, 3], has_postprocessing=True, ) + # Same as above, but with a materialization. + self.check_detector_records( + query.where( + _x.visit_detector_region.region.overlaps(_x.patch.region), + tract=0, + patch=4, + ) + .materialize(dimensions=["detector"]) + .dimension_records("detector"), + [1, 2, 3], + has_postprocessing=True, + ) # Query for that patch's region and express the previous query as # a region-constraint instead of a spatial join. (patch_record,) = query.where(tract=0, patch=4).dimension_records("patch") @@ -777,6 +789,21 @@ def test_spatial_overlaps(self) -> None: ), ids=[1, 2, 3], ) + # Query for detectors where a patch/visit+detector overlap is + # satisfied, in the case where there are no rows with an overlap, + # but the union of the patch regions overlaps the union of the + # visit+detector regions. + self.check_detector_records( + query.where( + _x.visit_detector_region.region.overlaps(_x.patch.region), + _x.any( + _x.all(_x.tract == 1, _x.visit == 1), + _x.all(_x.tract == 0, _x.patch == 0, _x.visit == 2), + ), + ).dimension_records("detector"), + [], + has_postprocessing=True, + ) # Combine postprocessing with order_by and limit. self.check_detector_records( query.where(