Add an option to include dimension records into general query result (DM-47980)

The `GeneralQueryResults.iter_tuples` method returned DataIds without
dimension records. In some cases (e.g. for obscore export) it would be
very useful to include the records in the same result to avoid querying
them separately. A new method, `with_dimension_records`, is added to the
class to trigger adding fields from all dimension records to the returned
pages. This will produce many duplicate values for some dimensions
(e.g. `instrument`), but it keeps the page structure simple.

This adds one attribute to the `GeneralResultSpec` class, which will need
some care with Butler server compatibility.
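
A usage sketch of the new option (the repository path, dataset type name, and collection name below are placeholders, following the test added in this commit):

```python
from lsst.daf.butler import Butler

butler = Butler("/path/to/repo")  # hypothetical repository
dimensions = butler.dimensions.conform(["detector", "physical_filter"])
flat = butler.get_dataset_type("flat")  # hypothetical dataset type; "imported_g" is a placeholder collection

with butler.query() as query:
    query = query.join_dataset_search("flat", "imported_g")
    result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True)
    # New in this commit: also return dimension-record fields in each row.
    result = result.with_dimension_records()
    for row_tuple in result.iter_tuples(flat):
        # Data IDs now carry records, so no separate record query is needed.
        assert row_tuple.data_id.hasRecords()
```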
andy-slac committed Dec 16, 2024
1 parent 812ef82 commit 825f7fe
Showing 8 changed files with 207 additions and 23 deletions.
@@ -326,17 +326,35 @@ class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01

def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
self.spec = spec

result_columns = spec.get_result_columns()
# If `spec.include_dimension_records` is True then, in addition to the
# columns returned by the query, we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
query_result_columns = set(spec.get_result_columns())
output_columns = spec.get_all_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in result_columns:
for column in output_columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
if column.field == TimespanDatabaseRepresentation.NAME:
self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
converter: _GeneralColumnConverter
if column not in query_result_columns and column.field is not None:
# This must be a field from a cached dimension record or
# skypix record.
assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here"
element = universe[column.logical_table]
if isinstance(element, SkyPixDimension):
converter = _SkypixRecordGeneralColumnConverter(element, column.field)
else:
converter = _CachedRecordGeneralColumnConverter(
element, column.field, ctx.dimension_record_cache
)
elif column.field == TimespanDatabaseRepresentation.NAME:
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
self.converters.append(_TimestampGeneralColumnConverter(column_name))
converter = _TimestampGeneralColumnConverter(column_name)
else:
self.converters.append(_DefaultGeneralColumnConverter(column_name))
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
@@ -422,3 +440,47 @@ def __init__(self, name: str, db: Database):
def convert(self, row: sqlalchemy.Row) -> Any:
timespan = self.timespan_class.extract(row._mapping, self.name)
return timespan


class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for cached
dimension records.
Parameters
----------
element : `DimensionElement`
Dimension element, must be of cached type.
field : `str`
Name of the field to extract from the dimension record.
cache : `DimensionRecordCache`
Cache for dimension records.
"""

def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None:
self._record_converter = _CachedDimensionRecordRowConverter(element, cache)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)


class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter):
"""Helper for converting result row into a field value for skypix
dimension records.
Parameters
----------
element : `SkyPixDimension`
Dimension element.
field : `str`
Name of the field to extract from the dimension record.
"""

def __init__(self, element: SkyPixDimension, field: str) -> None:
self._record_converter = _SkypixDimensionRecordRowConverter(element)
self._field = field

def convert(self, row: sqlalchemy.Row) -> Any:
record = self._record_converter.convert(row)
return getattr(record, self._field)
53 changes: 43 additions & 10 deletions python/lsst/daf/butler/queries/_general_query_results.py
@@ -35,11 +35,11 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionGroup
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import QueryTree
from .tree import QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
@@ -99,9 +99,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
fields (separated from dataset type name by dot).
"""
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
columns = tuple(str(column) for column in page.spec.get_all_result_columns())
for row in page.rows:
yield dict(zip(columns, row))
yield dict(zip(columns, row, strict=True))

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
@@ -125,14 +125,10 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
values = tuple(
row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied)
)
data_coordinate = DataCoordinate.from_full_values(all_dimensions, values)
data_coordinate = self._make_data_id(row, all_dimensions)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_id = DataCoordinate.from_full_values(dimensions, values)
data_id = self._make_data_id(row, dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)

@@ -141,6 +137,19 @@ def dimensions(self) -> DimensionGroup:
# Docstring inherited
return self._spec.dimensions

@property
def has_dimension_records(self) -> bool:
"""Whether all data IDs in this iterable contain dimension records."""
return self._spec.include_dimension_records

def with_dimension_records(self) -> GeneralQueryResults:
"""Return a results object for which `has_dimension_records` is
`True`.
"""
if self.has_dimension_records:
return self
return self._copy(tree=self._tree, include_dimension_records=True)

def count(self, *, exact: bool = True, discard: bool = False) -> int:
# Docstring inherited.
return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)
@@ -152,3 +161,27 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults:
def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {
name: self._make_dimension_record(row, dimensions.universe[name])
for name in dimensions.elements
}
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
column_map = list(
zip(
element.schema.dimensions.names,
element.dimensions.names,
)
)
for field in element.schema.remainder.names:
column_map.append((field, str(ResultColumn(element.name, field))))
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)
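
For reference, `_make_dimension_record` above relies on a simple naming convention: dimension key values appear in the row under their bare dimension names, while the other record fields appear under qualified `"<element>.<field>"` names. A minimal stand-alone sketch of that mapping, using a hypothetical row and hand-written field lists rather than the real schema objects:

```python
def make_record_kwargs(row, element_name, key_map, remainder_fields):
    """Build keyword arguments for a dimension record class from a flat result row.

    ``key_map`` pairs each record key field with the bare column it is stored
    under in the row (e.g. ``("id", "detector")``); the remaining record fields
    are stored under qualified ``"<element>.<field>"`` columns.
    """
    kwargs = {record_field: row[column] for record_field, column in key_map}
    for field in remainder_fields:
        kwargs[field] = row[f"{element_name}.{field}"]
    return kwargs


# Hypothetical flattened row for the "detector" element.
row = {"instrument": "Cam1", "detector": 1, "detector.full_name": "Aa", "detector.raft": "A"}
print(
    make_record_kwargs(
        row,
        "detector",
        key_map=[("instrument", "instrument"), ("id", "detector")],
        remainder_fields=["full_name", "raft"],
    )
)
# -> {'instrument': 'Cam1', 'id': 1, 'full_name': 'Aa', 'raft': 'A'}
```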
2 changes: 1 addition & 1 deletion python/lsst/daf/butler/queries/driver.py
@@ -117,7 +117,7 @@ class GeneralResultPage:
spec: GeneralResultSpec

# Raw tabular data, with columns in the same order as
# spec.get_result_columns().
# spec.get_all_result_columns().
rows: list[tuple[Any, ...]]


32 changes: 32 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase):
dataset_fields: Mapping[str, set[DatasetFieldName]]
"""Dataset fields included in this query."""

include_dimension_records: bool = False
"""Whether to include fields for all dimension records, in addition to
explicitly specified in `dimension_fields`.
"""

find_first: bool
"""Whether this query requires find-first resolution for a dataset.
@@ -241,6 +246,33 @@ def get_result_columns(self) -> ColumnSet:
result.dimension_fields[element_name].update(fields_for_element)
for dataset_type, fields_for_dataset in self.dataset_fields.items():
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, which is what we want when generating the query. We could
# potentially add the others too, but it may make queries slower, so
# instead we query cached dimension records separately and add them
# to the result page in the page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

def get_all_result_columns(self) -> ColumnSet:
"""Return all columns that have to appear in the result. This includes
columns for all dimension records for all dimensions if
``include_dimension_records`` is `True`.
Returns
-------
columns : `ColumnSet`
Full column set.
"""
dimensions = self.dimensions
result = self.get_result_columns()
if self.include_dimension_records:
for element_name in dimensions.elements:
element = dimensions.universe[element_name]
# Non-cached dimensions are already there, but it does no harm
# to add them again.
result.dimension_fields[element_name].update(element.schema.remainder.names)
return result

@pydantic.model_validator(mode="after")
15 changes: 13 additions & 2 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -257,12 +257,23 @@ def _convert_query_result_page(

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
columns = spec.get_result_columns()
columns = spec.get_all_result_columns()
# Verify that the column list we received from the server matches local
# expectations (a mismatch could result from different versions). An older
# server may not know about `model.columns`; in that case it will be empty.
if model.columns:
expected_column_names = [str(column) for column in columns]
if expected_column_names != model.columns:
raise ValueError(
"Inconsistent columns in general result -- "
f"server columns: {model.columns}, expected: {expected_column_names}"
)

serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers))
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)
@@ -79,11 +79,12 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult

def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
"""Convert GeneralResultPage to a serializable model."""
columns = page.spec.get_result_columns()
columns = page.spec.get_all_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows)
return GeneralResultModel(rows=rows, columns=[str(column) for column in columns])
3 changes: 3 additions & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
@@ -313,6 +313,9 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# List of column names; the default is used for compatibility with older
# servers that do not set this field.
columns: list[str] = pydantic.Field(default_factory=list)


class QueryErrorResultModel(pydantic.BaseModel):
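
A minimal sketch of why the default matters for backwards compatibility, using a stripped-down stand-in rather than the real `GeneralResultModel`:

```python
from typing import Any

import pydantic


class SketchGeneralResultModel(pydantic.BaseModel):
    """Stripped-down stand-in for `GeneralResultModel`, for illustration only."""

    rows: list[tuple[Any, ...]]
    # Older servers omit this field entirely; the default keeps parsing working.
    columns: list[str] = pydantic.Field(default_factory=list)


# A payload from an older server carries no "columns" key.
old_page = SketchGeneralResultModel(rows=[(1, "Aa")])
assert old_page.columns == []  # the client then skips its column-name check
```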
42 changes: 42 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
@@ -436,7 +436,9 @@ def test_general_query(self) -> None:
self.assertEqual(len(row_tuple.refs), 1)
self.assertEqual(row_tuple.refs[0].datasetType, flat)
self.assertTrue(row_tuple.refs[0].dataId.hasFull())
self.assertFalse(row_tuple.refs[0].dataId.hasRecords())
self.assertTrue(row_tuple.data_id.hasFull())
self.assertFalse(row_tuple.data_id.hasRecords())
self.assertEqual(row_tuple.data_id.dimensions, dimensions)
self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g")

@@ -511,6 +513,46 @@ def test_general_query(self) -> None:
{Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None},
)

dimensions = butler.dimensions["detector"].minimal_group

# Include dimension records in the query.
with butler.query() as query:
query = query.join_dimensions(dimensions)
result = query.general(dimensions).order_by("detector")
rows = list(result.with_dimension_records())
self.assertEqual(
rows[0],
{
"instrument": "Cam1",
"detector": 1,
"instrument.visit_max": 1024,
"instrument.visit_system": 1,
"instrument.exposure_max": 512,
"instrument.detector_max": 4,
"instrument.class_name": "lsst.pipe.base.Instrument",
"detector.full_name": "Aa",
"detector.name_in_raft": "a",
"detector.raft": "A",
"detector.purpose": "SCIENCE",
},
)

dimensions = butler.dimensions.conform(["detector", "physical_filter"])

# DataIds should come with records.
with butler.query() as query:
query = query.join_dataset_search("flat", "imported_g")
result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by(
"detector"
)
result = result.with_dimension_records()
row_tuples = list(result.iter_tuples(flat))
self.assertEqual(len(row_tuples), 3)
for row_tuple in row_tuples:
self.assertTrue(row_tuple.data_id.hasRecords())
self.assertEqual(len(row_tuple.refs), 1)
self.assertTrue(row_tuple.refs[0].dataId.hasRecords())

def test_query_ingest_date(self) -> None:
"""Test general query returning ingest_date field."""
before_ingest = astropy.time.Time.now()