diff --git a/e2e/test_e2e.py b/e2e/test_e2e.py
index 0612902ecb..da446defcd 100644
--- a/e2e/test_e2e.py
+++ b/e2e/test_e2e.py
@@ -1,6 +1,4 @@
-import base64
import io
-import json
import os
import random
import time
@@ -137,18 +135,18 @@ def test_resource_processed(kbid: str, resource_id: str):
def test_search(kbid: str, resource_id: str):
resp = requests.post(
- os.path.join(BASE_URL, f"api/v1/kb/{kbid}/chat"),
+ os.path.join(BASE_URL, f"api/v1/kb/{kbid}/ask"),
headers={
"content-type": "application/json",
"X-NUCLIADB-ROLES": "READER",
"x-ndb-client": "web",
+ "x-synchronous": "true",
},
json={
"query": "Why is soccer called soccer?",
"context": [],
"show": ["basic", "values", "origin"],
- "features": ["paragraphs", "relations"],
- "inTitleOnly": False,
+ "features": ["keyword", "relations"],
"highlight": True,
"autofilter": False,
"page_number": 0,
@@ -158,38 +156,12 @@ def test_search(kbid: str, resource_id: str):
)
raise_for_status(resp)
-
- raw = io.BytesIO(resp.content)
- toread_bytes = raw.read(4)
- toread = int.from_bytes(toread_bytes, byteorder="big", signed=False)
- print(f"toread: {toread}")
- raw_search_results = raw.read(toread)
- search_results = json.loads(base64.b64decode(raw_search_results))
- print(f"Search results: {search_results}")
-
- data = raw.read()
- try:
- answer, relations_payload = data.split(b"_END_")
- except ValueError:
- answer = data
- relations_payload = b""
- if len(relations_payload) > 0:
- decoded_relations_payload = base64.b64decode(relations_payload)
- print(f"Relations payload: {decoded_relations_payload}")
- try:
- answer, tail = answer.split(b"_CIT_")
- chat_answer = answer.decode("utf-8")
- citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False)
- citations_bytes = tail[4 : 4 + citations_length]
- citations = json.loads(base64.b64decode(citations_bytes).decode())
- except ValueError:
- chat_answer = answer.decode("utf-8")
- citations = {}
- print(f"Answer: {chat_answer}")
- print(f"Citations: {citations}")
-
- # assert "Not enough data to answer this" not in chat_answer, search_results
- assert len(search_results["resources"]) == 1
+ ask_response = resp.json()
+ print(f"Search results: {ask_response['retrieval_results']}")
+ assert len(ask_response["retrieval_results"]["resources"]) == 1
+ print(f"Relations payload: {ask_response['relations']}")
+ print(f"Answer: {ask_response['answer']}")
+ print(f"Citations: {ask_response['citations']}")
def test_predict_proxy(kbid: str):
diff --git a/nucliadb/src/nucliadb/search/api/v1/chat.py b/nucliadb/src/nucliadb/search/api/v1/chat.py
index 8e554f8cc3..4ad68dc704 100644
--- a/nucliadb/src/nucliadb/search/api/v1/chat.py
+++ b/nucliadb/src/nucliadb/search/api/v1/chat.py
@@ -53,8 +53,10 @@
parse_max_tokens,
)
from nucliadb_telemetry.errors import capture_exception
+from nucliadb_utils import const
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError
+from nucliadb_utils.utilities import has_feature
END_OF_STREAM = "_END_"
@@ -96,6 +98,7 @@ class SyncChatResponse(pydantic.BaseModel):
tags=["Search"],
response_model=None,
deprecated=True,
+ include_in_schema=False,
)
@requires(NucliaDBRoles.READER)
@version(1)
@@ -112,6 +115,13 @@ async def chat_knowledgebox_endpoint(
"This is slower and requires waiting for entire answer to be ready.",
),
) -> Union[StreamingResponse, HTTPClientError, Response]:
+ if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}):
+ # Kept temporarily so that /chat can be re-enabled on a per-KB basis via the feature flag, if needed
+ return HTTPClientError(
+ status_code=404,
+ detail="This endpoint has been deprecated. Please use /ask instead.",
+ )
+
try:
return await create_chat_response(
kbid, item, x_nucliadb_user, x_ndb_client, x_forwarded_for, x_synchronous
@@ -155,7 +165,7 @@ async def create_chat_response(
origin: str,
x_synchronous: bool,
resource: Optional[str] = None,
-) -> Response:
+) -> Response: # pragma: no cover
chat_request.max_tokens = parse_max_tokens(chat_request.max_tokens)
chat_result = await chat(
kbid,
diff --git a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py b/nucliadb/src/nucliadb/search/api/v1/resource/chat.py
index 3ce0d9c40a..1d1289f281 100644
--- a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py
+++ b/nucliadb/src/nucliadb/search/api/v1/resource/chat.py
@@ -33,8 +33,10 @@
)
from nucliadb_models.resource import NucliaDBRoles
from nucliadb_models.search import ChatRequest, NucliaDBClientType
+from nucliadb_utils import const
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError
+from nucliadb_utils.utilities import has_feature
from ..chat import create_chat_response
@@ -47,6 +49,7 @@
tags=["Search"],
response_model=None,
deprecated=True,
+ include_in_schema=False,
)
@requires(NucliaDBRoles.READER)
@version(1)
@@ -124,6 +127,13 @@ async def resource_chat_endpoint(
resource_id: Optional[str] = None,
resource_slug: Optional[str] = None,
) -> Union[StreamingResponse, HTTPClientError, Response]:
+ if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}):
+ # Kept temporarily so that /chat can be re-enabled on a per-KB basis via the feature flag, if needed
+ return HTTPClientError(
+ status_code=404,
+ detail="This endpoint has been deprecated. Please use /ask instead.",
+ )
+
if resource_id is None:
if resource_slug is None:
raise ValueError("Either resource_id or resource_slug must be provided")
diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py
index d1568a94db..e27a607dbb 100644
--- a/nucliadb/src/nucliadb/search/search/chat/query.py
+++ b/nucliadb/src/nucliadb/search/search/chat/query.py
@@ -230,7 +230,7 @@ async def chat(
client_type: NucliaDBClientType,
origin: str,
resource: Optional[str] = None,
-) -> ChatResult:
+) -> ChatResult: # pragma: no cover
metrics = RAGMetrics()
nuclia_learning_id: Optional[str] = None
chat_history = chat_request.context or []
diff --git a/nucliadb/tests/nucliadb/integration/test_chat.py b/nucliadb/tests/nucliadb/integration/test_chat.py
index 086e662e91..045be035b3 100644
--- a/nucliadb/tests/nucliadb/integration/test_chat.py
+++ b/nucliadb/tests/nucliadb/integration/test_chat.py
@@ -17,25 +17,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
#
-import base64
-import io
-import json
from unittest import mock
import pytest
from httpx import AsyncClient
-from nucliadb.search.api.v1.chat import SyncChatResponse
-from nucliadb.search.predict import AnswerStatusCode
-from nucliadb.search.utilities import get_predict
-
-
-@pytest.fixture(scope="function", autouse=True)
-def audit():
- audit_mock = mock.Mock(chat=mock.AsyncMock())
- with mock.patch("nucliadb.search.search.chat.query.get_audit", return_value=audit_mock):
- yield audit_mock
-
@pytest.mark.asyncio()
@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
@@ -44,501 +30,9 @@ async def test_chat(
knowledgebox,
):
resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"})
- assert resp.status_code == 200
-
- context = [{"author": "USER", "text": "query"}]
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat", json={"query": "query", "context": context}
- )
- assert resp.status_code == 200
-
-
-@pytest.fixture(scope="function")
-def find_incomplete_results():
- with mock.patch(
- "nucliadb.search.search.chat.query.find",
- return_value=(mock.MagicMock(), True, None),
- ):
- yield
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_handles_incomplete_find_results(
- nucliadb_reader: AsyncClient,
- knowledgebox,
- find_incomplete_results,
-):
- resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"})
- assert resp.status_code == 529
- assert resp.json() == {"detail": "Temporary error on information retrieval. Please try again."}
-
-
-@pytest.fixture
-async def resource(nucliadb_writer, knowledgebox):
- kbid = knowledgebox
- resp = await nucliadb_writer.post(
- f"/kb/{kbid}/resources",
- json={
- "title": "The title",
- "summary": "The summary",
- "texts": {"text_field": {"body": "The body of the text field"}},
- },
- )
- assert resp.status_code in (200, 201)
- rid = resp.json()["uuid"]
- yield rid
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_handles_status_codes_in_a_different_chunk(
- nucliadb_reader: AsyncClient, knowledgebox, resource
-):
- predict = get_predict()
- predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"-2"] # type: ignore
-
- resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"})
- assert resp.status_code == 200
- _, answer, _, _ = parse_chat_response(resp.content)
-
- assert answer == b"some text with status."
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_handles_status_codes_in_the_same_chunk(
- nucliadb_reader: AsyncClient, knowledgebox, resource
-):
- predict = get_predict()
- predict.generated_answer = [b"some ", b"text ", b"with ", b"status.-2"] # type: ignore
-
- resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"})
- assert resp.status_code == 200
- _, answer, _, _ = parse_chat_response(resp.content)
-
- assert answer == b"some text with status."
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_handles_status_codes_with_last_chunk_empty(
- nucliadb_reader: AsyncClient, knowledgebox, resource
-):
- predict = get_predict()
- predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"-2", b""] # type: ignore
-
- resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "title"})
- assert resp.status_code == 200
- _, answer, _, _ = parse_chat_response(resp.content)
-
- assert answer == b"some text with status."
-
-
-def parse_chat_response(content: bytes):
- raw = io.BytesIO(content)
- header = raw.read(4)
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
- data = raw.read(payload_size)
- find_result = json.loads(base64.b64decode(data))
- data = raw.read()
- try:
- answer, relations_payload = data.split(b"_END_")
- except ValueError:
- answer = data
- relations_payload = b""
- relations_result = None
- if len(relations_payload) > 0:
- relations_result = json.loads(base64.b64decode(relations_payload))
- try:
- answer, tail = answer.split(b"_CIT_")
- citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False)
- citations_part = tail[4 : 4 + citations_length]
- citations = json.loads(base64.b64decode(citations_part).decode())
- except ValueError:
- answer = answer
- citations = {}
- return find_result, answer, relations_result, citations
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_always_returns_relations(nucliadb_reader: AsyncClient, knowledgebox):
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "summary", "features": ["relations"]},
- )
- assert resp.status_code == 200
- _, answer, relations_result, _ = parse_chat_response(resp.content)
- assert answer == b"Not enough data to answer this."
- assert "Ferran" in relations_result["entities"]
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_synchronous(nucliadb_reader: AsyncClient, knowledgebox, resource):
- predict = get_predict()
- predict.generated_answer = [b"some ", b"text ", b"with ", b"status.", b"0"] # type: ignore
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "title"},
- headers={"X-Synchronous": "True"},
- )
- assert resp.status_code == 200
- resp_data = SyncChatResponse.model_validate_json(resp.content)
-
- assert resp_data.answer == "some text with status."
- assert len(resp_data.results.resources) == 1
- assert resp_data.status == AnswerStatusCode.SUCCESS
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-@pytest.mark.parametrize("sync_chat", (True, False))
-async def test_chat_with_citations(nucliadb_reader: AsyncClient, knowledgebox, resource, sync_chat):
- citations = {"foo": [], "bar": []} # type: ignore
- citations_payload = base64.b64encode(json.dumps(citations).encode())
- citations_size = len(citations_payload).to_bytes(4, byteorder="big", signed=False)
-
- predict = get_predict()
- predict.generated_answer = [ # type: ignore
- b"some ",
- b"text ",
- b"with ",
- b"status.",
- b"_CIT_",
- citations_size,
- citations_payload,
- b"0",
- ]
-
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "title", "citations": True},
- headers={"X-Synchronous": str(sync_chat)},
- )
- assert resp.status_code == 200
-
- if sync_chat:
- resp_data = SyncChatResponse.model_validate_json(resp.content)
- resp_citations = resp_data.citations
- else:
- resp_citations = parse_chat_response(resp.content)[-1]
- assert resp_citations == citations
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-@pytest.mark.parametrize("sync_chat", (True, False))
-async def test_chat_without_citations(nucliadb_reader: AsyncClient, knowledgebox, resource, sync_chat):
- predict = get_predict()
- predict.generated_answer = [ # type: ignore
- b"some ",
- b"text ",
- b"with ",
- b"status.",
- b"0",
- ]
-
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "title", "citations": False},
- headers={"X-Synchronous": str(sync_chat)},
- )
- assert resp.status_code == 200
-
- if sync_chat:
- resp_data = SyncChatResponse.model_validate_json(resp.content)
- resp_citations = resp_data.citations
- else:
- resp_citations = parse_chat_response(resp.content)[-1]
- assert resp_citations == {}
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-@pytest.mark.parametrize("debug", (True, False))
-async def test_sync_chat_returns_prompt_context(
- nucliadb_reader: AsyncClient, knowledgebox, resource, debug
-):
- # Make sure prompt context is returned if debug is True
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "title", "debug": debug},
- headers={"X-Synchronous": "True"},
- )
- assert resp.status_code == 200
- resp_data = SyncChatResponse.model_validate_json(resp.content)
- if debug:
- assert resp_data.prompt_context
- assert resp_data.prompt_context_order
- else:
- assert resp_data.prompt_context is None
- assert resp_data.prompt_context_order is None
-
-
-@pytest.fixture
-async def resources(nucliadb_writer, knowledgebox):
- kbid = knowledgebox
- rids = []
- for i in range(2):
- resp = await nucliadb_writer.post(
- f"/kb/{kbid}/resources",
- json={
- "title": f"The title {i}",
- "summary": f"The summary {i}",
- "texts": {"text_field": {"body": "The body of the text field"}},
- },
- )
- assert resp.status_code in (200, 201)
- rid = resp.json()["uuid"]
- rids.append(rid)
- yield rids
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_rag_options_full_resource(nucliadb_reader: AsyncClient, knowledgebox, resources):
- resource1, resource2 = resources
-
- predict = get_predict()
- predict.calls.clear() # type: ignore
-
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={"query": "title", "rag_strategies": [{"name": "full_resource"}]},
- )
- assert resp.status_code == 200
- _ = parse_chat_response(resp.content)
-
- # Make sure the prompt context is properly crafted
- assert predict.calls[-2][0] == "chat_query" # type: ignore
- prompt_context = predict.calls[-2][1].query_context # type: ignore
-
- # All fields of the matching resource should be in the prompt context
- assert len(prompt_context) == 6
- assert prompt_context[f"{resource1}/a/title"] == "The title 0"
- assert prompt_context[f"{resource1}/a/summary"] == "The summary 0"
- assert prompt_context[f"{resource1}/t/text_field"] == "The body of the text field"
- assert prompt_context[f"{resource2}/a/title"] == "The title 1"
- assert prompt_context[f"{resource2}/a/summary"] == "The summary 1"
- assert prompt_context[f"{resource2}/t/text_field"] == "The body of the text field"
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_rag_options_extend_with_fields(
- nucliadb_reader: AsyncClient, knowledgebox, resources
-):
- resource1, resource2 = resources
-
- predict = get_predict()
- predict.calls.clear() # type: ignore
-
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "field_extension", "fields": ["a/summary"]}],
- },
- )
- assert resp.status_code == 200
- _ = parse_chat_response(resp.content)
-
- # Make sure the prompt context is properly crafted
- assert predict.calls[-2][0] == "chat_query" # type: ignore
- prompt_context = predict.calls[-2][1].query_context # type: ignore
-
- # Matching paragraphs should be in the prompt
- # context, plus the extended field for each resource
- assert len(prompt_context) == 4
- # The matching paragraphs
- assert prompt_context[f"{resource1}/a/title/0-11"] == "The title 0"
- assert prompt_context[f"{resource2}/a/title/0-11"] == "The title 1"
- # The extended fields
- assert prompt_context[f"{resource1}/a/summary"] == "The summary 0"
- assert prompt_context[f"{resource2}/a/summary"] == "The summary 1"
-
-
-@pytest.mark.asyncio()
-async def test_chat_rag_options_validation(nucliadb_reader):
- # Invalid strategy
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "foobar", "fields": ["a/summary"]}],
- },
- )
- assert resp.status_code == 422
-
- # Invalid strategy as a string
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": ["full_resource"],
- },
- )
- assert resp.status_code == 422
-
- # Invalid strategy without name
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": [{"fields": ["a/summary"]}],
- },
- )
- assert resp.status_code == 422
-
- # full_resource cannot be combined with other strategies
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": [
- {"name": "full_resource"},
- {"name": "field_extension", "fields": ["a/summary"]},
- ],
- },
- )
- assert resp.status_code == 422
- detail = resp.json()["detail"]
- assert "If 'full_resource' strategy is chosen, it must be the only strategy" in detail[0]["msg"]
-
- # field_extension requires fields
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={"query": "title", "rag_strategies": [{"name": "field_extension"}]},
- )
- assert resp.status_code == 422
- detail = resp.json()["detail"]
- detail[0]["loc"][-1] == "fields"
- assert detail[0]["msg"] == "Field required"
-
- # fields must be in the right format: field_type/field_name
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "field_extension", "fields": ["foo/t/text"]}],
- },
- )
- assert resp.status_code == 422
- detail = resp.json()["detail"]
- detail[0]["loc"][-1] == "fields"
- assert (
- detail[0]["msg"]
- == "Value error, Field 'foo/t/text' is not in the format {field_type}/{field_name}"
- )
-
- # But fields can have leading and trailing slashes and they will be ignored
- resp = await nucliadb_reader.post(
- f"/kb/foo/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "field_extension", "fields": ["/a/text/"]}],
- },
- )
- assert resp.status_code != 422
-
- # fields must have a valid field type
- resp = await nucliadb_reader.post(
- f"/kb/kbid/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "field_extension", "fields": ["X/fieldname"]}],
- },
- )
- assert resp.status_code == 422
- detail = resp.json()["detail"]
- detail[0]["loc"][-1] == "fields"
- assert (
- "Field 'X/fieldname' does not have a valid field type. Valid field types are" in detail[0]["msg"]
- )
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_capped_context(nucliadb_reader: AsyncClient, knowledgebox, resources):
- # By default, max size is big enough to fit all the prompt context
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "full_resource"}],
- "debug": True,
- },
- headers={"X-Synchronous": "True"},
- )
- assert resp.status_code == 200
- resp_data = SyncChatResponse.model_validate_json(resp.content)
- assert resp_data.prompt_context is not None
- assert len(resp_data.prompt_context) == 6
- total_size = sum(len(v) for v in resp_data.prompt_context.values())
- # Try now setting a smaller max size. It should be respected
- max_size = 28
- assert total_size > max_size * 3
-
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "rag_strategies": [{"name": "full_resource"}],
- "debug": True,
- "max_tokens": {"context": max_size},
- },
- headers={"X-Synchronous": "True"},
- )
- assert resp.status_code == 200, resp.text
- resp_data = SyncChatResponse.model_validate_json(resp.content)
- assert resp_data.prompt_context is not None
- total_size = sum(len(v) for v in resp_data.prompt_context.values())
- assert total_size <= max_size * 3
-
-
-@pytest.mark.asyncio()
-async def test_chat_on_a_kb_not_found(nucliadb_reader):
- resp = await nucliadb_reader.post("/kb/unknown_kb_id/chat", json={"query": "title"})
assert resp.status_code == 404
- assert resp.json() == {"detail": "Knowledge Box 'unknown_kb_id' not found."}
-
-
-@pytest.mark.asyncio()
-@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
-async def test_chat_max_tokens(nucliadb_reader, knowledgebox, resources):
- # As an integer
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "max_tokens": 100,
- },
- )
- assert resp.status_code == 200
-
- # Same but with the max tokens in a dict
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "max_tokens": {"context": 100, "answer": 50},
- },
- )
- assert resp.status_code == 200
+ assert resp.json()["detail"] == "This endpoint has been deprecated. Please use /ask instead."
- # If the context requested is bigger than the max tokens, it should fail
- predict = get_predict()
- resp = await nucliadb_reader.post(
- f"/kb/{knowledgebox}/chat",
- json={
- "query": "title",
- "max_tokens": {"context": predict.max_context + 1},
- },
- )
- assert resp.status_code == 412
- assert (
- resp.json()["detail"]
- == "Invalid query. Error in max_tokens.context: Max context tokens is higher than the model's limit of 1000"
- )
+ with mock.patch("nucliadb.search.api.v1.chat.has_feature", return_value=True):
+ resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"})
+ assert resp.status_code == 200
diff --git a/nucliadb/tests/search/unit/api/v1/resource/test_chat.py b/nucliadb/tests/search/unit/api/v1/resource/test_chat.py
deleted file mode 100644
index 160caa5e6a..0000000000
--- a/nucliadb/tests/search/unit/api/v1/resource/test_chat.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-#
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from starlette.requests import Request
-
-from nucliadb.models.responses import HTTPClientError
-from nucliadb.search import predict
-from nucliadb.search.api.v1.resource.chat import resource_chat_endpoint
-from nucliadb_utils.exceptions import LimitsExceededError
-
-pytestmark = pytest.mark.asyncio
-
-
-class DummyTestRequest(Request):
- @property
- def auth(self):
- return Mock(scopes=["READER"])
-
- @property
- def user(self):
- return Mock(display_name="username")
-
-
-@pytest.fixture(scope="function")
-def create_chat_response_mock():
- with mock.patch(
- "nucliadb.search.api.v1.resource.chat.create_chat_response",
- ) as mocked:
- yield mocked
-
-
-@pytest.mark.parametrize(
- "predict_error,http_error_response",
- [
- (
- LimitsExceededError(402, "over the quota"),
- HTTPClientError(status_code=402, detail="over the quota"),
- ),
- (
- predict.RephraseError("foobar"),
- HTTPClientError(
- status_code=529,
- detail="Temporary error while rephrasing the query. Please try again later. Error: foobar",
- ),
- ),
- (
- predict.RephraseMissingContextError(),
- HTTPClientError(
- status_code=412,
- detail="Unable to rephrase the query with the provided context.",
- ),
- ),
- ],
-)
-async def test_resource_chat_endpoint_handles_errors(
- create_chat_response_mock, predict_error, http_error_response
-):
- create_chat_response_mock.side_effect = predict_error
- request = DummyTestRequest(
- scope={
- "type": "http",
- "http_version": "1.1",
- "method": "GET",
- "headers": [],
- }
- )
- response = await resource_chat_endpoint(
- request=request,
- kbid="kbid",
- item=Mock(),
- x_ndb_client=None,
- x_nucliadb_user="",
- x_forwarded_for="",
- x_synchronous=True,
- resource_id="rid",
- )
- assert response.status_code == http_error_response.status_code
- assert response.body == http_error_response.body
diff --git a/nucliadb/tests/search/unit/api/v1/test_chat.py b/nucliadb/tests/search/unit/api/v1/test_chat.py
deleted file mode 100644
index 0ac3fa80a1..0000000000
--- a/nucliadb/tests/search/unit/api/v1/test_chat.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-#
-from unittest import mock
-from unittest.mock import Mock
-
-import pytest
-from starlette.requests import Request
-
-from nucliadb.models.responses import HTTPClientError
-from nucliadb.search import predict
-from nucliadb.search.api.v1.chat import chat_knowledgebox_endpoint
-from nucliadb_utils.exceptions import LimitsExceededError
-
-pytestmark = pytest.mark.asyncio
-
-
-class DummyTestRequest(Request):
- @property
- def auth(self):
- return Mock(scopes=["READER"])
-
- @property
- def user(self):
- return Mock(display_name="username")
-
-
-@pytest.fixture(scope="function")
-def create_chat_response_mock():
- with mock.patch(
- "nucliadb.search.api.v1.chat.create_chat_response",
- ) as mocked:
- yield mocked
-
-
-@pytest.mark.parametrize(
- "predict_error,http_error_response",
- [
- (
- LimitsExceededError(402, "over the quota"),
- HTTPClientError(status_code=402, detail="over the quota"),
- ),
- (
- predict.RephraseError("foobar"),
- HTTPClientError(
- status_code=529,
- detail="Temporary error while rephrasing the query. Please try again later. Error: foobar",
- ),
- ),
- (
- predict.RephraseMissingContextError(),
- HTTPClientError(
- status_code=412,
- detail="Unable to rephrase the query with the provided context.",
- ),
- ),
- ],
-)
-async def test_chat_endpoint_handles_errors(
- create_chat_response_mock, predict_error, http_error_response
-):
- create_chat_response_mock.side_effect = predict_error
- request = DummyTestRequest(
- scope={
- "type": "http",
- "http_version": "1.1",
- "method": "GET",
- "headers": [],
- }
- )
- response = await chat_knowledgebox_endpoint(
- request=request,
- kbid="kbid",
- item=Mock(),
- x_ndb_client=None,
- x_nucliadb_user="",
- x_forwarded_for="",
- )
- assert response.status_code == http_error_response.status_code
- assert response.body == http_error_response.body
diff --git a/nucliadb_sdk/README.md b/nucliadb_sdk/README.md
index f950ec5ae8..ad37c4906b 100644
--- a/nucliadb_sdk/README.md
+++ b/nucliadb_sdk/README.md
@@ -134,7 +134,7 @@ After the data is pushed, the NucliaDB SDK could also be used to find answers on
>>> import nucliadb_sdk
>>>
>>> ndb = nucliadb_sdk.NucliaDB(region="on-prem", url="http://localhost:8080")
->>> resp = ndb.chat(kbid="my-kb-id", query="What does Hakuna Matata mean?")
+>>> resp = ndb.ask(kbid="my-kb-id", query="What does Hakuna Matata mean?")
>>> print(resp.answer)
'Hakuna matata is actually a phrase in the East African language of Swahili that literally means “no trouble” or “no problems”.'
```
diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py
index f86fc2cdcb..c0a57f31ee 100644
--- a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py
+++ b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py
@@ -232,50 +232,6 @@ class Docstring(BaseModel):
],
)
-CHAT = Docstring(
- doc="""Chat with your Knowledge Box""",
- examples=[
- Example(
- description="Get an answer for a question that is part of the data in the Knowledge Box",
- code=""">>> from nucliadb_sdk import *
->>> sdk = NucliaDBSDK(api_key="api-key")
->>> sdk.chat(kbid="mykbid", query="Will France be in recession in 2023?").answer
-'Yes, according to the provided context, France is expected to be in recession in 2023.'
-""",
- ),
- Example(
- description="You can use the `content` parameter to pass a `ChatRequest` object",
- code=""">>> content = ChatRequest(query="Who won the 2018 football World Cup?")
->>> sdk.chat(kbid="mykbid", content=content).answer
-'France won the 2018 football World Cup.'
-""",
- ),
- ],
-)
-
-RESOURCE_CHAT = Docstring(
- doc="""Chat with your document""",
- examples=[
- Example(
- description="Have a chat with your document. Generated answers are scoped to the context of the document.",
- code=""">>> sdk.chat_on_resource(kbid="mykbid", query="What is the coldest season in Sevilla?").answer
-'January is the coldest month.'
-""",
- ),
- Example(
- description="You can use the `content` parameter to pass previous context to the query",
- code=""">>> from nucliadb_models.search import ChatRequest, ChatContextMessage
->>> content = ChatRequest()
->>> content.query = "What is the average temperature?"
->>> content.context.append(ChatContextMessage(author="USER", text="What is the coldest season in Sevilla?"))
->>> content.context.append(ChatContextMessage(author="NUCLIA", text="January is the coldest month."))
->>> sdk.chat(kbid="mykbid", content=content).answer
-'According to the context, the average temperature in January in Sevilla is 15.9 °C and 5.2 °C.'
-""",
- ),
- ],
-)
-
SUMMARIZE = Docstring(
doc="""Summarize your documents""",
examples=[
diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py
index 3259ffbebf..dd2f3a4ce0 100644
--- a/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py
+++ b/nucliadb_sdk/src/nucliadb_sdk/v2/sdk.py
@@ -17,7 +17,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
-import base64
import enum
import inspect
import io
@@ -65,7 +64,6 @@
AnswerAskResponseItem,
AskRequest,
AskResponseItem,
- ChatRequest,
CitationsAskResponseItem,
ErrorAskResponseItem,
FeedbackRequest,
@@ -74,7 +72,6 @@
KnowledgeboxFindResults,
KnowledgeboxSearchResults,
MetadataAskResponseItem,
- Relations,
RelationsAskResponseItem,
RetrievalAskResponseItem,
SearchRequest,
@@ -101,14 +98,6 @@ class Region(enum.Enum):
AWS_US_EAST_2_1 = "aws-us-east-2-1"
-class ChatResponse(BaseModel):
- result: KnowledgeboxFindResults
- answer: str
- relations: Optional[Relations] = None
- learning_id: Optional[str] = None
- citations: dict[str, Any] = {}
-
-
RawRequestContent = Union[str, bytes, Iterable[bytes], AsyncIterable[bytes]]
@@ -119,38 +108,6 @@ def json_response_parser(response: httpx.Response) -> Any:
return orjson.loads(response.content.decode())
-def chat_response_parser(response: httpx.Response) -> ChatResponse:
- raw = io.BytesIO(response.content)
- header = raw.read(4)
- payload_size = int.from_bytes(header, byteorder="big", signed=False)
- data = raw.read(payload_size)
- find_result = KnowledgeboxFindResults.model_validate_json(base64.b64decode(data))
- data = raw.read()
- try:
- answer, relations_payload = data.split(b"_END_")
- except ValueError:
- answer = data
- relations_payload = b""
- learning_id = response.headers.get("NUCLIA-LEARNING-ID")
- relations_result = None
- if len(relations_payload) > 0:
- relations_result = Relations.model_validate_json(base64.b64decode(relations_payload))
- try:
- answer, tail = answer.split(b"_CIT_")
- citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False)
- citations_bytes = tail[4 : 4 + citations_length]
- citations = orjson.loads(base64.b64decode(citations_bytes).decode())
- except ValueError:
- citations = {}
- return ChatResponse(
- result=find_result,
- answer=answer.decode("utf-8"),
- relations=relations_result,
- learning_id=learning_id,
- citations=citations,
- )
-
-
def ask_response_parser(response: httpx.Response) -> SyncAskResponse:
content_type = response.headers.get("Content-Type")
if content_type not in ("application/json", "application/x-ndjson"):
@@ -646,14 +603,6 @@ def _check_response(self, response: httpx.Response):
request_type=SearchRequest,
response_type=KnowledgeboxSearchResults,
)
- chat = _request_builder(
- name="chat",
- path_template="/v1/kb/{kbid}/chat",
- method="POST",
- path_params=("kbid",),
- request_type=ChatRequest,
- response_type=chat_response_parser,
- )
ask = _request_builder(
name="ask",
@@ -664,24 +613,6 @@ def _check_response(self, response: httpx.Response):
response_type=ask_response_parser,
)
- chat_on_resource = _request_builder(
- name="chat_on_resource",
- path_template="/v1/kb/{kbid}/resource/{rid}/chat",
- method="POST",
- path_params=("kbid", "rid"),
- request_type=ChatRequest,
- response_type=chat_response_parser,
- )
-
- chat_on_resource_by_slug = _request_builder(
- name="chat_on_resource_by_slug",
- path_template="/v1/kb/{kbid}/slug/{slug}/chat",
- method="POST",
- path_params=("kbid", "slug"),
- request_type=ChatRequest,
- response_type=chat_response_parser,
- )
-
ask_on_resource = _request_builder(
name="ask_on_resource",
path_template="/v1/kb/{kbid}/resource/{rid}/ask",
diff --git a/nucliadb_sdk/tests/test_chat.py b/nucliadb_sdk/tests/test_chat.py
deleted file mode 100644
index 783dfd3061..0000000000
--- a/nucliadb_sdk/tests/test_chat.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-import nucliadb_sdk
-from nucliadb_models.search import MaxTokens
-
-
-def test_chat_on_kb(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
- result = sdk.chat(
- kbid=docs_dataset,
- query="Nuclia loves Semantic Search",
- generative_model="everest",
- prompt="Given this context: {context}. Answer this {question} in a concise way using the provided context",
- extra_context=[
- "Nuclia is a powerful AI search platform",
- "AI Search involves semantic search",
- ],
- # Control the number of AI tokens used for every request
- max_tokens=MaxTokens(context=100, answer=50),
- )
- assert result.learning_id == "00"
- assert result.answer == "valid answer to"
- assert len(result.result.resources) == 7
- assert result.relations
- assert len(result.relations.entities["Nuclia"].related_to) == 18
-
-
-def test_chat_on_kb_with_citations(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
- result = sdk.chat(
- kbid=docs_dataset,
- query="Nuclia loves Semantic Search",
- citations=True,
- )
- assert result.citations == {}
-
-
-def test_chat_on_kb_no_context_found(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
- result = sdk.chat(kbid=docs_dataset, query="penguin")
- assert result.answer == "Not enough data to answer this."
-
-
-def test_chat_on_resource(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
- rid = sdk.list_resources(kbid=docs_dataset).resources[0].id
- # With retrieval
- _ = sdk.chat_on_resource(kbid=docs_dataset, rid=rid, query="Nuclia loves Semantic Search")
-
- # Check chatting with the whole resource (no retrieval)
- _ = sdk.chat_on_resource(
- kbid=docs_dataset,
- rid=rid,
- query="Nuclia loves Semantic Search",
- rag_strategies=[{"name": "full_resource"}],
- )
diff --git a/nucliadb_sdk/tests/test_sdk.py b/nucliadb_sdk/tests/test_sdk.py
index 33aa2aa640..0f29c3c3c1 100644
--- a/nucliadb_sdk/tests/test_sdk.py
+++ b/nucliadb_sdk/tests/test_sdk.py
@@ -85,12 +85,9 @@ def test_resource_endpoints(sdk: nucliadb_sdk.NucliaDB, kb):
def test_search_endpoints(sdk: nucliadb_sdk.NucliaDB, kb):
sdk.find(kbid=kb.uuid, query="foo")
sdk.search(kbid=kb.uuid, query="foo")
- sdk.chat(kbid=kb.uuid, query="foo")
sdk.ask(kbid=kb.uuid, query="foo")
resource = sdk.create_resource(kbid=kb.uuid, title="Resource", slug="resource")
- sdk.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
- sdk.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
sdk.ask_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
sdk.ask_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
sdk.feedback(kbid=kb.uuid, ident="bar", good=True, feedback="baz", task="CHAT")
diff --git a/nucliadb_sdk/tests/test_sdk_async.py b/nucliadb_sdk/tests/test_sdk_async.py
index f454284b21..c8ec221cc0 100644
--- a/nucliadb_sdk/tests/test_sdk_async.py
+++ b/nucliadb_sdk/tests/test_sdk_async.py
@@ -73,12 +73,9 @@ async def test_resource_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb):
async def test_search_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb):
await sdk_async.find(kbid=kb.uuid, query="foo")
await sdk_async.search(kbid=kb.uuid, query="foo")
- await sdk_async.chat(kbid=kb.uuid, query="foo")
await sdk_async.ask(kbid=kb.uuid, query="foo")
resource = await sdk_async.create_resource(kbid=kb.uuid, title="Resource", slug="resource")
- await sdk_async.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
- await sdk_async.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
await sdk_async.ask_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
await sdk_async.ask_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
await sdk_async.feedback(
diff --git a/nucliadb_utils/src/nucliadb_utils/const.py b/nucliadb_utils/src/nucliadb_utils/const.py
index 6c689b95ce..e9a4a5d5ee 100644
--- a/nucliadb_utils/src/nucliadb_utils/const.py
+++ b/nucliadb_utils/src/nucliadb_utils/const.py
@@ -84,3 +84,4 @@ class Features:
VECTORSETS_V0 = "vectorsets_v0_new_kbs_with_multiple_vectorsets"
SKIP_EXTERNAL_INDEX = "nucliadb_skip_external_index"
NATS_SYNC_ACK = "nucliadb_nats_sync_ack"
+ DEPRECATED_CHAT_ENABLED = "nucliadb_deprecated_chat_enabled"