Skip to content

Commit

Permalink
(soft) remove chat endpoint (#2406)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Aug 22, 2024
1 parent d30b4bd commit 366b241
Show file tree
Hide file tree
Showing 14 changed files with 37 additions and 932 deletions.
46 changes: 9 additions & 37 deletions e2e/test_e2e.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import base64
import io
import json
import os
import random
import time
Expand Down Expand Up @@ -137,18 +135,18 @@ def test_resource_processed(kbid: str, resource_id: str):

def test_search(kbid: str, resource_id: str):
resp = requests.post(
os.path.join(BASE_URL, f"api/v1/kb/{kbid}/chat"),
os.path.join(BASE_URL, f"api/v1/kb/{kbid}/ask"),
headers={
"content-type": "application/json",
"X-NUCLIADB-ROLES": "READER",
"x-ndb-client": "web",
"x-synchronous": "true",
},
json={
"query": "Why is soccer called soccer?",
"context": [],
"show": ["basic", "values", "origin"],
"features": ["paragraphs", "relations"],
"inTitleOnly": False,
"features": ["keyword", "relations"],
"highlight": True,
"autofilter": False,
"page_number": 0,
Expand All @@ -158,38 +156,12 @@ def test_search(kbid: str, resource_id: str):
)

raise_for_status(resp)

raw = io.BytesIO(resp.content)
toread_bytes = raw.read(4)
toread = int.from_bytes(toread_bytes, byteorder="big", signed=False)
print(f"toread: {toread}")
raw_search_results = raw.read(toread)
search_results = json.loads(base64.b64decode(raw_search_results))
print(f"Search results: {search_results}")

data = raw.read()
try:
answer, relations_payload = data.split(b"_END_")
except ValueError:
answer = data
relations_payload = b""
if len(relations_payload) > 0:
decoded_relations_payload = base64.b64decode(relations_payload)
print(f"Relations payload: {decoded_relations_payload}")
try:
answer, tail = answer.split(b"_CIT_")
chat_answer = answer.decode("utf-8")
citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False)
citations_bytes = tail[4 : 4 + citations_length]
citations = json.loads(base64.b64decode(citations_bytes).decode())
except ValueError:
chat_answer = answer.decode("utf-8")
citations = {}
print(f"Answer: {chat_answer}")
print(f"Citations: {citations}")

# assert "Not enough data to answer this" not in chat_answer, search_results
assert len(search_results["resources"]) == 1
ask_response = resp.json()
print(f"Search results: {ask_response["retrieval_results"]}")
assert len(ask_response["retrieval_results"]["resources"]) == 1
print(f"Relations payload: {ask_response["relations"]}")
print(f"Answer: {ask_response["answer"]}")
print(f"Citations: {ask_response["citations"]}")


def test_predict_proxy(kbid: str):
Expand Down
12 changes: 11 additions & 1 deletion nucliadb/src/nucliadb/search/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@
parse_max_tokens,
)
from nucliadb_telemetry.errors import capture_exception
from nucliadb_utils import const
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError
from nucliadb_utils.utilities import has_feature

END_OF_STREAM = "_END_"

Expand Down Expand Up @@ -96,6 +98,7 @@ class SyncChatResponse(pydantic.BaseModel):
tags=["Search"],
response_model=None,
deprecated=True,
include_in_schema=False,
)
@requires(NucliaDBRoles.READER)
@version(1)
Expand All @@ -112,6 +115,13 @@ async def chat_knowledgebox_endpoint(
"This is slower and requires waiting for entire answer to be ready.",
),
) -> Union[StreamingResponse, HTTPClientError, Response]:
if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}):
        # We keep this around for a while so that, if needed, we can re-enable chat on a per-KB basis
return HTTPClientError(
status_code=404,
detail="This endpoint has been deprecated. Please use /ask instead.",
)

try:
return await create_chat_response(
kbid, item, x_nucliadb_user, x_ndb_client, x_forwarded_for, x_synchronous
Expand Down Expand Up @@ -155,7 +165,7 @@ async def create_chat_response(
origin: str,
x_synchronous: bool,
resource: Optional[str] = None,
) -> Response:
) -> Response: # pragma: no cover
chat_request.max_tokens = parse_max_tokens(chat_request.max_tokens)
chat_result = await chat(
kbid,
Expand Down
10 changes: 10 additions & 0 deletions nucliadb/src/nucliadb/search/api/v1/resource/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
)
from nucliadb_models.resource import NucliaDBRoles
from nucliadb_models.search import ChatRequest, NucliaDBClientType
from nucliadb_utils import const
from nucliadb_utils.authentication import requires
from nucliadb_utils.exceptions import LimitsExceededError
from nucliadb_utils.utilities import has_feature

from ..chat import create_chat_response

Expand All @@ -47,6 +49,7 @@
tags=["Search"],
response_model=None,
deprecated=True,
include_in_schema=False,
)
@requires(NucliaDBRoles.READER)
@version(1)
Expand Down Expand Up @@ -124,6 +127,13 @@ async def resource_chat_endpoint(
resource_id: Optional[str] = None,
resource_slug: Optional[str] = None,
) -> Union[StreamingResponse, HTTPClientError, Response]:
if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}):
        # We keep this around for a while so that, if needed, we can re-enable chat on a per-KB basis
return HTTPClientError(
status_code=404,
detail="This endpoint has been deprecated. Please use /ask instead.",
)

if resource_id is None:
if resource_slug is None:
raise ValueError("Either resource_id or resource_slug must be provided")
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/src/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def chat(
client_type: NucliaDBClientType,
origin: str,
resource: Optional[str] = None,
) -> ChatResult:
) -> ChatResult: # pragma: no cover
metrics = RAGMetrics()
nuclia_learning_id: Optional[str] = None
chat_history = chat_request.context or []
Expand Down
Loading

0 comments on commit 366b241

Please sign in to comment.