From 874bc68af39baab90acb4d4a77e1b45ea92bf166 Mon Sep 17 00:00:00 2001 From: Carles Onielfa <31882346+carlesonielfa@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:04:06 +0200 Subject: [PATCH] Add citation threshold on ask (#2466) * Add citation threshold on ask * Better wording --- nucliadb/src/nucliadb/search/search/chat/ask.py | 1 + nucliadb/tests/nucliadb/integration/test_ask.py | 2 +- nucliadb_models/src/nucliadb_models/search.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/nucliadb/src/nucliadb/search/search/chat/ask.py b/nucliadb/src/nucliadb/search/search/chat/ask.py index de28e97b21..b0d80c7c10 100644 --- a/nucliadb/src/nucliadb/search/search/chat/ask.py +++ b/nucliadb/src/nucliadb/search/search/chat/ask.py @@ -475,6 +475,7 @@ async def ask( question=user_query, truncate=True, citations=ask_request.citations, + citation_threshold=ask_request.citation_threshold, generative_model=ask_request.generative_model, max_tokens=query_parser.get_max_tokens_answer(), query_context_images=prompt_context_images, diff --git a/nucliadb/tests/nucliadb/integration/test_ask.py b/nucliadb/tests/nucliadb/integration/test_ask.py index 54658bbbf0..399f67f594 100644 --- a/nucliadb/tests/nucliadb/integration/test_ask.py +++ b/nucliadb/tests/nucliadb/integration/test_ask.py @@ -121,7 +121,7 @@ async def test_ask_with_citations(nucliadb_reader: AsyncClient, knowledgebox, re resp = await nucliadb_reader.post( f"/kb/{knowledgebox}/ask", - json={"query": "title", "citations": True}, + json={"query": "title", "citations": True, "citation_threshold": 0.5}, headers={"X-Synchronous": "true"}, ) assert resp.status_code == 200 diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index a4e26fa9c0..bafb8bca0b 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -784,6 +784,12 @@ class ChatModel(BaseModel): default=None, description="Optional custom prompt input by the user" ) citations: bool = Field(default=False, description="Whether to include the citations in the answer") + citation_threshold: Optional[float] = Field( + default=None, + description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. If not provided, Nuclia's default threshold is used.", # noqa + ge=0.0, + le=1.0, + ) generative_model: Optional[str] = Field( default=None, title="Generative model", @@ -1148,6 +1154,12 @@ class ChatRequest(BaseModel): default=False, description="Whether to include the citations for the answer in the response", ) + citation_threshold: Optional[float] = Field( + default=None, + description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. If not provided, Nuclia's default threshold is used.", + ge=0.0, + le=1.0, + ) security: Optional[RequestSecurity] = SearchParamDefaults.security.to_pydantic_field() rag_strategies: list[RagStrategies] = Field( default=[],