Skip to content

Commit

Permalink
Handle answer json schema with too many properties (#2498)
Browse files Browse the repository at this point in the history
* Handle answer_json_schema with too many properties

* Increase number of prequeries
  • Loading branch information
jotare authored Sep 26, 2024
1 parent 69e3ca9 commit 73e484c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 11 deletions.
20 changes: 12 additions & 8 deletions nucliadb/src/nucliadb/search/api/v1/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nucliadb.search.api.v1.router import KB_PREFIX, api
from nucliadb.search.search import cache
from nucliadb.search.search.chat.ask import AskResult, ask, handled_ask_exceptions
from nucliadb.search.search.chat.exceptions import AnswerJsonSchemaTooLong
from nucliadb.search.search.utils import maybe_log_request_payload
from nucliadb_models.resource import NucliaDBRoles
from nucliadb_models.search import (
Expand Down Expand Up @@ -86,14 +87,17 @@ async def create_ask_response(
maybe_log_request_payload(kbid, "/ask", ask_request)
ask_request.max_tokens = parse_max_tokens(ask_request.max_tokens)
with cache.request_caches():
ask_result: AskResult = await ask(
kbid=kbid,
ask_request=ask_request,
user_id=user_id,
client_type=client_type,
origin=origin,
resource=resource,
)
try:
ask_result: AskResult = await ask(
kbid=kbid,
ask_request=ask_request,
user_id=user_id,
client_type=client_type,
origin=origin,
resource=resource,
)
except AnswerJsonSchemaTooLong as err:
return HTTPClientError(status_code=400, detail=str(err))

headers = {
"NUCLIA-LEARNING-ID": ask_result.nuclia_learning_id or "unknown",
Expand Down
12 changes: 10 additions & 2 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Optional, cast

from pydantic_core import ValidationError

from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
from nucliadb.models.responses import HTTPClientError
from nucliadb.search import logger, predict
Expand All @@ -35,7 +37,7 @@
StatusGenerativeResponse,
TextGenerativeResponse,
)
from nucliadb.search.search.chat.exceptions import NoRetrievalResultsError
from nucliadb.search.search.chat.exceptions import AnswerJsonSchemaTooLong, NoRetrievalResultsError
from nucliadb.search.search.chat.prompt import PromptContextBuilder
from nucliadb.search.search.chat.query import (
NOT_ENOUGH_CONTEXT_ANSWER,
Expand Down Expand Up @@ -789,6 +791,12 @@ def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[Pr
weight=1.0,
)
prequeries.append(prequery)
strategy = PreQueriesStrategy(queries=prequeries)
try:
strategy = PreQueriesStrategy(queries=prequeries)
except ValidationError:
raise AnswerJsonSchemaTooLong(
"Answer JSON schema with too many properties generated too many prequeries"
)

ask_request.rag_strategies = [strategy]
return strategy
4 changes: 4 additions & 0 deletions nucliadb/src/nucliadb/search/search/chat/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ def __init__(
self.main_query = main
self.prequeries = prequeries
self.prefilters = prefilters


class AnswerJsonSchemaTooLong(Exception):
pass
38 changes: 38 additions & 0 deletions nucliadb/tests/nucliadb/integration/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from nucliadb.search.utilities import get_predict
from nucliadb_models.search import (
AskRequest,
AskResponseItem,
ChatRequest,
FieldExtensionStrategy,
Expand Down Expand Up @@ -859,3 +860,40 @@ def valid_combination(combination: list[RagStrategies]) -> bool:
},
)
assert resp.status_code == 200, resp.text


async def test_ask_fails_with_answer_json_schema_too_big(
nucliadb_reader: AsyncClient,
knowledgebox: str,
resources: list[str],
):
kbid = knowledgebox
rid = resources[0]

resp = await nucliadb_reader.post(
f"/kb/{kbid}/resource/{rid}/ask",
json=AskRequest(
query="",
answer_json_schema={
"name": "structred_response",
"description": "Structured response with custom fields",
"parameters": {
"type": "object",
"properties": {
f"property-{i}": {
"type": "string",
"description": f"Yet another property... ({i})",
}
for i in range(50)
},
"required": ["property-0"],
},
},
).model_dump(),
)

assert resp.status_code == 400
assert (
resp.json()["detail"]
== "Answer JSON schema with too many properties generated too many prequeries"
)
2 changes: 1 addition & 1 deletion nucliadb_models/src/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ class PreQueriesStrategy(RagStrategy):
title="Queries",
description="List of queries to run before the main query. The results are added to the context with the specified weights for each query. There is a limit of 10 prequeries per request.",
min_length=1,
max_length=10,
max_length=15,
)
main_query_weight: float = Field(
default=1.0,
Expand Down

0 comments on commit 73e484c

Please sign in to comment.