Skip to content

Commit

Permalink
Remove deprecated labelset logic (#2562)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Oct 22, 2024
1 parent 3193aff commit 5cac6a3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 115 deletions.
13 changes: 1 addition & 12 deletions nucliadb/src/migrations/0011_materialize_labelset_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import logging

from nucliadb.common import datamanagers
from nucliadb.migrator.context import ExecutionContext

logger = logging.getLogger(__name__)
Expand All @@ -35,14 +34,4 @@
async def migrate(context: ExecutionContext) -> None: ...


async def migrate_kb(context: ExecutionContext, kbid: str) -> None:
async with context.kv_driver.transaction() as txn:
labelset_list = await datamanagers.labels._get_labelset_ids(txn, kbid=kbid)
if labelset_list is not None:
logger.info("No need for labelset list migration", extra={"kbid": kbid})
return

labelset_list = await datamanagers.labels._deprecated_scan_labelset_ids(txn, kbid=kbid)
await datamanagers.labels._set_labelset_ids(txn, kbid=kbid, labelsets=labelset_list)
logger.info("Labelset list migrated", extra={"kbid": kbid})
await txn.commit()
async def migrate_kb(context: ExecutionContext, kbid: str) -> None: ...
56 changes: 14 additions & 42 deletions nucliadb/src/nucliadb/common/datamanagers/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ async def get_labels(txn: Transaction, *, kbid: str) -> kb_pb2.Labels:
Get all labels for a knowledge box (from multiple labelsets)
"""
labels = kb_pb2.Labels()
labelset_ids = await _get_labelset_ids_bw_compat(txn, kbid=kbid)
labelset_ids = await _get_labelset_ids(txn, kbid=kbid)
if labelset_ids is None:
return labels
for labelset_id in labelset_ids:
labelset = await txn.get(KB_LABELSET.format(kbid=kbid, id=labelset_id))
if not labelset:
Expand All @@ -48,26 +50,6 @@ async def get_labels(txn: Transaction, *, kbid: str) -> kb_pb2.Labels:
return labels


async def _get_labelset_ids_bw_compat(txn: Transaction, *, kbid: str) -> list[str]:
labelsets = await _get_labelset_ids(txn, kbid=kbid)
if labelsets is not None:
return labelsets
# TODO: Remove this after migration #11
return await _deprecated_scan_labelset_ids(txn, kbid=kbid)


async def _deprecated_scan_labelset_ids(txn: Transaction, *, kbid: str) -> list[str]:
logger.warning(
"Scanning labelset ids. This is not optimal and should have been migrated.", extra={"kbid": kbid}
)
labelsets = []
labels_key = KB_LABELS.format(kbid=kbid)
async for key in txn.keys(labels_key, count=-1, include_start=False):
lsid = key.split("/")[-1]
labelsets.append(lsid)
return labelsets


async def _get_labelset_ids(txn: Transaction, *, kbid: str) -> Optional[list[str]]:
key = KB_LABELSET_IDS.format(kbid=kbid)
data = await txn.get(key, for_update=True)
Expand All @@ -77,33 +59,23 @@ async def _get_labelset_ids(txn: Transaction, *, kbid: str) -> Optional[list[str


async def _add_to_labelset_ids(txn: Transaction, *, kbid: str, labelsets: list[str]) -> None:
updated = set(labelsets)
previous = await _get_labelset_ids(txn, kbid=kbid)
needs_set = False
if previous is None:
# TODO: Remove this after migration #11
needs_set = True
previous = await _deprecated_scan_labelset_ids(txn, kbid=kbid)
for labelset in labelsets:
if labelset not in previous:
needs_set = True
previous.append(labelset)
if needs_set:
await _set_labelset_ids(txn, kbid=kbid, labelsets=previous)
if previous is not None:
updated.update(previous)
if previous is None or previous != updated:
await _set_labelset_ids(txn, kbid=kbid, labelsets=list(updated))


async def _delete_from_labelset_ids(txn: Transaction, *, kbid: str, labelsets: list[str]) -> None:
needs_set = False
previous = await _get_labelset_ids(txn, kbid=kbid)
if previous is None:
# TODO: Remove this after migration #11
needs_set = True
previous = await _deprecated_scan_labelset_ids(txn, kbid=kbid)
for labelset in labelsets:
if labelset in previous:
needs_set = True
previous.remove(labelset)
if needs_set:
await _set_labelset_ids(txn, kbid=kbid, labelsets=previous)
# Nothing to delete
return
previous_set = set(previous)
updated = previous_set - set(labelsets)
if previous_set != updated:
await _set_labelset_ids(txn, kbid=kbid, labelsets=list(updated))


async def _set_labelset_ids(txn: Transaction, *, kbid: str, labelsets: list[str]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,23 @@ async def test_labelset_ids(maindb_driver):
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) is None

# Check that deleting from an empty list does not break
async with maindb_driver.transaction() as txn:
await datamanagers.labels._delete_from_labelset_ids(txn, kbid=kbid, labelsets=["foo"])
await txn.commit()

# Check that adding to the list creates the list
async with maindb_driver.transaction() as txn:
await datamanagers.labels._add_to_labelset_ids(txn, kbid=kbid, labelsets=["bar", "ba"])
await txn.commit()
async with maindb_driver.transaction() as txn:
await datamanagers.labels._add_to_labelset_ids(txn, kbid=kbid, labelsets=["bar", "baz"])
await txn.commit()

async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == [
"bar",
assert sorted(await datamanagers.labels._get_labelset_ids(txn, kbid=kbid)) == [
"ba",
"bar",
"baz",
]

Expand All @@ -50,7 +56,7 @@ async def test_labelset_ids(maindb_driver):
await datamanagers.labels._delete_from_labelset_ids(txn, kbid=kbid, labelsets=["ba"])
await txn.commit()
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == [
assert sorted(await datamanagers.labels._get_labelset_ids(txn, kbid=kbid)) == [
"bar",
"baz",
]
Expand All @@ -75,61 +81,3 @@ async def test_labelset_ids(maindb_driver):

async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == ["bar"]


async def test_labelset_ids_bw_compat(maindb_driver):
kbid = "foo"
# Check that initially all are empty
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._deprecated_scan_labelset_ids(txn, kbid=kbid) == []
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) is None
assert await datamanagers.labels._get_labelset_ids_bw_compat(txn, kbid=kbid) == []

# Check that adding to the list creates the list
async with maindb_driver.transaction() as txn:
await datamanagers.labels._add_to_labelset_ids(txn, kbid=kbid, labelsets=["bar", "ba"])
await txn.commit()
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == [
"bar",
"ba",
]
assert await datamanagers.labels._deprecated_scan_labelset_ids(txn, kbid=kbid) == []

# Check that adding appends to the list
async with maindb_driver.transaction() as txn:
await datamanagers.labels._add_to_labelset_ids(txn, kbid=kbid, labelsets=["baz", "ba"])
await txn.commit()
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == [
"bar",
"ba",
"baz",
]

# Check that removing from the list removes the item
async with maindb_driver.transaction() as txn:
await datamanagers.labels._delete_from_labelset_ids(txn, kbid=kbid, labelsets=["ba"])
await txn.commit()
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid=kbid) == [
"bar",
"baz",
]

# Check that removing also creates the list
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid="other") is None
await datamanagers.labels._delete_from_labelset_ids(txn, kbid="other", labelsets=["bar", "baz"])
await txn.commit()
async with maindb_driver.transaction() as txn:
assert await datamanagers.labels._get_labelset_ids(txn, kbid="other") == []

# Check legacy method
async with maindb_driver.transaction() as txn:
await txn.set("kb/labels/foo/1", b"somedata")
await txn.commit()
async with maindb_driver.transaction() as txn:
await datamanagers.labels._get_labelset_ids(txn, kbid="foo") is None
await datamanagers.labels._deprecated_scan_labelset_ids(txn, kbid="foo") == ["1"]
await datamanagers.labels._get_labelset_ids_bw_compat(txn, kbid="foo") == ["1"]

0 comments on commit 5cac6a3

Please sign in to comment.