From 55b59ae8c0568f22f347418c069fb4135536fdb6 Mon Sep 17 00:00:00 2001
From: Ferran Llamas
Date: Mon, 18 Dec 2023 10:53:53 +0100
Subject: [PATCH] Return whether paragraph search results are fuzzy or not
 (#1673)

---
 nucliadb/nucliadb/search/search/find_merge.py | 20 +++----
 nucliadb/nucliadb/search/search/merge.py      |  2 +
 .../tests/integration/search/test_search.py   | 58 ++++++++++++++++++
 .../nucliadb/tests/integration/test_find.py   | 60 +++++++++++++++++++
 nucliadb_models/nucliadb_models/search.py     |  3 +
 5 files changed, 131 insertions(+), 12 deletions(-)

diff --git a/nucliadb/nucliadb/search/search/find_merge.py b/nucliadb/nucliadb/search/search/find_merge.py
index 0ce1df929c..a56ecfe164 100644
--- a/nucliadb/nucliadb/search/search/find_merge.py
+++ b/nucliadb/nucliadb/search/search/find_merge.py
@@ -64,9 +64,7 @@ async def set_text_value(
     ematches: Optional[List[str]] = None,
     extracted_text_cache: Optional[paragraphs.ExtractedTextCache] = None,
 ):
-    # TODO: Improve
-    await max_operations.acquire()
-    try:
+    async with max_operations:
         assert result_paragraph.paragraph
         assert result_paragraph.paragraph.position
         result_paragraph.paragraph.text = await paragraphs.get_paragraph_text(
@@ -81,8 +79,6 @@ async def set_text_value(
             matches=[],  # TODO
             extracted_text_cache=extracted_text_cache,
         )
-    finally:
-        max_operations.release()
 
 
 @merge_observer.wrap({"type": "set_resource_metadada_value"})
@@ -95,9 +91,7 @@ async def set_resource_metadata_value(
     find_resources: Dict[str, FindResource],
     max_operations: asyncio.Semaphore,
 ):
-    await max_operations.acquire()
-
-    try:
+    async with max_operations:
         serialized_resource = await serialize(
             kbid,
             resource,
@@ -112,9 +106,6 @@ async def set_resource_metadata_value(
             logger.warning(f"Resource {resource} not found in {kbid}")
             find_resources.pop(resource, None)
 
-    finally:
-        max_operations.release()
-
 
 class Orderer:
     def __init__(self):
@@ -248,6 +239,7 @@ def merge_paragraphs_vectors(
     # We assume that paragraphs_shards and vectors_shards are already ordered
     for paragraphs_shard in paragraphs_shards:
         for paragraph in paragraphs_shard:
+            fuzzy_result = len(paragraph.matches) > 0
             merged_paragrahs.append(
                 TempFindParagraph(
                     paragraph_index=paragraph,
@@ -258,6 +250,7 @@ def merge_paragraphs_vectors(
                     split=paragraph.split,
                     end=paragraph.end,
                     id=paragraph.paragraph,
+                    fuzzy_result=fuzzy_result,
                 )
             )
 
@@ -322,8 +315,10 @@ def merge_paragraphs_vectors(
                     ],
                 ),
                 id=merged_paragraph.id,
+                # Vector searches don't have fuzziness
+                fuzzy_result=False,
             )
-        if merged_paragraph.paragraph_index is not None:
+        elif merged_paragraph.paragraph_index is not None:
             merged_paragraph.paragraph = FindParagraph(
                 score=merged_paragraph.paragraph_index.score.bm25,
                 score_type=SCORE_TYPE.BM25,
@@ -344,6 +339,7 @@ def merge_paragraphs_vectors(
                     ],
                 ),
                 id=merged_paragraph.id,
+                fuzzy_result=merged_paragraph.fuzzy_result,
             )
     return merged_paragrahs, next_page
 
diff --git a/nucliadb/nucliadb/search/search/merge.py b/nucliadb/nucliadb/search/search/merge.py
index db4085f4a9..56f98a546f 100644
--- a/nucliadb/nucliadb/search/search/merge.py
+++ b/nucliadb/nucliadb/search/search/merge.py
@@ -401,6 +401,7 @@ async def merge_paragraph_results(
             extracted_text_cache=etcache,
         )
         labels = await get_labels_paragraph(result, kbid)
+        fuzzy_result = len(result.matches) > 0
         new_paragraph = Paragraph(
             score=result.score.bm25,
             rid=result.uuid,
@@ -414,6 +415,7 @@ async def merge_paragraph_results(
                 end=result.metadata.position.end,
                 page_number=result.metadata.position.page_number,
             ),
+            fuzzy_result=fuzzy_result,
         )
         if len(result.metadata.position.start_seconds) or len(
             result.metadata.position.end_seconds
diff --git a/nucliadb/nucliadb/tests/integration/search/test_search.py b/nucliadb/nucliadb/tests/integration/search/test_search.py
index 060a1a9c57..d54ba5304e 100644
--- a/nucliadb/nucliadb/tests/integration/search/test_search.py
+++ b/nucliadb/nucliadb/tests/integration/search/test_search.py
@@ -1441,3 +1441,61 @@ async def test_facets_validation(
     else:
         assert resp.status_code == 422
         assert error_message == resp.json()["detail"][0]["msg"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
+async def test_search_marks_fuzzy_results(
+    nucliadb_reader: AsyncClient,
+    nucliadb_writer: AsyncClient,
+    knowledgebox,
+):
+    resp = await nucliadb_writer.post(
+        f"/kb/{knowledgebox}/resources",
+        json={
+            "slug": "myresource",
+            "title": "My Title",
+        },
+    )
+    assert resp.status_code == 201
+
+    # Should get only one non-fuzzy result
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/search",
+        json={
+            "query": "Title",
+        },
+    )
+    assert resp.status_code == 200
+    body = resp.json()
+    check_fuzzy_paragraphs(body, fuzzy_result=False, n_expected=1)
+
+    # Should get only one fuzzy result
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/search",
+        json={
+            "query": "totle",
+        },
+    )
+    assert resp.status_code == 200
+    body = resp.json()
+    check_fuzzy_paragraphs(body, fuzzy_result=True, n_expected=1)
+
+    # Should not get any result if exact match term queried
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/search",
+        json={
+            "query": '"totle"',
+        },
+    )
+    assert resp.status_code == 200
+    body = resp.json()
+    check_fuzzy_paragraphs(body, fuzzy_result=True, n_expected=0)
+
+
+def check_fuzzy_paragraphs(search_response, *, fuzzy_result: bool, n_expected: int):
+    found = 0
+    for paragraph in search_response["paragraphs"]["results"]:
+        assert paragraph["fuzzy_result"] is fuzzy_result
+        found += 1
+    assert found == n_expected
diff --git a/nucliadb/nucliadb/tests/integration/test_find.py b/nucliadb/nucliadb/tests/integration/test_find.py
index 57c1fa563f..14f1ae67a9 100644
--- a/nucliadb/nucliadb/tests/integration/test_find.py
+++ b/nucliadb/nucliadb/tests/integration/test_find.py
@@ -258,3 +258,63 @@ async def test_story_7286(
     body = resp.json()
     assert len(body["resources"]) == 0
     assert caplog.record_tuples[0][2] == f"Resource {rid} not found in {knowledgebox}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
+async def test_find_marks_fuzzy_results(
+    nucliadb_reader: AsyncClient,
+    nucliadb_writer: AsyncClient,
+    knowledgebox,
+):
+    resp = await nucliadb_writer.post(
+        f"/kb/{knowledgebox}/resources",
+        json={
+            "slug": "myresource",
+            "title": "My Title",
+        },
+    )
+    assert resp.status_code == 201
+
+    # Should get only one non-fuzzy result
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/find",
+        json={
+            "query": "Title",
+        },
+    )
+    assert resp.status_code == 200
+    body = resp.json()
+    check_fuzzy_paragraphs(body, fuzzy_result=False, n_expected=1)
+
+    # Should get only one fuzzy result
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/find",
+        json={
+            "query": "totle",
+        },
+    )
+    assert resp.status_code == 200
+    body = resp.json()
+    check_fuzzy_paragraphs(body, fuzzy_result=True, n_expected=1)
+
+    # Should not get any result if exact match term queried
+    resp = await nucliadb_reader.post(
+        f"/kb/{knowledgebox}/find",
+        json={
+            "query": '"totle"',
'"totle"', + }, + ) + assert resp.status_code == 200 + body = resp.json() + check_fuzzy_paragraphs(body, fuzzy_result=True, n_expected=0) + + +def check_fuzzy_paragraphs(find_response, *, fuzzy_result: bool, n_expected: int): + found = 0 + for resource in find_response["resources"].values(): + for field in resource["fields"].values(): + for paragraph in field["paragraphs"].values(): + assert paragraph["fuzzy_result"] is fuzzy_result + found += 1 + assert found == n_expected diff --git a/nucliadb_models/nucliadb_models/search.py b/nucliadb_models/nucliadb_models/search.py index 301cdbf61c..b510768327 100644 --- a/nucliadb_models/nucliadb_models/search.py +++ b/nucliadb_models/nucliadb_models/search.py @@ -150,6 +150,7 @@ class Paragraph(BaseModel): start_seconds: Optional[List[int]] = None end_seconds: Optional[List[int]] = None position: Optional[TextPosition] = None + fuzzy_result: bool = False class Paragraphs(BaseModel): @@ -825,6 +826,7 @@ class FindParagraph(BaseModel): id: str labels: Optional[List[str]] = [] position: Optional[TextPosition] = None + fuzzy_result: bool = False @dataclass @@ -839,6 +841,7 @@ class TempFindParagraph: paragraph: Optional[FindParagraph] = None vector_index: Optional[DocumentScored] = None paragraph_index: Optional[PBParagraphResult] = None + fuzzy_result: bool = False class FindField(BaseModel):