Skip to content

Commit

Permalink
Add summarize to sdk (#1731)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Jan 16, 2024
1 parent 4804ba2 commit 14ca50a
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 5 deletions.
13 changes: 9 additions & 4 deletions nucliadb/nucliadb/search/search/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
)
from nucliadb_utils.utilities import get_storage

ExtractedTexts = list[tuple[str, str, Optional[ExtractedText]]]

MAX_GET_EXTRACTED_TEXT_OPS = 20


Expand All @@ -59,10 +61,8 @@ async def summarize(kbid: str, request: SummarizeRequest) -> SummarizedResponse:
return await predict.summarize(kbid, predict_request)


async def get_extracted_texts(
kbid: str, resource_uuids: list[str]
) -> list[tuple[str, str, Optional[ExtractedText]]]:
results = []
async def get_extracted_texts(kbid: str, resource_uuids: list[str]) -> ExtractedTexts:
results: ExtractedTexts = []

driver = get_driver()
storage = await get_storage()
Expand All @@ -83,6 +83,11 @@ async def get_extracted_texts(
for _, field in fields.items():
task = asyncio.create_task(get_extracted_text(rid, field, max_tasks))
tasks.append(task)

if len(tasks) == 0:
# No extracted text to get
return results

done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)

# Parse the task results
Expand Down
2 changes: 1 addition & 1 deletion nucliadb_models/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ class SummarizedResponse(BaseModel):
default={}, title="Resources", description="Individual resource summaries"
)
summary: str = Field(
default="", title="Summary", description="Globla summary of all resources"
default="", title="Summary", description="Global summary of all resources"
)


Expand Down
1 change: 1 addition & 0 deletions nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@ def test_search_endpoints(sdk: nucliadb_sdk.NucliaDB, kb):
resource = sdk.create_resource(kbid=kb.uuid, title="Resource", slug="resource")
sdk.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
sdk.feedback(kbid=kb.uuid, ident="bar", good=True, feedback="baz", task="CHAT")
sdk.summarize(kbid=kb.uuid, resources=["foobar"])
1 change: 1 addition & 0 deletions nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ async def test_search_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb):
await sdk_async.feedback(
kbid=kb.uuid, ident="bar", good=True, feedback="baz", task=FeedbackTasks.CHAT
)
await sdk_async.summarize(kbid=kb.uuid, resources=["foobar"])
32 changes: 32 additions & 0 deletions nucliadb_sdk/nucliadb_sdk/tests/test_summarize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at [email protected].
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import nucliadb_sdk
from nucliadb_models.search import KnowledgeboxFindResults, SummarizeRequest


def test_summarize(docs_dataset, sdk: nucliadb_sdk.NucliaDB):
results: KnowledgeboxFindResults = sdk.find(kbid=docs_dataset, query="love")
resource_uuids = [uuid for uuid in results.resources.keys()]

response = sdk.summarize(kbid=docs_dataset, resources=[resource_uuids[0]])
assert response.summary == "global summary"

content = SummarizeRequest(resources=[resource_uuids[0]])
response = sdk.summarize(kbid=docs_dataset, content=content)
assert response.summary == "global summary"
13 changes: 13 additions & 0 deletions nucliadb_sdk/nucliadb_sdk/v2/docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ class Docstring(BaseModel):
],
)

SUMMARIZE = Docstring(
doc="""Summarize your documents""",
examples=[
Example(
description="Get a summary of a document or a list of documents",
code=""">>> summary = sdk.summarize(kbid="mykbid", resources=["uuid1"]).summary
>>> print(summary)
'The document talks about Seville and its temperature. It also mentions the coldest month of the year, which is January.' # noqa
""",
),
],
)


DELETE_LABELSET = Docstring(
doc="Delete a specific set of labels",
Expand Down
12 changes: 12 additions & 0 deletions nucliadb_sdk/nucliadb_sdk/v2/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
KnowledgeboxSearchResults,
Relations,
SearchRequest,
SummarizedResponse,
SummarizeRequest,
)
from nucliadb_models.vectors import VectorSet, VectorSets
from nucliadb_models.writer import (
Expand Down Expand Up @@ -652,6 +654,16 @@ def _check_response(self, response: httpx.Response):
response_type=chat_response_parser,
docstring=docstrings.RESOURCE_CHAT,
)
summarize = _request_builder(
name="summarize",
path_template="/v1/kb/{kbid}/summarize",
method="POST",
path_params=("kbid",),
request_type=SummarizeRequest,
response_type=SummarizedResponse,
docstring=docstrings.SUMMARIZE,
)

feedback = _request_builder(
name="feedback",
path_template="/v1/kb/{kbid}/feedback",
Expand Down

1 comment on commit 14ca50a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 14ca50a Previous: 5a633b0 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12716.022026413211 iter/sec (stddev: 6.182663650076198e-7) 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7) 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.