Allow users to pass generic context to LLM (#1678)
* Add chat extra context option to chat endpoint

* Add SDK integration tests too, as documentation

* Fix docstrings
lferran authored Dec 19, 2023
1 parent 3958c09 commit a2a9376
Showing 7 changed files with 73 additions and 27 deletions.
22 changes: 18 additions & 4 deletions nucliadb/nucliadb/search/search/chat/prompt.py
@@ -101,20 +101,34 @@ async def get_expanded_conversation_messages(


 async def get_chat_prompt_context(
-    kbid: str, results: KnowledgeboxFindResults
+    kbid: str,
+    results: KnowledgeboxFindResults,
+    user_context: Optional[list[str]] = None,
 ) -> dict[str, str]:
+    """
+    - Returns an ordered dict of context_id -> context_text.
+    - context_id is typically the paragraph id, but has a special value for the
+      user context (USER_CONTEXT_0, USER_CONTEXT_1, ...).
+    - Paragraphs are inserted in order of relevance, by increasing `order` field
+      of the find result paragraphs.
+    - User context is inserted first, in order of appearance.
+    - Using a dict prevents duplicates pulled in through conversation expansion.
+    """
+    output = {}
+    # Chat extra context passed by the user is the most important, so it goes first
+    for i, context in enumerate(user_context or []):
+        output[f"USER_CONTEXT_{i}"] = context
+
     # Sort retrieved paragraphs by increasing `order` field (most relevant first)
     ordered_paras = []
     for result in results.resources.values():
         for field_path, field in result.fields.items():
             for paragraph in field.paragraphs.values():
                 ordered_paras.append((field_path, paragraph))

     ordered_paras.sort(key=lambda x: x[1].order, reverse=False)

     driver = get_driver()
     storage = await get_storage()
-    # ordered dict that prevents duplicates pulled in through conversation expansion
-    output = {}
     async with driver.transaction() as txn:
         kb = KnowledgeBoxORM(txn, storage, kbid)
         for field_path, paragraph in ordered_paras:
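The ordering contract of the new `user_context` parameter can be illustrated standalone. A minimal sketch, assuming plain tuples in place of the real find-result models (which carry the `order` field used above):

# Sketch of the context-ordering rules from the docstring above:
# user context first (USER_CONTEXT_0, USER_CONTEXT_1, ...), then
# paragraphs sorted by increasing `order`.
def build_context(user_context: list[str], paragraphs: list[tuple[str, str, int]]) -> dict[str, str]:
    output = {}
    for i, context in enumerate(user_context):
        output[f"USER_CONTEXT_{i}"] = context
    for pid, text, _order in sorted(paragraphs, key=lambda p: p[2]):
        output[pid] = text
    return output

ctx = build_context(
    ["Some extra context"],
    [("rid/f/field/10-20", "second-best paragraph", 1),
     ("rid/f/field/0-10", "best paragraph", 0)],
)
assert list(ctx) == ["USER_CONTEXT_0", "rid/f/field/0-10", "rid/f/field/10-20"]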
28 changes: 19 additions & 9 deletions nucliadb/nucliadb/search/search/chat/query.py
@@ -76,17 +76,19 @@ class ChatResult:
     find_results: KnowledgeboxFindResults


-async def rephrase_query_from_chat_history(
+async def rephrase_query(
     kbid: str,
     chat_history: List[ChatContextMessage],
     query: str,
     user_id: str,
+    user_context: List[str],
 ) -> str:
     predict = get_predict()
     req = RephraseModel(
         question=query,
         chat_history=chat_history,
         user_id=user_id,
+        user_context=user_context,
     )
     return await predict.rephrase_query(kbid, req)

@@ -201,20 +203,25 @@ async def chat(
     client_type: NucliaDBClientType,
     origin: str,
 ) -> ChatResult:
-    start_time = time()
     nuclia_learning_id: Optional[str] = None
+    chat_history = chat_request.context or []
+    start_time = time()
+
+    user_context = chat_request.extra_context or []
     user_query = chat_request.query
     rephrased_query = None
-    if chat_request.context and len(chat_request.context) > 0:
-        rephrased_query = await rephrase_query_from_chat_history(
-            kbid, chat_request.context, user_query, user_id
+
+    if len(chat_history) > 0 or len(user_context) > 0:
+        rephrased_query = await rephrase_query(
+            kbid,
+            chat_history=chat_history,
+            query=user_query,
+            user_id=user_id,
+            user_context=user_context,
         )

     find_results: KnowledgeboxFindResults = await get_find_results(
         kbid=kbid,
-        query=rephrased_query or user_query or "",
+        query=rephrased_query or user_query,
         chat_request=chat_request,
         ndb_client=client_type,
         user=user_id,
@@ -227,20 +234,23 @@ async def chat(
             not_enough_context_generator(), status_code
         )
     else:
-        query_context = await get_chat_prompt_context(kbid, find_results)
+        query_context = await get_chat_prompt_context(
+            kbid, find_results, user_context=user_context
+        )
+        query_context_order = {
+            paragraph_id: order
+            for order, paragraph_id in enumerate(query_context.keys())
+        }
         user_prompt = None
         if chat_request.prompt is not None:
             user_prompt = UserPrompt(prompt=chat_request.prompt)

         chat_model = ChatModel(
             user_id=user_id,
             query_context=query_context,
+            query_context_order=query_context_order,
             chat_history=chat_history,
-            question=chat_request.query,
+            question=user_query,
             truncate=True,
             user_prompt=user_prompt,
             citations=chat_request.citations,
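The `query_context_order` mapping built in the `else:` branch above is just an enumeration of the insertion-ordered context dict, so user context entries always receive the lowest order values. A minimal illustration (the paragraph ids are hypothetical):

query_context = {"USER_CONTEXT_0": "extra context", "rid/f/field/0-10": "a paragraph"}
query_context_order = {
    paragraph_id: order for order, paragraph_id in enumerate(query_context.keys())
}
assert query_context_order == {"USER_CONTEXT_0": 0, "rid/f/field/0-10": 1}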
26 changes: 18 additions & 8 deletions nucliadb/nucliadb/search/tests/unit/search/test_chat_prompt.py
@@ -142,7 +142,7 @@ async def test_get_expanded_conversation_messages_missing(kb, messages):


 def _create_find_result(
-    _id: str, result_text: str, score_type: SCORE_TYPE = SCORE_TYPE.BM25
+    _id: str, result_text: str, score_type: SCORE_TYPE = SCORE_TYPE.BM25, order=1
 ):
     return FindResource(
         id=_id.split("/")[0],
@@ -153,7 +153,7 @@ def _create_find_result(
                         id=_id,
                         score=1.0,
                         score_type=score_type,
-                        order=1,
+                        order=order,
                         text=result_text,
                     )
                 }
@@ -173,16 +173,26 @@ async def test_get_chat_prompt_context(kb):
             facets={},
             resources={
                 "bmid": _create_find_result(
-                    "bmid/c/conv/ident", result_text, SCORE_TYPE.BM25
+                    "bmid/c/conv/ident", result_text, SCORE_TYPE.BM25, order=1
                 ),
                 "vecid": _create_find_result(
-                    "vecid/c/conv/ident", result_text, SCORE_TYPE.VECTOR
+                    "vecid/c/conv/ident", result_text, SCORE_TYPE.VECTOR, order=2
                 ),
+                "both_id": _create_find_result(
+                    "both_id/c/conv/ident", result_text, SCORE_TYPE.BOTH, order=0
+                ),
             },
             min_score=-1,
         ),
+        user_context=["Some extra context"],
     )
-    assert prompt_result == {
-        "bmid/c/conv/ident": result_text,
-        "vecid/c/conv/ident": result_text,
-    }
+    # Check that the results are sorted by increasing order and that the extra
+    # context is added at the beginning, indicating that it has the highest priority
+    paragraph_ids = list(prompt_result.keys())
+    assert paragraph_ids == [
+        "USER_CONTEXT_0",
+        "both_id/c/conv/ident",
+        "bmid/c/conv/ident",
+        "vecid/c/conv/ident",
+    ]
+    assert prompt_result["USER_CONTEXT_0"] == "Some extra context"
4 changes: 3 additions & 1 deletion nucliadb/nucliadb/search/tests/unit/test_predict.py
@@ -355,7 +355,9 @@ async def test_rephrase():
"POST", 200, json="rephrased", context_manager=False
)

item = RephraseModel(question="question", chat_history=[], user_id="foo")
item = RephraseModel(
question="question", chat_history=[], user_id="foo", user_context=["foo"]
)
rephrased_query = await pe.rephrase_query("kbid", item)
# The rephrase query should not be wrapped in quotes, otherwise it will trigger an exact match query to the index
assert rephrased_query.strip('"') == rephrased_query
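As the updated test suggests, callers now pass the user context through RephraseModel. A hedged sketch of building the payload; the question, user_id, and use of pydantic's default serialization are assumptions, while the field names come from the model definition below:

from nucliadb_models.search import RephraseModel

item = RephraseModel(
    question="What is the average temperature?",
    chat_history=[],
    user_id="user-1",  # hypothetical user id
    user_context=["Nuclia is an AI search platform"],  # new field from this change
)
payload = item.json()  # assumes pydantic v1-style serialization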
8 changes: 7 additions & 1 deletion nucliadb_models/nucliadb_models/search.py
@@ -536,7 +536,6 @@ class SearchParamDefaults:
title="Chat history",
description="Use to rephrase the new LLM query by taking into account the chat conversation history", # noqa
)

chat_features = ParamDefault(
default=[ChatOptions.VECTORS, ChatOptions.PARAGRAPHS, ChatOptions.RELATIONS],
title="Chat features",
@@ -696,6 +695,7 @@ class RephraseModel(BaseModel):
     question: str
     chat_history: List[ChatContextMessage] = []
     user_id: str
+    user_context: List[str] = []


 class AskDocumentModel(BaseModel):
@@ -735,6 +735,12 @@ class ChatRequest(BaseModel):
     context: Optional[
         List[ChatContextMessage]
     ] = SearchParamDefaults.chat_context.to_pydantic_field()
+    extra_context: Optional[List[str]] = Field(
+        default=None,
+        title="Extra query context",
+        description="""Additional context that is added to the retrieval context sent to the LLM.
+        It allows extending the chat feature with content that may not be in the Knowledge Box.""",
+    )
     autofilter: bool = SearchParamDefaults.autofilter.to_pydantic_field()
     highlight: bool = SearchParamDefaults.highlight.to_pydantic_field()
     resource_filters: List[
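Over HTTP, `extra_context` is just another key in the chat request body. A hedged example; the `/v1/kb/{kbid}/chat` route, base URL, and kbid are assumptions, not confirmed by this diff:

import requests

# Hypothetical deployment; adjust the base URL, kbid, and auth to your setup.
resp = requests.post(
    "http://localhost:8080/api/v1/kb/mykbid/chat",
    json={
        "query": "What is semantic search?",
        "extra_context": [
            "Nuclia is a powerful AI search platform",
            "AI Search involves semantic search",
        ],
    },
)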
6 changes: 5 additions & 1 deletion nucliadb_sdk/nucliadb_sdk/tests/test_chat.py
@@ -25,10 +25,14 @@ def test_chat_on_kb(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
         kbid=docs_dataset,
         query="Nuclia loves Semantic Search",
         prompt="Given this context: {context}. Answer this {question} in a concise way using the provided context",
+        extra_context=[
+            "Nuclia is a powerful AI search platform",
+            "AI Search involves semantic search",
+        ],
     )
     assert result.learning_id == "00"
     assert result.answer == "valid answer to"
-    assert len(result.result.resources) == 9
+    assert len(result.result.resources) == 7
     assert result.relations
     assert len(result.relations.entities["Nuclia"].related_to) == 18

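Outside the test harness, the same call pattern applies. A minimal sketch against a hypothetical Knowledge Box; the client constructor arguments are assumptions, while the `chat(kbid=..., query=..., extra_context=...)` signature comes from the test above:

import nucliadb_sdk

# Hypothetical client setup; adjust region/url to your deployment.
sdk = nucliadb_sdk.NucliaDB(region="on-prem", url="http://localhost:8080/api")
result = sdk.chat(
    kbid="mykbid",
    query="Nuclia loves Semantic Search",
    extra_context=["Nuclia is a powerful AI search platform"],
)
print(result.answer)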
6 changes: 3 additions & 3 deletions nucliadb_sdk/nucliadb_sdk/v2/docstrings.py
@@ -264,11 +264,11 @@ class Docstring(BaseModel):
     ),
     Example(
         description="You can use the `content` parameter to pass previous context to the query",
-        code=""">>> from nucliadb_models.search import ChatRequest, Message
+        code=""">>> from nucliadb_models.search import ChatRequest, ChatContextMessage
 >>> content = ChatRequest()
 >>> content.query = "What is the average temperature?"
->>> content.context.append(Messate(author="USER", text="What is the coldest season in Sevilla?"))
->>> content.context.append(Messate(author="NUCLIA", text="January is the coldest month."))
+>>> content.context.append(ChatContextMessage(author="USER", text="What is the coldest season in Sevilla?"))
+>>> content.context.append(ChatContextMessage(author="NUCLIA", text="January is the coldest month."))
 >>> sdk.chat(kbid="mykbid", content=content).answer
 'According to the context, the average temperature in January in Sevilla is 15.9 °C and 5.2 °C.'
 """,

1 comment on commit a2a9376

@github-actions (Contributor)

Benchmark

Benchmark suite: nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error
Current (a2a9376): 12728.254590086523 iter/sec (stddev: 2.0693560511969848e-7)
Previous (5a633b0): 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7)
Ratio: 1.00

This comment was automatically generated by workflow using github-action-benchmark.