diff --git a/nucliadb/src/nucliadb/search/search/query.py b/nucliadb/src/nucliadb/search/search/query.py index 99a9b7b655..bafbd26c7c 100644 --- a/nucliadb/src/nucliadb/search/search/query.py +++ b/nucliadb/src/nucliadb/search/search/query.py @@ -19,6 +19,7 @@ # import asyncio import json +import string from datetime import datetime from typing import Any, Awaitable, Optional, Union @@ -539,7 +540,17 @@ async def parse_relation_search(self, request: nodereader_pb2.SearchRequest) -> return autofilters async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None: - if not self.with_synonyms: + """ + Replace the terms in the query with an expression that will make it match with the configured synonyms. + We're using the Tantivy's query language here: https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html + + Example: + - Synonyms: Foo -> Bar, Baz + - Query: "What is Foo?" + - Advanced Query: "What is (Foo OR Bar OR Baz)?" + """ + if not self.with_synonyms or not self.query: + # Nothing to do return if self.has_vector_search or self.has_relations_search: @@ -548,27 +559,32 @@ async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None: "Search with custom synonyms is only supported on paragraph and document search", ) - if not self.query: - # Nothing to do - return - synonyms = await self._get_synomyns() if synonyms is None: # No synonyms found return - synonyms_found: list[str] = [] - advanced_query = [] - for term in self.query.split(" "): - advanced_query.append(term) - term_synonyms = synonyms.terms.get(term) - if term_synonyms is None or len(term_synonyms.synonyms) == 0: - # No synonyms found for this term - continue - synonyms_found.extend(term_synonyms.synonyms) - - if len(synonyms_found): - request.advanced_query = " OR ".join(advanced_query + synonyms_found) + # Calculate term variants: 'term' -> '(term OR synonym1 OR synonym2)' + variants: dict[str, str] = {} + for term, term_synonyms in synonyms.terms.items(): + if len(term_synonyms.synonyms) > 0: + variants[term] = f"({" OR ".join([term] + list(term_synonyms.synonyms))})" + + # Split the query into terms + query_terms = self.query.split() + + # Remove punctuation from the query terms + clean_query_terms = [term.strip(string.punctuation) for term in query_terms] + + # Replace the original terms with the variants if the cleaned term is in the variants + term_with_synonyms_found = False + for index, clean_term in enumerate(clean_query_terms): + if clean_term in variants: + term_with_synonyms_found = True + query_terms[index] = query_terms[index].replace(clean_term, variants[clean_term]) + + if term_with_synonyms_found: + request.advanced_query = " ".join(query_terms) request.ClearField("body") async def get_visual_llm_enabled(self) -> bool: diff --git a/nucliadb/tests/nucliadb/integration/test_synonyms.py b/nucliadb/tests/nucliadb/integration/test_synonyms.py index 34cad5f6b9..e25e80b07f 100644 --- a/nucliadb/tests/nucliadb/integration/test_synonyms.py +++ b/nucliadb/tests/nucliadb/integration/test_synonyms.py @@ -132,9 +132,9 @@ async def test_search_with_synonyms( tomatoe_rid = resp.json()["uuid"] resp = await nucliadb_reader.post( - f"/kb/{kbid}/search", + f"/kb/{kbid}/find", json=dict( - features=["paragraph", "document"], + features=["keyword"], query="planet", with_synonyms=True, highlight=True, @@ -145,21 +145,19 @@ async def test_search_with_synonyms( # Paragraph and fulltext search should match on summary (term) # and title (synonym) for the two resources - assert len(body["paragraphs"]["results"]) == 4 - assert len(body["fulltext"]["results"]) == 4 + assert len(get_pararagraphs(body)) == 4 assert body["resources"][planet_rid] assert body["resources"][sphere_rid] assert tomatoe_rid not in body["resources"] # Check that searching without synonyms matches only query term resp = await nucliadb_reader.post( - f"/kb/{kbid}/search", - json=dict(features=["paragraph", "document"], query="planet"), + f"/kb/{kbid}/find", + json=dict(features=["keyword"], query="planet"), ) assert resp.status_code == 200 body = resp.json() - assert len(body["paragraphs"]["results"]) == 1 - assert len(body["fulltext"]["results"]) == 1 + assert len(get_pararagraphs(body)) == 1 assert body["resources"][planet_rid] assert sphere_rid not in body["resources"] assert tomatoe_rid not in body["resources"] @@ -167,22 +165,29 @@ async def test_search_with_synonyms( # Check that searching with a term that has synonyms and # one that doesn't matches all of them resp = await nucliadb_reader.post( - f"/kb/{kbid}/search", + f"/kb/{kbid}/find", json=dict( - features=["paragraph", "document"], + features=["keyword"], query="planet tomatoe", with_synonyms=True, ), ) assert resp.status_code == 200 body = resp.json() - assert len(body["paragraphs"]["results"]) == 5 - assert len(body["fulltext"]["results"]) == 5 + assert len(get_pararagraphs(body)) == 5 assert body["resources"][planet_rid] assert body["resources"][sphere_rid] assert body["resources"][tomatoe_rid] +def get_pararagraphs(body): + paragraphs = [] + for resource in body.get("resources", {}).values(): + for field in resource.get("fields", {}).values(): + paragraphs.extend(field.get("paragraphs", [])) + return paragraphs + + @pytest.mark.asyncio async def test_search_errors_if_vectors_or_relations_requested( nucliadb_reader, diff --git a/nucliadb/tests/search/unit/search/test_query.py b/nucliadb/tests/search/unit/search/test_query.py index 25ebd63316..9b6bad67c0 100644 --- a/nucliadb/tests/search/unit/search/test_query.py +++ b/nucliadb/tests/search/unit/search/test_query.py @@ -128,11 +128,12 @@ async def test_not_applies_if_synonyms_not_found_for_query( request.ClearField.assert_not_called() - query_parser.query = "planet" + # Append planetary at the end of the query to test that partial terms are not replaced + query_parser.query = "which is this planet? planetary" await query_parser.parse_synonyms(request) request.ClearField.assert_called_once_with("body") - assert request.advanced_query == "planet OR earth OR globe" + assert request.advanced_query == "which is this (planet OR earth OR globe)? planetary" def test_check_supported_filters():