Refactor rollover state persistence #2395

Merged
5 commits merged on Aug 12, 2024
147 changes: 85 additions & 62 deletions nucliadb/src/nucliadb/common/cluster/rollover.py
@@ -19,14 +19,14 @@
#
import argparse
import asyncio
import enum
import logging
from datetime import datetime
from typing import Optional

from nucliadb.common import datamanagers, locking
from nucliadb.common.cluster import manager as cluster_manager
from nucliadb.common.context import ApplicationContext
from nucliadb.common.datamanagers.rollover import RolloverState, RolloverStateNotFoundError
from nucliadb_protos import nodewriter_pb2, writer_pb2
from nucliadb_telemetry import errors

@@ -37,25 +37,6 @@
logger = logging.getLogger(__name__)


class RolloverStatus(enum.Enum):
RESOURCES_SCHEDULED = "resources_scheduled"
RESOURCES_INDEXED = "resources_indexed"
RESOURCES_VALIDATED = "resources_validated"


def _get_rollover_status(rollover_shards: writer_pb2.Shards, status: RolloverStatus) -> bool:
return rollover_shards.extra.get(status.value) == "true"


def _set_rollover_status(rollover_shards: writer_pb2.Shards, status: RolloverStatus):
rollover_shards.extra[status.value] = "true"


def _clear_rollover_status(rollover_shards: writer_pb2.Shards):
for status in RolloverStatus:
rollover_shards.extra.pop(status.value, None)


class UnexpectedRolloverError(Exception):
pass

@@ -71,15 +71,26 @@ async def create_rollover_shards(
sm = app_context.shard_manager

async with datamanagers.with_ro_transaction() as txn:
existing_rollover_shards = await datamanagers.rollover.get_kb_rollover_shards(txn, kbid=kbid)
if existing_rollover_shards is not None:
logger.info("Rollover shards already exist, skipping", extra={"kbid": kbid})
return existing_rollover_shards
try:
state = await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
except RolloverStateNotFoundError:
# State is not set yet, create it
state = RolloverState(
rollover_shards_created=False,
resources_scheduled=False,
resources_indexed=False,
cutover=False,
resources_validated=False,
)

kb_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid)
if kb_shards is None:
raise UnexpectedRolloverError(f"No shards found for KB {kbid}")

if state.rollover_shards_created:
logger.info("Rollover shards already created, skipping", extra={"kbid": kbid})
return kb_shards

# create new shards
created_shards = []
try:
@@ -130,6 +122,8 @@ async def create_rollover_shards(

async with datamanagers.with_transaction() as txn:
await datamanagers.rollover.update_kb_rollover_shards(txn, kbid=kbid, kb_shards=kb_shards)
state.rollover_shards_created = True
await datamanagers.rollover.set_rollover_state(txn, kbid=kbid, state=state)
await txn.commit()
return kb_shards

@@ -145,19 +139,17 @@ async def schedule_resource_indexing(app_context: ApplicationContext, kbid: str)
"""
Schedule indexing all data in a kb in rollover shards
"""
logger.info("Indexing rollover shards", extra={"kbid": kbid})

async with datamanagers.with_transaction() as txn:
rollover_shards = await datamanagers.rollover.get_kb_rollover_shards(txn, kbid=kbid)
if rollover_shards is None:
logger.info("Scheduling resources to be indexed to rollover shards", extra={"kbid": kbid})
async with datamanagers.with_ro_transaction() as txn:
state = await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
if not state.rollover_shards_created:
raise UnexpectedRolloverError(f"No rollover shards found for KB {kbid}")

if _get_rollover_status(rollover_shards, RolloverStatus.RESOURCES_SCHEDULED):
logger.info(
"Resources already scheduled for indexing, skipping",
extra={"kbid": kbid},
)
return
if state.resources_scheduled:
logger.info(
"Resources already scheduled for indexing, skipping",
extra={"kbid": kbid},
)
return

batch = []
async for resource_id in datamanagers.resources.iterate_resource_ids(kbid=kbid):
@@ -174,8 +166,8 @@ async def schedule_resource_indexing(app_context: ApplicationContext, kbid: str)
await txn.commit()

async with datamanagers.with_transaction() as txn:
_set_rollover_status(rollover_shards, RolloverStatus.RESOURCES_SCHEDULED)
await datamanagers.rollover.update_kb_rollover_shards(txn, kbid=kbid, kb_shards=rollover_shards)
state.resources_scheduled = True
await datamanagers.rollover.set_rollover_state(txn, kbid=kbid, state=state)
await txn.commit()


@@ -188,17 +180,18 @@ async def index_rollover_shards(app_context: ApplicationContext, kbid: str) -> N
Indexes all data in a kb in rollover shards
"""

async with datamanagers.with_transaction() as txn:
async with datamanagers.with_ro_transaction() as txn:
state = await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
if not all([state.rollover_shards_created, state.resources_scheduled]):
raise UnexpectedRolloverError(f"Preconditions not met for KB {kbid}")
rollover_shards = await datamanagers.rollover.get_kb_rollover_shards(txn, kbid=kbid)
if rollover_shards is None:
raise UnexpectedRolloverError(f"No rollover shards found for KB {kbid}")

if _get_rollover_status(rollover_shards, RolloverStatus.RESOURCES_INDEXED):
if rollover_shards is None:
raise UnexpectedRolloverError(f"No rollover shards found for KB {kbid}")
if state.resources_indexed:
logger.info("Resources already indexed, skipping", extra={"kbid": kbid})
return

logger.info("Indexing rollover shards", extra={"kbid": kbid})

wait_index_batch: list[writer_pb2.ShardObject] = []
# now index on all new shards only
while True:
@@ -259,8 +252,9 @@ async def index_rollover_shards(app_context: ApplicationContext, kbid: str) -> N
await wait_for_node(app_context, node_id)
wait_index_batch = []

_set_rollover_status(rollover_shards, RolloverStatus.RESOURCES_INDEXED)
async with datamanagers.with_transaction() as txn:
state.resources_indexed = True
await datamanagers.rollover.set_rollover_state(txn, kbid=kbid, state=state)
await datamanagers.rollover.update_kb_rollover_shards(txn, kbid=kbid, kb_shards=rollover_shards)
await txn.commit()

@@ -273,20 +267,35 @@ async def cutover_shards(app_context: ApplicationContext, kbid: str) -> None:
async with datamanagers.with_transaction() as txn:
sm = app_context.shard_manager

state = await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
if not all(
[
state.rollover_shards_created,
state.resources_scheduled,
state.resources_indexed,
]
):
raise UnexpectedRolloverError(f"Preconditions not met for KB {kbid}")
if state.cutover:
logger.info("Shards already cut over, skipping", extra={"kbid": kbid})
return

previously_active_shards = await datamanagers.cluster.get_kb_shards(
txn, kbid=kbid, for_update=True
)
rollover_shards = await datamanagers.rollover.get_kb_rollover_shards(txn, kbid=kbid)
if previously_active_shards is None or rollover_shards is None:
raise UnexpectedRolloverError("Shards for kb not found")

_clear_rollover_status(rollover_shards)
await datamanagers.cluster.update_kb_shards(txn, kbid=kbid, shards=rollover_shards)
await datamanagers.rollover.delete_kb_rollover_shards(txn, kbid=kbid)

for shard in previously_active_shards.shards:
await sm.rollback_shard(shard)

state.cutover = True
await datamanagers.rollover.set_rollover_state(txn, kbid=kbid, state=state)

await txn.commit()


@@ -301,11 +310,22 @@ async def validate_indexed_data(app_context: ApplicationContext, kbid: str) -> l
"""

async with datamanagers.with_ro_transaction() as txn:
state = await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
if not all(
[
state.rollover_shards_created,
state.resources_scheduled,
state.resources_indexed,
state.cutover,
]
):
raise UnexpectedRolloverError(f"Preconditions not met for KB {kbid}")

rolled_over_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid)
if rolled_over_shards is None:
raise UnexpectedRolloverError(f"No rollover shards found for KB {kbid}")

if _get_rollover_status(rolled_over_shards, RolloverStatus.RESOURCES_VALIDATED):
if state.resources_validated:
logger.info("Resources already validated, skipping", extra={"kbid": kbid})
return []

@@ -396,8 +416,9 @@ async def validate_indexed_data(app_context: ApplicationContext, kbid: str) -> l
raise UnexpectedRolloverError("Shard not found. This should not happen")
await delete_resource_from_shard(app_context, kbid, resource_id, shard)

_set_rollover_status(rolled_over_shards, RolloverStatus.RESOURCES_VALIDATED)
async with datamanagers.with_transaction() as txn:
state.resources_validated = True
await datamanagers.rollover.set_rollover_state(txn, kbid=kbid, state=state)
await datamanagers.cluster.update_kb_shards(txn, kbid=kbid, shards=rolled_over_shards)

return repaired_resources
@@ -420,16 +441,25 @@ async def clean_indexed_data(app_context: ApplicationContext, kbid: str) -> None

async def clean_rollover_status(app_context: ApplicationContext, kbid: str) -> None:
async with datamanagers.with_transaction() as txn:
kb_shards = await datamanagers.cluster.get_kb_shards(txn, kbid=kbid, for_update=True)
if kb_shards is None:
try:
await datamanagers.rollover.get_rollover_state(txn, kbid=kbid)
except RolloverStateNotFoundError:
logger.warning(
"No shards found for KB, skipping clean rollover status",
extra={"kbid": kbid},
"No rollover state found, skipping clean rollover status", extra={"kbid": kbid}
)
return
await datamanagers.rollover.clear_rollover_state(txn, kbid=kbid)
await txn.commit()

_clear_rollover_status(kb_shards)
await datamanagers.cluster.update_kb_shards(txn, kbid=kbid, shards=kb_shards)

async def wait_for_cluster_ready() -> None:
node_ready_checks = 0
while len(cluster_manager.INDEX_NODES) == 0:
if node_ready_checks > 10:
raise Exception("No index nodes available")
logger.info("Waiting for index nodes to be available")
await asyncio.sleep(1)
node_ready_checks += 1


async def rollover_kb_shards(
@@ -455,16 +485,9 @@ async def rollover_kb_shards(
- Validate that all resources are in the new shards
- Clean up indexed data
"""
node_ready_checks = 0
while len(cluster_manager.INDEX_NODES) == 0:
if node_ready_checks > 10:
raise Exception("No index nodes available")
logger.info("Waiting for index nodes to be available")
await asyncio.sleep(1)
node_ready_checks += 1
await wait_for_cluster_ready()

logger.info("Rolling over shards", extra={"kbid": kbid})

async with locking.distributed_lock(locking.KB_SHARDS_LOCK.format(kbid=kbid)):
await create_rollover_shards(app_context, kbid, drain_nodes=drain_nodes)
await schedule_resource_indexing(app_context, kbid)
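For orientation, every step in this file now follows the same idempotent gating pattern: load the persisted RolloverState, raise if a prerequisite step has not completed, return early if this step already ran, do the work, then flip the step's flag and write the state back. Below is a minimal, self-contained sketch of that pattern; the in-memory STATES dict and the schedule_resources name are illustrative stand-ins rather than part of the PR, which runs these checks inside datamanagers transactions.

```python
# Sketch only: an in-memory dict stands in for maindb, and schedule_resources is
# an illustrative name; the real code runs these checks inside datamanagers
# transactions (see the diff above). Requires pydantic v2.
from pydantic import BaseModel


class RolloverState(BaseModel):
    rollover_shards_created: bool = False
    resources_scheduled: bool = False
    resources_indexed: bool = False
    cutover: bool = False
    resources_validated: bool = False


STATES: dict[str, RolloverState] = {}  # kbid -> persisted state (stand-in for maindb)


def schedule_resources(kbid: str) -> None:
    state = STATES[kbid]
    # Precondition: the previous step must have completed.
    if not state.rollover_shards_created:
        raise RuntimeError(f"Preconditions not met for KB {kbid}")
    # Idempotency: re-running a completed step is a no-op.
    if state.resources_scheduled:
        return
    # ... actual work would go here: enqueue every resource id for indexing ...
    state.resources_scheduled = True  # record progress before the next step runs
    STATES[kbid] = state


if __name__ == "__main__":
    STATES["kb1"] = RolloverState(rollover_shards_created=True)
    schedule_resources("kb1")
    schedule_resources("kb1")  # safe to retry; the persisted flag makes it a no-op
    assert STATES["kb1"].resources_scheduled
```

Because each flag is written back before the next step starts, an interrupted rollover can be re-run and will skip the steps that already completed.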
36 changes: 36 additions & 0 deletions nucliadb/src/nucliadb/common/datamanagers/rollover.py
@@ -21,6 +21,7 @@
from typing import AsyncGenerator, Optional

import orjson
from pydantic import BaseModel

from nucliadb.common.maindb.driver import Transaction
from nucliadb_protos import writer_pb2
@@ -29,11 +30,28 @@

logger = logging.getLogger(__name__)

KB_ROLLOVER_STATE = "/kbs/{kbid}/rollover/state"
KB_ROLLOVER_SHARDS = "/kbs/{kbid}/rollover/shards"
KB_ROLLOVER_RESOURCES_TO_INDEX = "/kbs/{kbid}/rollover/to-index/{resource}"
KB_ROLLOVER_RESOURCES_INDEXED = "/kbs/{kbid}/rollover/indexed/{resource}"


class RolloverState(BaseModel):
rollover_shards_created: bool = False
resources_scheduled: bool = False
resources_indexed: bool = False
cutover: bool = False
resources_validated: bool = False


class RolloverStateNotFoundError(Exception):
"""
Raised when the rollover state is not found.
"""

...


async def get_kb_rollover_shards(txn: Transaction, *, kbid: str) -> Optional[writer_pb2.Shards]:
key = KB_ROLLOVER_SHARDS.format(kbid=kbid)
return await get_kv_pb(txn, key, writer_pb2.Shards)
@@ -163,3 +181,21 @@ async def iterate_indexed_data(*, kbid: str) -> AsyncGenerator[tuple[str, tuple[
if len(batch) > 0:
for key, val in await _get_batch_indexed_data(kbid=kbid, batch=batch):
yield key, val


async def get_rollover_state(txn: Transaction, kbid: str) -> RolloverState:
key = KB_ROLLOVER_STATE.format(kbid=kbid)
val = await txn.get(key)
if not val:
raise RolloverStateNotFoundError(kbid)
return RolloverState.model_validate_json(val)


async def set_rollover_state(txn: Transaction, kbid: str, state: RolloverState) -> None:
key = KB_ROLLOVER_STATE.format(kbid=kbid)
await txn.set(key, state.model_dump_json().encode())


async def clear_rollover_state(txn: Transaction, kbid: str) -> None:
key = KB_ROLLOVER_STATE.format(kbid=kbid)
await txn.delete(key)
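As a usage sketch of the persistence helpers added above: the state is serialized as a pydantic JSON document under /kbs/{kbid}/rollover/state, a missing key surfaces as RolloverStateNotFoundError, and clear_rollover_state deletes the key. The FakeTransaction below is an assumed in-memory stand-in exposing only the get/set/delete methods these helpers call; it is not nucliadb's real maindb driver, and the snippet assumes a development environment where nucliadb and its protos import.

```python
# Assumed stand-in for a maindb Transaction, for illustration only.
import asyncio
from typing import Optional

from nucliadb.common.datamanagers import rollover  # helpers added in this PR


class FakeTransaction:
    def __init__(self) -> None:
        self._kv: dict[str, bytes] = {}

    async def get(self, key: str) -> Optional[bytes]:
        return self._kv.get(key)

    async def set(self, key: str, value: bytes) -> None:
        self._kv[key] = value

    async def delete(self, key: str) -> None:
        self._kv.pop(key, None)


async def main() -> None:
    txn = FakeTransaction()
    state = rollover.RolloverState(rollover_shards_created=True)
    await rollover.set_rollover_state(txn, "kb1", state)
    print(await rollover.get_rollover_state(txn, "kb1"))

    await rollover.clear_rollover_state(txn, "kb1")
    try:
        await rollover.get_rollover_state(txn, "kb1")
    except rollover.RolloverStateNotFoundError:
        print("state cleared")  # missing key is reported as RolloverStateNotFoundError


asyncio.run(main())
```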
4 changes: 1 addition & 3 deletions nucliadb/src/nucliadb/writer/tus/gcs.py
@@ -127,9 +127,7 @@ async def initialize(
if self.json_credentials is not None and self.json_credentials.strip() != "":
self.json_credentials_file = os.path.join(tempfile.mkdtemp(), "gcs_credentials.json")
with open(self.json_credentials_file, "w") as file:
file.write(
base64.b64decode(self.json_credentials).decode("utf-8")
)
file.write(base64.b64decode(self.json_credentials).decode("utf-8"))
self._credentials = ServiceAccountCredentials.from_json_keyfile_name(
self.json_credentials_file, SCOPES
)