Skip to content

Commit

Permalink
Merge pull request #512 from lidofinance/upload-state
Browse files Browse the repository at this point in the history
feat: State dump CID in tree dump
  • Loading branch information
F4ever authored Sep 4, 2024
2 parents 7263717 + 2f9b09e commit dbd4a62
Show file tree
Hide file tree
Showing 25 changed files with 386 additions and 169 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM python:3.12.4-slim as base
RUN apt-get update && apt-get install -y --no-install-recommends -qq \
libffi-dev=3.4.4-1 \
g++=4:12.2.0-3 \
curl=7.88.1-10+deb12u6 \
curl=7.88.1-10+deb12u7 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

Expand Down
48 changes: 24 additions & 24 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ types-setuptools = "^67.6.0.0"
types-urllib3 = "^1.26.25.8"
# }}}
hypothesis = "^6.68.2"
black = "^23.3.0"
black = "^24.8"
pylint = "^3.2.3"
mypy = "^1.10.0"

Expand Down
4 changes: 2 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Iterable, cast
from typing import Iterator, cast

from packaging.version import Version
from prometheus_client import start_http_server
Expand Down Expand Up @@ -152,7 +152,7 @@ def check_providers_chain_ids(web3: Web3, cc: ConsensusClientModule, kac: KeysAP
)


def ipfs_providers() -> Iterable[IPFSProvider]:
def ipfs_providers() -> Iterator[IPFSProvider]:
if variables.GW3_ACCESS_KEY and variables.GW3_SECRET_KEY:
yield GW3(
variables.GW3_ACCESS_KEY,
Expand Down
2 changes: 1 addition & 1 deletion src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _check_duty(
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
self.state.add_processed_epoch(duty_epoch)
self.state.commit()
self.state.log_status()
self.state.log_progress()
unprocessed_epochs = self.state.unprocessed_epochs
CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs))
CSM_MIN_UNPROCESSED_EPOCH.set(min(unprocessed_epochs or {EpochNumber(-1)}))
Expand Down
98 changes: 66 additions & 32 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from collections import defaultdict
from typing import Iterable
from typing import Iterator

from hexbytes import HexBytes

from src.constants import TOTAL_BASIS_POINTS, UINT64_MAX
from src.metrics.prometheus.business import CONTRACT_ON_PAUSE
Expand All @@ -10,22 +12,23 @@
)
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
from src.modules.csm.log import FramePerfLog
from src.modules.csm.state import State
from src.modules.csm.tree import Tree
from src.modules.csm.types import ReportData, Shares
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.providers.execution.contracts.cs_fee_oracle import CSFeeOracleContract
from src.providers.execution.exceptions import InconsistentData
from src.providers.ipfs.cid import CID
from src.providers.ipfs import CID
from src.types import (
BlockStamp,
EpochNumber,
ReferenceBlockStamp,
SlotNumber,
StakingModuleAddress,
ValidatorIndex,
StakingModuleId,
ValidatorIndex,
)
from src.utils.blockstamp import build_blockstamp
from src.utils.cache import global_lru_cache as lru_cache
Expand Down Expand Up @@ -55,6 +58,7 @@ class CSOracle(BaseModule, ConsensusModule):
2. Calculate the performance of each validator based on the attestations.
3. Calculate the share of each CSM node operator excluding underperforming validators.
"""

COMPATIBLE_ONCHAIN_VERSIONS = [(1, 1)]

report_contract: CSFeeOracleContract
Expand Down Expand Up @@ -89,15 +93,9 @@ def execute_module(self, last_finalized_blockstamp: BlockStamp) -> ModuleExecute
@lru_cache(maxsize=1)
@duration_meter()
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
# NOTE: We cannot use `r_epoch` from the `current_frame_range` call because the `blockstamp` is a
# `ReferenceBlockStamp`, hence it's a block the frame ends at. We use `ref_epoch` instead.
l_epoch, _ = self.current_frame_range(blockstamp)
r_epoch = blockstamp.ref_epoch

self.state.validate(l_epoch, r_epoch)
self.state.log_status()
self.validate_state(blockstamp)

distributed, shares = self.calculate_distribution(blockstamp)
distributed, shares, log = self.calculate_distribution(blockstamp)
if not distributed:
logger.info({"msg": "No shares distributed in the current frame"})

Expand All @@ -106,29 +104,23 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
prev_cid = self.w3.csm.get_csm_tree_cid(blockstamp)

if prev_cid:
logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(prev_cid)})
ipfs_tree = Tree.decode(self.w3.ipfs.fetch(prev_cid))

if ipfs_tree.root != prev_root:
raise ValueError("Unexpected tree root got from IPFS dump")

logger.info({"msg": "Restored tree from IPFS dump", "root": repr(prev_root)})
# Update cumulative amount of shares for all operators.
for v in ipfs_tree.tree.values:
no_id, amount = v["value"]
shares[no_id] += amount
for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
shares[no_id] += acc_shares
else:
logger.info({"msg": "No previous CID available"})

tree = self.make_tree(shares)
cid: CID | None = None
tree_cid: CID | None = None

if tree:
cid = self.w3.ipfs.publish(tree.encode())
tree_cid = self.publish_tree(tree, log)

return ReportData(
self.report_contract.get_consensus_version(blockstamp.block_hash),
blockstamp.ref_slot,
tree_root=tree.root if tree else prev_root,
tree_cid=cid or prev_cid,
tree_cid=tree_cid or prev_cid or "",
distributed=distributed,
).as_tuple()

Expand All @@ -151,6 +143,14 @@ def module_validators_by_node_operators(self, blockstamp: BlockStamp) -> Validat
StakingModuleAddress(self.w3.csm.module.address), blockstamp
)

def validate_state(self, blockstamp: ReferenceBlockStamp) -> None:
# NOTE: We cannot use `r_epoch` from the `current_frame_range` call because the `blockstamp` is a
# `ReferenceBlockStamp`, hence it's a block the frame ends at. We use `ref_epoch` instead.
l_epoch, _ = self.current_frame_range(blockstamp)
r_epoch = blockstamp.ref_epoch

self.state.validate(l_epoch, r_epoch)

def collect_data(self, blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection for the estimated reference slot"""

Expand All @@ -175,11 +175,11 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
return False

self.state.migrate(l_epoch, r_epoch)
self.state.log_status()
self.state.log_progress()

if done := self.state.is_fulfilled:
if self.state.is_fulfilled:
logger.info({"msg": "All epochs are already processed. Nothing to collect"})
return done
return True

try:
checkpoints = FrameCheckpointsIterator(
Expand All @@ -198,17 +198,23 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:

return self.state.is_fulfilled

def calculate_distribution(self, blockstamp: ReferenceBlockStamp) -> tuple[int, defaultdict[NodeOperatorId, int]]:
def calculate_distribution(
self, blockstamp: ReferenceBlockStamp
) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]:
"""Computes distribution of fee shares at the given timestamp"""

threshold = self.state.avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
network_avg_perf = self.state.get_network_aggr().perf
threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
operators_to_validators = self.module_validators_by_node_operators(blockstamp)

# Build the map of the current distribution operators.
distribution: dict[NodeOperatorId, int] = defaultdict(int)
stuck_operators = self.stuck_operators(blockstamp)
log = FramePerfLog(self.state.frame, threshold)

for (_, no_id), validators in operators_to_validators.items():
if no_id in stuck_operators:
log.operators[no_id].stuck = True
continue

for v in validators:
Expand All @@ -218,30 +224,51 @@ def calculate_distribution(self, blockstamp: ReferenceBlockStamp) -> tuple[int,
# It's possible that the validator is not assigned to any duty, hence it's performance
# is not presented in the aggregates (e.g. exited, pending for activation etc).
continue
else:
if v.validator.slashed is True:
# It means that validator was active during the frame and got slashed and didn't meet the exit
# epoch, so we should not count such validator for operator's share.
log.operators[no_id].validators[v.index].slashed = True
continue

if aggr.perf > threshold:
# Count of assigned attestations used as a metrics of time
# the validator was active in the current frame.
distribution[no_id] += aggr.assigned

log.operators[no_id].validators[v.index].perf = aggr

# Calculate share of each CSM node operator.
shares = defaultdict[NodeOperatorId, int](int)
total = sum(p for p in distribution.values())

if not total:
return 0, shares
return 0, shares, log

to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash)
for no_id, no_share in distribution.items():
if no_share:
shares[no_id] = to_distribute * no_share // total
log.operators[no_id].distributed = shares[no_id]

distributed = sum(s for s in shares.values())
if distributed > to_distribute:
raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}")
return distributed, shares
return distributed, shares, log

def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]:
logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)})
tree = Tree.decode(self.w3.ipfs.fetch(cid))

logger.info({"msg": "Restored tree from IPFS dump", "root": repr(tree.root)})

def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> Iterable[NodeOperatorId]:
if tree.root != root:
raise ValueError("Unexpected tree root got from IPFS dump")

for v in tree.tree.values:
yield v["value"]

def stuck_operators(self, blockstamp: ReferenceBlockStamp) -> set[NodeOperatorId]:
stuck: set[NodeOperatorId] = set()
l_epoch, _ = self.current_frame_range(blockstamp)
l_ref_slot = self.converter(blockstamp).get_epoch_first_slot(l_epoch)
Expand Down Expand Up @@ -284,6 +311,13 @@ def make_tree(self, shares: dict[NodeOperatorId, Shares]) -> Tree | None:
logger.info({"msg": "New tree built for the report", "root": repr(tree.root)})
return tree

def publish_tree(self, tree: Tree, log: FramePerfLog) -> CID:
log_cid = self.w3.ipfs.publish(log.encode())
logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)})
tree_cid = self.w3.ipfs.publish(tree.encode({"logCID": log_cid}))
logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)})
return tree_cid

@lru_cache(maxsize=1)
def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, EpochNumber]:
converter = self.converter(blockstamp)
Expand Down
48 changes: 48 additions & 0 deletions src/modules/csm/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json
from collections import defaultdict
from dataclasses import asdict, dataclass, field

from src.modules.csm.state import AttestationsAccumulator
from src.types import EpochNumber, NodeOperatorId


class LogJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, AttestationsAccumulator):
return asdict(o)
return super().default(o)


@dataclass
class ValidatorFrameSummary:
perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
slashed: bool = False


@dataclass
class OperatorFrameSummary:
distributed: int = 0
validators: dict[str, ValidatorFrameSummary] = field(default_factory=lambda: defaultdict(ValidatorFrameSummary))
stuck: bool = False


@dataclass
class FramePerfLog:
"""A log of performance assessed per operator in the given frame"""

frame: tuple[EpochNumber, EpochNumber]
threshold: float = 0.0
operators: dict[NodeOperatorId, OperatorFrameSummary] = field(
default_factory=lambda: defaultdict(OperatorFrameSummary)
)

def encode(self) -> bytes:
return (
LogJSONEncoder(
indent=None,
separators=(',', ':'),
sort_keys=True,
)
.encode(asdict(self))
.encode()
)
Loading

0 comments on commit dbd4a62

Please sign in to comment.