diff --git a/e2e/test_e2e.py b/e2e/test_e2e.py index 0612902ecb..da446defcd 100644 --- a/e2e/test_e2e.py +++ b/e2e/test_e2e.py @@ -1,6 +1,4 @@ -import base64 import io -import json import os import random import time @@ -137,18 +135,18 @@ def test_resource_processed(kbid: str, resource_id: str): def test_search(kbid: str, resource_id: str): resp = requests.post( - os.path.join(BASE_URL, f"api/v1/kb/{kbid}/chat"), + os.path.join(BASE_URL, f"api/v1/kb/{kbid}/ask"), headers={ "content-type": "application/json", "X-NUCLIADB-ROLES": "READER", "x-ndb-client": "web", + "x-synchronous": "true", }, json={ "query": "Why is soccer called soccer?", "context": [], "show": ["basic", "values", "origin"], - "features": ["paragraphs", "relations"], - "inTitleOnly": False, + "features": ["keyword", "relations"], "highlight": True, "autofilter": False, "page_number": 0, @@ -158,38 +156,12 @@ def test_search(kbid: str, resource_id: str): ) raise_for_status(resp) - - raw = io.BytesIO(resp.content) - toread_bytes = raw.read(4) - toread = int.from_bytes(toread_bytes, byteorder="big", signed=False) - print(f"toread: {toread}") - raw_search_results = raw.read(toread) - search_results = json.loads(base64.b64decode(raw_search_results)) - print(f"Search results: {search_results}") - - data = raw.read() - try: - answer, relations_payload = data.split(b"_END_") - except ValueError: - answer = data - relations_payload = b"" - if len(relations_payload) > 0: - decoded_relations_payload = base64.b64decode(relations_payload) - print(f"Relations payload: {decoded_relations_payload}") - try: - answer, tail = answer.split(b"_CIT_") - chat_answer = answer.decode("utf-8") - citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False) - citations_bytes = tail[4 : 4 + citations_length] - citations = json.loads(base64.b64decode(citations_bytes).decode()) - except ValueError: - chat_answer = answer.decode("utf-8") - citations = {} - print(f"Answer: {chat_answer}") - print(f"Citations: {citations}") - - # assert "Not enough data to answer this" not in chat_answer, search_results - assert len(search_results["resources"]) == 1 + ask_response = resp.json() + print(f"Search results: {ask_response["retrieval_results"]}") + assert len(ask_response["retrieval_results"]["resources"]) == 1 + print(f"Relations payload: {ask_response["relations"]}") + print(f"Answer: {ask_response["answer"]}") + print(f"Citations: {ask_response["citations"]}") def test_predict_proxy(kbid: str): diff --git a/nucliadb/src/nucliadb/search/api/v1/chat.py b/nucliadb/src/nucliadb/search/api/v1/chat.py index 8e554f8cc3..4ad68dc704 100644 --- a/nucliadb/src/nucliadb/search/api/v1/chat.py +++ b/nucliadb/src/nucliadb/search/api/v1/chat.py @@ -53,8 +53,10 @@ parse_max_tokens, ) from nucliadb_telemetry.errors import capture_exception +from nucliadb_utils import const from nucliadb_utils.authentication import requires from nucliadb_utils.exceptions import LimitsExceededError +from nucliadb_utils.utilities import has_feature END_OF_STREAM = "_END_" @@ -96,6 +98,7 @@ class SyncChatResponse(pydantic.BaseModel): tags=["Search"], response_model=None, deprecated=True, + include_in_schema=False, ) @requires(NucliaDBRoles.READER) @version(1) @@ -112,6 +115,13 @@ async def chat_knowledgebox_endpoint( "This is slower and requires waiting for entire answer to be ready.", ), ) -> Union[StreamingResponse, HTTPClientError, Response]: + if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}): + # We keep this for a while so we can enable back chat on a per KB basis, in case we need to + return HTTPClientError( + status_code=404, + detail="This endpoint has been deprecated. Please use /ask instead.", + ) + try: return await create_chat_response( kbid, item, x_nucliadb_user, x_ndb_client, x_forwarded_for, x_synchronous @@ -155,7 +165,7 @@ async def create_chat_response( origin: str, x_synchronous: bool, resource: Optional[str] = None, -) -> Response: +) -> Response: # pragma: no cover chat_request.max_tokens = parse_max_tokens(chat_request.max_tokens) chat_result = await chat( kbid, diff --git a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py b/nucliadb/src/nucliadb/search/api/v1/resource/chat.py index 3ce0d9c40a..1d1289f281 100644 --- a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py +++ b/nucliadb/src/nucliadb/search/api/v1/resource/chat.py @@ -33,8 +33,10 @@ ) from nucliadb_models.resource import NucliaDBRoles from nucliadb_models.search import ChatRequest, NucliaDBClientType +from nucliadb_utils import const from nucliadb_utils.authentication import requires from nucliadb_utils.exceptions import LimitsExceededError +from nucliadb_utils.utilities import has_feature from ..chat import create_chat_response @@ -47,6 +49,7 @@ tags=["Search"], response_model=None, deprecated=True, + include_in_schema=False, ) @requires(NucliaDBRoles.READER) @version(1) @@ -124,6 +127,13 @@ async def resource_chat_endpoint( resource_id: Optional[str] = None, resource_slug: Optional[str] = None, ) -> Union[StreamingResponse, HTTPClientError, Response]: + if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}): + # We keep this for a while so we can enable back chat on a per KB basis, in case we need to + return HTTPClientError( + status_code=404, + detail="This endpoint has been deprecated. Please use /ask instead.", + ) + if resource_id is None: if resource_slug is None: raise ValueError("Either resource_id or resource_slug must be provided") diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index d1568a94db..e27a607dbb 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -230,7 +230,7 @@ async def chat( client_type: NucliaDBClientType, origin: str, resource: Optional[str] = None, -) -> ChatResult: +) -> ChatResult: # pragma: no cover metrics = RAGMetrics() nuclia_learning_id: Optional[str] = None chat_history = chat_request.context or [] diff --git a/nucliadb/tests/nucliadb/integration/test_chat.py b/nucliadb/tests/nucliadb/integration/test_chat.py index 086e662e91..045be035b3 100644 --- a/nucliadb/tests/nucliadb/integration/test_chat.py +++ b/nucliadb/tests/nucliadb/integration/test_chat.py @@ -17,25 +17,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -import base64 -import io -import json from unittest import mock import pytest from httpx import AsyncClient -from nucliadb.search.api.v1.chat import SyncChatResponse -from nucliadb.search.predict import AnswerStatusCode -from nucliadb.search.utilities import get_predict - - -@pytest.fixture(scope="function", autouse=True) -def audit(): - audit_mock = mock.Mock(chat=mock.AsyncMock()) - with mock.patch("nucliadb.search.search.chat.query.get_audit", return_value=audit_mock): - yield audit_mock - @pytest.mark.asyncio() @pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) @@ -44,501 +30,9 @@ async def test_chat( knowledgebox, ): resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"}) - assert resp.status_code == 200 - - context = [{"author": "USER", "text": "query"}] - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", json={"query": "query", "context": context} - ) - assert resp.status_code == 200 - - -@pytest.fixture(scope="function") -def find_incomplete_results(): - with mock.patch( - "nucliadb.search.search.chat.query.find", - return_value=(mock.MagicMock(), True, None), - ): - yield - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_handles_incomplete_find_results( - nucliadb_reader: AsyncClient, - knowledgebox, - find_incomplete_results, -): - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"}) - assert resp.status_code == 529 - assert resp.json() == {"detail": "Temporary error on information retrieval. Please try again."} - - -@pytest.fixture -async def resource(nucliadb_writer, knowledgebox): - kbid = knowledgebox - resp = await nucliadb_writer.post( - f"/kb/{kbid}/resources", - json={ - "title": "The title", - "summary": "The summary", - "texts": {"text_field": {"body": "The body of the text field"}}, - }, - ) - assert resp.status_code in (200, 201) - rid = resp.json()["uuid"] - yield rid - - -@pytest.mark.asyncio -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_handles_status_codes_in_a_different_chunk( - nucliadb_reader: AsyncClient, knowledgebox, resource -): - predict = get_predict() - predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"-2"] # type: ignore - - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"}) - assert resp.status_code == 200 - _, answer, _, _ = parse_chat_response(resp.content) - - assert answer == b"some text with status." - - -@pytest.mark.asyncio -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_handles_status_codes_in_the_same_chunk( - nucliadb_reader: AsyncClient, knowledgebox, resource -): - predict = get_predict() - predict.generated_answer = [b"some ", b"text ", b"with ", b"status.-2"] # type: ignore - - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"}) - assert resp.status_code == 200 - _, answer, _, _ = parse_chat_response(resp.content) - - assert answer == b"some text with status." - - -@pytest.mark.asyncio -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_handles_status_codes_with_last_chunk_empty( - nucliadb_reader: AsyncClient, knowledgebox, resource -): - predict = get_predict() - predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"-2", b""] # type: ignore - - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"}) - assert resp.status_code == 200 - _, answer, _, _ = parse_chat_response(resp.content) - - assert answer == b"some text with status." - - -def parse_chat_response(content: bytes): - raw = io.BytesIO(content) - header = raw.read(4) - payload_size = int.from_bytes(header, byteorder="big", signed=False) - data = raw.read(payload_size) - find_result = json.loads(base64.b64decode(data)) - data = raw.read() - try: - answer, relations_payload = data.split(b"_END_") - except ValueError: - answer = data - relations_payload = b"" - relations_result = None - if len(relations_payload) > 0: - relations_result = json.loads(base64.b64decode(relations_payload)) - try: - answer, tail = answer.split(b"_CIT_") - citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False) - citations_part = tail[4 : 4 + citations_length] - citations = json.loads(base64.b64decode(citations_part).decode()) - except ValueError: - answer = answer - citations = {} - return find_result, answer, relations_result, citations - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_always_returns_relations(nucliadb_reader: AsyncClient, knowledgebox): - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "summary", "features": ["relations"]}, - ) - assert resp.status_code == 200 - _, answer, relations_result, _ = parse_chat_response(resp.content) - assert answer == b"Not enough data to answer this." - assert "Ferran" in relations_result["entities"] - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_synchronous(nucliadb_reader: AsyncClient, knowledgebox, resource): - predict = get_predict() - predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"0"] # type: ignore - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "title"}, - headers={"X-Synchronous": "True"}, - ) - assert resp.status_code == 200 - resp_data = SyncChatResponse.model_validate_json(resp.content) - - assert resp_data.answer == "some text with status." - assert len(resp_data.results.resources) == 1 - assert resp_data.status == AnswerStatusCode.SUCCESS - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -@pytest.mark.parametrize("sync_chat", (True, False)) -async def test_chat_with_citations(nucliadb_reader: AsyncClient, knowledgebox, resource, sync_chat): - citations = {"foo": [], "bar": []} # type: ignore - citations_payload = base64.b64encode(json.dumps(citations).encode()) - citations_size = len(citations_payload).to_bytes(4, byteorder="big", signed=False) - - predict = get_predict() - predict.generated_answer = [ # type: ignore - b"some ", - b"text ", - b"with ", - b"status.", - b"_CIT_", - citations_size, - citations_payload, - b"0", - ] - - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "title", "citations": True}, - headers={"X-Synchronous": str(sync_chat)}, - ) - assert resp.status_code == 200 - - if sync_chat: - resp_data = SyncChatResponse.model_validate_json(resp.content) - resp_citations = resp_data.citations - else: - resp_citations = parse_chat_response(resp.content)[-1] - assert resp_citations == citations - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -@pytest.mark.parametrize("sync_chat", (True, False)) -async def test_chat_without_citations(nucliadb_reader: AsyncClient, knowledgebox, resource, sync_chat): - predict = get_predict() - predict.generated_answer = [ # type: ignore - b"some ", - b"text ", - b"with ", - b"status.", - b"0", - ] - - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "title", "citations": False}, - headers={"X-Synchronous": str(sync_chat)}, - ) - assert resp.status_code == 200 - - if sync_chat: - resp_data = SyncChatResponse.model_validate_json(resp.content) - resp_citations = resp_data.citations - else: - resp_citations = parse_chat_response(resp.content)[-1] - assert resp_citations == {} - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -@pytest.mark.parametrize("debug", (True, False)) -async def test_sync_chat_returns_prompt_context( - nucliadb_reader: AsyncClient, knowledgebox, resource, debug -): - # Make sure prompt context is returned if debug is True - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "title", "debug": debug}, - headers={"X-Synchronous": "True"}, - ) - assert resp.status_code == 200 - resp_data = SyncChatResponse.model_validate_json(resp.content) - if debug: - assert resp_data.prompt_context - assert resp_data.prompt_context_order - else: - assert resp_data.prompt_context is None - assert resp_data.prompt_context_order is None - - -@pytest.fixture -async def resources(nucliadb_writer, knowledgebox): - kbid = knowledgebox - rids = [] - for i in range(2): - resp = await nucliadb_writer.post( - f"/kb/{kbid}/resources", - json={ - "title": f"The title {i}", - "summary": f"The summary {i}", - "texts": {"text_field": {"body": "The body of the text field"}}, - }, - ) - assert resp.status_code in (200, 201) - rid = resp.json()["uuid"] - rids.append(rid) - yield rids - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_rag_options_full_resource(nucliadb_reader: AsyncClient, knowledgebox, resources): - resource1, resource2 = resources - - predict = get_predict() - predict.calls.clear() # type: ignore - - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={"query": "title", "rag_strategies": [{"name": "full_resource"}]}, - ) - assert resp.status_code == 200 - _ = parse_chat_response(resp.content) - - # Make sure the prompt context is properly crafted - assert predict.calls[-2][0] == "chat_query" # type: ignore - prompt_context = predict.calls[-2][1].query_context # type: ignore - - # All fields of the matching resource should be in the prompt context - assert len(prompt_context) == 6 - assert prompt_context[f"{resource1}/a/title"] == "The title 0" - assert prompt_context[f"{resource1}/a/summary"] == "The summary 0" - assert prompt_context[f"{resource1}/t/text_field"] == "The body of the text field" - assert prompt_context[f"{resource2}/a/title"] == "The title 1" - assert prompt_context[f"{resource2}/a/summary"] == "The summary 1" - assert prompt_context[f"{resource2}/t/text_field"] == "The body of the text field" - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_rag_options_extend_with_fields( - nucliadb_reader: AsyncClient, knowledgebox, resources -): - resource1, resource2 = resources - - predict = get_predict() - predict.calls.clear() # type: ignore - - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "field_extension", "fields": ["a/summary"]}], - }, - ) - assert resp.status_code == 200 - _ = parse_chat_response(resp.content) - - # Make sure the prompt context is properly crafted - assert predict.calls[-2][0] == "chat_query" # type: ignore - prompt_context = predict.calls[-2][1].query_context # type: ignore - - # Matching paragraphs should be in the prompt - # context, plus the extended field for each resource - assert len(prompt_context) == 4 - # The matching paragraphs - assert prompt_context[f"{resource1}/a/title/0-11"] == "The title 0" - assert prompt_context[f"{resource2}/a/title/0-11"] == "The title 1" - # The extended fields - assert prompt_context[f"{resource1}/a/summary"] == "The summary 0" - assert prompt_context[f"{resource2}/a/summary"] == "The summary 1" - - -@pytest.mark.asyncio() -async def test_chat_rag_options_validation(nucliadb_reader): - # Invalid strategy - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "foobar", "fields": ["a/summary"]}], - }, - ) - assert resp.status_code == 422 - - # Invalid strategy as a string - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": ["full_resource"], - }, - ) - assert resp.status_code == 422 - - # Invalid strategy without name - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": [{"fields": ["a/summary"]}], - }, - ) - assert resp.status_code == 422 - - # full_resource cannot be combined with other strategies - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": [ - {"name": "full_resource"}, - {"name": "field_extension", "fields": ["a/summary"]}, - ], - }, - ) - assert resp.status_code == 422 - detail = resp.json()["detail"] - assert "If 'full_resource' strategy is chosen, it must be the only strategy" in detail[0]["msg"] - - # field_extension requires fields - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={"query": "title", "rag_strategies": [{"name": "field_extension"}]}, - ) - assert resp.status_code == 422 - detail = resp.json()["detail"] - detail[0]["loc"][-1] == "fields" - assert detail[0]["msg"] == "Field required" - - # fields must be in the right format: field_type/field_name - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "field_extension", "fields": ["foo/t/text"]}], - }, - ) - assert resp.status_code == 422 - detail = resp.json()["detail"] - detail[0]["loc"][-1] == "fields" - assert ( - detail[0]["msg"] - == "Value error, Field 'foo/t/text' is not in the format {field_type}/{field_name}" - ) - - # But fields can have leading and trailing slashes and they will be ignored - resp = await nucliadb_reader.post( - f"/kb/foo/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "field_extension", "fields": ["/a/text/"]}], - }, - ) - assert resp.status_code != 422 - - # fields must have a valid field type - resp = await nucliadb_reader.post( - f"/kb/kbid/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "field_extension", "fields": ["X/fieldname"]}], - }, - ) - assert resp.status_code == 422 - detail = resp.json()["detail"] - detail[0]["loc"][-1] == "fields" - assert ( - "Field 'X/fieldname' does not have a valid field type. Valid field types are" in detail[0]["msg"] - ) - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_capped_context(nucliadb_reader: AsyncClient, knowledgebox, resources): - # By default, max size is big enough to fit all the prompt context - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "full_resource"}], - "debug": True, - }, - headers={"X-Synchronous": "True"}, - ) - assert resp.status_code == 200 - resp_data = SyncChatResponse.model_validate_json(resp.content) - assert resp_data.prompt_context is not None - assert len(resp_data.prompt_context) == 6 - total_size = sum(len(v) for v in resp_data.prompt_context.values()) - # Try now setting a smaller max size. It should be respected - max_size = 28 - assert total_size > max_size * 3 - - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "rag_strategies": [{"name": "full_resource"}], - "debug": True, - "max_tokens": {"context": max_size}, - }, - headers={"X-Synchronous": "True"}, - ) - assert resp.status_code == 200, resp.text - resp_data = SyncChatResponse.model_validate_json(resp.content) - assert resp_data.prompt_context is not None - total_size = sum(len(v) for v in resp_data.prompt_context.values()) - assert total_size <= max_size * 3 - - -@pytest.mark.asyncio() -async def test_chat_on_a_kb_not_found(nucliadb_reader): - resp = await nucliadb_reader.post("/kb/unknown_kb_id/chat", json={"query": "title"}) assert resp.status_code == 404 - assert resp.json() == {"detail": "Knowledge Box 'unknown_kb_id' not found."} - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat_max_tokens(nucliadb_reader, knowledgebox, resources): - # As an integer - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "max_tokens": 100, - }, - ) - assert resp.status_code == 200 - - # Same but with the max tokens in a dict - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "max_tokens": {"context": 100, "answer": 50}, - }, - ) - assert resp.status_code == 200 + assert resp.json()["detail"] == "This endpoint has been deprecated. Please use /ask instead." - # If the context requested is bigger than the max tokens, it should fail - predict = get_predict() - resp = await nucliadb_reader.post( - f"/kb/{knowledgebox}/chat", - json={ - "query": "title", - "max_tokens": {"context": predict.max_context + 1}, - }, - ) - assert resp.status_code == 412 - assert ( - resp.json()["detail"] - == "Invalid query. Error in max_tokens.context: Max context tokens is higher than the model's limit of 1000" - ) + with mock.patch("nucliadb.search.api.v1.chat.has_feature", return_value=True): + resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"}) + assert resp.status_code == 200 diff --git a/nucliadb/tests/search/unit/api/v1/resource/test_chat.py b/nucliadb/tests/search/unit/api/v1/resource/test_chat.py deleted file mode 100644 index 160caa5e6a..0000000000 --- a/nucliadb/tests/search/unit/api/v1/resource/test_chat.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# -from unittest import mock -from unittest.mock import Mock - -import pytest -from starlette.requests import Request - -from nucliadb.models.responses import HTTPClientError -from nucliadb.search import predict -from nucliadb.search.api.v1.resource.chat import resource_chat_endpoint -from nucliadb_utils.exceptions import LimitsExceededError - -pytestmark = pytest.mark.asyncio - - -class DummyTestRequest(Request): - @property - def auth(self): - return Mock(scopes=["READER"]) - - @property - def user(self): - return Mock(display_name="username") - - -@pytest.fixture(scope="function") -def create_chat_response_mock(): - with mock.patch( - "nucliadb.search.api.v1.resource.chat.create_chat_response", - ) as mocked: - yield mocked - - -@pytest.mark.parametrize( - "predict_error,http_error_response", - [ - ( - LimitsExceededError(402, "over the quota"), - HTTPClientError(status_code=402, detail="over the quota"), - ), - ( - predict.RephraseError("foobar"), - HTTPClientError( - status_code=529, - detail="Temporary error while rephrasing the query. Please try again later. Error: foobar", - ), - ), - ( - predict.RephraseMissingContextError(), - HTTPClientError( - status_code=412, - detail="Unable to rephrase the query with the provided context.", - ), - ), - ], -) -async def test_resource_chat_endpoint_handles_errors( - create_chat_response_mock, predict_error, http_error_response -): - create_chat_response_mock.side_effect = predict_error - request = DummyTestRequest( - scope={ - "type": "http", - "http_version": "1.1", - "method": "GET", - "headers": [], - } - ) - response = await resource_chat_endpoint( - request=request, - kbid="kbid", - item=Mock(), - x_ndb_client=None, - x_nucliadb_user="", - x_forwarded_for="", - x_synchronous=True, - resource_id="rid", - ) - assert response.status_code == http_error_response.status_code - assert response.body == http_error_response.body diff --git a/nucliadb/tests/search/unit/api/v1/test_chat.py b/nucliadb/tests/search/unit/api/v1/test_chat.py deleted file mode 100644 index 0ac3fa80a1..0000000000 --- a/nucliadb/tests/search/unit/api/v1/test_chat.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# -from unittest import mock -from unittest.mock import Mock - -import pytest -from starlette.requests import Request - -from nucliadb.models.responses import HTTPClientError -from nucliadb.search import predict -from nucliadb.search.api.v1.chat import chat_knowledgebox_endpoint -from nucliadb_utils.exceptions import LimitsExceededError - -pytestmark = pytest.mark.asyncio - - -class DummyTestRequest(Request): - @property - def auth(self): - return Mock(scopes=["READER"]) - - @property - def user(self): - return Mock(display_name="username") - - -@pytest.fixture(scope="function") -def create_chat_response_mock(): - with mock.patch( - "nucliadb.search.api.v1.chat.create_chat_response", - ) as mocked: - yield mocked - - -@pytest.mark.parametrize( - "predict_error,http_error_response", - [ - ( - LimitsExceededError(402, "over the quota"), - HTTPClientError(status_code=402, detail="over the quota"), - ), - ( - predict.RephraseError("foobar"), - HTTPClientError( - status_code=529, - detail="Temporary error while rephrasing the query. Please try again later. Error: foobar", - ), - ), - ( - predict.RephraseMissingContextError(), - HTTPClientError( - status_code=412, - detail="Unable to rephrase the query with the provided context.", - ), - ), - ], -) -async def test_chat_endpoint_handles_errors( - create_chat_response_mock, predict_error, http_error_response -): - create_chat_response_mock.side_effect = predict_error - request = DummyTestRequest( - scope={ - "type": "http", - "http_version": "1.1", - "method": "GET", - "headers": [], - } - ) - response = await chat_knowledgebox_endpoint( - request=request, - kbid="kbid", - item=Mock(), - x_ndb_client=None, - x_nucliadb_user="", - x_forwarded_for="", - ) - assert response.status_code == http_error_response.status_code - assert response.body == http_error_response.body diff --git a/nucliadb_sdk/README.md b/nucliadb_sdk/README.md index f950ec5ae8..ad37c4906b 100644 --- a/nucliadb_sdk/README.md +++ b/nucliadb_sdk/README.md @@ -134,7 +134,7 @@ After the data is pushed, the NucliaDB SDK could also be used to find answers on >>> import nucliadb_sdk >>> >>> ndb = nucliadb_sdk.NucliaDB(region="on-prem", url="http://localhost:8080") ->>> resp = ndb.chat(kbid="my-kb-id", query="What does Hakuna Matata mean?") +>>> resp = ndb.ask(kbid="my-kb-id", query="What does Hakuna Matata mean?") >>> print(resp.answer) 'Hakuna matata is actually a phrase in the East African language of Swahili that literally means “no trouble” or “no problems”.' ``` diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py index f86fc2cdcb..c0a57f31ee 100644 --- a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py @@ -232,50 +232,6 @@ class Docstring(BaseModel): ], ) -CHAT = Docstring( - doc="""Chat with your Knowledge Box""", - examples=[ - Example( - description="Get an answer for a question that is part of the data in the Knowledge Box", - code=""">>> from nucliadb_sdk import * ->>> sdk = NucliaDBSDK(api_key="api-key") ->>> sdk.chat(kbid="mykbid", query="Will France be in recession in 2023?").answer -'Yes, according to the provided context, France is expected to be in recession in 2023.' -""", - ), - Example( - description="You can use the `content` parameter to pass a `ChatRequest` object", - code=""">>> content = ChatRequest(query="Who won the 2018 football World Cup?") ->>> sdk.chat(kbid="mykbid", content=content).answer -'France won the 2018 football World Cup.' -""", - ), - ], -) - -RESOURCE_CHAT = Docstring( - doc="""Chat with your document""", - examples=[ - Example( - description="Have a chat with your document. Generated answers are scoped to the context of the document.", - code=""">>> sdk.chat_on_resource(kbid="mykbid", query="What is the coldest season in Sevilla?").answer -'January is the coldest month.' -""", - ), - Example( - description="You can use the `content` parameter to pass previous context to the query", - code=""">>> from nucliadb_models.search import ChatRequest, ChatContextMessage ->>> content = ChatRequest() ->>> content.query = "What is the average temperature?" ->>> content.context.append(ChatContextMessage(author="USER", text="What is the coldest season in Sevilla?")) ->>> content.context.append(ChatContextMessage(author="NUCLIA", text="January is the coldest month.")) ->>> sdk.chat(kbid="mykbid", content=content).answer -'According to the context, the average temperature in January in Sevilla is 15.9 °C and 5.2 °C.' -""", - ), - ], -) - SUMMARIZE = Docstring( doc="""Summarize your documents""", examples=[ diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py index 3259ffbebf..dd2f3a4ce0 100644 --- a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py @@ -17,7 +17,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import asyncio -import base64 import enum import inspect import io @@ -65,7 +64,6 @@ AnswerAskResponseItem, AskRequest, AskResponseItem, - ChatRequest, CitationsAskResponseItem, ErrorAskResponseItem, FeedbackRequest, @@ -74,7 +72,6 @@ KnowledgeboxFindResults, KnowledgeboxSearchResults, MetadataAskResponseItem, - Relations, RelationsAskResponseItem, RetrievalAskResponseItem, SearchRequest, @@ -101,14 +98,6 @@ class Region(enum.Enum): AWS_US_EAST_2_1 = "aws-us-east-2-1" -class ChatResponse(BaseModel): - result: KnowledgeboxFindResults - answer: str - relations: Optional[Relations] = None - learning_id: Optional[str] = None - citations: dict[str, Any] = {} - - RawRequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]] @@ -119,38 +108,6 @@ def json_response_parser(response: httpx.Response) -> Any: return orjson.loads(response.content.decode()) -def chat_response_parser(response: httpx.Response) -> ChatResponse: - raw = io.BytesIO(response.content) - header = raw.read(4) - payload_size = int.from_bytes(header, byteorder="big", signed=False) - data = raw.read(payload_size) - find_result = KnowledgeboxFindResults.model_validate_json(base64.b64decode(data)) - data = raw.read() - try: - answer, relations_payload = data.split(b"_END_") - except ValueError: - answer = data - relations_payload = b"" - learning_id = response.headers.get("NUCLIA-LEARNING-ID") - relations_result = None - if len(relations_payload) > 0: - relations_result = Relations.model_validate_json(base64.b64decode(relations_payload)) - try: - answer, tail = answer.split(b"_CIT_") - citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False) - citations_bytes = tail[4 : 4 + citations_length] - citations = orjson.loads(base64.b64decode(citations_bytes).decode()) - except ValueError: - citations = {} - return ChatResponse( - result=find_result, - answer=answer.decode("utf-8"), - relations=relations_result, - learning_id=learning_id, - citations=citations, - ) - - def ask_response_parser(response: httpx.Response) -> SyncAskResponse: content_type = response.headers.get("Content-Type") if content_type not in ("application/json", "application/x-ndjson"): @@ -646,14 +603,6 @@ def _check_response(self, response: httpx.Response): request_type=SearchRequest, response_type=KnowledgeboxSearchResults, ) - chat = _request_builder( - name="chat", - path_template="/v1/kb/{kbid}/chat", - method="POST", - path_params=("kbid",), - request_type=ChatRequest, - response_type=chat_response_parser, - ) ask = _request_builder( name="ask", @@ -664,24 +613,6 @@ def _check_response(self, response: httpx.Response): response_type=ask_response_parser, ) - chat_on_resource = _request_builder( - name="chat_on_resource", - path_template="/v1/kb/{kbid}/resource/{rid}/chat", - method="POST", - path_params=("kbid", "rid"), - request_type=ChatRequest, - response_type=chat_response_parser, - ) - - chat_on_resource_by_slug = _request_builder( - name="chat_on_resource_by_slug", - path_template="/v1/kb/{kbid}/slug/{slug}/chat", - method="POST", - path_params=("kbid", "slug"), - request_type=ChatRequest, - response_type=chat_response_parser, - ) - ask_on_resource = _request_builder( name="ask_on_resource", path_template="/v1/kb/{kbid}/resource/{rid}/ask", diff --git a/nucliadb_sdk/tests/test_chat.py b/nucliadb_sdk/tests/test_chat.py deleted file mode 100644 index 783dfd3061..0000000000 --- a/nucliadb_sdk/tests/test_chat.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import nucliadb_sdk -from nucliadb_models.search import MaxTokens - - -def test_chat_on_kb(docs_dataset, sdk: nucliadb_sdk.NucliaDB): - result = sdk.chat( - kbid=docs_dataset, - query="Nuclia loves Semantic Search", - generative_model="everest", - prompt="Given this context: {context}. Answer this {question} in a concise way using the provided context", - extra_context=[ - "Nuclia is a powerful AI search platform", - "AI Search involves semantic search", - ], - # Control the number of AI tokens used for every request - max_tokens=MaxTokens(context=100, answer=50), - ) - assert result.learning_id == "00" - assert result.answer == "valid answer to" - assert len(result.result.resources) == 7 - assert result.relations - assert len(result.relations.entities["Nuclia"].related_to) == 18 - - -def test_chat_on_kb_with_citations(docs_dataset, sdk: nucliadb_sdk.NucliaDB): - result = sdk.chat( - kbid=docs_dataset, - query="Nuclia loves Semantic Search", - citations=True, - ) - assert result.citations == {} - - -def test_chat_on_kb_no_context_found(docs_dataset, sdk: nucliadb_sdk.NucliaDB): - result = sdk.chat(kbid=docs_dataset, query="penguin") - assert result.answer == "Not enough data to answer this." - - -def test_chat_on_resource(docs_dataset, sdk: nucliadb_sdk.NucliaDB): - rid = sdk.list_resources(kbid=docs_dataset).resources[0].id - # With retrieval - _ = sdk.chat_on_resource(kbid=docs_dataset, rid=rid, query="Nuclia loves Semantic Search") - - # Check chatting with the whole resource (no retrieval) - _ = sdk.chat_on_resource( - kbid=docs_dataset, - rid=rid, - query="Nuclia loves Semantic Search", - rag_strategies=[{"name": "full_resource"}], - ) diff --git a/nucliadb_sdk/tests/test_sdk.py b/nucliadb_sdk/tests/test_sdk.py index 33aa2aa640..0f29c3c3c1 100644 --- a/nucliadb_sdk/tests/test_sdk.py +++ b/nucliadb_sdk/tests/test_sdk.py @@ -85,12 +85,9 @@ def test_resource_endpoints(sdk: nucliadb_sdk.NucliaDB, kb): def test_search_endpoints(sdk: nucliadb_sdk.NucliaDB, kb): sdk.find(kbid=kb.uuid, query="foo") sdk.search(kbid=kb.uuid, query="foo") - sdk.chat(kbid=kb.uuid, query="foo") sdk.ask(kbid=kb.uuid, query="foo") resource = sdk.create_resource(kbid=kb.uuid, title="Resource", slug="resource") - sdk.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") - sdk.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") sdk.ask_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") sdk.ask_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") sdk.feedback(kbid=kb.uuid, ident="bar", good=True, feedback="baz", task="CHAT") diff --git a/nucliadb_sdk/tests/test_sdk_async.py b/nucliadb_sdk/tests/test_sdk_async.py index f454284b21..c8ec221cc0 100644 --- a/nucliadb_sdk/tests/test_sdk_async.py +++ b/nucliadb_sdk/tests/test_sdk_async.py @@ -73,12 +73,9 @@ async def test_resource_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb): async def test_search_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb): await sdk_async.find(kbid=kb.uuid, query="foo") await sdk_async.search(kbid=kb.uuid, query="foo") - await sdk_async.chat(kbid=kb.uuid, query="foo") await sdk_async.ask(kbid=kb.uuid, query="foo") resource = await sdk_async.create_resource(kbid=kb.uuid, title="Resource", slug="resource") - await sdk_async.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") - await sdk_async.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") await sdk_async.ask_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") await sdk_async.ask_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") await sdk_async.feedback( diff --git a/nucliadb_utils/src/nucliadb_utils/const.py b/nucliadb_utils/src/nucliadb_utils/const.py index 6c689b95ce..e9a4a5d5ee 100644 --- a/nucliadb_utils/src/nucliadb_utils/const.py +++ b/nucliadb_utils/src/nucliadb_utils/const.py @@ -84,3 +84,4 @@ class Features: VECTORSETS_V0 = "vectorsets_v0_new_kbs_with_multiple_vectorsets" SKIP_EXTERNAL_INDEX = "nucliadb_skip_external_index" NATS_SYNC_ACK = "nucliadb_nats_sync_ack" + DEPRECATED_CHAT_ENABLED = "nucliadb_deprecated_chat_enabled"