From 292bf3efbcf3d2a5aef63a551dbc5f050fb88358 Mon Sep 17 00:00:00 2001 From: Ferran Llamas Date: Thu, 19 Sep 2024 16:15:25 +0200 Subject: [PATCH] Add rephrase query prompt (#2473) --- nucliadb/src/nucliadb/search/api/v1/search.py | 1 + nucliadb/src/nucliadb/search/predict.py | 23 +++++++- .../src/nucliadb/search/search/chat/ask.py | 8 ++- .../src/nucliadb/search/search/chat/query.py | 2 + nucliadb/src/nucliadb/search/search/find.py | 2 + nucliadb/src/nucliadb/search/search/query.py | 7 ++- nucliadb_models/src/nucliadb_models/search.py | 59 +++++++++++++++++-- 7 files changed, 89 insertions(+), 13 deletions(-) diff --git a/nucliadb/src/nucliadb/search/api/v1/search.py b/nucliadb/src/nucliadb/search/api/v1/search.py index aaa3ed40c5..e97d1e2507 100644 --- a/nucliadb/src/nucliadb/search/api/v1/search.py +++ b/nucliadb/src/nucliadb/search/api/v1/search.py @@ -477,6 +477,7 @@ async def search( autofilter=item.autofilter, security=item.security, rephrase=item.rephrase, + rephrase_prompt=item.rephrase_prompt, ) pb_query, incomplete_results, autofilters = await query_parser.parse() diff --git a/nucliadb/src/nucliadb/search/predict.py b/nucliadb/src/nucliadb/search/predict.py index 48c32c9b70..ddc10a64fe 100644 --- a/nucliadb/src/nucliadb/search/predict.py +++ b/nucliadb/src/nucliadb/search/predict.py @@ -362,8 +362,24 @@ async def query( sentence: str, semantic_model: Optional[str] = None, generative_model: Optional[str] = None, - rephrase: Optional[bool] = False, + rephrase: bool = False, + rephrase_prompt: Optional[str] = None, ) -> QueryInfo: + """ + Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance: + - The embeddings + - The entities + - The stop words + - The semantic threshold + - etc. + + :param kbid: KnowledgeBox ID + :param sentence: The query sentence + :param semantic_model: The semantic model to use to generate the embeddings + :param generative_model: The generative model that will be used to generate the answer + :param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval + :param rephrase_prompt: Custom prompt to use for rephrasing + """ try: self.check_nua_key_is_configured_for_onprem() except NUAKeyMissingError: @@ -375,6 +391,8 @@ async def query( "text": sentence, "rephrase": str(rephrase), } + if rephrase_prompt is not None: + params["rephrase_prompt"] = rephrase_prompt if semantic_model is not None: params["semantic_models"] = [semantic_model] if generative_model is not None: @@ -491,7 +509,8 @@ async def query( sentence: str, semantic_model: Optional[str] = None, generative_model: Optional[str] = None, - rephrase: Optional[bool] = False, + rephrase: bool = False, + rephrase_prompt: Optional[str] = None, ) -> QueryInfo: self.calls.append(("query", sentence)) diff --git a/nucliadb/src/nucliadb/search/search/chat/ask.py b/nucliadb/src/nucliadb/search/search/chat/ask.py index b0d80c7c10..c3c80486d3 100644 --- a/nucliadb/src/nucliadb/search/search/chat/ask.py +++ b/nucliadb/src/nucliadb/search/search/chat/ask.py @@ -87,6 +87,7 @@ SyncAskResponse, UserPrompt, parse_custom_prompt, + parse_rephrase_prompt, ) from nucliadb_telemetry import errors from nucliadb_utils.exceptions import LimitsExceededError @@ -462,9 +463,8 @@ async def ask( prompt_context_images, ) = await prompt_context_builder.build() - custom_prompt = parse_custom_prompt(ask_request) - # Make the chat request to the predict API + custom_prompt = parse_custom_prompt(ask_request) chat_model = ChatModel( user_id=user_id, system=custom_prompt.system, @@ -756,6 +756,7 @@ def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[Pr features.append(SearchOptions.SEMANTIC) if ChatOptions.KEYWORD in ask_request.features: features.append(SearchOptions.KEYWORD) + properties = json_schema.get("parameters", {}).get("properties", {}) if len(properties) == 0: # pragma: no cover return None @@ -778,7 +779,8 @@ def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[Pr with_duplicates=False, with_synonyms=False, resource_filters=[], # to be filled with the resource filter - rephrase=False, + rephrase=ask_request.rephrase, + rephrase_prompt=parse_rephrase_prompt(ask_request), security=ask_request.security, autofilter=False, ) diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index 418749f099..e68141f9b2 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -46,6 +46,7 @@ Relations, RephraseModel, SearchOptions, + parse_rephrase_prompt, ) from nucliadb_protos import audit_pb2 from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse @@ -178,6 +179,7 @@ async def run_main_query( find_request.security = item.security find_request.debug = item.debug find_request.rephrase = item.rephrase + find_request.rephrase_prompt = parse_rephrase_prompt(item) # We don't support pagination, we always get the top_k results. find_request.page_size = item.top_k find_request.page_number = 0 diff --git a/nucliadb/src/nucliadb/search/search/find.py b/nucliadb/src/nucliadb/search/search/find.py index 608fc537bb..df1bae0c52 100644 --- a/nucliadb/src/nucliadb/search/search/find.py +++ b/nucliadb/src/nucliadb/search/search/find.py @@ -123,6 +123,7 @@ async def _index_node_retrieval( security=item.security, generative_model=generative_model, rephrase=item.rephrase, + rephrase_prompt=item.rephrase_prompt, ) with metrics.time("query_parse"): pb_query, incomplete_results, autofilters = await query_parser.parse() @@ -234,6 +235,7 @@ async def _external_index_retrieval( security=item.security, generative_model=generative_model, rephrase=item.rephrase, + rephrase_prompt=item.rephrase_prompt, ) search_request, incomplete_results, _ = await query_parser.parse() diff --git a/nucliadb/src/nucliadb/search/search/query.py b/nucliadb/src/nucliadb/search/search/query.py index 4f74c3dba9..8e41dc6818 100644 --- a/nucliadb/src/nucliadb/search/search/query.py +++ b/nucliadb/src/nucliadb/search/search/query.py @@ -118,6 +118,7 @@ def __init__( security: Optional[RequestSecurity] = None, generative_model: Optional[str] = None, rephrase: bool = False, + rephrase_prompt: Optional[str] = None, max_tokens: Optional[MaxTokens] = None, ): self.kbid = kbid @@ -146,6 +147,7 @@ def __init__( self.security = security self.generative_model = generative_model self.rephrase = rephrase + self.rephrase_prompt = rephrase_prompt self.query_endpoint_used = False if len(self.label_filters) > 0: self.label_filters = translate_label_filters(self.label_filters) @@ -168,7 +170,7 @@ def _get_query_information(self) -> Awaitable[QueryInfo]: async def _query_information(self) -> QueryInfo: vectorset = await self.select_vectorset() return await query_information( - self.kbid, self.query, vectorset, self.generative_model, self.rephrase + self.kbid, self.query, vectorset, self.generative_model, self.rephrase, self.rephrase_prompt ) def _get_matryoshka_dimension(self) -> Awaitable[Optional[int]]: @@ -600,9 +602,10 @@ async def query_information( semantic_model: Optional[str], generative_model: Optional[str] = None, rephrase: bool = False, + rephrase_prompt: Optional[str] = None, ) -> QueryInfo: predict = get_predict() - return await predict.query(kbid, query, semantic_model, generative_model, rephrase) + return await predict.query(kbid, query, semantic_model, generative_model, rephrase, rephrase_prompt) @query_parse_dependency_observer.wrap({"type": "detect_entities"}) diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index bafb8bca0b..8c9dd05448 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -687,8 +687,29 @@ class BaseSearchRequest(BaseModel): rephrase: bool = Field( default=False, - title="Rephrase the query to improve search", - description="Consume LLM tokens to rephrase the query so the semantic search is better", + description=( + "Rephrase the query for a more efficient retrieval. This will consume LLM tokens and make the request slower." + ), + ) + + rephrase_prompt: Optional[str] = Field( + default=None, + title="Rephrase", + description=( + "Rephrase prompt given to the generative model responsible for rephrasing the query for a more effective retrieval step. " + "This is only used if the `rephrase` flag is set to true in the request.\n" + "If not specified, Nuclia's default prompt is used. It must include the {question} placeholder. " + "The placeholder will be replaced with the original question" + ), + min_length=1, + examples=[ + """Rephrase this question so its better for retrieval, and keep the rephrased question in the same language as the original. +QUESTION: {question} +Please return ONLY the question without any explanation. Just the rephrased question.""", + """Rephrase this question so its better for retrieval, identify any part numbers and append them to the end of the question separated by a commas. + QUESTION: {question} + Please return ONLY the question without any explanation.""", + ], ) @field_validator("features", mode="after") @@ -1061,7 +1082,7 @@ class CustomPrompt(BaseModel): system: Optional[str] = Field( default=None, title="System prompt", - description="System prompt given to the generative model. This can help customize the behavior of the model. If not specified, the default model provider's prompt is used.", # noqa + description="System prompt given to the generative model responsible of generating the answer. This can help customize the behavior of the model when generating the answer. If not specified, the default model provider's prompt is used.", # noqa min_length=1, examples=[ "You are a medical assistant, use medical terminology", @@ -1073,7 +1094,7 @@ class CustomPrompt(BaseModel): user: Optional[str] = Field( default=None, title="User prompt", - description="User prompt given to the generative model. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa + description="User prompt given to the generative model responsible of generating the answer. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa min_length=1, examples=[ "Taking into account our previous conversation, and this context: {context} answer this {question}", @@ -1082,6 +1103,25 @@ class CustomPrompt(BaseModel): "Given this context: {context}. Answer this {question} using the provided context. Please, answer always in French", ], ) + rephrase: Optional[str] = Field( + default=None, + title="Rephrase", + description=( + "Rephrase prompt given to the generative model responsible for rephrasing the query for a more effective retrieval step. " + "This is only used if the `rephrase` flag is set to true in the request.\n" + "If not specified, Nuclia's default prompt is used. It must include the {question} placeholder. " + "The placeholder will be replaced with the original question" + ), + min_length=1, + examples=[ + """Rephrase this question so its better for retrieval, and keep the rephrased question in the same language as the original. +QUESTION: {question} +Please return ONLY the question without any explanation. Just the rephrased question.""", + """Rephrase this question so its better for retrieval, identify any part numbers and append them to the end of the question separated by a commas. + QUESTION: {question} + Please return ONLY the question without any explanation.""", + ], + ) class ChatRequest(BaseModel): @@ -1227,8 +1267,9 @@ class ChatRequest(BaseModel): rephrase: bool = Field( default=False, - title="Rephrase the query to improve search", - description="Consume LLM tokens to rephrase the query so the semantic search is better", + description=( + "Rephrase the query for a more efficient retrieval. This will consume LLM tokens and make the request slower." + ), ) prefer_markdown: bool = Field( @@ -1723,4 +1764,10 @@ def parse_custom_prompt(item: AskRequest) -> CustomPrompt: else: prompt.user = item.prompt.user prompt.system = item.prompt.system + prompt.rephrase = item.prompt.rephrase return prompt + + +def parse_rephrase_prompt(item: AskRequest) -> Optional[str]: + prompt = parse_custom_prompt(item) + return prompt.rephrase