Skip to content

Commit

Permalink
Add citation threshold on ask (#2466)
Browse files Browse the repository at this point in the history
* Add citation threshold on ask

* Better wording
  • Loading branch information
carlesonielfa authored Sep 19, 2024
1 parent 5a871e4 commit 874bc68
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ async def ask(
question=user_query,
truncate=True,
citations=ask_request.citations,
citation_threshold=ask_request.citation_threshold,
generative_model=ask_request.generative_model,
max_tokens=query_parser.get_max_tokens_answer(),
query_context_images=prompt_context_images,
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/tests/nucliadb/integration/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def test_ask_with_citations(nucliadb_reader: AsyncClient, knowledgebox, re

resp = await nucliadb_reader.post(
f"/kb/{knowledgebox}/ask",
json={"query": "title", "citations": True},
json={"query": "title", "citations": True, "citation_threshold": 0.5},
headers={"X-Synchronous": "true"},
)
assert resp.status_code == 200
Expand Down
12 changes: 12 additions & 0 deletions nucliadb_models/src/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,12 @@ class ChatModel(BaseModel):
default=None, description="Optional custom prompt input by the user"
)
citations: bool = Field(default=False, description="Whether to include the citations in the answer")
citation_threshold: Optional[float] = Field(
default=None,
description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. If not provided, Nuclia's default threshold is used.", # noqa
ge=0.0,
le=1.0,
)
generative_model: Optional[str] = Field(
default=None,
title="Generative model",
Expand Down Expand Up @@ -1148,6 +1154,12 @@ class ChatRequest(BaseModel):
default=False,
description="Whether to include the citations for the answer in the response",
)
citation_threshold: Optional[float] = Field(
default=None,
description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. If not provided, Nuclia's default threshold is used.",
ge=0.0,
le=1.0,
)
security: Optional[RequestSecurity] = SearchParamDefaults.security.to_pydantic_field()
rag_strategies: list[RagStrategies] = Field(
default=[],
Expand Down

0 comments on commit 874bc68

Please sign in to comment.