From 825f7fe614de92a5de8b8b3876c6965812e9faca Mon Sep 17 00:00:00 2001
From: Andy Salnikov
Date: Fri, 13 Dec 2024 16:58:58 -0800
Subject: [PATCH] Add an option to include dimension records in general query
 results (DM-47980)

The `GeneralQueryResults.iter_tuples` method returned data IDs without
dimension records. In some cases (e.g. obscore export) it is very useful
to include the records in the same result, to avoid querying them
separately. A new method, `with_dimension_records`, is added to the class
to trigger adding fields from all dimension records to the returned page.
This produces many duplicate values for some dimensions (e.g.
`instrument`), but it keeps the page structure simple.

This adds one attribute to the `GeneralResultSpec` class and will need
some care with Butler server compatibility.
---
 .../_result_page_converter.py | 76 +++++++++++++++++--
 .../butler/queries/_general_query_results.py | 53 ++++++++++---
 python/lsst/daf/butler/queries/driver.py | 2 +-
 .../lsst/daf/butler/queries/result_specs.py | 32 ++++++++
 .../daf/butler/remote_butler/_query_driver.py | 15 +++-
 .../server/handlers/_query_serialization.py | 7 +-
 .../daf/butler/remote_butler/server_models.py | 3 +
 .../lsst/daf/butler/tests/butler_queries.py | 42 ++++++++++
 8 files changed, 207 insertions(+), 23 deletions(-)

diff --git a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py
index 33c2988155..1d97d28df2 100644
--- a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py
+++ b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py
@@ -326,17 +326,35 @@ class GeneralResultPageConverter(ResultPageConverter):  # numpydoc ignore=PR01
 
     def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
         self.spec = spec
-
-        result_columns = spec.get_result_columns()
+        # If `spec.include_dimension_records` is True then, in addition to
+        # the columns returned by the query, we have to add columns from
+        # dimension records that are not returned by the query. These columns
+        # belong to either cached or skypix dimensions.
+        query_result_columns = set(spec.get_result_columns())
+        output_columns = spec.get_all_result_columns()
+        universe = spec.dimensions.universe
         self.converters: list[_GeneralColumnConverter] = []
-        for column in result_columns:
+        for column in output_columns:
             column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
-            if column.field == TimespanDatabaseRepresentation.NAME:
-                self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
+            converter: _GeneralColumnConverter
+            if column not in query_result_columns and column.field is not None:
+                # This must be a field from a cached dimension record or
+                # skypix record.
+ assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here" + element = universe[column.logical_table] + if isinstance(element, SkyPixDimension): + converter = _SkypixRecordGeneralColumnConverter(element, column.field) + else: + converter = _CachedRecordGeneralColumnConverter( + element, column.field, ctx.dimension_record_cache + ) + elif column.field == TimespanDatabaseRepresentation.NAME: + converter = _TimespanGeneralColumnConverter(column_name, ctx.db) elif column.field == "ingest_date": - self.converters.append(_TimestampGeneralColumnConverter(column_name)) + converter = _TimestampGeneralColumnConverter(column_name) else: - self.converters.append(_DefaultGeneralColumnConverter(column_name)) + converter = _DefaultGeneralColumnConverter(column_name) + self.converters.append(converter) def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage: rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows] @@ -422,3 +440,47 @@ def __init__(self, name: str, db: Database): def convert(self, row: sqlalchemy.Row) -> Any: timespan = self.timespan_class.extract(row._mapping, self.name) return timespan + + +class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter): + """Helper for converting result row into a field value for cached + dimension records. + + Parameters + ---------- + element : `DimensionElement` + Dimension element, must be of cached type. + field : `str` + Name of the field to extract from the dimension record. + cache : `DimensionRecordCache` + Cache for dimension records. + """ + + def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None: + self._record_converter = _CachedDimensionRecordRowConverter(element, cache) + self._field = field + + def convert(self, row: sqlalchemy.Row) -> Any: + record = self._record_converter.convert(row) + return getattr(record, self._field) + + +class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter): + """Helper for converting result row into a field value for skypix + dimension records. + + Parameters + ---------- + element : `SkyPixDimension` + Dimension element. + field : `str` + Name of the field to extract from the dimension record. + """ + + def __init__(self, element: SkyPixDimension, field: str) -> None: + self._record_converter = _SkypixDimensionRecordRowConverter(element) + self._field = field + + def convert(self, row: sqlalchemy.Row) -> Any: + record = self._record_converter.convert(row) + return getattr(record, self._field) diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py index 3a98264bff..fa8c06bbea 100644 --- a/python/lsst/daf/butler/queries/_general_query_results.py +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -35,11 +35,11 @@ from .._dataset_ref import DatasetRef from .._dataset_type import DatasetType -from ..dimensions import DataCoordinate, DimensionGroup +from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord from ._base import QueryResultsBase from .driver import QueryDriver from .result_specs import GeneralResultSpec -from .tree import QueryTree +from .tree import QueryTree, ResultColumn class GeneralResultTuple(NamedTuple): @@ -99,9 +99,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]: fields (separated from dataset type name by dot). 
""" for page in self._driver.execute(self._spec, self._tree): - columns = tuple(str(column) for column in page.spec.get_result_columns()) + columns = tuple(str(column) for column in page.spec.get_all_result_columns()) for row in page.rows: - yield dict(zip(columns, row)) + yield dict(zip(columns, row, strict=True)) def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]: """Iterate over result rows and return data coordinate, and dataset @@ -125,14 +125,10 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl run_key = f"{dataset_type.name}.run" dataset_keys.append((dataset_type, dimensions, id_key, run_key)) for row in self: - values = tuple( - row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied) - ) - data_coordinate = DataCoordinate.from_full_values(all_dimensions, values) + data_coordinate = self._make_data_id(row, all_dimensions) refs = [] for dataset_type, dimensions, id_key, run_key in dataset_keys: - values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) - data_id = DataCoordinate.from_full_values(dimensions, values) + data_id = self._make_data_id(row, dimensions) refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key])) yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row) @@ -141,6 +137,19 @@ def dimensions(self) -> DimensionGroup: # Docstring inherited return self._spec.dimensions + @property + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + return self._spec.include_dimension_records + + def with_dimension_records(self) -> GeneralQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + def count(self, *, exact: bool = True, discard: bool = False) -> int: # Docstring inherited. return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) @@ -152,3 +161,27 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults: def _get_datasets(self) -> frozenset[str]: # Docstring inherited. 
return frozenset(self._spec.dataset_fields) + + def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate: + values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) + data_coordinate = DataCoordinate.from_full_values(dimensions, values) + if self.has_dimension_records: + records = { + name: self._make_dimension_record(row, dimensions.universe[name]) + for name in dimensions.elements + } + data_coordinate = data_coordinate.expanded(records) + return data_coordinate + + def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord: + column_map = list( + zip( + element.schema.dimensions.names, + element.dimensions.names, + ) + ) + for field in element.schema.remainder.names: + column_map.append((field, str(ResultColumn(element.name, field)))) + d = {k: row[v] for k, v in column_map} + record_cls = element.RecordClass + return record_cls(**d) diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py index f7a6af352a..d1e23ef813 100644 --- a/python/lsst/daf/butler/queries/driver.py +++ b/python/lsst/daf/butler/queries/driver.py @@ -117,7 +117,7 @@ class GeneralResultPage: spec: GeneralResultSpec # Raw tabular data, with columns in the same order as - # spec.get_result_columns(). + # spec.get_all_result_columns(). rows: list[tuple[Any, ...]] diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py index 5bdc459f7c..6e3b1360e2 100644 --- a/python/lsst/daf/butler/queries/result_specs.py +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase): dataset_fields: Mapping[str, set[DatasetFieldName]] """Dataset fields included in this query.""" + include_dimension_records: bool = False + """Whether to include fields for all dimension records, in addition to + explicitly specified in `dimension_fields`. + """ + find_first: bool """Whether this query requires find-first resolution for a dataset. @@ -241,6 +246,33 @@ def get_result_columns(self) -> ColumnSet: result.dimension_fields[element_name].update(fields_for_element) for dataset_type, fields_for_dataset in self.dataset_fields.items(): result.dataset_fields[dataset_type].update(fields_for_dataset) + if self.include_dimension_records: + # This only adds record fields for non-cached and non-skypix + # elements, this is what we want when generating query. We could + # potentially add those too but it may make queries slower, so + # instead we query cached dimension records separately and add them + # to the result page in the page converter. + _add_dimension_records_to_column_set(self.dimensions, result) + return result + + def get_all_result_columns(self) -> ColumnSet: + """Return all columns that have to appear in the result. This includes + columns for all dimension records for all dimensions if + ``include_dimension_records`` is `True`. + + Returns + ------- + columns : `ColumnSet` + Full column set. + """ + dimensions = self.dimensions + result = self.get_result_columns() + if self.include_dimension_records: + for element_name in dimensions.elements: + element = dimensions.universe[element_name] + # Non-cached dimensions are already there, but it does not harm + # to add them again. 
+ result.dimension_fields[element_name].update(element.schema.remainder.names) return result @pydantic.model_validator(mode="after") diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 6f3035bc9e..6c883a4a67 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -257,12 +257,23 @@ def _convert_query_result_page( def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: """Convert GeneralResultModel to a general result page.""" - columns = spec.get_result_columns() + columns = spec.get_all_result_columns() + # Verify that column list that we received from server matches local + # expectations (mismatch could result from different versions). Older + # server may not know about `model.columns` in that case it will be empty. + if model.columns: + expected_column_names = [str(column) for column in columns] + if expected_column_names != model.columns: + raise ValueError( + "Inconsistent columns in general result -- " + f"server columns: {model.columns}, expected: {expected_column_names}" + ) + serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] rows = [ - tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers)) + tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True)) for row in model.rows ] return GeneralResultPage(spec=spec, rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index e934948374..7357699f19 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -79,11 +79,12 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: """Convert GeneralResultPage to a serializable model.""" - columns = page.spec.get_result_columns() + columns = page.spec.get_all_result_columns() serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] rows = [ - tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows + tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True)) + for row in page.rows ] - return GeneralResultModel(rows=rows) + return GeneralResultModel(rows=rows, columns=[str(column) for column in columns]) diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index 37309b6431..6a92d4a4ad 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -313,6 +313,9 @@ class GeneralResultModel(pydantic.BaseModel): type: Literal["general"] = "general" rows: list[tuple[Any, ...]] + # List of column names, default is used for compatibility with older + # servers that do not set this field. 
+ columns: list[str] = pydantic.Field(default_factory=list) class QueryErrorResultModel(pydantic.BaseModel): diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 139ce47d4b..4c1367fcf7 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -436,7 +436,9 @@ def test_general_query(self) -> None: self.assertEqual(len(row_tuple.refs), 1) self.assertEqual(row_tuple.refs[0].datasetType, flat) self.assertTrue(row_tuple.refs[0].dataId.hasFull()) + self.assertFalse(row_tuple.refs[0].dataId.hasRecords()) self.assertTrue(row_tuple.data_id.hasFull()) + self.assertFalse(row_tuple.data_id.hasRecords()) self.assertEqual(row_tuple.data_id.dimensions, dimensions) self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g") @@ -511,6 +513,46 @@ def test_general_query(self) -> None: {Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None}, ) + dimensions = butler.dimensions["detector"].minimal_group + + # Include dimension records into query. + with butler.query() as query: + query = query.join_dimensions(dimensions) + result = query.general(dimensions).order_by("detector") + rows = list(result.with_dimension_records()) + self.assertEqual( + rows[0], + { + "instrument": "Cam1", + "detector": 1, + "instrument.visit_max": 1024, + "instrument.visit_system": 1, + "instrument.exposure_max": 512, + "instrument.detector_max": 4, + "instrument.class_name": "lsst.pipe.base.Instrument", + "detector.full_name": "Aa", + "detector.name_in_raft": "a", + "detector.raft": "A", + "detector.purpose": "SCIENCE", + }, + ) + + dimensions = butler.dimensions.conform(["detector", "physical_filter"]) + + # DataIds should come with records. + with butler.query() as query: + query = query.join_dataset_search("flat", "imported_g") + result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by( + "detector" + ) + result = result.with_dimension_records() + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 3) + for row_tuple in row_tuples: + self.assertTrue(row_tuple.data_id.hasRecords()) + self.assertEqual(len(row_tuple.refs), 1) + self.assertTrue(row_tuple.refs[0].dataId.hasRecords()) + def test_query_ingest_date(self) -> None: """Test general query returning ingest_date field.""" before_ingest = astropy.time.Time.now()
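
For reference, a minimal usage sketch of the new `with_dimension_records` option; the
repository path below is hypothetical, while the `flat` dataset type and `imported_g`
collection are the ones used in the tests above.

    from lsst.daf.butler import Butler

    butler = Butler("/path/to/repo")  # hypothetical repository path
    flat = butler.get_dataset_type("flat")
    dimensions = butler.dimensions.conform(["detector", "physical_filter"])

    with butler.query() as query:
        query = query.join_dataset_search("flat", "imported_g")
        result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True)
        # Ask for dimension-record fields to be included in every returned row;
        # data IDs produced by iter_tuples() then have hasRecords() == True.
        result = result.with_dimension_records()
        for row_tuple in result.iter_tuples(flat):
            # Record fields are also available directly in the raw row, e.g.
            # "detector.full_name", as shown in test_general_query.
            print(row_tuple.data_id, row_tuple.refs[0], row_tuple.raw_row["detector.full_name"])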