Use strategy pattern for rank fusion algorithms + Implement reciprocal rank fusion (RRF) (#2599)

* Strategy pattern for rank fusion and add RRF

RRF still not available through the API

* Add rank_fusion parameter to /find and /ask (void)

* Move rank fusion parsing to get_rank_fusion

* Implement RRF

* Add discriminator to help pydantic validation

* Parametrize RRF + Implement weighted (boosted) RRF

* Implement RRF deduplication

* Fix default_factory for rank fusion

* Use a model for RRF boosting

* Hide rank_fusion parameter from public API
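
For context: reciprocal rank fusion gives each result a score of sum(w / (k + rank)) over the ranked lists it appears in, where rank is its position in a list, w is that list's weight (the boosting above) and the constant k (commonly 60) damps lower-ranked hits. A result returned by both keyword and semantic search accumulates score from both lists, which is also why deduplication falls out naturally. A minimal illustrative sketch — names and structure are ours, not the nucliadb implementation:

from collections import defaultdict
from typing import Hashable, Iterable, TypeVar

T = TypeVar("T", bound=Hashable)


def reciprocal_rank_fusion(rankings: Iterable[tuple[list[T], float]], k: float = 60.0) -> list[T]:
    """Fuse several ranked lists into one, weighting each list (boosted RRF)."""
    scores: dict[T, float] = defaultdict(float)
    for ranking, weight in rankings:
        for rank, item in enumerate(ranking):
            # an item present in several rankings accumulates score: dedup for free
            scores[item] += weight / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)


keyword = ["doc-a", "doc-b", "doc-c"]
semantic = ["doc-c", "doc-d"]
print(reciprocal_rank_fusion([(keyword, 1.0), (semantic, 2.0)]))
# ['doc-c', 'doc-d', 'doc-a', 'doc-b']: doc-c wins because both searches rank it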
jotare authored Nov 8, 2024
1 parent 392bd4e commit d47930c
Showing 14 changed files with 913 additions and 276 deletions.
5 changes: 5 additions & 0 deletions nucliadb/src/nucliadb/search/api/v1/find.py
@@ -40,6 +40,7 @@
     FindRequest,
     KnowledgeboxFindResults,
     NucliaDBClientType,
+    RankFusionName,
     Reranker,
     ResourceProperties,
     SearchOptions,
@@ -125,6 +126,9 @@ async def find_knowledgebox(
     autofilter: bool = fastapi_query(SearchParamDefaults.autofilter),
     security_groups: list[str] = fastapi_query(SearchParamDefaults.security_groups),
     show_hidden: bool = fastapi_query(SearchParamDefaults.show_hidden),
+    rank_fusion: RankFusionName = fastapi_query(
+        SearchParamDefaults.rank_fusion, include_in_schema=False
+    ),
     reranker: Reranker = fastapi_query(SearchParamDefaults.reranker),
     x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
     x_nucliadb_user: str = Header(""),
@@ -159,6 +163,7 @@ async def find_knowledgebox(
             autofilter=autofilter,
             security=security,
             show_hidden=show_hidden,
+            rank_fusion=rank_fusion,
             reranker=reranker,
         )
     except ValidationError as exc:
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/search/search/chat/query.py
@@ -181,6 +181,7 @@ async def run_main_query(
     find_request.debug = item.debug
     find_request.rephrase = item.rephrase
     find_request.rephrase_prompt = parse_rephrase_prompt(item)
+    find_request.rank_fusion = item.rank_fusion
     find_request.reranker = item.reranker
     # We don't support pagination, we always get the top_k results.
     find_request.top_k = item.top_k
4 changes: 4 additions & 0 deletions nucliadb/src/nucliadb/search/search/find.py
@@ -36,6 +36,7 @@
 )
 from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
+from nucliadb.search.search.rank_fusion import get_rank_fusion
 from nucliadb.search.search.rerankers import RerankingOptions, get_reranker
 from nucliadb.search.search.utils import (
     filter_hidden_resources,
@@ -120,6 +121,7 @@ async def _index_node_retrieval(
         extracted=item.extracted,
         field_type_filter=item.field_type_filter,
         highlight=item.highlight,
+        rank_fusion_algorithm=query_parser.rank_fusion,
         reranker=query_parser.reranker,
     )
 
@@ -247,6 +249,7 @@ async def query_parser_from_find_request(
     hidden = await filter_hidden_resources(kbid, item.show_hidden)
 
     reranker = get_reranker(item.reranker)
+    rank_fusion = get_rank_fusion(item.rank_fusion)
     query_parser = QueryParser(
         kbid=kbid,
         features=item.features,
@@ -274,6 +277,7 @@
         rephrase=item.rephrase,
         rephrase_prompt=item.rephrase_prompt,
         hidden=hidden,
+        rank_fusion=rank_fusion,
         reranker=reranker,
     )
     return query_parser
167 changes: 90 additions & 77 deletions nucliadb/src/nucliadb/search/search/find_merge.py
@@ -31,6 +31,7 @@
     text_block_to_find_paragraph,
 )
 from nucliadb.search.search.merge import merge_relations_results
+from nucliadb.search.search.rank_fusion import RankFusionAlgorithm
 from nucliadb.search.search.rerankers import (
     RerankableItem,
     Reranker,
@@ -77,6 +78,7 @@ async def build_find_response(
     page_number: int,
     min_score_bm25: float,
     min_score_semantic: float,
+    rank_fusion_algorithm: RankFusionAlgorithm,
     reranker: Reranker,
     show: list[ResourceProperties] = [],
     extracted: list[ExtractedDataTypeName] = [],
@@ -85,12 +87,17 @@
 ) -> KnowledgeboxFindResults:
     # merge
     search_response = merge_shard_responses(search_responses)
-    merged_text_blocks: list[TextBlockMatch] = rank_fusion_merge(
-        search_response.paragraph.results,
+
+    keyword_results = keyword_results_to_text_block_matches(search_response.paragraph.results)
+    semantic_results = semantic_results_to_text_block_matches(
         filter(
             lambda x: x.score >= min_score_semantic,
             search_response.vector.documents,
         ),
-    )
+    )
+
+    merged_text_blocks: list[TextBlockMatch] = rank_fusion_merge(
+        keyword_results, semantic_results, rank_fusion_algorithm
+    )
 
     # cut
@@ -212,86 +219,92 @@ def merge_shards_relation_responses(
     return merged
 
 
-@merge_observer.wrap({"type": "rank_fusion_merge"})
-def rank_fusion_merge(
-    paragraphs: Iterable[ParagraphResult],
-    vectors: Iterable[DocumentScored],
-) -> list[TextBlockMatch]:
-    """Merge results from different indexes using a rank fusion algorithm.
-
-    Given two list of sorted results from keyword and semantic search, this rank
-    fusion algorithm mixes them in the following way:
-    - 1st result from keyword search
-    - 2nd result from semantic search
-    - 2 keyword results and 1 semantic
-
-    """
-    merged_paragraphs: list[TextBlockMatch] = []
-
-    # sort results by it's score before merging them
-    paragraphs = [p for p in sorted(paragraphs, key=lambda r: r.score.bm25, reverse=True)]
-    vectors = [v for v in sorted(vectors, key=lambda r: r.score, reverse=True)]
-
-    for paragraph in paragraphs:
-        fuzzy_result = len(paragraph.matches) > 0
-        merged_paragraphs.append(
-            TextBlockMatch(
-                paragraph_id=ParagraphId.from_string(paragraph.paragraph),
-                score=paragraph.score.bm25,
-                score_type=SCORE_TYPE.BM25,
-                order=0,  # NOTE: this will be filled later
-                text="",  # NOTE: this will be filled later too
-                position=TextPosition(
-                    page_number=paragraph.metadata.position.page_number,
-                    index=paragraph.metadata.position.index,
-                    start=paragraph.start,
-                    end=paragraph.end,
-                    start_seconds=[x for x in paragraph.metadata.position.start_seconds],
-                    end_seconds=[x for x in paragraph.metadata.position.end_seconds],
-                ),
-                paragraph_labels=[x for x in paragraph.labels],  # XXX could be list(paragraph.labels)?
-                fuzzy_search=fuzzy_result,
-                is_a_table=paragraph.metadata.representation.is_a_table,
-                representation_file=paragraph.metadata.representation.file,
-                page_with_visual=paragraph.metadata.page_with_visual,
-            )
-        )
-
-    nextpos = 1
-    for vector in vectors:
-        try:
-            vector_id = VectorId.from_string(vector.doc_id.id)
-        except (IndexError, ValueError):
-            logger.warning(f"Skipping invalid doc_id: {vector.doc_id.id}")
-            continue
-        merged_paragraphs.insert(
-            nextpos,
-            TextBlockMatch(
-                paragraph_id=ParagraphId.from_vector_id(vector_id),
-                score=vector.score,
-                score_type=SCORE_TYPE.VECTOR,
-                order=0,  # NOTE: this will be filled later
-                text="",  # NOTE: this will be filled later too
-                position=TextPosition(
-                    page_number=vector.metadata.position.page_number,
-                    index=vector.metadata.position.index,
-                    start=vector_id.vector_start,
-                    end=vector_id.vector_end,
-                    start_seconds=[x for x in vector.metadata.position.start_seconds],
-                    end_seconds=[x for x in vector.metadata.position.end_seconds],
-                ),
-                # TODO: get labels from index
-                field_labels=[],
-                paragraph_labels=[],
-                fuzzy_search=False,  # semantic search doesn't have fuzziness
-                is_a_table=vector.metadata.representation.is_a_table,
-                representation_file=vector.metadata.representation.file,
-                page_with_visual=vector.metadata.page_with_visual,
-            ),
-        )
-        nextpos += 3
-
-    return merged_paragraphs
+def keyword_result_to_text_block_match(item: ParagraphResult) -> TextBlockMatch:
+    fuzzy_result = len(item.matches) > 0
+    return TextBlockMatch(
+        paragraph_id=ParagraphId.from_string(item.paragraph),
+        score=item.score.bm25,
+        score_type=SCORE_TYPE.BM25,
+        order=0,  # NOTE: this will be filled later
+        text="",  # NOTE: this will be filled later too
+        position=TextPosition(
+            page_number=item.metadata.position.page_number,
+            index=item.metadata.position.index,
+            start=item.start,
+            end=item.end,
+            start_seconds=[x for x in item.metadata.position.start_seconds],
+            end_seconds=[x for x in item.metadata.position.end_seconds],
+        ),
+        paragraph_labels=[x for x in item.labels],  # XXX could be list(paragraph.labels)?
+        fuzzy_search=fuzzy_result,
+        is_a_table=item.metadata.representation.is_a_table,
+        representation_file=item.metadata.representation.file,
+        page_with_visual=item.metadata.page_with_visual,
+    )
+
+
+def keyword_results_to_text_block_matches(items: Iterable[ParagraphResult]) -> list[TextBlockMatch]:
+    return [keyword_result_to_text_block_match(item) for item in items]
+
+
+class InvalidDocId(Exception):
+    """Raised while parsing an invalid id coming from semantic search"""
+
+    def __init__(self, invalid_vector_id: str):
+        self.invalid_vector_id = invalid_vector_id
+        super().__init__(f"Invalid vector ID: {invalid_vector_id}")
+
+
+def semantic_result_to_text_block_match(item: DocumentScored) -> TextBlockMatch:
+    try:
+        vector_id = VectorId.from_string(item.doc_id.id)
+    except (IndexError, ValueError):
+        raise InvalidDocId(item.doc_id.id)
+
+    return TextBlockMatch(
+        paragraph_id=ParagraphId.from_vector_id(vector_id),
+        score=item.score,
+        score_type=SCORE_TYPE.VECTOR,
+        order=0,  # NOTE: this will be filled later
+        text="",  # NOTE: this will be filled later too
+        position=TextPosition(
+            page_number=item.metadata.position.page_number,
+            index=item.metadata.position.index,
+            start=vector_id.vector_start,
+            end=vector_id.vector_end,
+            start_seconds=[x for x in item.metadata.position.start_seconds],
+            end_seconds=[x for x in item.metadata.position.end_seconds],
+        ),
+        # TODO: get labels from index
+        field_labels=[],
+        paragraph_labels=[],
+        fuzzy_search=False,  # semantic search doesn't have fuzziness
+        is_a_table=item.metadata.representation.is_a_table,
+        representation_file=item.metadata.representation.file,
+        page_with_visual=item.metadata.page_with_visual,
+    )
+
+
+def semantic_results_to_text_block_matches(items: Iterable[DocumentScored]) -> list[TextBlockMatch]:
+    text_blocks: list[TextBlockMatch] = []
+    for item in items:
+        try:
+            text_block = semantic_result_to_text_block_match(item)
+        except InvalidDocId as exc:
+            logger.warning(f"Skipping invalid doc_id: {exc.invalid_vector_id}")
+            continue
+        text_blocks.append(text_block)
+    return text_blocks
+
+
+# we use a wrapper function to apply observability here
+@merge_observer.wrap({"type": "rank_fusion_merge"})
+def rank_fusion_merge(
+    keyword: Iterable[TextBlockMatch],
+    semantic: Iterable[TextBlockMatch],
+    rank_fusion_algorithm: RankFusionAlgorithm,
+) -> list[TextBlockMatch]:
+    return rank_fusion_algorithm.fuse(keyword, semantic)
 
 
 def cut_page(items: list[Any], page_size: int, page_number: int) -> tuple[list[Any], bool]:
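
The interleaving described by the removed docstring above (1st keyword result, then 1 semantic, then 2 keyword and 1 semantic, repeating) maps directly onto the new strategy interface. A hypothetical sketch of that legacy behaviour expressed as one RankFusionAlgorithm-style strategy — the actual strategy classes live in rank_fusion.py, which is not expanded in this view:

from typing import Iterable, TypeVar

T = TypeVar("T")


class LegacyRankFusion:
    """Old merge order: semantic results are inserted at positions 1, 4, 7, ...
    into the keyword result list (both inputs assumed already sorted by score).
    """

    def fuse(self, keyword: Iterable[T], semantic: Iterable[T]) -> list[T]:
        merged: list[T] = list(keyword)
        nextpos = 1
        for item in semantic:
            merged.insert(nextpos, item)  # inserting past the end simply appends
            nextpos += 3
        return merged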
11 changes: 7 additions & 4 deletions nucliadb/src/nucliadb/search/search/query.py
@@ -41,10 +41,11 @@
     node_features,
     query_parse_dependency_observer,
 )
+from nucliadb.search.search.rank_fusion import RankFusionAlgorithm, get_default_rank_fusion
 from nucliadb.search.search.rerankers import (
-    MultiMatchBoosterReranker,
     PredictReranker,
     Reranker,
+    get_default_reranker,
 )
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.internal.predict import QueryInfo
@@ -127,7 +128,8 @@ def __init__(
         rephrase_prompt: Optional[str] = None,
         max_tokens: Optional[MaxTokens] = None,
         hidden: Optional[bool] = None,
-        reranker: Reranker = MultiMatchBoosterReranker(),
+        rank_fusion: RankFusionAlgorithm = get_default_rank_fusion(),
+        reranker: Reranker = get_default_reranker(),
     ):
         self.kbid = kbid
         self.features = features
@@ -168,13 +170,14 @@ def __init__(
         self.label_filters = translate_label_filters(self.label_filters)
         self.flat_label_filters = flatten_filter_literals(self.label_filters)
         self.max_tokens = max_tokens
+        self.rank_fusion = rank_fusion
         self.reranker: Reranker
         if page_number > 0 and isinstance(reranker, PredictReranker):
             logger.warning(
-                "Trying to use predict reranker with pagination. Using multi-match booster instead",
+                "Trying to use predict reranker with pagination. Using default instead",
                 extra={"kbid": kbid},
             )
-            self.reranker = MultiMatchBoosterReranker()
+            self.reranker = get_default_reranker()
         else:
             self.reranker = reranker
 
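
The rank_fusion module itself is among the files not expanded in this view. Judging only by the names these hunks import (RankFusionAlgorithm, get_rank_fusion, get_default_rank_fusion, RankFusionName), its shape is presumably close to the sketch below; the enum values, the constructor signature and the choice of default are assumptions:

from abc import ABC, abstractmethod
from enum import Enum
from typing import Hashable, Iterable, TypeVar

T = TypeVar("T", bound=Hashable)


class RankFusionName(str, Enum):
    # values are assumptions; the commit hides this parameter from the public API
    LEGACY = "legacy"
    RECIPROCAL_RANK_FUSION = "rrf"


class RankFusionAlgorithm(ABC):
    @abstractmethod
    def fuse(self, keyword: Iterable[T], semantic: Iterable[T]) -> list[T]:
        """Merge keyword and semantic results into a single ranked list."""


class ReciprocalRankFusion(RankFusionAlgorithm):
    def __init__(self, k: float = 60.0):
        self._k = k

    def fuse(self, keyword: Iterable[T], semantic: Iterable[T]) -> list[T]:
        scores: dict[T, float] = {}
        for ranking in (list(keyword), list(semantic)):
            for rank, item in enumerate(ranking):
                scores[item] = scores.get(item, 0.0) + 1.0 / (self._k + rank)
        return sorted(scores, key=scores.__getitem__, reverse=True)


def get_default_rank_fusion() -> RankFusionAlgorithm:
    return ReciprocalRankFusion()


def get_rank_fusion(name: RankFusionName) -> RankFusionAlgorithm:
    # map the (internal) API name to a concrete strategy
    if name == RankFusionName.RECIPROCAL_RANK_FUSION:
        return ReciprocalRankFusion()
    return get_default_rank_fusion()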