Skip to content

Commit

Permalink
Fetch resource title on notifications (#1784)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Jan 30, 2024
1 parent 56e1386 commit d094603
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 11 deletions.
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/reader/api/v1/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ async def notifications_endpoint(
return HTTPClientError(status_code=404, detail="Knowledge Box not found")

response = StreamingResponse(
content=kb_notifications_stream(kbid),
content=kb_notifications_stream(context, kbid),
status_code=200,
media_type="binary/octet-stream",
)
Expand Down
57 changes: 52 additions & 5 deletions nucliadb/nucliadb/reader/reader/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
import contextlib
import uuid
from collections.abc import AsyncGenerator
from typing import Optional

import async_timeout

from nucliadb.common.context import ApplicationContext
from nucliadb.common.datamanagers.resources import ResourcesDataManager
from nucliadb.common.maindb.driver import Driver
from nucliadb.reader import logger
from nucliadb_models.notifications import (
Notification,
Expand Down Expand Up @@ -55,18 +59,26 @@
}


async def kb_notifications_stream(kbid: str) -> AsyncGenerator[bytes, None]:
async def kb_notifications_stream(
context: ApplicationContext, kbid: str
) -> AsyncGenerator[bytes, None]:
"""
Returns an async generator that yields pubsub notifications for the given kbid.
The generator will return after NOTIFICATIONS_TIMEOUT_S seconds.
"""
try:
resource_cache: dict[str, str] = {}
async with async_timeout.timeout(NOTIFICATIONS_TIMEOUT_S):
async for pb_notification in kb_notifications(kbid):
line = encode_streamed_notification(pb_notification) + b"\n"
notification = await serialize_notification(
context, pb_notification, resource_cache
)
line = encode_streamed_notification(notification) + b"\n"
yield line
except asyncio.TimeoutError:
return
finally: # pragma: no cover
resource_cache.clear()


async def kb_notifications(kbid: str) -> AsyncGenerator[writer_pb2.Notification, None]:
Expand Down Expand Up @@ -131,14 +143,21 @@ async def managed_subscription(pubsub: PubSubDriver, key: str, handler: Callback
)


def serialize_notification(pb: writer_pb2.Notification) -> Notification:
async def serialize_notification(
context: ApplicationContext, pb: writer_pb2.Notification, cache: dict[str, str]
) -> Notification:
kbid = pb.kbid
resource_uuid = pb.uuid
seqid = pb.seqid

resource_title = await get_resource_title_cached(
context.kv_driver, kbid, resource_uuid, cache
)
if pb.action == writer_pb2.Notification.Action.INDEXED:
return ResourceIndexedNotification(
data=ResourceIndexed(
resource_uuid=resource_uuid,
resource_title=resource_title,
seqid=seqid,
)
)
Expand All @@ -151,6 +170,7 @@ def serialize_notification(pb: writer_pb2.Notification) -> Notification:
return ResourceWrittenNotification(
data=ResourceWritten(
resource_uuid=resource_uuid,
resource_title=resource_title,
seqid=seqid,
operation=writer_operation,
error=has_ingestion_error,
Expand All @@ -160,6 +180,7 @@ def serialize_notification(pb: writer_pb2.Notification) -> Notification:
return ResourceProcessedNotification(
data=ResourceProcessed(
resource_uuid=resource_uuid,
resource_title=resource_title,
seqid=seqid,
ingestion_succeeded=not has_ingestion_error,
processing_errors=has_processing_error,
Expand All @@ -169,7 +190,33 @@ def serialize_notification(pb: writer_pb2.Notification) -> Notification:
raise ValueError(f"Unknown notification source: {pb.source}")


def encode_streamed_notification(pb: writer_pb2.Notification) -> bytes:
notification = serialize_notification(pb)
async def get_resource_title_cached(
    kv_driver: Driver,
    kbid: str,
    resource_uuid: str,
    cache: dict[str, str],
) -> str:
    """
    Return the title of a resource, looking it up in `cache` first.

    On a cache miss the title is fetched with `get_resource_title` and stored
    in `cache` under `resource_uuid`. When no title is found, an empty string
    is returned and nothing is stored, so a later call for the same resource
    performs the lookup again.
    """
    cached_title = cache.get(resource_uuid)
    if cached_title is not None:
        # Cache hit
        return cached_title
    # Cache miss
    resource_title = await get_resource_title(kv_driver, kbid, resource_uuid)
    if resource_title is None:
        # Missing titles are deliberately not cached
        return ""
    cache[resource_uuid] = resource_title
    return resource_title


async def get_resource_title(
    kv_driver: Driver, kbid: str, resource_uuid: str
) -> Optional[str]:
    """
    Fetch the title of a resource inside a read-only transaction.

    Returns None when `ResourcesDataManager.get_resource_basic` yields no
    basic metadata for the given kbid / resource uuid.
    """
    async with kv_driver.transaction(read_only=True) as txn:
        resource_basic = await ResourcesDataManager.get_resource_basic(
            txn, kbid, resource_uuid
        )
        return resource_basic.title if resource_basic is not None else None


def encode_streamed_notification(notification: Notification) -> bytes:
    """
    Encode a notification for streaming.

    Returns the string produced by the notification's `.json()` method,
    encoded as UTF-8 bytes.
    """
    return notification.json().encode("utf-8")
12 changes: 12 additions & 0 deletions nucliadb/nucliadb/reader/tests/integration/api/v1/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
from nucliadb_protos import writer_pb2


@pytest.fixture(scope="function", autouse=True)
def get_resource_title():
    """Patch the notifications module's `get_resource_title` to return "Resource"."""
    patcher = mock.patch(
        "nucliadb.reader.reader.notifications.get_resource_title",
        return_value="Resource",
    )
    with patcher as patched:
        yield patched


@pytest.fixture(scope="function")
def kb_notifications():
async def _kb_notifications(
Expand Down Expand Up @@ -106,12 +115,14 @@ async def test_activity(
notif = ResourceIndexedNotification.parse_raw(line)
assert notif.type == "resource_indexed"
assert notif.data.resource_uuid == "resource"
assert notif.data.resource_title == "Resource"
assert notif.data.seqid == 1

elif notification_type == "resource_written":
notif = ResourceWrittenNotification.parse_raw(line)
assert notif.type == "resource_written"
assert notif.data.resource_uuid == "resource"
assert notif.data.resource_title == "Resource"
assert notif.data.seqid == 1
assert notif.data.operation == "created"
assert notif.data.error is False
Expand All @@ -120,6 +131,7 @@ async def test_activity(
notif = ResourceProcessedNotification.parse_raw(line)
assert notif.type == "resource_processed"
assert notif.data.resource_uuid == "resource"
assert notif.data.resource_title == "Resource"
assert notif.data.seqid == 1
assert notif.data.ingestion_succeeded is True
assert notif.data.processing_errors is True
Expand Down
71 changes: 66 additions & 5 deletions nucliadb/nucliadb/reader/tests/unit/reader/test_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest

from nucliadb.reader.reader.notifications import (
get_resource_title,
kb_notifications_stream,
serialize_notification,
)
Expand All @@ -44,6 +45,7 @@ def timeout():


async def test_kb_notifications_stream_timeout_gracefully():
context = mock.Mock()
event = asyncio.Event()
cancelled_event = asyncio.Event()

Expand All @@ -58,14 +60,26 @@ async def mocked_kb_notifications(kbid):

with mock.patch(f"{MODULE}.kb_notifications", new=mocked_kb_notifications):
# Check that the generator returns gracefully after NOTIFICATIONS_TIMEOUT_S seconds
async for _ in kb_notifications_stream("testkb"):
async for _ in kb_notifications_stream(context, "testkb"):
assert False, "Should not be reached"

assert not event.is_set()
assert cancelled_event.is_set()


async def test_kb_notifications_stream_timeout_gracefully_while_streaming():
@pytest.fixture(scope="function")
def get_resource_title_mock():
    """Patch the notifications module's `get_resource_title` to return "Resource"."""
    patcher = mock.patch(
        "nucliadb.reader.reader.notifications.get_resource_title",
        return_value="Resource",
    )
    with patcher as patched:
        yield patched


async def test_kb_notifications_stream_timeout_gracefully_while_streaming(
get_resource_title_mock,
):
context = mock.Mock()
cancelled_event = asyncio.Event()

async def mocked_kb_notifications(kbid):
Expand All @@ -79,7 +93,7 @@ async def mocked_kb_notifications(kbid):

with mock.patch(f"{MODULE}.kb_notifications", new=mocked_kb_notifications):
# Yield a notification first
stream = kb_notifications_stream("testkb")
stream = kb_notifications_stream(context, "testkb")
assert await stream.__anext__()

# Since there are no more notifications, the generator will eventually finish due to the timeout
Expand All @@ -101,6 +115,7 @@ async def mocked_kb_notifications(kbid):
processing_errors=True,
),
ResourceProcessed(
resource_title="Resource",
resource_uuid="rid",
seqid=1,
ingestion_succeeded=True,
Expand All @@ -116,6 +131,7 @@ async def mocked_kb_notifications(kbid):
action=writer_pb2.Notification.Action.ABORT,
),
ResourceWritten(
resource_title="Resource",
resource_uuid="rid",
seqid=1,
operation=ResourceOperationType.DELETED,
Expand All @@ -129,12 +145,57 @@ async def mocked_kb_notifications(kbid):
action=writer_pb2.Notification.Action.INDEXED,
),
ResourceIndexed(
resource_title="Resource",
resource_uuid="rid",
seqid=1,
),
),
],
)
def test_serialize_notification(pb, serialized_data):
serialized = serialize_notification(pb)
async def test_serialize_notification(pb, serialized_data, get_resource_title_mock):
context = mock.Mock()
cache = {}
serialized = await serialize_notification(context, pb, cache)
assert serialized.data == serialized_data


async def test_serialize_notification_caches_resource_titles(get_resource_title_mock):
    """Serializing the same resource twice must fetch its title only once."""
    title_cache = {}
    indexed = writer_pb2.Notification(
        uuid="rid",
        seqid=1,
        action=writer_pb2.Notification.Action.INDEXED,
    )

    # First serialization misses the cache and stores the fetched title
    await serialize_notification(mock.Mock(), indexed, title_cache)
    assert title_cache == {"rid": "Resource"}
    get_resource_title_mock.assert_called_once()

    # Check that the cache is used: the call count must stay at one
    await serialize_notification(mock.Mock(), indexed, title_cache)
    get_resource_title_mock.assert_called_once()


@pytest.fixture(scope="function")
def get_resource_basic():
    """Patch `ResourcesDataManager.get_resource_basic` as seen from the notifications module."""
    patcher = mock.patch(
        "nucliadb.reader.reader.notifications.ResourcesDataManager.get_resource_basic"
    )
    with patcher as patched:
        yield patched


@pytest.fixture(scope="function")
def kv_driver():
    """Mocked driver whose `transaction()` async context manager yields a mock transaction."""
    fake_driver = mock.MagicMock()
    fake_driver.transaction.return_value.__aenter__.return_value = mock.Mock()
    return fake_driver


async def test_get_resource_title(kv_driver, get_resource_basic):
    """`get_resource_title` returns the basic's title, or None when there is no basic."""
    # Resource found: the title comes from the basic metadata
    get_resource_basic.return_value = mock.Mock(title="Resource")
    assert await get_resource_title(kv_driver, "kbid", "rid") == "Resource"

    # Resource not found
    get_resource_basic.return_value = None
    assert await get_resource_title(kv_driver, "kbid", "rid") is None
9 changes: 9 additions & 0 deletions nucliadb_models/nucliadb_models/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class ResourceIndexed(BaseModel):
resource_uuid: str = Field(
..., title="Resource UUID", description="UUID of the resource."
)
resource_title: str = Field(
..., title="Resource Title", description="Title of the resource."
)
seqid: int = Field(
...,
title="Sequence ID",
Expand All @@ -62,6 +65,9 @@ class ResourceWritten(BaseModel):
resource_uuid: str = Field(
..., title="Resource UUID", description="UUID of the resource."
)
resource_title: str = Field(
..., title="Resource Title", description="Title of the resource."
)
seqid: int = Field(
...,
title="Sequence ID",
Expand All @@ -81,6 +87,9 @@ class ResourceProcessed(BaseModel):
resource_uuid: str = Field(
..., title="Resource UUID", description="UUID of the resource."
)
resource_title: str = Field(
..., title="Resource Title", description="Title of the resource."
)
seqid: int = Field(
...,
title="Sequence ID",
Expand Down

3 comments on commit d094603

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: d094603 Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 13275.659500466483 iter/sec (stddev: 8.931064577194443e-7) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 0.98

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: d094603 Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 13122.14233431664 iter/sec (stddev: 0.0000011003824404240356) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 0.99

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: d094603 Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12522.139086133056 iter/sec (stddev: 0.0000013683507917258907) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 1.04

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.