From ae6ad6bc522904ce209dc291af7eb4fb283b082b Mon Sep 17 00:00:00 2001 From: Ferran Llamas Date: Tue, 27 Aug 2024 20:40:24 +0200 Subject: [PATCH] Hard delete chat logic (#2414) --- .../nucliadb_search/templates/search.vs.yaml | 6 +- .../src/nucliadb/search/api/v1/__init__.py | 2 - nucliadb/src/nucliadb/search/api/v1/chat.py | 271 ------------------ .../nucliadb/search/api/v1/resource/chat.py | 183 ------------ nucliadb/src/nucliadb/search/predict.py | 35 --- .../src/nucliadb/search/search/chat/query.py | 258 +---------------- .../tests/nucliadb/integration/test_chat.py | 38 --- .../unit/search/search/chat/test_query.py | 78 +---- nucliadb/tests/search/unit/test_predict.py | 6 +- nucliadb_utils/src/nucliadb_utils/const.py | 1 - .../src/nucliadb_utils/featureflagging.py | 4 - 11 files changed, 9 insertions(+), 873 deletions(-) delete mode 100644 nucliadb/src/nucliadb/search/api/v1/chat.py delete mode 100644 nucliadb/src/nucliadb/search/api/v1/resource/chat.py delete mode 100644 nucliadb/tests/nucliadb/integration/test_chat.py diff --git a/charts/nucliadb_search/templates/search.vs.yaml b/charts/nucliadb_search/templates/search.vs.yaml index e310fc25ae..b23460b9d9 100644 --- a/charts/nucliadb_search/templates/search.vs.yaml +++ b/charts/nucliadb_search/templates/search.vs.yaml @@ -18,10 +18,6 @@ spec: regex: '^/api/v\d+/kb/[^/]+/find$' method: regex: "GET|POST|OPTIONS" - - uri: - regex: '^/api/v\d+/kb/[^/]+/chat$' - method: - regex: "POST|OPTIONS" - uri: regex: '^/api/v\d+/kb/[^/]+/ask$' method: @@ -43,7 +39,7 @@ spec: method: regex: "GET|POST|OPTIONS" - uri: - regex: '^/api/v\d+/kb/[^/]+/(resource|slug)/[^/]+/(chat|find|search|ask)$' + regex: '^/api/v\d+/kb/[^/]+/(resource|slug)/[^/]+/(find|search|ask)$' method: regex: "GET|POST|OPTIONS" - uri: diff --git a/nucliadb/src/nucliadb/search/api/v1/__init__.py b/nucliadb/src/nucliadb/search/api/v1/__init__.py index 755b9593d8..ea347605b7 100644 --- a/nucliadb/src/nucliadb/search/api/v1/__init__.py +++ b/nucliadb/src/nucliadb/search/api/v1/__init__.py @@ -18,7 +18,6 @@ # along with this program. If not, see . # from . import ask # noqa -from . import chat # noqa from . import feedback # noqa from . import find # noqa from . import knowledgebox # noqa @@ -27,6 +26,5 @@ from . import suggest # noqa from . import summarize # noqa from .resource import ask as ask_resource # noqa -from .resource import chat as chat_resource # noqa from .resource import search as search_resource # noqa from .router import api # noqa diff --git a/nucliadb/src/nucliadb/search/api/v1/chat.py b/nucliadb/src/nucliadb/search/api/v1/chat.py deleted file mode 100644 index 4ad68dc704..0000000000 --- a/nucliadb/src/nucliadb/search/api/v1/chat.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# -import base64 -import json -from typing import Any, Optional, Union - -import pydantic -from fastapi import Body, Header, Request, Response -from fastapi.openapi.models import Example -from fastapi_versioning import version -from starlette.responses import StreamingResponse - -from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound -from nucliadb.models.responses import HTTPClientError -from nucliadb.search import logger, predict -from nucliadb.search.api.v1.router import KB_PREFIX, api -from nucliadb.search.predict import AnswerStatusCode -from nucliadb.search.search.chat.query import ( - START_OF_CITATIONS, - chat, - get_relations_results, -) -from nucliadb.search.search.exceptions import ( - IncompleteFindResultsError, - InvalidQueryError, -) -from nucliadb_models.resource import NucliaDBRoles -from nucliadb_models.search import ( - ChatOptions, - ChatRequest, - KnowledgeboxFindResults, - NucliaDBClientType, - PromptContext, - PromptContextOrder, - Relations, - parse_max_tokens, -) -from nucliadb_telemetry.errors import capture_exception -from nucliadb_utils import const -from nucliadb_utils.authentication import requires -from nucliadb_utils.exceptions import LimitsExceededError -from nucliadb_utils.utilities import has_feature - -END_OF_STREAM = "_END_" - - -class SyncChatResponse(pydantic.BaseModel): - answer: str - relations: Optional[Relations] = None - results: KnowledgeboxFindResults - status: AnswerStatusCode - citations: dict[str, Any] = {} - prompt_context: Optional[PromptContext] = None - prompt_context_order: Optional[PromptContextOrder] = None - - -CHAT_EXAMPLES = { - "search_and_chat": Example( - summary="Ask who won the league final", - description="You can ask a question to your knowledge box", # noqa - value={ - "query": "Who won the league final?", - }, - ), - "search_and_chat_with_custom_prompt": Example( - summary="Ask for the gold price evolution in 2023 in a very conscise way", - description="You can ask a question and specify a custom prompt to tweak the tone of the response", # noqa - value={ - "query": "How has the price of gold evolved during 2023?", - "prompt": "Given this context: {context}. Answer this {question} in a concise way using the provided context", # noqa - }, - ), -} - - -@api.post( - f"/{KB_PREFIX}/{{kbid}}/chat", - status_code=200, - summary="Chat on a Knowledge Box", - description="Chat on a Knowledge Box", - tags=["Search"], - response_model=None, - deprecated=True, - include_in_schema=False, -) -@requires(NucliaDBRoles.READER) -@version(1) -async def chat_knowledgebox_endpoint( - request: Request, - kbid: str, - item: ChatRequest = Body(openapi_examples=CHAT_EXAMPLES), - x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), - x_nucliadb_user: str = Header(""), - x_forwarded_for: str = Header(""), - x_synchronous: bool = Header( - False, - description="When set to true, outputs response as JSON in a non-streaming way. " - "This is slower and requires waiting for entire answer to be ready.", - ), -) -> Union[StreamingResponse, HTTPClientError, Response]: - if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}): - # We keep this for a while so we can enable back chat on a per KB basis, in case we need to - return HTTPClientError( - status_code=404, - detail="This endpoint has been deprecated. Please use /ask instead.", - ) - - try: - return await create_chat_response( - kbid, item, x_nucliadb_user, x_ndb_client, x_forwarded_for, x_synchronous - ) - except KnowledgeBoxNotFound: - return HTTPClientError( - status_code=404, - detail=f"Knowledge Box '{kbid}' not found.", - ) - except LimitsExceededError as exc: - return HTTPClientError(status_code=exc.status_code, detail=exc.detail) - except predict.ProxiedPredictAPIError as err: - return HTTPClientError( - status_code=err.status, - detail=err.detail, - ) - except IncompleteFindResultsError: - return HTTPClientError( - status_code=529, - detail="Temporary error on information retrieval. Please try again.", - ) - except predict.RephraseMissingContextError: - return HTTPClientError( - status_code=412, - detail="Unable to rephrase the query with the provided context.", - ) - except predict.RephraseError as err: - return HTTPClientError( - status_code=529, - detail=f"Temporary error while rephrasing the query. Please try again later. Error: {err}", - ) - except InvalidQueryError as exc: - return HTTPClientError(status_code=412, detail=str(exc)) - - -async def create_chat_response( - kbid: str, - chat_request: ChatRequest, - user_id: str, - client_type: NucliaDBClientType, - origin: str, - x_synchronous: bool, - resource: Optional[str] = None, -) -> Response: # pragma: no cover - chat_request.max_tokens = parse_max_tokens(chat_request.max_tokens) - chat_result = await chat( - kbid, - chat_request, - user_id, - client_type, - origin, - resource=resource, - ) - - if x_synchronous: - streamed_answer = b"" - async for chunk in chat_result.answer_stream: - streamed_answer += chunk - - answer, citations = parse_streamed_answer(streamed_answer, chat_request.citations) - - relations_results = None - if ChatOptions.RELATIONS in chat_request.features: - # XXX should use query parser here - relations_results = await get_relations_results( - kbid=kbid, text_answer=answer, target_shard_replicas=chat_request.shards - ) - - sync_chat_resp = SyncChatResponse( - answer=answer, - relations=relations_results, - results=chat_result.find_results, - status=chat_result.status_code.value, - citations=citations, - ) - if chat_request.debug: - sync_chat_resp.prompt_context = chat_result.prompt_context - sync_chat_resp.prompt_context_order = chat_result.prompt_context_order - return Response( - content=sync_chat_resp.model_dump_json(exclude_unset=True), - headers={ - "NUCLIA-LEARNING-ID": chat_result.nuclia_learning_id or "unknown", - "Access-Control-Expose-Headers": "NUCLIA-LEARNING-ID", - "Content-Type": "application/json", - }, - ) - else: - - async def _streaming_response(): - bytes_results = base64.b64encode(chat_result.find_results.model_dump_json().encode()) - yield len(bytes_results).to_bytes(length=4, byteorder="big", signed=False) - yield bytes_results - - streamed_answer = b"" - async for chunk in chat_result.answer_stream: - streamed_answer += chunk - yield chunk - - answer, _ = parse_streamed_answer(streamed_answer, chat_request.citations) - - yield END_OF_STREAM.encode() - if ChatOptions.RELATIONS in chat_request.features: - # XXX should use query parser here - relations_results = await get_relations_results( - kbid=kbid, - text_answer=answer, - target_shard_replicas=chat_request.shards, - ) - yield base64.b64encode(relations_results.model_dump_json().encode()) - - return StreamingResponse( - _streaming_response(), - media_type="application/octet-stream", - headers={ - "NUCLIA-LEARNING-ID": chat_result.nuclia_learning_id or "unknown", - "Access-Control-Expose-Headers": "NUCLIA-LEARNING-ID", - }, - ) - - -def parse_streamed_answer( - streamed_bytes: bytes, requested_citations: bool -) -> tuple[str, dict[str, Any]]: - try: - text_answer, tail = streamed_bytes.split(START_OF_CITATIONS, 1) - except ValueError: - if requested_citations: - logger.warning( - "Citations were requested but not found in the answer. " - "Returning the answer without citations." - ) - return streamed_bytes.decode("utf-8"), {} - if not requested_citations: - logger.warning( - "Citations were not requested but found in the answer. " - "Returning the answer without citations." - ) - return text_answer.decode("utf-8"), {} - try: - citations_length = int.from_bytes(tail[:4], byteorder="big", signed=False) - citations_bytes = tail[4 : 4 + citations_length] - citations = json.loads(base64.b64decode(citations_bytes).decode()) - return text_answer.decode("utf-8"), citations - except Exception as exc: - capture_exception(exc) - logger.exception("Error parsing citations. Returning the answer without citations.") - return text_answer.decode("utf-8"), {} diff --git a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py b/nucliadb/src/nucliadb/search/api/v1/resource/chat.py deleted file mode 100644 index 1d1289f281..0000000000 --- a/nucliadb/src/nucliadb/search/api/v1/resource/chat.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# -from typing import Optional, Union - -from fastapi import Header, Request, Response -from fastapi_versioning import version -from starlette.responses import StreamingResponse - -from nucliadb.common import datamanagers -from nucliadb.models.responses import HTTPClientError -from nucliadb.search import predict -from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_SLUG_PREFIX, api -from nucliadb.search.search.exceptions import ( - IncompleteFindResultsError, - InvalidQueryError, -) -from nucliadb_models.resource import NucliaDBRoles -from nucliadb_models.search import ChatRequest, NucliaDBClientType -from nucliadb_utils import const -from nucliadb_utils.authentication import requires -from nucliadb_utils.exceptions import LimitsExceededError -from nucliadb_utils.utilities import has_feature - -from ..chat import create_chat_response - - -@api.post( - f"/{KB_PREFIX}/{{kbid}}/resource/{{rid}}/chat", - status_code=200, - summary="Chat with a resource (by id)", - description="Chat with a resource", - tags=["Search"], - response_model=None, - deprecated=True, - include_in_schema=False, -) -@requires(NucliaDBRoles.READER) -@version(1) -async def resource_chat_endpoint_by_uuid( - request: Request, - kbid: str, - rid: str, - item: ChatRequest, - x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), - x_nucliadb_user: str = Header(""), - x_forwarded_for: str = Header(""), - x_synchronous: bool = Header( - False, - description="When set to true, outputs response as JSON in a non-streaming way. " - "This is slower and requires waiting for entire answer to be ready.", - ), -) -> Union[StreamingResponse, HTTPClientError, Response]: - return await resource_chat_endpoint( - request, - kbid, - item, - x_ndb_client, - x_nucliadb_user, - x_forwarded_for, - x_synchronous, - resource_id=rid, - ) - - -@api.post( - f"/{KB_PREFIX}/{{kbid}}/{RESOURCE_SLUG_PREFIX}/{{slug}}/chat", - status_code=200, - summary="Chat with a resource (by slug)", - description="Chat with a resource", - tags=["Search"], - response_model=None, - deprecated=True, -) -@requires(NucliaDBRoles.READER) -@version(1) -async def resource_chat_endpoint_by_slug( - request: Request, - kbid: str, - slug: str, - item: ChatRequest, - x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), - x_nucliadb_user: str = Header(""), - x_forwarded_for: str = Header(""), - x_synchronous: bool = Header( - False, - description="When set to true, outputs response as JSON in a non-streaming way. " - "This is slower and requires waiting for entire answer to be ready.", - ), -) -> Union[StreamingResponse, HTTPClientError, Response]: - return await resource_chat_endpoint( - request, - kbid, - item, - x_ndb_client, - x_nucliadb_user, - x_forwarded_for, - x_synchronous, - resource_slug=slug, - ) - - -async def resource_chat_endpoint( - request: Request, - kbid: str, - item: ChatRequest, - x_ndb_client: NucliaDBClientType, - x_nucliadb_user: str, - x_forwarded_for: str, - x_synchronous: bool, - resource_id: Optional[str] = None, - resource_slug: Optional[str] = None, -) -> Union[StreamingResponse, HTTPClientError, Response]: - if not has_feature(const.Features.DEPRECATED_CHAT_ENABLED, default=False, context={"kbid": kbid}): - # We keep this for a while so we can enable back chat on a per KB basis, in case we need to - return HTTPClientError( - status_code=404, - detail="This endpoint has been deprecated. Please use /ask instead.", - ) - - if resource_id is None: - if resource_slug is None: - raise ValueError("Either resource_id or resource_slug must be provided") - - resource_id = await get_resource_uuid_by_slug(kbid, resource_slug) - if resource_id is None: - return HTTPClientError(status_code=404, detail="Resource not found") - - try: - return await create_chat_response( - kbid, - item, - x_nucliadb_user, - x_ndb_client, - x_forwarded_for, - x_synchronous, - resource=resource_id, - ) - except LimitsExceededError as exc: - return HTTPClientError(status_code=exc.status_code, detail=exc.detail) - except predict.ProxiedPredictAPIError as err: - return HTTPClientError( - status_code=err.status, - detail=err.detail, - ) - except IncompleteFindResultsError: - return HTTPClientError( - status_code=529, - detail="Temporary error on information retrieval. Please try again.", - ) - except predict.RephraseMissingContextError: - return HTTPClientError( - status_code=412, - detail="Unable to rephrase the query with the provided context.", - ) - except predict.RephraseError as err: - return HTTPClientError( - status_code=529, - detail=f"Temporary error while rephrasing the query. Please try again later. Error: {err}", - ) - except InvalidQueryError as exc: - return HTTPClientError(status_code=412, detail=str(exc)) - - -async def get_resource_uuid_by_slug(kbid: str, slug: str) -> Optional[str]: - async with datamanagers.with_ro_transaction() as txn: - return await datamanagers.resources.get_resource_uuid_from_slug(txn, kbid=kbid, slug=slug) diff --git a/nucliadb/src/nucliadb/search/predict.py b/nucliadb/src/nucliadb/search/predict.py index 83a55e0bc6..d3dd679502 100644 --- a/nucliadb/src/nucliadb/search/predict.py +++ b/nucliadb/src/nucliadb/search/predict.py @@ -320,26 +320,6 @@ async def rephrase_query(self, kbid: str, item: RephraseModel) -> str: await self.check_response(resp, expected_status=200) return await _parse_rephrase_response(resp) - @predict_observer.wrap({"type": "chat"}) - async def chat_query(self, kbid: str, item: ChatModel) -> tuple[str, AsyncIterator[bytes]]: - try: - self.check_nua_key_is_configured_for_onprem() - except NUAKeyMissingError: - error = "Nuclia Service account is not defined so the chat operation could not be performed" - logger.warning(error) - raise SendToPredictError(error) - - resp = await self.make_request( - "POST", - url=self.get_predict_url(CHAT, kbid), - json=item.model_dump(), - headers=self.get_predict_headers(kbid), - timeout=None, - ) - await self.check_response(resp, expected_status=200) - ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER) - return ident, get_answer_generator(resp) - @predict_observer.wrap({"type": "chat_ndjson"}) async def chat_query_ndjson( self, kbid: str, item: ChatModel @@ -452,12 +432,6 @@ def __init__(self): self.cluster_url = "http://localhost:8000" self.public_url = "http://localhost:8000" self.calls = [] - self.generated_answer = [ - b"valid ", - b"answer ", - b" to", - AnswerStatusCode.SUCCESS.encode(), - ] self.ndjson_answer = [ b'{"chunk": {"type": "text", "text": "valid "}}\n', b'{"chunk": {"type": "text", "text": "answer "}}\n', @@ -496,15 +470,6 @@ async def rephrase_query(self, kbid: str, item: RephraseModel) -> str: self.calls.append(("rephrase_query", item)) return DUMMY_REPHRASE_QUERY - async def chat_query(self, kbid: str, item: ChatModel) -> tuple[str, AsyncIterator[bytes]]: - self.calls.append(("chat_query", item)) - - async def generate(): - for i in self.generated_answer: - yield i - - return (DUMMY_LEARNING_ID, generate()) - async def chat_query_ndjson( self, kbid: str, item: ChatModel ) -> tuple[str, AsyncIterator[GenerativeChunk]]: diff --git a/nucliadb/src/nucliadb/search/search/chat/query.py b/nucliadb/src/nucliadb/search/search/chat/query.py index e27a607dbb..525f3e93f8 100644 --- a/nucliadb/src/nucliadb/search/search/chat/query.py +++ b/nucliadb/src/nucliadb/search/search/chat/query.py @@ -17,15 +17,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -import asyncio -import time -from dataclasses import dataclass -from typing import AsyncGenerator, AsyncIterator, Optional +from typing import Optional from nucliadb.search import logger -from nucliadb.search.predict import AnswerStatusCode, RephraseMissingContextError +from nucliadb.search.predict import AnswerStatusCode from nucliadb.search.requesters.utils import Method, node_query -from nucliadb.search.search.chat.prompt import PromptContextBuilder from nucliadb.search.search.exceptions import IncompleteFindResultsError from nucliadb.search.search.find import find from nucliadb.search.search.merge import merge_relations_results @@ -34,52 +30,23 @@ from nucliadb.search.utilities import get_predict from nucliadb_models.search import ( ChatContextMessage, - ChatModel, ChatOptions, ChatRequest, FindRequest, KnowledgeboxFindResults, - MinScore, NucliaDBClientType, PromptContext, PromptContextOrder, Relations, RephraseModel, SearchOptions, - UserPrompt, - parse_custom_prompt, ) from nucliadb_protos import audit_pb2 from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse from nucliadb_telemetry.errors import capture_exception -from nucliadb_utils.helpers import async_gen_lookahead from nucliadb_utils.utilities import get_audit NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this." -AUDIT_TEXT_RESULT_SEP = " \n\n " -START_OF_CITATIONS = b"_CIT_" - - -class FoundStatusCode: - def __init__(self, default: AnswerStatusCode = AnswerStatusCode.SUCCESS): - self._value = AnswerStatusCode.SUCCESS - - def set(self, value: AnswerStatusCode) -> None: - self._value = value - - @property - def value(self) -> AnswerStatusCode: - return self._value - - -@dataclass -class ChatResult: - nuclia_learning_id: Optional[str] - answer_stream: AsyncIterator[bytes] - status_code: FoundStatusCode - find_results: KnowledgeboxFindResults - prompt_context: PromptContext - prompt_context_order: PromptContextOrder async def rephrase_query( @@ -101,35 +68,6 @@ async def rephrase_query( return await predict.rephrase_query(kbid, req) -async def format_generated_answer( - answer_generator: AsyncGenerator[bytes, None], output_status_code: FoundStatusCode -): - status_code: Optional[AnswerStatusCode] = None - is_last_chunk = False - async for answer_chunk, is_last_chunk in async_gen_lookahead(answer_generator): - if is_last_chunk: - try: - status_code = _parse_answer_status_code(answer_chunk) - except ValueError: - # TODO: remove this in the future, it's - # just for bw compatibility until predict - # is updated to the new protocol - status_code = AnswerStatusCode.SUCCESS - yield answer_chunk - else: - # TODO: this should be needed but, in case we receive the status - # code mixed with text, we strip it and return the text - if len(answer_chunk) != len(status_code.encode()): - answer_chunk = answer_chunk.rstrip(status_code.encode()) - yield answer_chunk - break - yield answer_chunk - if not is_last_chunk: - logger.warning("BUG: /chat endpoint without last chunk") - - output_status_code.set(status_code or AnswerStatusCode.SUCCESS) - - async def get_find_results( *, kbid: str, @@ -217,190 +155,6 @@ async def get_relations_results( return Relations(entities={}) -async def not_enough_context_generator(): - await asyncio.sleep(0) - yield NOT_ENOUGH_CONTEXT_ANSWER.encode() - yield AnswerStatusCode.NO_CONTEXT.encode() - - -async def chat( - kbid: str, - chat_request: ChatRequest, - user_id: str, - client_type: NucliaDBClientType, - origin: str, - resource: Optional[str] = None, -) -> ChatResult: # pragma: no cover - metrics = RAGMetrics() - nuclia_learning_id: Optional[str] = None - chat_history = chat_request.context or [] - user_context = chat_request.extra_context or [] - user_query = chat_request.query - rephrased_query = None - prompt_context: PromptContext = {} - prompt_context_order: PromptContextOrder = {} - - if len(chat_history) > 0 or len(user_context) > 0: - try: - with metrics.time("rephrase"): - rephrased_query = await rephrase_query( - kbid, - chat_history=chat_history, - query=user_query, - user_id=user_id, - user_context=user_context, - generative_model=chat_request.generative_model, - ) - except RephraseMissingContextError: - logger.info("Failed to rephrase chat query, using original") - - # Retrieval is not needed if we are chatting on a specific - # resource and the full_resource strategy is enabled - needs_retrieval = True - if resource is not None: - chat_request.resource_filters = [resource] - if any(strategy.name == "full_resource" for strategy in chat_request.rag_strategies): - needs_retrieval = False - - if needs_retrieval: - with metrics.time("retrieval"): - find_results, query_parser = await get_find_results( - kbid=kbid, - query=rephrased_query or user_query, - chat_request=chat_request, - ndb_client=client_type, - user=user_id, - origin=origin, - metrics=metrics, - ) - status_code = FoundStatusCode() - if len(find_results.resources) == 0: - # If no resources were found on the retrieval, we return - # a "Not enough context" answer and skip the llm query - answer_stream = format_generated_answer(not_enough_context_generator(), status_code) - return ChatResult( - nuclia_learning_id=nuclia_learning_id, - answer_stream=answer_stream, - status_code=status_code, - find_results=find_results, - prompt_context=prompt_context, - prompt_context_order=prompt_context_order, - ) - else: - status_code = FoundStatusCode() - find_results = KnowledgeboxFindResults(resources={}, min_score=None) - query_parser = QueryParser( - kbid=kbid, - features=[], - query="", - filters=chat_request.filters, - page_number=0, - page_size=0, - min_score=MinScore(), - ) - - with metrics.time("context_building"): - query_parser.max_tokens = chat_request.max_tokens # type: ignore - max_tokens_context = await query_parser.get_max_tokens_context() - prompt_context_builder = PromptContextBuilder( - kbid=kbid, - find_results=find_results, - resource=resource, - user_context=user_context, - strategies=chat_request.rag_strategies, - image_strategies=chat_request.rag_images_strategies, - max_context_characters=tokens_to_chars(max_tokens_context), - visual_llm=await query_parser.get_visual_llm_enabled(), - ) - ( - prompt_context, - prompt_context_order, - prompt_context_images, - ) = await prompt_context_builder.build() - - custom_prompt = parse_custom_prompt(chat_request) - chat_model = ChatModel( - user_id=user_id, - system=custom_prompt.system, - user_prompt=UserPrompt(prompt=custom_prompt.user) if custom_prompt.user else None, - query_context=prompt_context, - query_context_order=prompt_context_order, - chat_history=chat_history, - question=user_query, - truncate=True, - citations=chat_request.citations, - generative_model=chat_request.generative_model, - max_tokens=query_parser.get_max_tokens_answer(), - query_context_images=prompt_context_images, - prefer_markdown=chat_request.prefer_markdown, - ) - predict = get_predict() - generative_start = time.monotonic() - nuclia_learning_id, predict_generator = await predict.chat_query(kbid, chat_model) - - async def _wrapped_stream(): - # so we can audit after streamed out answer - text_answer = b"" - async for chunk in format_generated_answer(predict_generator, status_code): - if text_answer == b"": - # first chunk - metrics.record_first_chunk_yielded() - text_answer += chunk - yield chunk - try: - rephrase_time = metrics.elapsed("rephrase") - except KeyError: - rephrase_time = None - - maybe_audit_chat( - kbid=kbid, - user_id=user_id, - client_type=client_type, - origin=origin, - user_query=user_query, - rephrased_query=rephrased_query, - rephrase_time=rephrase_time, - generative_answer_time=time.monotonic() - generative_start, - generative_answer_first_chunk_time=metrics.first_chunk_yielded_at - metrics.global_start, - text_answer=text_answer, - status_code=status_code.value, - chat_history=chat_history, - query_context=prompt_context, - query_context_order=prompt_context_order, - learning_id=nuclia_learning_id, - ) - - answer_stream = _wrapped_stream() - return ChatResult( - nuclia_learning_id=nuclia_learning_id, - answer_stream=answer_stream, - status_code=status_code, - find_results=find_results, - prompt_context=prompt_context, - prompt_context_order=prompt_context_order, - ) - - -def _parse_answer_status_code(chunk: bytes) -> AnswerStatusCode: - """ - Parses the status code from the last chunk of the answer. - """ - try: - return AnswerStatusCode(chunk.decode()) - except ValueError: - # In some cases, even if the status code was yield separately - # at the server side, the status code is appended to the previous chunk... - # It may be a bug in the aiohttp.StreamResponse implementation, - # but we haven't spotted it yet. For now, we just try to parse the status code - # from the tail of the chunk. - logger.debug(f"Error decoding status code from /chat's last chunk. Chunk: {chunk!r}") - if chunk == b"": - raise - if chunk.endswith(b"0"): - return AnswerStatusCode.SUCCESS - return AnswerStatusCode(chunk[-2:].decode()) - - def maybe_audit_chat( *, kbid: str, @@ -456,13 +210,7 @@ def parse_audit_answer(raw_text_answer: bytes, status_code: Optional[AnswerStatu if status_code == AnswerStatusCode.NO_CONTEXT: # We don't want to audit "Not enough context to answer this." and instead set a None. return None - # Split citations part from answer - try: - raw_audit_answer, _ = raw_text_answer.split(START_OF_CITATIONS) - except ValueError: - raw_audit_answer = raw_text_answer - audit_answer = raw_audit_answer.decode() - return audit_answer + return raw_text_answer.decode() def tokens_to_chars(n_tokens: int) -> int: diff --git a/nucliadb/tests/nucliadb/integration/test_chat.py b/nucliadb/tests/nucliadb/integration/test_chat.py deleted file mode 100644 index 045be035b3..0000000000 --- a/nucliadb/tests/nucliadb/integration/test_chat.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (C) 2021 Bosutech XXI S.L. -# -# nucliadb is offered under the AGPL v3.0 and as commercial software. -# For commercial licensing, contact us at info@nuclia.com. -# -# AGPL: -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -# -from unittest import mock - -import pytest -from httpx import AsyncClient - - -@pytest.mark.asyncio() -@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True) -async def test_chat( - nucliadb_reader: AsyncClient, - knowledgebox, -): - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"}) - assert resp.status_code == 404 - assert resp.json()["detail"] == "This endpoint has been deprecated. Please use /ask instead." - - with mock.patch("nucliadb.search.api.v1.chat.has_feature", return_value=True): - resp = await nucliadb_reader.post(f"/kb/{knowledgebox}/chat", json={"query": "query"}) - assert resp.status_code == 200 diff --git a/nucliadb/tests/search/unit/search/search/chat/test_query.py b/nucliadb/tests/search/unit/search/search/chat/test_query.py index 812e225694..d3184d39b4 100644 --- a/nucliadb/tests/search/unit/search/search/chat/test_query.py +++ b/nucliadb/tests/search/unit/search/search/chat/test_query.py @@ -20,16 +20,12 @@ import pytest -from nucliadb.search.predict import AnswerStatusCode, RephraseMissingContextError +from nucliadb.search.predict import AnswerStatusCode from nucliadb.search.search.chat.query import ( - _parse_answer_status_code, - chat, get_find_results, parse_audit_answer, ) from nucliadb_models.search import ( - Author, - ChatContextMessage, ChatOptions, ChatRequest, KnowledgeboxFindResults, @@ -46,74 +42,6 @@ def predict(): yield predict -async def test_chat_does_not_call_predict_if_no_find_results(predict): - find_results = KnowledgeboxFindResults( - total=0, min_score=MinScore(semantic=0.7), resources={}, facets=[] - ) - chat_request = ChatRequest(query="query") - - with mock.patch( - "nucliadb.search.search.chat.query.get_find_results", - return_value=(find_results, None), - ): - await chat( - "kbid", - chat_request, - "user_id", - NucliaDBClientType.API, - "origin", - ) - - predict.chat_query.assert_not_called() - - -async def test_chat_uses_original_query_if_failed_to_rephrase(predict): - find_results = KnowledgeboxFindResults( - total=0, min_score=MinScore(semantic=0.7), resources={}, facets=[] - ) - chat_request = ChatRequest( - query="query", context=[ChatContextMessage(author=Author.NUCLIA, text="hello!")] - ) - - predict.rephrase_query.side_effect = RephraseMissingContextError() - - with mock.patch( - "nucliadb.search.search.chat.query.get_find_results", - return_value=(find_results, None), - ) as find_mock: - await chat( - "kbid", - chat_request, - "user_id", - NucliaDBClientType.API, - "origin", - ) - - find_mock.assert_called_once() - assert find_mock.call_args.kwargs["query"] == "query" - - -@pytest.mark.parametrize( - "chunk,status_code,error", - [ - (b"", None, True), - (b"errorcodeisnotpresetn", None, True), - (b"0", AnswerStatusCode.SUCCESS, False), - (b"-1", AnswerStatusCode.ERROR, False), - (b"-2", AnswerStatusCode.NO_CONTEXT, False), - (b"foo.0", AnswerStatusCode.SUCCESS, False), - (b"bar.-1", AnswerStatusCode.ERROR, False), - (b"baz.-2", AnswerStatusCode.NO_CONTEXT, False), - ], -) -def test_parse_status_code(chunk, status_code, error): - if error: - with pytest.raises(ValueError): - _parse_answer_status_code(chunk) - else: - assert _parse_answer_status_code(chunk) == status_code - - @pytest.mark.parametrize( "chat_features,find_features", [ @@ -176,9 +104,7 @@ async def test_get_find_results_vector_search_is_optional(predict, chat_features @pytest.mark.parametrize( "raw_text_answer,status_code,audit_answer", [ - (b"foobar_CIT_blabla", AnswerStatusCode.NO_CONTEXT, None), - (b"foobar_CIT_blabla", AnswerStatusCode.SUCCESS, "foobar"), - (b"foobar_CIT_blabla", None, "foobar"), + (b"foobar", AnswerStatusCode.NO_CONTEXT, None), (b"foobar", AnswerStatusCode.SUCCESS, "foobar"), ], ) diff --git a/nucliadb/tests/search/unit/test_predict.py b/nucliadb/tests/search/unit/test_predict.py index aca5767f5a..5e83704a7e 100644 --- a/nucliadb/tests/search/unit/test_predict.py +++ b/nucliadb/tests/search/unit/test_predict.py @@ -62,7 +62,7 @@ async def test_dummy_predict_engine(): await pe.finalize() await pe.send_feedback("kbid", Mock(), "", "", "") assert await pe.rephrase_query("kbid", Mock()) - assert await pe.chat_query("kbid", Mock()) + assert await pe.chat_query_ndjson("kbid", Mock()) assert await pe.detect_entities("kbid", "some sentence") assert await pe.summarize("kbid", Mock(resources={})) @@ -144,7 +144,7 @@ def session_limits_exceeded(): "method,args", [ ("detect_entities", ["kbid", "sentence"]), - ("chat_query", ["kbid", ChatModel(question="foo", user_id="bar")]), + ("chat_query_ndjson", ["kbid", ChatModel(question="foo", user_id="bar")]), ( "send_feedback", [ @@ -173,7 +173,7 @@ async def test_predict_engine_handles_limits_exceeded_error(session_limits_excee @pytest.mark.parametrize( "method,args,exception,output", [ - ("chat_query", ["kbid", Mock()], True, None), + ("chat_query_ndjson", ["kbid", Mock()], True, None), ("rephrase_query", ["kbid", Mock()], True, None), ("send_feedback", ["kbid", MagicMock(), "", "", ""], False, None), ("detect_entities", ["kbid", "sentence"], False, []), diff --git a/nucliadb_utils/src/nucliadb_utils/const.py b/nucliadb_utils/src/nucliadb_utils/const.py index c9178cdd4e..bc944e0531 100644 --- a/nucliadb_utils/src/nucliadb_utils/const.py +++ b/nucliadb_utils/src/nucliadb_utils/const.py @@ -84,5 +84,4 @@ class Features: VECTORSETS_V0 = "vectorsets_v0_new_kbs_with_multiple_vectorsets" SKIP_EXTERNAL_INDEX = "nucliadb_skip_external_index" NATS_SYNC_ACK = "nucliadb_nats_sync_ack" - DEPRECATED_CHAT_ENABLED = "nucliadb_deprecated_chat_enabled" LOG_REQUEST_PAYLOADS = "nucliadb_log_request_payloads" diff --git a/nucliadb_utils/src/nucliadb_utils/featureflagging.py b/nucliadb_utils/src/nucliadb_utils/featureflagging.py index a366b68b90..11fb2391f5 100644 --- a/nucliadb_utils/src/nucliadb_utils/featureflagging.py +++ b/nucliadb_utils/src/nucliadb_utils/featureflagging.py @@ -89,10 +89,6 @@ class Settings(pydantic_settings.BaseSettings): "rollout": 0, "variants": {"environment": ["none"]}, }, - const.Features.DEPRECATED_CHAT_ENABLED: { - "rollout": 0, - "variants": {"environment": ["none"]}, - }, }