Skip to content

Commit

Permalink
Fix audit chat when citations are present (#1670)
Browse files Browse the repository at this point in the history
* fix audit chat when citations are present

* make sure audit logic is covered in tests

* Refactor header into a const

* increase coverage of chat
  • Loading branch information
lferran authored Dec 15, 2023
1 parent 40663e4 commit 45dfe19
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 12 deletions.
7 changes: 5 additions & 2 deletions nucliadb/nucliadb/search/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from nucliadb.search import logger, predict
from nucliadb.search.api.v1.router import KB_PREFIX, api
from nucliadb.search.predict import AnswerStatusCode
from nucliadb.search.search.chat.query import chat, get_relations_results
from nucliadb.search.search.chat.query import (
START_OF_CITATIONS,
chat,
get_relations_results,
)
from nucliadb.search.search.exceptions import (
IncompleteFindResultsError,
InvalidQueryError,
Expand All @@ -49,7 +53,6 @@
from nucliadb_utils.exceptions import LimitsExceededError

END_OF_STREAM = "_END_"
START_OF_CITATIONS = b"_CIT_"


class SyncChatResponse(pydantic.BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions nucliadb/nucliadb/search/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class RephraseMissingContextError(Exception):
REPHRASE = "/rephrase"
FEEDBACK = "/feedback"

NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"


predict_observer = metrics.Observer(
"predict_engine",
Expand Down Expand Up @@ -321,7 +323,7 @@ async def chat_query(
timeout=None,
)
await self.check_response(resp, expected_status=200)
ident = resp.headers.get("NUCLIA-LEARNING-ID")
ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
return ident, get_answer_generator(resp)

@predict_observer.wrap({"type": "ask_document"})
Expand Down Expand Up @@ -434,7 +436,7 @@ async def make_request(self, method: str, **request_args):
self.calls.append((method, request_args))
response = Mock(status=200)
response.json = AsyncMock(return_value={"foo": "bar"})
response.headers = {"NUCLIA-LEARNING-ID": DUMMY_LEARNING_ID}
response.headers = {NUCLIA_LEARNING_ID_HEADER: DUMMY_LEARNING_ID}
return response

async def send_feedback(
Expand Down
20 changes: 17 additions & 3 deletions nucliadb/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

NOT_ENOUGH_CONTEXT_ANSWER = "Not enough data to answer this."
AUDIT_TEXT_RESULT_SEP = " \n\n "
# Byte marker that separates the answer text from the citations part in a raw
# chat answer; everything before it is what gets audited as the answer.
START_OF_CITATIONS = b"_CIT_"


class FoundStatusCode:
Expand Down Expand Up @@ -320,9 +321,7 @@ async def maybe_audit_chat(
if audit is None:
return

audit_answer: Optional[str] = text_answer.decode()
if status_code == AnswerStatusCode.NO_CONTEXT:
audit_answer = None
audit_answer = parse_audit_answer(text_answer, status_code)

# Append chat history and query context
audit_context = [
Expand All @@ -346,3 +345,18 @@ async def maybe_audit_chat(
context=audit_context,
answer=audit_answer,
)


def parse_audit_answer(
    raw_text_answer: bytes, status_code: Optional[AnswerStatusCode]
) -> Optional[str]:
    """Extract the text that should be audited from a raw chat answer.

    Args:
        raw_text_answer: Raw answer bytes, optionally followed by the
            START_OF_CITATIONS marker and a citations part.
        status_code: Status code of the answer, or None when unknown.

    Returns:
        None when ``status_code`` is ``AnswerStatusCode.NO_CONTEXT``;
        otherwise the decoded bytes that precede the first
        START_OF_CITATIONS marker (the whole answer if no marker is present).

    Raises:
        UnicodeDecodeError: If the answer part is not valid UTF-8.
    """
    if status_code == AnswerStatusCode.NO_CONTEXT:
        # We don't want to audit "Not enough context to answer this." and instead set a None.
        return None
    # Split citations part from answer. partition() cuts at the FIRST marker
    # and never raises, so an answer with no marker is returned whole and a
    # payload containing the marker more than once still has its citations
    # stripped (split() + two-name unpacking failed on that case and fell
    # back to auditing the entire payload).
    raw_audit_answer, _, _ = raw_text_answer.partition(START_OF_CITATIONS)
    return raw_audit_answer.decode()
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_parse_answer_status_code,
chat,
get_find_results,
parse_audit_answer,
)
from nucliadb_models.search import (
ChatOptions,
Expand Down Expand Up @@ -143,3 +144,16 @@ async def test_get_find_results_vector_search_is_optional(
)
find_request = find_mock.call_args[0][1]
assert set(find_request.features) == set(find_features)


@pytest.mark.parametrize(
    ["raw_text_answer", "status_code", "audit_answer"],
    [
        (b"foobar_CIT_blabla", AnswerStatusCode.NO_CONTEXT, None),
        (b"foobar_CIT_blabla", AnswerStatusCode.SUCCESS, "foobar"),
        (b"foobar_CIT_blabla", None, "foobar"),
        (b"foobar", AnswerStatusCode.SUCCESS, "foobar"),
    ],
)
def test_parse_audit_answer(raw_text_answer, status_code, audit_answer):
    """Check the audited answer for each raw answer / status code pair.

    Cases cover: NO_CONTEXT yields None, citations are stripped for SUCCESS
    and for a missing status code, and an answer without citations is
    returned unchanged.
    """
    parsed = parse_audit_answer(raw_text_answer, status_code)
    assert parsed == audit_answer
18 changes: 18 additions & 0 deletions nucliadb/nucliadb/search/tests/unit/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import asyncio
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock

Expand All @@ -34,6 +35,7 @@
RephraseMissingContextError,
SendToPredictError,
_parse_rephrase_response,
get_answer_generator,
)
from nucliadb.tests.utils.aiohttp_session import get_mocked_session
from nucliadb_models.search import (
Expand Down Expand Up @@ -469,3 +471,19 @@ async def test_get_predict_headers(onprem, txn):
assert predict_headers["X-STF-VISUAL-LABELING"] == kb_config.visual_labeling
else:
assert predict_headers == {"X-STF-KBID": "kbid"}


async def test_get_answer_generator():
    """Check that get_answer_generator joins partial chunks.

    The mocked response streams ``(data, end_of_chunk)`` pairs; data is
    expected to be accumulated until ``end_of_chunk`` is True, so ``b"foo"``
    and ``b"bar"`` come out as a single ``b"foobar"`` item.
    """

    async def _iter_chunks():
        # Suspend once so the stream is consumed asynchronously, without
        # adding wall-clock delay to the unit test.
        await asyncio.sleep(0)
        # Chunk, end_of_chunk
        yield b"foo", False
        yield b"bar", True
        yield b"baz", True

    resp = Mock()
    resp.content.iter_chunks = Mock(return_value=_iter_chunks())

    # Consume the generator exactly once; the previous extra call whose
    # result was discarded exercised nothing.
    answer_chunks = [chunk async for chunk in get_answer_generator(resp)]
    assert answer_chunks == [b"foobar", b"baz"]
14 changes: 9 additions & 5 deletions nucliadb/nucliadb/tests/integration/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
from nucliadb.search.utilities import get_predict


@pytest.fixture(scope="function", autouse=True)
def audit():
    """Patch ``get_audit`` in the chat query module with a mock for every test.

    Yields the mock so tests can inspect calls made to its async ``chat``
    method.
    """
    fake_audit = mock.Mock()
    fake_audit.chat = mock.AsyncMock()
    get_audit_patch = mock.patch(
        "nucliadb.search.search.chat.query.get_audit", return_value=fake_audit
    )
    with get_audit_patch:
        yield fake_audit


@pytest.mark.asyncio()
@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
async def test_chat(
Expand Down Expand Up @@ -86,11 +95,6 @@ async def resource(nucliadb_writer, knowledgebox):
)
assert resp.status_code in (200, 201)
rid = resp.json()["uuid"]

import asyncio

await asyncio.sleep(1)

yield rid


Expand Down

1 comment on commit 45dfe19

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 45dfe19 Previous: 5a633b0 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12801.053779023852 iter/sec (stddev: 2.87681053873229e-7) 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7) 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.