Add rephrase query prompt (#2473)
lferran authored Sep 19, 2024
1 parent 874bc68 commit 292bf3e
Showing 7 changed files with 89 additions and 13 deletions.
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/search/api/v1/search.py
@@ -477,6 +477,7 @@ async def search(
autofilter=item.autofilter,
security=item.security,
rephrase=item.rephrase,
+ rephrase_prompt=item.rephrase_prompt,
)
pb_query, incomplete_results, autofilters = await query_parser.parse()

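For reference, the new field can be exercised end to end over HTTP. A hedged sketch: the endpoint path and auth header follow Nuclia's public API conventions but are assumptions here, and the prompt text is illustrative (it must contain the {question} placeholder).

import httpx

# Illustrative values: substitute a real zone URL, knowledge box id and API key.
KB_URL = "https://europe-1.nuclia.cloud/api/v1/kb/<kbid>"

response = httpx.post(
    f"{KB_URL}/search",
    headers={"X-NUCLIA-SERVICEACCOUNT": "Bearer <api-key>"},  # assumed auth scheme
    json={
        "query": "maintenance intervals for the pump",
        "rephrase": True,  # rephrase_prompt only takes effect with this flag set
        "rephrase_prompt": (
            "Rephrase this question so it's better for retrieval.\n"
            "QUESTION: {question}\n"
            "Please return ONLY the question without any explanation."
        ),
    },
)
response.raise_for_status()
results = response.json()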
23 changes: 21 additions & 2 deletions nucliadb/src/nucliadb/search/predict.py
@@ -362,8 +362,24 @@ async def query(
sentence: str,
semantic_model: Optional[str] = None,
generative_model: Optional[str] = None,
- rephrase: Optional[bool] = False,
+ rephrase: bool = False,
+ rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
"""
Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance:
- The embeddings
- The entities
- The stop words
- The semantic threshold
- etc.
:param kbid: KnowledgeBox ID
:param sentence: The query sentence
:param semantic_model: The semantic model to use to generate the embeddings
:param generative_model: The generative model that will be used to generate the answer
:param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval
:param rephrase_prompt: Custom prompt to use for rephrasing
"""
try:
self.check_nua_key_is_configured_for_onprem()
except NUAKeyMissingError:
@@ -375,6 +391,8 @@ async def query(
"text": sentence,
"rephrase": str(rephrase),
}
+ if rephrase_prompt is not None:
+     params["rephrase_prompt"] = rephrase_prompt
if semantic_model is not None:
params["semantic_models"] = [semantic_model]
if generative_model is not None:
@@ -491,7 +509,8 @@ async def query(
sentence: str,
semantic_model: Optional[str] = None,
generative_model: Optional[str] = None,
- rephrase: Optional[bool] = False,
+ rephrase: bool = False,
+ rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
self.calls.append(("query", sentence))

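With the extended signature, callers forward the custom prompt straight to the predict API. A minimal sketch, assuming `predict` is an already configured predict engine (e.g. obtained via get_predict()) and the identifiers are placeholders:

# Sketch only: runs inside an async function; kbid and prompt are placeholders.
query_info = await predict.query(
    kbid="0123abcd",
    sentence="what is the warranty period?",
    generative_model=None,  # fall back to the knowledge box default
    rephrase=True,
    rephrase_prompt="Rephrase for better retrieval, same language.\nQUESTION: {question}",
)
# query_info carries the embeddings, entities, stop words, semantic threshold, etc.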
8 changes: 5 additions & 3 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
@@ -87,6 +87,7 @@
SyncAskResponse,
UserPrompt,
parse_custom_prompt,
+ parse_rephrase_prompt,
)
from nucliadb_telemetry import errors
from nucliadb_utils.exceptions import LimitsExceededError
@@ -462,9 +463,8 @@ async def ask(
prompt_context_images,
) = await prompt_context_builder.build()

- custom_prompt = parse_custom_prompt(ask_request)
-
# Make the chat request to the predict API
+ custom_prompt = parse_custom_prompt(ask_request)
chat_model = ChatModel(
user_id=user_id,
system=custom_prompt.system,
@@ -756,6 +756,7 @@ def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[Pr
features.append(SearchOptions.SEMANTIC)
if ChatOptions.KEYWORD in ask_request.features:
features.append(SearchOptions.KEYWORD)

properties = json_schema.get("parameters", {}).get("properties", {})
if len(properties) == 0: # pragma: no cover
return None
@@ -778,7 +779,8 @@ def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[Pr
with_duplicates=False,
with_synonyms=False,
resource_filters=[], # to be filled with the resource filter
- rephrase=False,
+ rephrase=ask_request.rephrase,
+ rephrase_prompt=parse_rephrase_prompt(ask_request),
security=ask_request.security,
autofilter=False,
)
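Two things change in this file: custom prompt parsing moves next to the chat request construction, and the prequeries generated for JSON-schema answers now inherit the request's rephrase flag and custom rephrase prompt instead of a hardcoded rephrase=False. A hedged sketch of the effect; the AskRequest construction is illustrative and elides unrelated fields:

from nucliadb_models.search import AskRequest, CustomPrompt

# Prequeries derived from the answer JSON schema will now carry
# rephrase=True and the custom prompt below, matching the main query.
ask_request = AskRequest(
    query="summarize the contract terms",
    rephrase=True,
    prompt=CustomPrompt(
        rephrase="Rephrase this question so it's better for retrieval.\nQUESTION: {question}",
    ),
)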
2 changes: 2 additions & 0 deletions nucliadb/src/nucliadb/search/search/chat/query.py
@@ -46,6 +46,7 @@
Relations,
RephraseModel,
SearchOptions,
+ parse_rephrase_prompt,
)
from nucliadb_protos import audit_pb2
from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
@@ -178,6 +179,7 @@ async def run_main_query(
find_request.security = item.security
find_request.debug = item.debug
find_request.rephrase = item.rephrase
+ find_request.rephrase_prompt = parse_rephrase_prompt(item)
# We don't support pagination, we always get the top_k results.
find_request.page_size = item.top_k
find_request.page_number = 0
2 changes: 2 additions & 0 deletions nucliadb/src/nucliadb/search/search/find.py
@@ -123,6 +123,7 @@ async def _index_node_retrieval(
security=item.security,
generative_model=generative_model,
rephrase=item.rephrase,
+ rephrase_prompt=item.rephrase_prompt,
)
with metrics.time("query_parse"):
pb_query, incomplete_results, autofilters = await query_parser.parse()
@@ -234,6 +235,7 @@ async def _external_index_retrieval(
security=item.security,
generative_model=generative_model,
rephrase=item.rephrase,
+ rephrase_prompt=item.rephrase_prompt,
)
search_request, incomplete_results, _ = await query_parser.parse()

7 changes: 5 additions & 2 deletions nucliadb/src/nucliadb/search/search/query.py
@@ -118,6 +118,7 @@ def __init__(
security: Optional[RequestSecurity] = None,
generative_model: Optional[str] = None,
rephrase: bool = False,
+ rephrase_prompt: Optional[str] = None,
max_tokens: Optional[MaxTokens] = None,
):
self.kbid = kbid
@@ -146,6 +147,7 @@ def __init__(
self.security = security
self.generative_model = generative_model
self.rephrase = rephrase
+ self.rephrase_prompt = rephrase_prompt
self.query_endpoint_used = False
if len(self.label_filters) > 0:
self.label_filters = translate_label_filters(self.label_filters)
@@ -168,7 +170,7 @@ def _get_query_information(self) -> Awaitable[QueryInfo]:
async def _query_information(self) -> QueryInfo:
vectorset = await self.select_vectorset()
return await query_information(
- self.kbid, self.query, vectorset, self.generative_model, self.rephrase
+ self.kbid, self.query, vectorset, self.generative_model, self.rephrase, self.rephrase_prompt
)

def _get_matryoshka_dimension(self) -> Awaitable[Optional[int]]:
@@ -600,9 +602,10 @@ async def query_information(
semantic_model: Optional[str],
generative_model: Optional[str] = None,
rephrase: bool = False,
+ rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
predict = get_predict()
- return await predict.query(kbid, query, semantic_model, generative_model, rephrase)
+ return await predict.query(kbid, query, semantic_model, generative_model, rephrase, rephrase_prompt)


@query_parse_dependency_observer.wrap({"type": "detect_entities"})
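Within this module the new argument simply threads through: a QueryParser built with rephrase_prompt forwards it via query_information to predict.query. A sketch of the standalone helper, assuming kbid and vectorset are in scope inside an async function:

from nucliadb.search.search.query import query_information

# Mirrors the call made from QueryParser._query_information.
query_info = await query_information(
    kbid,
    "maintenance schedule for the pump",  # the user query
    vectorset,  # semantic model selected for the knowledge box
    generative_model=None,
    rephrase=True,
    rephrase_prompt="Rephrase this question so it's better for retrieval: {question}",
)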
59 changes: 53 additions & 6 deletions nucliadb_models/src/nucliadb_models/search.py
@@ -687,8 +687,29 @@ class BaseSearchRequest(BaseModel):

rephrase: bool = Field(
default=False,
title="Rephrase the query to improve search",
description="Consume LLM tokens to rephrase the query so the semantic search is better",
description=(
"Rephrase the query for a more efficient retrieval. This will consume LLM tokens and make the request slower."
),
)

+ rephrase_prompt: Optional[str] = Field(
+     default=None,
+     title="Rephrase",
+     description=(
+         "Rephrase prompt given to the generative model responsible for rephrasing the query for a more effective retrieval step. "
+         "This is only used if the `rephrase` flag is set to true in the request.\n"
+         "If not specified, Nuclia's default prompt is used. It must include the {question} placeholder; "
+         "the placeholder will be replaced with the original question."
+     ),
+     min_length=1,
+     examples=[
+         """Rephrase this question so it's better for retrieval, and keep the rephrased question in the same language as the original.
+ QUESTION: {question}
+ Please return ONLY the question without any explanation. Just the rephrased question.""",
+         """Rephrase this question so it's better for retrieval; identify any part numbers and append them to the end of the question, separated by commas.
+ QUESTION: {question}
+ Please return ONLY the question without any explanation.""",
+     ],
+ )

@field_validator("features", mode="after")
@@ -1061,7 +1082,7 @@ class CustomPrompt(BaseModel):
system: Optional[str] = Field(
default=None,
title="System prompt",
description="System prompt given to the generative model. This can help customize the behavior of the model. If not specified, the default model provider's prompt is used.", # noqa
description="System prompt given to the generative model responsible of generating the answer. This can help customize the behavior of the model when generating the answer. If not specified, the default model provider's prompt is used.", # noqa
min_length=1,
examples=[
"You are a medical assistant, use medical terminology",
@@ -1073,7 +1094,7 @@
user: Optional[str] = Field(
default=None,
title="User prompt",
description="User prompt given to the generative model. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa
description="User prompt given to the generative model responsible of generating the answer. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa
min_length=1,
examples=[
"Taking into account our previous conversation, and this context: {context} answer this {question}",
Expand All @@ -1082,6 +1103,25 @@ class CustomPrompt(BaseModel):
"Given this context: {context}. Answer this {question} using the provided context. Please, answer always in French",
],
)
+ rephrase: Optional[str] = Field(
+     default=None,
+     title="Rephrase",
+     description=(
+         "Rephrase prompt given to the generative model responsible for rephrasing the query for a more effective retrieval step. "
+         "This is only used if the `rephrase` flag is set to true in the request.\n"
+         "If not specified, Nuclia's default prompt is used. It must include the {question} placeholder; "
+         "the placeholder will be replaced with the original question."
+     ),
+     min_length=1,
+     examples=[
+         """Rephrase this question so it's better for retrieval, and keep the rephrased question in the same language as the original.
+ QUESTION: {question}
+ Please return ONLY the question without any explanation. Just the rephrased question.""",
+         """Rephrase this question so it's better for retrieval; identify any part numbers and append them to the end of the question, separated by commas.
+ QUESTION: {question}
+ Please return ONLY the question without any explanation.""",
+     ],
+ )


class ChatRequest(BaseModel):
@@ -1227,8 +1267,9 @@ class ChatRequest(BaseModel):

rephrase: bool = Field(
default=False,
title="Rephrase the query to improve search",
description="Consume LLM tokens to rephrase the query so the semantic search is better",
description=(
"Rephrase the query for a more efficient retrieval. This will consume LLM tokens and make the request slower."
),
)

prefer_markdown: bool = Field(
@@ -1723,4 +1764,10 @@ def parse_custom_prompt(item: AskRequest) -> CustomPrompt:
else:
prompt.user = item.prompt.user
prompt.system = item.prompt.system
+ prompt.rephrase = item.prompt.rephrase
return prompt


+ def parse_rephrase_prompt(item: AskRequest) -> Optional[str]:
+     prompt = parse_custom_prompt(item)
+     return prompt.rephrase
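parse_rephrase_prompt is a small resolver: it reuses parse_custom_prompt and returns the request's custom rephrase prompt, or None so that Nuclia's default prompt applies. A hedged usage sketch; the AskRequest construction elides unrelated fields:

from nucliadb_models.search import AskRequest, CustomPrompt, parse_rephrase_prompt

item = AskRequest(
    query="what is the part number of the intake filter?",
    rephrase=True,
    prompt=CustomPrompt(
        rephrase="Rephrase and append any part numbers, separated by commas.\nQUESTION: {question}",
    ),
)
assert parse_rephrase_prompt(item) == item.prompt.rephrase
# Without a custom prompt on the request, the helper returns None.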
