From 45dfe197ba1337211461c0742b50ba8b32c4de0f Mon Sep 17 00:00:00 2001 From: Ferran Llamas Date: Fri, 15 Dec 2023 09:36:16 +0100 Subject: [PATCH] Fix audit chat when citations are present (#1670) * fix audit chat when citations are present * make sure audit logic is covered in tests * Refactor header into a const * increase coverage of chat --- nucliadb/nucliadb/search/api/v1/chat.py | 7 +++++-- nucliadb/nucliadb/search/predict.py | 6 ++++-- nucliadb/nucliadb/search/search/chat/query.py | 20 ++++++++++++++++--- .../unit/search/search/chat/test_query.py | 14 +++++++++++++ .../search/tests/unit/test_predict.py | 18 +++++++++++++++++ .../nucliadb/tests/integration/test_chat.py | 14 ++++++++----- 6 files changed, 67 insertions(+), 12 deletions(-) diff --git a/nucliadb/nucliadb/search/api/v1/chat.py b/nucliadb/nucliadb/search/api/v1/chat.py index da39d91ecf..c36d9f833f 100644 --- a/nucliadb/nucliadb/search/api/v1/chat.py +++ b/nucliadb/nucliadb/search/api/v1/chat.py @@ -31,7 +31,11 @@ from nucliadb.search import logger, predict from nucliadb.search.api.v1.router import KB_PREFIX, api from nucliadb.search.predict import AnswerStatusCode -from nucliadb.search.search.chat.query import chat, get_relations_results +from nucliadb.search.search.chat.query import ( + START_OF_CITATIONS, + chat, + get_relations_results, +) from nucliadb.search.search.exceptions import ( IncompleteFindResultsError, InvalidQueryError, @@ -49,7 +53,6 @@ from nucliadb_utils.exceptions import LimitsExceededError END_OF_STREAM = "_END_" -START_OF_CITATIONS = b"_CIT_" class SyncChatResponse(pydantic.BaseModel): diff --git a/nucliadb/nucliadb/search/predict.py b/nucliadb/nucliadb/search/predict.py index 0047f4c66b..cf0b9caf69 100644 --- a/nucliadb/nucliadb/search/predict.py +++ b/nucliadb/nucliadb/search/predict.py @@ -95,6 +95,8 @@ class RephraseMissingContextError(Exception): REPHRASE = "/rephrase" FEEDBACK = "/feedback" +NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID" + predict_observer = 
def parse_audit_answer(
    raw_text_answer: bytes, status_code: Optional[AnswerStatusCode]
) -> Optional[str]:
    """Extract the answer text that should be stored in the audit record.

    Args:
        raw_text_answer: Raw answer bytes. They may be followed by the
            START_OF_CITATIONS marker and a citations payload.
        status_code: Status code parsed from the answer, if any.

    Returns:
        None when the status code is AnswerStatusCode.NO_CONTEXT; otherwise
        the decoded text that precedes the first START_OF_CITATIONS marker
        (the whole decoded input when the marker is absent).

    Raises:
        UnicodeDecodeError: If the answer part is not valid UTF-8.
    """
    if status_code == AnswerStatusCode.NO_CONTEXT:
        # We don't want to audit the "not enough context" placeholder answer;
        # a None is recorded instead.
        return None
    # Keep only what precedes the first citations marker. Unlike unpacking the
    # result of split(), partition() always yields exactly three parts, so an
    # input with zero, one or several markers is handled uniformly and the
    # citations payload can never leak into the audited answer.
    raw_audit_answer, _, _ = raw_text_answer.partition(START_OF_CITATIONS)
    return raw_audit_answer.decode()
async def test_get_answer_generator():
    """Check that get_answer_generator regroups streamed chunks.

    The mocked response yields (chunk, end_of_chunk) pairs; pieces are
    expected to be concatenated until a pair flagged with end_of_chunk=True
    is seen, so b"foo" + b"bar" come out as one item and b"baz" as another.
    """

    async def _iter_chunks():
        # Suspend once so the generator really behaves asynchronously.
        await asyncio.sleep(0.1)
        # Chunk, end_of_chunk
        yield b"foo", False
        yield b"bar", True
        yield b"baz", True

    resp = Mock()
    resp.content.iter_chunks = Mock(return_value=_iter_chunks())

    # Build the generator exactly once: the mock hands out a single shared
    # chunk iterator, so a second, discarded generator would only be dead code.
    answer_chunks = [chunk async for chunk in get_answer_generator(resp)]
    assert answer_chunks == [b"foobar", b"baz"]