Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split history context from retrieved contexts #2393

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions nucliadb/src/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from nucliadb.search.search.query import QueryParser
from nucliadb.search.utilities import get_predict
from nucliadb_models.search import (
Author,
ChatContextMessage,
ChatModel,
ChatOptions,
Expand Down Expand Up @@ -418,18 +417,17 @@ def maybe_audit_chat(
return

audit_answer = parse_audit_answer(text_answer, status_code)

# Append chat history and query context
audit_context = [
# Append chat history
chat_history_context = [
audit_pb2.ChatContext(author=message.author, text=message.text) for message in chat_history
]
query_context_paragaph_ids = list(query_context.values())
audit_context.append(
audit_pb2.ChatContext(
author=Author.NUCLIA,
text=AUDIT_TEXT_RESULT_SEP.join(query_context_paragaph_ids),
)
)

# Append paragraphs retrieved on this chat
chat_retrieved_context = [
audit_pb2.RetrievedContext(text_block_id=paragraph_id, text=text)
for paragraph_id, text in query_context.items()
]

audit.chat(
kbid,
user_id,
Expand All @@ -440,7 +438,8 @@ def maybe_audit_chat(
generative_answer_first_chunk_time=generative_answer_first_chunk_time,
rephrase_time=rephrase_time,
rephrased_question=rephrased_query,
context=audit_context,
chat_context=chat_history_context,
retrieved_context=chat_retrieved_context,
answer=audit_answer,
learning_id=learning_id,
)
Expand Down
14 changes: 14 additions & 0 deletions nucliadb/tests/nucliadb/integration/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,17 @@ async def test_ask_with_json_schema_output(
answer_json = results[1].item.object
assert answer_json["answer"] == "valid answer to"
assert answer_json["confidence"] == 0.5


@pytest.mark.asyncio()
@pytest.mark.parametrize("knowledgebox", ("EXPERIMENTAL", "STABLE"), indirect=True)
async def test_ask_assert_audit_retrieval_contexts(
    nucliadb_reader: AsyncClient, knowledgebox, resources, audit
):
    """Check the retrieved contexts handed to the audit on an /ask request.

    Posts an ask query and compares the ``retrieved_context`` keyword argument
    of the first ``audit.chat`` call against one ``(text_block_id, text)`` pair
    per entry in ``resources``.
    """
    ask_url = f"/kb/{knowledgebox}/ask"
    payload = {"query": "title", "debug": True}
    resp = await nucliadb_reader.post(ask_url, json=payload)
    assert resp.status_code == 200

    # Only the first recorded audit.chat call is inspected.
    first_chat_call = audit.chat.call_args_list[0]
    retrieved_context = first_chat_call.kwargs["retrieved_context"]

    expected_pairs = set()
    for i, rid in enumerate(resources):
        expected_pairs.add((f"{rid}/a/title/0-11", f"The title {i}"))

    # Compared as sets: the order of the audited contexts is not checked.
    audited_pairs = {(entry.text_block_id, entry.text) for entry in retrieved_context}
    assert expected_pairs == audited_pairs
12 changes: 11 additions & 1 deletion nucliadb_protos/audit.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,21 @@ message ChatContext {
string text = 2;
}

// A text block retrieved as context for a chat request, together with its text.
message RetrievedContext {
// Identifier of the retrieved text block (the chat audit code fills it with a paragraph id)
string text_block_id = 1;
// Text of the retrieved block
string text = 2;

}

// Audit record of a chat interaction.
message ChatAudit {
string question = 1;
optional string answer = 2;
optional string rephrased_question = 3;
repeated ChatContext context = 4;
// Deprecated: single mixed context list; use chat_context and
// retrieved_context instead
repeated ChatContext context = 4 [deprecated = true];
// Conversation from chats (chat history messages)
repeated ChatContext chat_context = 6;
// Context retrieved on the current ask (text block id and text)
repeated RetrievedContext retrieved_context = 8;
string learning_id = 5;
}

Expand Down
22 changes: 13 additions & 9 deletions nucliadb_protos/python/src/nucliadb_protos/audit_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 32 additions & 2 deletions nucliadb_protos/python/src/nucliadb_protos/audit_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ class ChatContext(google.protobuf.message.Message):

global___ChatContext = ChatContext

# Type stub for the RetrievedContext protobuf message: a retrieved text block
# identified by text_block_id, together with its text.
# NOTE(review): this stub looks auto-generated from audit.proto; confirm before
# editing by hand, since regeneration would overwrite manual changes.
@typing.final
class RetrievedContext(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

TEXT_BLOCK_ID_FIELD_NUMBER: builtins.int
TEXT_FIELD_NUMBER: builtins.int
# Identifier of the retrieved text block
text_block_id: builtins.str
# Text of the retrieved block
text: builtins.str
def __init__(
self,
*,
text_block_id: builtins.str = ...,
text: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["text", b"text", "text_block_id", b"text_block_id"]) -> None: ...

# Module-level alias of the class under a global___ prefixed name
global___RetrievedContext = RetrievedContext

@typing.final
class ChatAudit(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
Expand All @@ -135,24 +153,36 @@ class ChatAudit(google.protobuf.message.Message):
ANSWER_FIELD_NUMBER: builtins.int
REPHRASED_QUESTION_FIELD_NUMBER: builtins.int
CONTEXT_FIELD_NUMBER: builtins.int
CHAT_CONTEXT_FIELD_NUMBER: builtins.int
RETRIEVED_CONTEXT_FIELD_NUMBER: builtins.int
LEARNING_ID_FIELD_NUMBER: builtins.int
question: builtins.str
answer: builtins.str
rephrased_question: builtins.str
learning_id: builtins.str
@property
def context(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ChatContext]: ...
def context(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ChatContext]:
"""Conversation from chats"""

@property
def chat_context(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ChatContext]:
"""context retrieved on the current ask"""

@property
def retrieved_context(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RetrievedContext]: ...
def __init__(
self,
*,
question: builtins.str = ...,
answer: builtins.str | None = ...,
rephrased_question: builtins.str | None = ...,
context: collections.abc.Iterable[global___ChatContext] | None = ...,
chat_context: collections.abc.Iterable[global___ChatContext] | None = ...,
retrieved_context: collections.abc.Iterable[global___RetrievedContext] | None = ...,
learning_id: builtins.str = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["_answer", b"_answer", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "rephrased_question", b"rephrased_question"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["_answer", b"_answer", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "context", b"context", "learning_id", b"learning_id", "question", b"question", "rephrased_question", b"rephrased_question"]) -> None: ...
def ClearField(self, field_name: typing.Literal["_answer", b"_answer", "_rephrased_question", b"_rephrased_question", "answer", b"answer", "chat_context", b"chat_context", "context", b"context", "learning_id", b"learning_id", "question", b"question", "rephrased_question", b"rephrased_question", "retrieved_context", b"retrieved_context"]) -> None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_answer", b"_answer"]) -> typing.Literal["answer"] | None: ...
@typing.overload
Expand Down
1 change: 1 addition & 0 deletions nucliadb_protos/python/src/nucliadb_protos/writer_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ from nucliadb_protos.audit_pb2 import (
ClientType as ClientType,
DASHBOARD as DASHBOARD,
DESKTOP as DESKTOP,
RetrievedContext as RetrievedContext,
WEB as WEB,
WIDGET as WIDGET,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from nucliadb_protos.audit_pb2 import (
ClientType as ClientType,
DASHBOARD as DASHBOARD,
DESKTOP as DESKTOP,
RetrievedContext as RetrievedContext,
WEB as WEB,
WIDGET as WIDGET,
)
Expand Down
15 changes: 15 additions & 0 deletions nucliadb_protos/rust/src/audit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,30 @@ pub struct ChatContext {
}
/// A retrieved text block: its identifier and its text.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RetrievedContext {
/// Identifier of the retrieved text block (protobuf tag 1)
#[prost(string, tag = "1")]
pub text_block_id: ::prost::alloc::string::String,
/// Text of the retrieved block (protobuf tag 2)
#[prost(string, tag = "2")]
pub text: ::prost::alloc::string::String,
}
/// Audit record of a chat interaction.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChatAudit {
#[prost(string, tag = "1")]
pub question: ::prost::alloc::string::String,
#[prost(string, optional, tag = "2")]
pub answer: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag = "3")]
pub rephrased_question: ::core::option::Option<::prost::alloc::string::String>,
/// Deprecated: single mixed context list; use `chat_context` and
/// `retrieved_context` instead
#[deprecated]
#[prost(message, repeated, tag = "4")]
pub context: ::prost::alloc::vec::Vec<ChatContext>,
/// Conversation from chats (chat history messages)
#[prost(message, repeated, tag = "6")]
pub chat_context: ::prost::alloc::vec::Vec<ChatContext>,
/// Context retrieved on the current ask (text block id and text)
#[prost(message, repeated, tag = "8")]
pub retrieved_context: ::prost::alloc::vec::Vec<RetrievedContext>,
#[prost(string, tag = "5")]
pub learning_id: ::prost::alloc::string::String,
}
Expand Down
9 changes: 3 additions & 6 deletions nucliadb_utils/src/nucliadb_utils/audit/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@

from google.protobuf.timestamp_pb2 import Timestamp

from nucliadb_protos.audit_pb2 import (
AuditField,
AuditRequest,
ChatContext,
)
from nucliadb_protos.audit_pb2 import AuditField, AuditRequest, ChatContext, RetrievedContext
from nucliadb_protos.nodereader_pb2 import SearchRequest
from nucliadb_protos.resources_pb2 import FieldID

Expand Down Expand Up @@ -85,7 +81,8 @@ def chat(
origin: str,
question: str,
rephrased_question: Optional[str],
context: List[ChatContext],
chat_context: List[ChatContext],
retrieved_context: List[RetrievedContext],
answer: Optional[str],
learning_id: str,
rephrase_time: Optional[float] = None,
Expand Down
9 changes: 3 additions & 6 deletions nucliadb_utils/src/nucliadb_utils/audit/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@

from google.protobuf.timestamp_pb2 import Timestamp

from nucliadb_protos.audit_pb2 import (
AuditField,
AuditRequest,
ChatContext,
)
from nucliadb_protos.audit_pb2 import AuditField, AuditRequest, ChatContext, RetrievedContext
from nucliadb_protos.nodereader_pb2 import SearchRequest
from nucliadb_protos.resources_pb2 import FieldID
from nucliadb_protos.writer_pb2 import BrokerMessage
Expand Down Expand Up @@ -86,7 +82,8 @@ def chat(
origin: str,
question: str,
rephrased_question: Optional[str],
context: List[ChatContext],
chat_context: List[ChatContext],
retrieved_context: List[RetrievedContext],
answer: Optional[str],
learning_id: str,
rephrase_time: Optional[float] = None,
Expand Down
13 changes: 5 additions & 8 deletions nucliadb_utils/src/nucliadb_utils/audit/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp

from nucliadb_protos.audit_pb2 import (
AuditField,
AuditRequest,
ChatContext,
ClientType,
)
from nucliadb_protos.audit_pb2 import AuditField, AuditRequest, ChatContext, ClientType, RetrievedContext
from nucliadb_protos.nodereader_pb2 import SearchRequest
from nucliadb_protos.resources_pb2 import FieldID
from nucliadb_utils import logger
Expand Down Expand Up @@ -380,7 +375,8 @@ def chat(
origin: str,
question: str,
rephrased_question: Optional[str],
context: List[ChatContext],
chat_context: List[ChatContext],
retrieved_context: List[RetrievedContext],
answer: Optional[str],
learning_id: str,
rephrase_time: Optional[float] = None,
Expand All @@ -405,7 +401,8 @@ def chat(
auditrequest.generative_answer_first_chunk_time = generative_answer_first_chunk_time
auditrequest.type = AuditRequest.CHAT
auditrequest.chat.question = question
auditrequest.chat.context.extend(context)
auditrequest.chat.chat_context.extend(chat_context)
auditrequest.chat.retrieved_context.extend(retrieved_context)
auditrequest.chat.learning_id = learning_id
if rephrased_question is not None:
auditrequest.chat.rephrased_question = rephrased_question
Expand Down
Loading
Loading