diff --git a/nucliadb/src/nucliadb/search/search/find_merge.py b/nucliadb/src/nucliadb/search/search/find_merge.py index d3cc3f8215..44daa28def 100644 --- a/nucliadb/src/nucliadb/search/search/find_merge.py +++ b/nucliadb/src/nucliadb/search/search/find_merge.py @@ -19,16 +19,16 @@ # import asyncio from dataclasses import dataclass -from typing import Optional, cast +from typing import cast from nucliadb.common.external_index_providers.base import TextBlockMatch from nucliadb.common.ids import ParagraphId, VectorId -from nucliadb.common.maindb.driver import Transaction -from nucliadb.common.maindb.utils import get_driver -from nucliadb.ingest.serialize import managed_serialize from nucliadb.search import SERVICE_NAME, logger from nucliadb.search.search.merge import merge_relations_results from nucliadb.search.search.results_hydrator.base import ( + ResourceHydrationOptions, + TextBlockHydrationOptions, + hydrate_resource_metadata, text_block_to_find_paragraph, ) from nucliadb_models.common import FieldTypeName @@ -50,8 +50,6 @@ SearchResponse, ) from nucliadb_telemetry import metrics -from nucliadb_utils import const -from nucliadb_utils.utilities import has_feature from . import paragraphs from .metrics import merge_observer @@ -80,64 +78,25 @@ async def set_text_value( paragraph_id: ParagraphId, find_paragraph: FindParagraph, max_operations: asyncio.Semaphore, - highlight: bool = False, - ematches: Optional[list[str]] = None, + hydration_options: TextBlockHydrationOptions, ): async with max_operations: find_paragraph.text = await paragraphs.get_paragraph_text( kbid=kbid, paragraph_id=paragraph_id, - highlight=highlight, - ematches=ematches, + highlight=hydration_options.highlight, + ematches=hydration_options.ematches, matches=[], # TODO ) -@merge_observer.wrap({"type": "set_resource_metadada_value"}) -async def set_resource_metadata_value( - txn: Transaction, - kbid: str, - resource: str, - show: list[ResourceProperties], - field_type_filter: list[FieldTypeName], - extracted: list[ExtractedDataTypeName], - find_resources: dict[str, FindResource], - max_operations: asyncio.Semaphore, -): - if ResourceProperties.EXTRACTED in show and has_feature( - const.Features.IGNORE_EXTRACTED_IN_SEARCH, context={"kbid": kbid}, default=False - ): - # Returning extracted metadata in search results is deprecated and this flag - # will be set to True for all KBs in the future. - show.remove(ResourceProperties.EXTRACTED) - extracted = [] - - async with max_operations: - serialized_resource = await managed_serialize( - txn, - kbid, - resource, - show, - field_type_filter=field_type_filter, - extracted=extracted, - service_name=SERVICE_NAME, - ) - if serialized_resource is not None: - find_resources[resource].updated_from(serialized_resource) - else: - logger.warning(f"Resource {resource} not found in {kbid}") - find_resources.pop(resource, None) - - @merge_observer.wrap({"type": "fetch_find_metadata"}) async def fetch_find_metadata( result_paragraphs: list[TextBlockMatch], kbid: str, - show: list[ResourceProperties], - field_type_filter: list[FieldTypeName], - extracted: list[ExtractedDataTypeName], - highlight: bool = False, - ematches: Optional[list[str]] = None, + *, + resource_hydration_options: ResourceHydrationOptions, + text_block_hydration_options: TextBlockHydrationOptions, ) -> tuple[dict[str, FindResource], list[str]]: find_resources: dict[str, FindResource] = {} best_matches: list[str] = [] @@ -184,8 +143,7 @@ async def fetch_find_metadata( kbid=kbid, paragraph_id=text_block.paragraph_id, find_paragraph=find_field.paragraphs[paragraph_id], - highlight=highlight, - ematches=ematches, + hydration_options=text_block_hydration_options, max_operations=max_operations, ) ) @@ -198,29 +156,26 @@ async def fetch_find_metadata( find_resources[paragraph.rid].fields[paragraph.fid].paragraphs[paragraph.pid].order = order best_matches.append(paragraph.pid) - async with get_driver().transaction(read_only=True) as txn: - for resource in resources: - operations.append( - asyncio.create_task( - set_resource_metadata_value( - txn, - kbid=kbid, - resource=resource, - show=show, - field_type_filter=field_type_filter, - extracted=extracted, - find_resources=find_resources, - max_operations=max_operations, - ) + for resource in resources: + operations.append( + asyncio.create_task( + hydrate_resource_metadata( + kbid, + resource_id=resource, + options=resource_hydration_options, + find_resources=find_resources, + concurrency_control=max_operations, + service_name=SERVICE_NAME, ) ) + ) - FIND_FETCH_OPS_DISTRIBUTION.observe(len(operations)) - if len(operations) > 0: - done, _ = await asyncio.wait(operations) - for task in done: - if task.exception() is not None: # pragma: no cover - logger.error("Error fetching find metadata", exc_info=task.exception()) + FIND_FETCH_OPS_DISTRIBUTION.observe(len(operations)) + if len(operations) > 0: + done, _ = await asyncio.wait(operations) + for task in done: + if task.exception() is not None: # pragma: no cover + logger.error("Error fetching find metadata", exc_info=task.exception()) return find_resources, best_matches @@ -385,11 +340,15 @@ async def find_merge_results( resources, best_matches = await fetch_find_metadata( result_paragraphs, kbid, - show, - field_type_filter, - extracted, - highlight, - ematches, + resource_hydration_options=ResourceHydrationOptions( + show=show, + extracted=extracted, + field_type_filter=field_type_filter, + ), + text_block_hydration_options=TextBlockHydrationOptions( + highlight=highlight, + ematches=ematches, + ), ) api_results.resources = resources api_results.best_matches = best_matches diff --git a/nucliadb/src/nucliadb/search/search/results_hydrator/base.py b/nucliadb/src/nucliadb/search/search/results_hydrator/base.py index 477f069ccb..41fe465832 100644 --- a/nucliadb/src/nucliadb/search/search/results_hydrator/base.py +++ b/nucliadb/src/nucliadb/search/search/results_hydrator/base.py @@ -19,13 +19,13 @@ # import asyncio import logging +from contextlib import AsyncExitStack from typing import Optional from pydantic import BaseModel from nucliadb.common.external_index_providers.base import QueryResults as ExternalIndexQueryResults from nucliadb.common.external_index_providers.base import TextBlockMatch -from nucliadb.common.maindb.driver import Transaction from nucliadb.common.maindb.utils import get_driver from nucliadb.ingest.serialize import managed_serialize from nucliadb.search.search import paragraphs @@ -39,6 +39,8 @@ ResourceProperties, ) from nucliadb_telemetry.metrics import Observer +from nucliadb_utils import const +from nucliadb_utils.utilities import has_feature logger = logging.getLogger(__name__) @@ -60,7 +62,11 @@ class TextBlockHydrationOptions(BaseModel): Options for hydrating text blocks (aka paragraphs). """ - pass + # whether to highlight the text block with `...` tags or not + highlight: bool = False + + # list of exact matches to highlight + ematches: Optional[list[str]] = None @hydrator_observer.wrap({"type": "hydrate_external"}) @@ -104,7 +110,7 @@ async def hydrate_external( async def _hydrate_text_block(**kwargs): async with semaphore: - await hydrate_text_block(**kwargs) + await hydrate_text_block_and_update_find_paragraph(**kwargs) hydrate_ops.append( asyncio.create_task( @@ -117,31 +123,26 @@ async def _hydrate_text_block(**kwargs): ) ) - async def _hydrate_resource_metadata(**kwargs): - async with semaphore: - await hydrate_resource_metadata(**kwargs) - if len(resource_ids) > 0: - async with get_driver().transaction(read_only=True) as ro_txn: - for resource_id in resource_ids: - hydrate_ops.append( - asyncio.create_task( - _hydrate_resource_metadata( - txn=ro_txn, - kbid=kbid, - resource_id=resource_id, - options=resource_options, - find_resources=retrieval_results.resources, - ) + for resource_id in resource_ids: + hydrate_ops.append( + asyncio.create_task( + hydrate_resource_metadata( + kbid=kbid, + resource_id=resource_id, + options=resource_options, + find_resources=retrieval_results.resources, + concurrency_control=semaphore, ) ) + ) if len(hydrate_ops) > 0: await asyncio.gather(*hydrate_ops) @hydrator_observer.wrap({"type": "text_block"}) -async def hydrate_text_block( +async def hydrate_text_block_and_update_find_paragraph( kbid: str, text_block: TextBlockMatch, options: TextBlockHydrationOptions, @@ -167,36 +168,77 @@ async def hydrate_text_block( ) +@hydrator_observer.wrap({"type": "text_block"}) +async def hydrate_text_block( + kbid: str, + text_block: TextBlockMatch, + options: TextBlockHydrationOptions, + *, + concurrency_control: Optional[asyncio.Semaphore] = None, +) -> TextBlockMatch: + """Given a `text_block`, fetch its corresponding text, modify and return the + `text_block` object. + + """ + async with AsyncExitStack() as stack: + if concurrency_control is not None: + await stack.enter_async_context(concurrency_control) + + text_block.text = await paragraphs.get_paragraph_text( + kbid=kbid, + paragraph_id=text_block.paragraph_id, + highlight=options.highlight, + matches=[], # TODO: this was never implemented + ematches=options.ematches, + ) + return text_block + + @hydrator_observer.wrap({"type": "resource_metadata"}) async def hydrate_resource_metadata( - txn: Transaction, kbid: str, resource_id: str, options: ResourceHydrationOptions, find_resources: dict[str, FindResource], + *, + concurrency_control: Optional[asyncio.Semaphore] = None, + service_name: Optional[str] = None, ) -> None: """ Fetch the various metadata fields of the resource and update the FindResource object. """ - serialized_resource = await managed_serialize( - txn=txn, - kbid=kbid, - rid=resource_id, - show=options.show, - field_type_filter=options.field_type_filter, - extracted=options.extracted, - ) - if serialized_resource is None: - logger.warning( - "Resource not found in database", - extra={ - "kbid": kbid, - "rid": resource_id, - }, - ) - find_resources.pop(resource_id, None) - return - find_resources[resource_id].updated_from(serialized_resource) + show = options.show + extracted = options.extracted + + if ResourceProperties.EXTRACTED in show and has_feature( + const.Features.IGNORE_EXTRACTED_IN_SEARCH, context={"kbid": kbid}, default=False + ): + # Returning extracted metadata in search results is deprecated and this flag + # will be set to True for all KBs in the future. + show.remove(ResourceProperties.EXTRACTED) + extracted = [] + + async with AsyncExitStack() as stack: + if concurrency_control is not None: + await stack.enter_async_context(concurrency_control) + + async with get_driver().transaction(read_only=True) as ro_txn: + serialized_resource = await managed_serialize( + txn=ro_txn, + kbid=kbid, + rid=resource_id, + show=show, + field_type_filter=options.field_type_filter, + extracted=extracted, + service_name=service_name, + ) + if serialized_resource is not None: + find_resources[resource_id].updated_from(serialized_resource) + else: + logger.warning( + "Resource not found in database", extra={"kbid": kbid, "rid": resource_id} + ) + find_resources.pop(resource_id, None) def text_block_to_find_paragraph(text_block: TextBlockMatch) -> FindParagraph: diff --git a/nucliadb/tests/nucliadb/integration/test_find.py b/nucliadb/tests/nucliadb/integration/test_find.py index 57da555dfc..cf60baedf1 100644 --- a/nucliadb/tests/nucliadb/integration/test_find.py +++ b/nucliadb/tests/nucliadb/integration/test_find.py @@ -238,7 +238,7 @@ async def test_story_7286( ) assert resp.status_code == 200 - with patch("nucliadb.search.search.find_merge.managed_serialize", return_value=None): + with patch("nucliadb.search.search.results_hydrator.base.managed_serialize", return_value=None): # should get no result (because serialize returns None, as the resource is not found in the DB) resp = await nucliadb_reader.post( f"/kb/{knowledgebox}/find", diff --git a/nucliadb/tests/search/unit/test_find_post_index.py b/nucliadb/tests/search/unit/test_find_post_index.py index 9e70f74343..697a505508 100644 --- a/nucliadb/tests/search/unit/test_find_post_index.py +++ b/nucliadb/tests/search/unit/test_find_post_index.py @@ -115,8 +115,7 @@ async def test_find_post_index_search(expected_find_response: dict): with ( patch("nucliadb.search.search.find_merge.set_text_value"), - patch("nucliadb.search.search.find_merge.set_resource_metadata_value"), - patch("nucliadb.search.search.find_merge.get_driver"), + patch("nucliadb.search.search.find_merge.hydrate_resource_metadata"), ): find_response = await find_merge_results( search_responses,