Skip to content

Commit

Permalink
Fix synonyms advanced query (#2606)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Nov 8, 2024
1 parent bc9cf29 commit ee7258b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 31 deletions.
50 changes: 33 additions & 17 deletions nucliadb/src/nucliadb/search/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#
import asyncio
import json
import string
from datetime import datetime
from typing import Any, Awaitable, Optional, Union

Expand Down Expand Up @@ -539,7 +540,17 @@ async def parse_relation_search(self, request: nodereader_pb2.SearchRequest) ->
return autofilters

async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None:
if not self.with_synonyms:
"""
Replace the terms in the query with an expression that will make it match with the configured synonyms.
We're using the Tantivy's query language here: https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html
Example:
- Synonyms: Foo -> Bar, Baz
- Query: "What is Foo?"
- Advanced Query: "What is (Foo OR Bar OR Baz)?"
"""
if not self.with_synonyms or not self.query:
# Nothing to do
return

if self.has_vector_search or self.has_relations_search:
Expand All @@ -548,27 +559,32 @@ async def parse_synonyms(self, request: nodereader_pb2.SearchRequest) -> None:
"Search with custom synonyms is only supported on paragraph and document search",
)

if not self.query:
# Nothing to do
return

synonyms = await self._get_synomyns()
if synonyms is None:
# No synonyms found
return

synonyms_found: list[str] = []
advanced_query = []
for term in self.query.split(" "):
advanced_query.append(term)
term_synonyms = synonyms.terms.get(term)
if term_synonyms is None or len(term_synonyms.synonyms) == 0:
# No synonyms found for this term
continue
synonyms_found.extend(term_synonyms.synonyms)

if len(synonyms_found):
request.advanced_query = " OR ".join(advanced_query + synonyms_found)
# Calculate term variants: 'term' -> '(term OR synonym1 OR synonym2)'
variants: dict[str, str] = {}
for term, term_synonyms in synonyms.terms.items():
if len(term_synonyms.synonyms) > 0:
variants[term] = f"({" OR ".join([term] + list(term_synonyms.synonyms))})"

# Split the query into terms
query_terms = self.query.split()

# Remove punctuation from the query terms
clean_query_terms = [term.strip(string.punctuation) for term in query_terms]

# Replace the original terms with the variants if the cleaned term is in the variants
term_with_synonyms_found = False
for index, clean_term in enumerate(clean_query_terms):
if clean_term in variants:
term_with_synonyms_found = True
query_terms[index] = query_terms[index].replace(clean_term, variants[clean_term])

if term_with_synonyms_found:
request.advanced_query = " ".join(query_terms)
request.ClearField("body")

async def get_visual_llm_enabled(self) -> bool:
Expand Down
29 changes: 17 additions & 12 deletions nucliadb/tests/nucliadb/integration/test_synonyms.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ async def test_search_with_synonyms(
tomatoe_rid = resp.json()["uuid"]

resp = await nucliadb_reader.post(
f"/kb/{kbid}/search",
f"/kb/{kbid}/find",
json=dict(
features=["paragraph", "document"],
features=["keyword"],
query="planet",
with_synonyms=True,
highlight=True,
Expand All @@ -145,44 +145,49 @@ async def test_search_with_synonyms(

# Paragraph and fulltext search should match on summary (term)
# and title (synonym) for the two resources
assert len(body["paragraphs"]["results"]) == 4
assert len(body["fulltext"]["results"]) == 4
assert len(get_pararagraphs(body)) == 4
assert body["resources"][planet_rid]
assert body["resources"][sphere_rid]
assert tomatoe_rid not in body["resources"]

# Check that searching without synonyms matches only query term
resp = await nucliadb_reader.post(
f"/kb/{kbid}/search",
json=dict(features=["paragraph", "document"], query="planet"),
f"/kb/{kbid}/find",
json=dict(features=["keyword"], query="planet"),
)
assert resp.status_code == 200
body = resp.json()
assert len(body["paragraphs"]["results"]) == 1
assert len(body["fulltext"]["results"]) == 1
assert len(get_pararagraphs(body)) == 1
assert body["resources"][planet_rid]
assert sphere_rid not in body["resources"]
assert tomatoe_rid not in body["resources"]

# Check that searching with a term that has synonyms and
# one that doesn't matches all of them
resp = await nucliadb_reader.post(
f"/kb/{kbid}/search",
f"/kb/{kbid}/find",
json=dict(
features=["paragraph", "document"],
features=["keyword"],
query="planet tomatoe",
with_synonyms=True,
),
)
assert resp.status_code == 200
body = resp.json()
assert len(body["paragraphs"]["results"]) == 5
assert len(body["fulltext"]["results"]) == 5
assert len(get_pararagraphs(body)) == 5
assert body["resources"][planet_rid]
assert body["resources"][sphere_rid]
assert body["resources"][tomatoe_rid]


def get_pararagraphs(body):
paragraphs = []
for resource in body.get("resources", {}).values():
for field in resource.get("fields", {}).values():
paragraphs.extend(field.get("paragraphs", []))
return paragraphs


@pytest.mark.asyncio
async def test_search_errors_if_vectors_or_relations_requested(
nucliadb_reader,
Expand Down
5 changes: 3 additions & 2 deletions nucliadb/tests/search/unit/search/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ async def test_not_applies_if_synonyms_not_found_for_query(

request.ClearField.assert_not_called()

query_parser.query = "planet"
# Append planetary at the end of the query to test that partial terms are not replaced
query_parser.query = "which is this planet? planetary"
await query_parser.parse_synonyms(request)

request.ClearField.assert_called_once_with("body")
assert request.advanced_query == "planet OR earth OR globe"
assert request.advanced_query == "which is this (planet OR earth OR globe)? planetary"


def test_check_supported_filters():
Expand Down

0 comments on commit ee7258b

Please sign in to comment.