From 53cf6685c619a185508f7d48b4048cb3c321eee6 Mon Sep 17 00:00:00 2001 From: Maximilian Langenfeld <15726643+ezdac@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:37:29 -0400 Subject: [PATCH] chore: add new chainsync --- .../medley/chainsync/chainsegment/chain.go | 316 +++++++++++++++ rolling-shutter/medley/chainsync/chainsync.go | 46 +++ .../medley/chainsync/client/client.go | 52 +++ .../medley/chainsync/client/test.go | 317 +++++++++++++++ rolling-shutter/medley/chainsync/options.go | 127 ++++++ .../medley/chainsync/syncer/chaincache.go | 78 ++++ .../chainsync/syncer/dynamically_typed.go | 68 ++++ .../medley/chainsync/syncer/fetch.go | 247 ++++++++++++ .../medley/chainsync/syncer/log_topics.go | 50 +++ .../medley/chainsync/syncer/loop.go | 192 +++++++++ .../medley/chainsync/syncer/types.go | 70 ++++ .../syncer_test/chainsegment_test.go | 366 ++++++++++++++++++ .../chainsync/syncer_test/fetcher_test.go | 162 ++++++++ .../medley/chainsync/syncer_test/util.go | 191 +++++++++ 14 files changed, 2282 insertions(+) create mode 100644 rolling-shutter/medley/chainsync/chainsegment/chain.go create mode 100644 rolling-shutter/medley/chainsync/chainsync.go create mode 100644 rolling-shutter/medley/chainsync/client/client.go create mode 100644 rolling-shutter/medley/chainsync/client/test.go create mode 100644 rolling-shutter/medley/chainsync/options.go create mode 100644 rolling-shutter/medley/chainsync/syncer/chaincache.go create mode 100644 rolling-shutter/medley/chainsync/syncer/dynamically_typed.go create mode 100644 rolling-shutter/medley/chainsync/syncer/fetch.go create mode 100644 rolling-shutter/medley/chainsync/syncer/log_topics.go create mode 100644 rolling-shutter/medley/chainsync/syncer/loop.go create mode 100644 rolling-shutter/medley/chainsync/syncer/types.go create mode 100644 rolling-shutter/medley/chainsync/syncer_test/chainsegment_test.go create mode 100644 rolling-shutter/medley/chainsync/syncer_test/fetcher_test.go create mode 100644 rolling-shutter/medley/chainsync/syncer_test/util.go diff --git a/rolling-shutter/medley/chainsync/chainsegment/chain.go b/rolling-shutter/medley/chainsync/chainsegment/chain.go new file mode 100644 index 00000000..15c60da2 --- /dev/null +++ b/rolling-shutter/medley/chainsync/chainsegment/chain.go @@ -0,0 +1,316 @@ +package chainsegment + +import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/client" +) + +const MaxNumPollBlocks = 50 + +var ( + ErrReorg = errors.New("detected reorg in updated chain-segment") + ErrEmpty = errors.New("empty chain-segment") + ErrUpdateBlockTooFarInPast = errors.New("the updated block reaches too far in the past for the chain-segment") + ErrOverlapTooBig = errors.New("chain-segment overlap too big") +) + +type UpdateLatestResult struct { + // the full new segment with the reorg applied + FullSegment *ChainSegment + // the removed segment that is not part of the new full segment anymore + // (reorged blocks) + RemovedSegment *ChainSegment + // the updated segment of new blocks that were not part of the old chain + // (new blocks including the replacement blocks from a reorg) + UpdatedSegment *ChainSegment +} + +// capNumPollBlocks is a pipeline function +// that restricts the number of blocks to be +// polled, e.g. during filling gaps between +// two chain-segments. +func capNumPollBlocks(num int) int { + if num > MaxNumPollBlocks { + return MaxNumPollBlocks + } else if num < 1 { + return 1 + } + return num +} + +type ChainSegment struct { + chain []*types.Header +} + +func NewChainSegment(chain ...*types.Header) *ChainSegment { + bc := &ChainSegment{ + chain: chain, + } + return bc +} + +func (bc *ChainSegment) GetHeaderByHash(h common.Hash) *types.Header { + // OPTIM: this should be implemented more efficiently + // with a hash-map + for _, header := range bc.chain { + if header.Hash().Cmp(h) == 0 { + return header + } + } + return nil +} + +func (bc *ChainSegment) Len() int { + return len(bc.chain) +} + +func (bc *ChainSegment) Earliest() *types.Header { + if len(bc.chain) == 0 { + return nil + } + return bc.chain[0] +} + +func (bc *ChainSegment) Latest() *types.Header { + if len(bc.chain) == 0 { + return nil + } + return bc.chain[len(bc.chain)-1] +} + +func (bc *ChainSegment) Get() []*types.Header { + return bc.chain +} + +func (bc *ChainSegment) Copy() *ChainSegment { + return NewChainSegment(bc.chain...) +} + +// UpdateLatest incorporates a new chainsegment `update` into it's existing +// chain-segment. +// For this it backtracks the new chain-segment until it finds the common ancestor +// with it's current chain-segment. If there is no ancestor because of a block-number +// gap between the old segments "latest" block and the new segments "earliest" block, +// it will incrementally batch-augment the 'update' chain-segment with blocks older than +// it's "earliest" block, and call the UpdateLatest latest method recursively +// until the algorithm finds a common ancestor. +// The outcome of this process is an `UpdateLatestResult`, which +// communicates to the caller what part of the previous chain-segment had to be removed, +// and what part of the `update` chain-segment was appended to the previous chain-segment +// after removal of out-of-date blocks, in addition to the full newly updated chain-segment. +// This is a pointer method that updates the internal state of it's chain-segment! +func (bc *ChainSegment) UpdateLatest(ctx context.Context, c client.Sync, update *ChainSegment) (UpdateLatestResult, error) { + update = update.Copy() + if bc.Len() == 0 { + // We can't compare anything - instead of silently absorbing the + // whole new segment, communicate this to the caller with a specific error. + return UpdateLatestResult{}, ErrEmpty + } + + if bc.Earliest().Number.Cmp(update.Earliest().Number) == 1 { + // We don't reach so far in the past for the old chain-segment. + // This happens when there is a large reorg, while the chain-segment + // of the cache is still small. + return UpdateLatestResult{}, fmt.Errorf( + "segment earliest=%d, update earliest=%d: %w", + bc.Earliest().Number.Int64(), update.Earliest().Number.Int64(), + ErrUpdateBlockTooFarInPast, + ) + } + overlapBig := new(big.Int).Add( + new(big.Int).Sub(bc.Latest().Number, update.Earliest().Number), + // both being the same height means one block overlap, so add 1 + big.NewInt(1), + ) + if !overlapBig.IsInt64() { + // this should never happen, this would be too large of a gap + return UpdateLatestResult{}, ErrOverlapTooBig + } + + overlap := int(overlapBig.Int64()) + if overlap < 0 { + // overlap is negative, this means we have a gap: + extendedUpdate, err := update.ExtendLeft(ctx, c, capNumPollBlocks(-overlap)) + if err != nil { + return UpdateLatestResult{}, fmt.Errorf("failed to extend left gap: %w", err) + } + return bc.UpdateLatest(ctx, c, extendedUpdate) + } else if overlap == 0 { + if update.Earliest().ParentHash.Cmp(bc.Latest().Hash()) == 0 { + // the new segment extends the old one perfectly + return UpdateLatestResult{ + FullSegment: bc.Copy().AddRight(update), + RemovedSegment: nil, + UpdatedSegment: update, + }, nil + } + // the block-numbers align, but the new segment + // seems to be from a reorg that branches off within the old segment + _, err := update.ExtendLeft(ctx, c, capNumPollBlocks(bc.Len())) + if err != nil { + return UpdateLatestResult{}, fmt.Errorf("failed to extend into reorg: %w", err) + } + return bc.UpdateLatest(ctx, c, update) + } + // implicit case - overlap > 0: + // now we can compare the segments and find the common ancestor + // Return the segment of the overlap from the current segment + // and compute the diff of the whole new update segment. + removed, updated := bc.GetLatest(overlap).DiffLeftAligned(update) + // don't copy, but use the method's struct, + // that way we modify in-place + full := bc + if removed != nil { + // cut the reorged section that has to be removed + // so that we only have the "left" section up until the + // common ancestor + full = full.GetEarliest(full.Len() - removed.Len()) + } + if updated != nil { + // and now append the update section + // to the right, effectively removing the reorged section + full.AddRight(updated) + } + return UpdateLatestResult{ + FullSegment: full, + RemovedSegment: removed, + UpdatedSegment: updated, + }, nil +} + +// AddRight adds the `add` chain-segment to the "right" of the +// original chain-segment, and thus assumes that the `add` segments +// Earliest() block is the child-block of the original segments +// Latest() block. This condition is *not* checked, +// so callers have to guarantee for it. +func (bc *ChainSegment) AddRight(add *ChainSegment) *ChainSegment { + bc.chain = append(bc.chain, add.chain...) + return bc +} + +// DiffLeftAligned compares the ChainSegment to another chain-segment that +// starts at the same Earliest() block-number. +// It walks both segments from earliest to latest header simultaneously +// and compares the block-hashes. As soon as there is a mismatch +// in block-hashes, a consecutive difference from that point on is assumed. +// All diff blocks from the `other` chain-segment will be appended to the returned `update` +// chain-segment, and all diff blocks from the original chain-segment +// will be appended to the `remove` chain-segment. +// If there is no overlap in the diff, but the `other` chain-segment is longer than +// the original segment, the `remove` segment will be nil, and the `update` segment +// will consist of the non-overlapping blocks of the `other` segment. +// If both segments are identical, both `update` and `remove` segments will be nil. +func (bc *ChainSegment) DiffLeftAligned(other *ChainSegment) (remove, update *ChainSegment) { + // 1) assumes both segments start at the same block height (earliest block at index 0 with same blocknum) + // 2) assumes the other.Len() >= bc.Len() + + // Compare the two and see if we have to reorg based on the hashes + removed := []*types.Header{} + updated := []*types.Header{} + oldChain := bc.Get() + newChain := other.Get() + + for i := 0; i < len(newChain); i++ { + var oldHeader *types.Header + newHeader := newChain[i] + if len(oldChain) > i { + oldHeader = oldChain[i] + } + if oldHeader == nil { + updated = append(updated, newHeader) + // TODO: sanity check also the blocknum + parent hash chain + // so that we are sure that we have consecutive segments. + } else if oldHeader.Hash().Cmp(newHeader.Hash()) != 0 { + removed = append(removed, oldHeader) + updated = append(updated, newHeader) + } + } + var removedSegment, updatedSegment *ChainSegment + if len(removed) > 0 { + removedSegment = NewChainSegment(removed...) + } + if len(updated) > 0 { + updatedSegment = NewChainSegment(updated...) + } + return removedSegment, updatedSegment +} + +// GetLatest retrieves the "n" latest blocks from this +// ChainSegment. +// If the segment is shorter than n, the whole segment gets returned. +func (bc *ChainSegment) GetLatest(n int) *ChainSegment { + if n > bc.Len() { + n = bc.Len() + } + return NewChainSegment(bc.chain[len(bc.chain)-n : len(bc.chain)]...) +} + +// GetLatest retrieves the "n" earliest blocks from this +// ChainSegment. +// If the segment is shorter than n, the whole segment gets returned. +func (bc *ChainSegment) GetEarliest(n int) *ChainSegment { + if n > bc.Len() { + n = bc.Len() + } + return NewChainSegment(bc.chain[:n]...) +} + +func (bc *ChainSegment) NewSegmentRight(ctx context.Context, c client.Sync, num int) (*ChainSegment, error) { + rightMost := bc.Latest() + if rightMost == nil { + return nil, ErrEmpty + } + chain := []*types.Header{} + previous := rightMost + for i := 1; i <= num; i++ { + blockNum := new(big.Int).Sub(rightMost.Number, big.NewInt(int64(i))) + h, err := c.HeaderByNumber(ctx, blockNum) + if err != nil { + return nil, err + } + if h.Hash().Cmp(previous.ParentHash) != 0 { + // the server has a different chain state than this segment, + // so it is part of a reorged away chain-segment + return nil, ErrReorg + } + chain = append(chain, h) + previous = h + } + return NewChainSegment(chain...), nil +} + +func (bc *ChainSegment) ExtendLeft(ctx context.Context, c client.Sync, num int) (*ChainSegment, error) { + leftMost := bc.Earliest() + if leftMost == nil { + return nil, ErrEmpty + } + for num > 0 { + blockNum := new(big.Int).Sub(leftMost.Number, big.NewInt(int64(1))) + //OPTIM: we do cap the max poll number when calling this method, + // but then we make one request per block anyways. + // This doesn't make sense, but there currently is no batching + // for retrieving ranges of headers. + h, err := c.HeaderByNumber(ctx, blockNum) + if err != nil { + return nil, fmt.Errorf("failed to retrieve header by number (#%d): %w", blockNum.Uint64(), err) + } + if h.Hash().Cmp(leftMost.ParentHash) != 0 { + // The server has a different chain state than this segment, + // so it is part of a reorged away chain-segment. + // This can also happen when the server reorged during this loop + // and we now polled the parent with an unexpected hash. + return nil, ErrReorg + } + bc.chain = append([]*types.Header{h}, bc.chain...) + leftMost = h + num-- + } + return bc, nil +} diff --git a/rolling-shutter/medley/chainsync/chainsync.go b/rolling-shutter/medley/chainsync/chainsync.go new file mode 100644 index 00000000..8365e29b --- /dev/null +++ b/rolling-shutter/medley/chainsync/chainsync.go @@ -0,0 +1,46 @@ +package chainsync + +import ( + "context" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/syncer" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service" +) + +type Chainsync struct { + options *options + fetcher *syncer.Fetcher +} + +func New(options ...Option) (*Chainsync, error) { + opts := defaultOptions() + for _, o := range options { + if err := o(opts); err != nil { + return nil, fmt.Errorf("error applying option to Chainsync: %w", err) + } + } + + if err := opts.verify(); err != nil { + return nil, fmt.Errorf("error verifying options to Chainsync: %w", err) + } + return &Chainsync{ + options: opts, + }, nil +} + +func (c *Chainsync) Start(ctx context.Context, runner service.Runner) error { + var err error + c.fetcher, err = c.options.initFetcher(ctx) + if err != nil { + return fmt.Errorf("error initializing Chainsync: %w", err) + } + return c.fetcher.Start(ctx, runner) +} + +func (c *Chainsync) GetHeaderByHash(ctx context.Context, h common.Hash) (*types.Header, error) { + return c.fetcher.GetHeaderByHash(ctx, h) +} diff --git a/rolling-shutter/medley/chainsync/client/client.go b/rolling-shutter/medley/chainsync/client/client.go new file mode 100644 index 00000000..47944f06 --- /dev/null +++ b/rolling-shutter/medley/chainsync/client/client.go @@ -0,0 +1,52 @@ +package client + +import ( + "context" + "math/big" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" +) + +type Client interface { + Close() + ChainID(ctx context.Context) (*big.Int, error) + BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) + BlockByNumber(ctx context.Context, number *big.Int) (*types.Block, error) + BlockNumber(ctx context.Context) (uint64, error) + PeerCount(ctx context.Context) (uint64, error) + HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) + HeaderByNumber(ctx context.Context, number *big.Int) (*types.Header, error) + TransactionByHash(ctx context.Context, hash common.Hash) (tx *types.Transaction, isPending bool, err error) + TransactionSender(ctx context.Context, tx *types.Transaction, block common.Hash, index uint) (common.Address, error) + TransactionCount(ctx context.Context, blockHash common.Hash) (uint, error) + TransactionInBlock(ctx context.Context, blockHash common.Hash, index uint) (*types.Transaction, error) + TransactionReceipt(ctx context.Context, txHash common.Hash) (*types.Receipt, error) + SyncProgress(ctx context.Context) (*ethereum.SyncProgress, error) + SubscribeNewHead(ctx context.Context, ch chan<- *types.Header) (ethereum.Subscription, error) + NetworkID(ctx context.Context) (*big.Int, error) + BalanceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (*big.Int, error) + StorageAt(ctx context.Context, account common.Address, key common.Hash, blockNumber *big.Int) ([]byte, error) + CodeAt(ctx context.Context, account common.Address, blockNumber *big.Int) ([]byte, error) + NonceAt(ctx context.Context, account common.Address, blockNumber *big.Int) (uint64, error) + FilterLogs(ctx context.Context, q ethereum.FilterQuery) ([]types.Log, error) + SubscribeFilterLogs(ctx context.Context, q ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) + PendingBalanceAt(ctx context.Context, account common.Address) (*big.Int, error) + PendingStorageAt(ctx context.Context, account common.Address, key common.Hash) ([]byte, error) + PendingCodeAt(ctx context.Context, account common.Address) ([]byte, error) + PendingNonceAt(ctx context.Context, account common.Address) (uint64, error) + PendingTransactionCount(ctx context.Context) (uint, error) + CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) + CallContractAtHash(ctx context.Context, msg ethereum.CallMsg, blockHash common.Hash) ([]byte, error) + PendingCallContract(ctx context.Context, msg ethereum.CallMsg) ([]byte, error) + SuggestGasPrice(ctx context.Context) (*big.Int, error) + SuggestGasTipCap(ctx context.Context) (*big.Int, error) + EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, error) + SendTransaction(ctx context.Context, tx *types.Transaction) error +} + +type Sync interface { + ethereum.LogFilterer + ethereum.ChainReader +} diff --git a/rolling-shutter/medley/chainsync/client/test.go b/rolling-shutter/medley/chainsync/client/test.go new file mode 100644 index 00000000..303c78a7 --- /dev/null +++ b/rolling-shutter/medley/chainsync/client/test.go @@ -0,0 +1,317 @@ +package client + +import ( + "context" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + gethLog "github.com/ethereum/go-ethereum/log" +) + +var ErrNotImplemented = errors.New("not implemented") + +var _ Sync = &TestClient{} + +type TestClient struct { + log gethLog.Logger + mux *sync.RWMutex + subsMux *sync.RWMutex + headerChain []*types.Header + logs map[common.Hash][]types.Log + latestHeadIndex int + initialProgress bool + latestHeadEmitter []chan<- *types.Header + latestHeadSubscription []*Subscription +} + +func NewSubscription(idx int) *Subscription { + return &Subscription{ + idx: idx, + err: make(chan error, 1), + } +} + +type Subscription struct { + idx int + err chan error +} + +func (su *Subscription) Unsubscribe() { + // TODO: not implemented yet, but we don't want to panic +} + +func (su *Subscription) Err() <-chan error { + return su.err +} + +type TestClientController struct { + c *TestClient +} + +func NewTestClient(logger gethLog.Logger) (*TestClient, *TestClientController) { + c := &TestClient{ + log: log, + mux: &sync.RWMutex{}, + subsMux: &sync.RWMutex{}, + headerChain: []*types.Header{}, + logs: map[common.Hash][]types.Log{}, + latestHeadIndex: 0, + initialProgress: false, + latestHeadEmitter: []chan<- *types.Header{}, + latestHeadSubscription: []*Subscription{}, + } + ctrl := &TestClientController{c} + return c, ctrl +} + +// progresses the internal state of the latest head +// until no more information is available in the +// internal header chain. +func (c *TestClientController) ProgressAllHeads() { + for c.ProgressHead() { + } +} + +// updates the internal state of the latest +// head one block. This will iterate over the +// internal headerChain and thus also includes reorging +// and decreasing the latest-head number. +func (c *TestClientController) ProgressHead() bool { + c.c.mux.Lock() + defer c.c.mux.Unlock() + + if c.c.latestHeadIndex >= len(c.c.headerChain)-1 { + return false + } + c.c.latestHeadIndex++ + return true +} + +func (c *TestClientController) WaitSubscribed(ctx context.Context) { + for { + c.c.subsMux.RLock() + if len(c.c.latestHeadEmitter) > 0 { + c.c.subsMux.RUnlock() + break + } + c.c.subsMux.RUnlock() + if ctx.Err() != nil { + return + } + time.After(50 * time.Millisecond) + } +} +func (c *TestClientController) EmitLatestHead(ctx context.Context) error { + c.c.subsMux.RLock() + defer c.c.subsMux.RUnlock() + + c.c.mux.RLock() + if len(c.c.latestHeadEmitter) == 0 { + c.c.mux.RUnlock() + return nil + } + h := c.c.getLatestHeader() + c.c.mux.RUnlock() + for _, em := range c.c.latestHeadEmitter { + select { + case em <- h: + case <-ctx.Done(): + return ctx.Err() + } + } + return nil +} + +func (c *TestClientController) AppendNextHeader(h *types.Header, events ...types.Log) { + c.c.mux.Lock() + defer c.c.mux.Unlock() + + c.c.headerChain = append(c.c.headerChain, h) + _, ok := c.c.logs[h.Hash()] + if ok { + return + } + c.c.logs[h.Hash()] = events +} + +func (t *TestClient) ChainID(_ context.Context) (*big.Int, error) { //nolint: unparam + return big.NewInt(42), nil +} + +func (t *TestClient) Close() { + // TODO: cleanup +} + +func (t *TestClient) getLatestHeader() *types.Header { + if len(t.headerChain) == 0 { + return nil + } + return t.headerChain[t.latestHeadIndex] +} + +func (t *TestClient) searchBlock(f func(*types.Header) bool) *types.Header { + for i := t.latestHeadIndex; i >= 0; i-- { + h := t.headerChain[i] + if f(h) { + return h + } + } + return nil +} + +func (t *TestClient) searchBlockByNumber(number *big.Int) *types.Header { + return t.searchBlock( + func(h *types.Header) bool { + return h.Number.Cmp(number) == 0 + }) +} + +func (t *TestClient) searchBlockByHash(hash common.Hash) *types.Header { + return t.searchBlock( + func(h *types.Header) bool { + return hash.Cmp(h.Hash()) == 0 + }) +} + +func (t *TestClient) BlockNumber(_ context.Context) (uint64, error) { //nolint: unparam + t.mux.RLock() + defer t.mux.RUnlock() + + return t.getLatestHeader().Nonce.Uint64(), nil +} + +func (t *TestClient) HeaderByHash(_ context.Context, hash common.Hash) (*types.Header, error) { + t.mux.RLock() + defer t.mux.RUnlock() + + h := t.searchBlockByHash(hash) + if h == nil { + return nil, errors.New("header not found") + } + return h, nil +} + +func (t *TestClient) HeaderByNumber(_ context.Context, number *big.Int) (*types.Header, error) { + t.mux.RLock() + defer t.mux.RUnlock() + + if number == nil { + return t.getLatestHeader(), nil + } + if number.Cmp(big.NewInt(-2)) == 0 { + return t.getLatestHeader(), nil + } + h := t.searchBlockByNumber(number) + if h == nil { + return nil, errors.New("not found") + } + return h, nil +} + +func (t *TestClient) SubscribeNewHead(_ context.Context, ch chan<- *types.Header) (ethereum.Subscription, error) { + t.subsMux.Lock() + defer t.subsMux.Unlock() + + t.latestHeadEmitter = append(t.latestHeadEmitter, ch) + su := NewSubscription(len(t.latestHeadSubscription) - 1) + t.latestHeadSubscription = append(t.latestHeadSubscription, su) + // TODO: unsubscribe and deleting from the array + // TODO: filling error promise in the subscription + return su, nil +} + +func (t *TestClient) getLogs(ctx context.Context, query ethereum.FilterQuery) ([]types.Log, error) { + logs := []types.Log{} + if query.BlockHash != nil { + log, ok := t.logs[*query.BlockHash] + if !ok { + // TODO: if possible return the same error as the client + return logs, fmt.Errorf("no logs found") + } + return log, nil + } + if query.FromBlock != nil { + current := query.FromBlock + toBlock := query.ToBlock + if toBlock == nil { + latest := t.getLatestHeader() + toBlock = latest.Number + } + for current.Cmp(toBlock) != +1 { + h := t.searchBlockByNumber(current) + + current = new(big.Int).Add(current, big.NewInt(1)) + log, ok := t.logs[h.Hash()] + if !ok { + continue + } + logs = append(logs, log...) + } + } + //FIXME: also return no logs found if empty? + return logs, nil +} + +func (t *TestClient) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([]types.Log, error) { + t.mux.RLock() + defer t.mux.RUnlock() + + logs, err := t.getLogs(ctx, query) + if len(logs) > 0 { + t.log.Info("logs found in FilterLogs", "logs", logs) + } + if err != nil { + return logs, err + } + filtered := []types.Log{} + + addrs := map[common.Address]struct{}{} + for _, a := range query.Addresses { + addrs[a] = struct{}{} + } + t.log.Info("query Addresses FilterLogs", "addresses", query.Addresses) + + for _, log := range logs { + if _, ok := addrs[log.Address]; !ok { + continue + } + filtered = append(filtered, log) + } + // OPTIM: filter by the topics, but this gets complex + // since it's position based as well. + // It's not strictly needed for the tests, since the downstream + // caller should also ignore wrong log types upon parsing. + return filtered, nil +} + +func (t *TestClient) SubscribeFilterLogs(_ context.Context, _ ethereum.FilterQuery, _ chan<- types.Log) (ethereum.Subscription, error) { + panic(ErrNotImplemented) +} + +func (t *TestClient) CodeAt(_ context.Context, _ common.Address, _ *big.Int) ([]byte, error) { + panic(ErrNotImplemented) +} + +func (t *TestClient) TransactionReceipt(_ context.Context, _ common.Hash) (*types.Receipt, error) { + panic(ErrNotImplemented) +} + +func (t *TestClient) BlockByHash(_ context.Context, _ common.Hash) (*types.Block, error) { + panic(ErrNotImplemented) +} + +func (t *TestClient) TransactionCount(_ context.Context, _ common.Hash) (uint, error) { + panic(ErrNotImplemented) + +} + +func (t *TestClient) TransactionInBlock(_ context.Context, _ common.Hash, _ uint) (*types.Transaction, error) { + panic(ErrNotImplemented) + +} diff --git a/rolling-shutter/medley/chainsync/options.go b/rolling-shutter/medley/chainsync/options.go new file mode 100644 index 00000000..301174a7 --- /dev/null +++ b/rolling-shutter/medley/chainsync/options.go @@ -0,0 +1,127 @@ +package chainsync + +import ( + "context" + + "github.com/ethereum/go-ethereum/ethclient" + "github.com/pkg/errors" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/client" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/syncer" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/encodeable/number" +) + +const defaultMemoryBlockCacheSize = 50 + +type Option func(*options) error + +type options struct { + clientURL string + ethClient client.Sync + syncStart *number.BlockNumber + chainCache syncer.ChainCache + eventHandler []syncer.ContractEventHandler + chainHandler []syncer.ChainUpdateHandler +} + +func (o *options) verify() error { + if o.clientURL != "" && o.ethClient != nil { + return errors.New("'WithClient' and 'WithClientURL' options are mutually exclusive") + } + if o.clientURL == "" && o.ethClient == nil { + return errors.New("either 'WithClient' or 'WithClientURL' options are expected") + } + return nil +} + +// initFetcher applies the options and initializes the fetcher. +// The context is only the initialisation context, +// and should not be considered to handle the lifecycle +// of shutter clients background workers. +func (o *options) initFetcher(ctx context.Context) (*syncer.Fetcher, error) { + var err error + if o.clientURL != "" { + o.ethClient, err = ethclient.DialContext(ctx, o.clientURL) + if err != nil { + return nil, err + } + } + + if o.chainCache == nil { + o.chainCache = syncer.NewMemoryChainCache(int(defaultMemoryBlockCacheSize), nil) + } + f := syncer.NewFetcher(o.ethClient, o.chainCache) + + for _, h := range o.chainHandler { + f.RegisterChainUpdateHandler(h) + } + for _, h := range o.eventHandler { + f.RegisterContractEventHandler(h) + } + return f, nil +} + +func defaultOptions() *options { + return &options{ + syncStart: number.NewBlockNumber(nil), + eventHandler: []syncer.ContractEventHandler{}, + chainHandler: []syncer.ChainUpdateHandler{}, + } +} + +func WithSyncStartBlock( + blockNumber *number.BlockNumber, +) Option { + if blockNumber == nil { + blockNumber = number.NewBlockNumber(nil) + } + return func(o *options) error { + o.syncStart = blockNumber + return nil + } +} + +func WithClientURL(url string) Option { + return func(o *options) error { + o.clientURL = url + return nil + } +} + +// NOTE: The Latest() of the chaincache determines what is the starting +// point of the chainsync. +// In case of an empty chaincache, we will initialize the cache +// with the current latest block. +// If we have a very old (persistent) chaincache, we will sync EVERY block +// since the latest known block of the cache due to consistency considerations. +// If that is unfeasible, the cache has to be emptied beforehand and the +// gap in state-updates has to be dealt with or accepted. +// If NO chaincache is passed with this option, an empty in-memory +// chain-cache with a capped cachesize will be used. +func WithChainCache(c syncer.ChainCache) Option { + return func(o *options) error { + o.chainCache = c + return nil + } +} + +func WithClient(c client.Sync) Option { + return func(o *options) error { + o.ethClient = c + return nil + } +} + +func WithContractEventHandler(h syncer.ContractEventHandler) Option { + return func(o *options) error { + o.eventHandler = append(o.eventHandler, h) + return nil + } +} + +func WithChainUpdateHandler(h syncer.ChainUpdateHandler) Option { + return func(o *options) error { + o.chainHandler = append(o.chainHandler, h) + return nil + } +} diff --git a/rolling-shutter/medley/chainsync/syncer/chaincache.go b/rolling-shutter/medley/chainsync/syncer/chaincache.go new file mode 100644 index 00000000..eb019365 --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/chaincache.go @@ -0,0 +1,78 @@ +package syncer + +import ( + "context" + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/chainsegment" +) + +type ChainCache interface { + Get(context.Context) (*chainsegment.ChainSegment, error) + Update(context.Context, ChainUpdateContext) error + GetHeaderByHash(context.Context, common.Hash) (*types.Header, error) +} + +var ErrEmpy = errors.New("chain-cache empty") + +var _ ChainCache = &MemoryChainCache{} + +func NewMemoryChainCache(maxSize int, chain *chainsegment.ChainSegment) *MemoryChainCache { + // chain can be nil + return &MemoryChainCache{ + chain: chain, + maxSize: maxSize, + } +} + +type MemoryChainCache struct { + chain *chainsegment.ChainSegment + maxSize int +} + +func (mcc *MemoryChainCache) Get(_ context.Context) (*chainsegment.ChainSegment, error) { + if mcc.chain == nil { + return nil, ErrEmpy + } + return mcc.chain, nil +} + +func (mcc *MemoryChainCache) GetHeaderByHash(_ context.Context, h common.Hash) (*types.Header, error) { + return mcc.chain.GetHeaderByHash(h), nil +} + +func (mcc *MemoryChainCache) Update(_ context.Context, update ChainUpdateContext) error { + newSegment := []*types.Header{} + if mcc.chain != nil { + // OPTIM: can be implemented more efficient, but mainly used for testing + removeHashes := map[common.Hash]struct{}{} + if update.Remove != nil { + for _, header := range update.Remove.Get() { + removeHashes[header.Hash()] = struct{}{} + } + } + for _, header := range mcc.chain.Get() { + _, remove := removeHashes[header.Hash()] + if !remove { + newSegment = append(newSegment, header) + } + } + if update.Append != nil { + newSegment = append(newSegment, update.Append.Get()...) + } + if len(newSegment) > mcc.maxSize { + //TODO: check for oneoff. + newSegment = newSegment[len(newSegment)-mcc.maxSize:] + } + } else { + if update.Append == nil { + return nil + } + newSegment = update.Append.Get() + } + mcc.chain = chainsegment.NewChainSegment(newSegment...) + return nil +} diff --git a/rolling-shutter/medley/chainsync/syncer/dynamically_typed.go b/rolling-shutter/medley/chainsync/syncer/dynamically_typed.go new file mode 100644 index 00000000..0eba8ed7 --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/dynamically_typed.go @@ -0,0 +1,68 @@ +package syncer + +import ( + "context" + "reflect" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +type contractEventHandler[T any] struct { + h IContractEventHandler[T] +} + +func (gh contractEventHandler[T]) Address() common.Address { + return gh.h.Address() +} + +func (gh contractEventHandler[T]) Topic() common.Hash { + return gh.h.ABI().Events[gh.h.Event()].ID +} + +func (gh contractEventHandler[T]) Parse(logger types.Log) (any, error) { + var event T + + if err := UnpackLog(gh.h.ABI(), &event, gh.h.Event(), logger); err != nil { + return nil, err + } + // Set the log to the Raw field + f := reflect.ValueOf(&event).Elem().FieldByName("Raw") + if f.CanSet() { + f.Set(reflect.ValueOf(logger)) + } + return event, nil +} + +func (gh contractEventHandler[T]) Accept(ctx context.Context, h types.Header, ev any) (bool, error) { + switch t := ev.(type) { + case T: + return gh.h.Accept(ctx, h, t) + default: + return false, nil + } +} + +func (gh contractEventHandler[T]) Handle(ctx context.Context, update ChainUpdateContext, events []any) error { + tList := []T{} + for _, ev := range events { + switch t := ev.(type) { + case T: + tList = append(tList, t) + default: + } + } + if len(tList) == 0 { + return nil + } + return gh.h.Handle(ctx, update, tList) +} + +func (gh contractEventHandler[T]) Logger() zerolog.Logger { + return log.With(). + Str("contract-event-handler", gh.h.Event()). + Str("contract-address", gh.Address().String()). + Logger() +} diff --git a/rolling-shutter/medley/chainsync/syncer/fetch.go b/rolling-shutter/medley/chainsync/syncer/fetch.go new file mode 100644 index 00000000..a06580bf --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/fetch.go @@ -0,0 +1,247 @@ +package syncer + +import ( + "context" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/hashicorp/go-multierror" + "github.com/rs/zerolog/log" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/chainsegment" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service" +) + +const MaxRequestBlockRange = 1000 +const MaxSyncedBlockCacheSize = 1000 + +var ErrServerStateInconsistent = errors.New("server's chain-state differs from assumed local state") + +type ChainSyncEthClient interface { + ethereum.LogFilterer + ethereum.ChainReader +} + +type Fetcher struct { + ethClient ChainSyncEthClient + chainCache ChainCache + syncMux sync.RWMutex + + chainUpdate *chainsegment.ChainSegment + + contractEventHandlers []ContractEventHandler + topics [][]common.Hash + addresses []common.Address + + chainUpdateHandlers []ChainUpdateHandler + + inChan chan *types.Header + processingTrig chan struct{} +} + +func NewFetcher(c ChainSyncEthClient, chainCache ChainCache) *Fetcher { + return &Fetcher{ + chainCache: chainCache, + ethClient: c, + syncMux: sync.RWMutex{}, + chainUpdate: &chainsegment.ChainSegment{}, + contractEventHandlers: []ContractEventHandler{}, + chainUpdateHandlers: []ChainUpdateHandler{}, + topics: [][]common.Hash{}, + addresses: []common.Address{}, + inChan: make(chan *types.Header), + processingTrig: make(chan struct{}, 1), + } +} +func (f *Fetcher) GetHeaderByHash(ctx context.Context, h common.Hash) (*types.Header, error) { + log.Error().Err(err).Msg("failed to query header from chain-cache") + if err != nil { + log.Error("failed to query header from chain-cache", "error", err) + err = nil + header, err = f.ethClient.HeaderByHash(ctx, h) + if header == nil { + header, err = f.client.HeaderByHash(ctx, h) + if err != nil { + err = fmt.Errorf("failed to query header from RPC client: %w", err) + } + } + return header, err +} + +func (f *Fetcher) Start(ctx context.Context, runner service.Runner) error { + var err error + for _, h := range f.contractEventHandlers { + f.addresses = append(f.addresses, h.Address()) + } + f.topics, err = topics(f.contractEventHandlers) + if err != nil { + return fmt.Errorf("can't construct topics for handler: %w", err) + } + + // TODO: retry + latest, err := f.client.HeaderByNumber(ctx, big.NewInt(-2)) + if err != nil { + return fmt.Errorf("can't get header by number: %w", err) + } + + f.chainUpdate = chainsegment.NewChainSegment(latest) + + subs, err := f.ethClient.SubscribeNewHead(ctx, f.inChan) + if err != nil { + return fmt.Errorf("can't subscribe to new head: %w", err) + } + runner.Defer(subs.Unsubscribe) + runner.Defer(func() { + close(f.inChan) + close(f.processingTrig) + }) + runner.Go(func() error { + err := f.loop(ctx) + if err != nil { + return fmt.Errorf("fetcher loop errored: %w", err) + } + return nil + }) + return nil +} + +// This method has to be called before starting the Fetcher. +func (f *Fetcher) RegisterContractEventHandler(h ContractEventHandler) { + f.syncMux.Lock() + defer f.syncMux.Unlock() + + f.contractEventHandlers = append(f.contractEventHandlers, h) +} + +// This method has to be called before starting the Fetcher. +func (f *Fetcher) RegisterChainUpdateHandler(h ChainUpdateHandler) { + f.syncMux.Lock() + defer f.syncMux.Unlock() + + f.chainUpdateHandlers = append(f.chainUpdateHandlers, h) +} + +func (f *Fetcher) processChainUpdateHandler(ctx context.Context, update ChainUpdateContext, h ChainUpdateHandler) error { + return h.Handle(ctx, update) +} + +func (f *Fetcher) processContractEventHandler( + ctx context.Context, + update ChainUpdateContext, + h ContractEventHandler, + logs []types.Log, +) error { + var result error + events := []any{} + for _, l := range logs { + // don't process logs from a different contract + if h.Address().Cmp(l.Address) != 0 { + continue + } + // don't process logs with non-matching topics + topicMatch := false + for _, t := range l.Topics { + if h.Topic().Cmp(t) == 0 { + topicMatch = true + break + } + } + if !topicMatch { + continue + } + + a, err := h.Parse(l) + // error here means we skip processing for this handler + if err != nil { + // TODO: we could log some errors here if they are not "wrong topic" + continue + } + header := update.Append.GetHeaderByHash(l.BlockHash) + if header == nil { + log.Error().Err(ErrServerStateInconsistent).Str("log-block-hash", l.BlockHash.String()) + result = multierror.Append(result, err) + continue + } + accept, err := h.Accept(ctx, *header, a) + if err != nil { + result = multierror.Append(result, err) + continue + } + if accept { + events = append(events, a) + } + } + if errors.Is(result, ErrCritical) { + if errors.Is(result, errs.ErrCritical) { + } + return h.Handle(ctx, update, events) +} + +func (f *Fetcher) FetchAndHandle(ctx context.Context, update ChainUpdateContext) error { + query := ethereum.FilterQuery{ + Addresses: f.addresses, + Topics: f.topics, + FromBlock: update.Append.Earliest().Number, + ToBlock: update.Append.Latest().Number, + } + + logs, err := f.ethClient.FilterLogs(ctx, query) + if err != nil { + return err + } + for _, l := range logs { + if update.Append.GetHeaderByHash(l.BlockHash) == nil { + // The API only allows filtering by blocknumber. + // If the retrieved log's block-hash is not present in the + // update query-context, + // this means the server is operating on a different + // chain-state (e.g. reorged). + return ErrServerStateInconsistent + } + } + + wg := sync.WaitGroup{} + var result error + f.syncMux.RLock() + for _, h := range f.contractEventHandlers { + handler := h + wg.Add(1) + go func() { + defer wg.Done() + + err := f.processContractEventHandler(ctx, update, handler, logs) + if err != nil { + err = fmt.Errorf("contract-event-handler error: %w", err) + log.Error().Err(err).Msg("handler processing errored") + result = multierror.Append(result, err) + } + }() + } + f.syncMux.RUnlock() + // run the chain-update handlers after the contract event handlers did run. + for _, h := range f.chainUpdateHandlers { + wg.Wait() + for i, h := range f.chainUpdateHandlers { + handler := h + f.log.Info("spawning chain update handler", "num", i) + wg.Add(1) + go func() { + defer wg.Done() + + err := f.processChainUpdateHandler(ctx, update, handler) + if err != nil { + err = fmt.Errorf("chain-update-handler error: %w", err) + log.Error().Err(err).Msg("handler processing errored") + result = multierror.Append(result, err) + } + }() + } + f.syncMux.RUnlock() + wg.Wait() + return result +} diff --git a/rolling-shutter/medley/chainsync/syncer/log_topics.go b/rolling-shutter/medley/chainsync/syncer/log_topics.go new file mode 100644 index 00000000..f3dd8c7e --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/log_topics.go @@ -0,0 +1,50 @@ +package syncer + +import ( + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" +) + +var errTopicMismatch = errors.New("passed in topic does not match handler") + +func topics(handler []ContractEventHandler) ([][]common.Hash, error) { + var query [][]any + for _, h := range handler { + // Append the event selector to the query parameters and construct the topic set + query = append([][]any{{h.Topic()}}, query...) + } + topics, err := abi.MakeTopics(query...) + if err != nil { + return nil, err + } + return topics, nil +} + +// UnpackLog unpacks a retrieved log into the provided output structure. +func UnpackLog(a abi.ABI, out interface{}, event string, log types.Log) error { + // Copy of bind.BoundContract.UnpackLog + + // Anonymous events are not supported. + if len(log.Topics) == 0 { + return errTopicMismatch + } + if log.Topics[0] != a.Events[event].ID { + return errTopicMismatch + } + if len(log.Data) > 0 { + if err := a.UnpackIntoInterface(out, event, log.Data); err != nil { + return fmt.Errorf("error marshaling into `out` value: %w", err) + } + } + var indexed abi.Arguments + for _, arg := range a.Events[event].Inputs { + if arg.Indexed { + indexed = append(indexed, arg) + } + } + return abi.ParseTopics(out, indexed, log.Topics[1:]) +} diff --git a/rolling-shutter/medley/chainsync/syncer/loop.go b/rolling-shutter/medley/chainsync/syncer/loop.go new file mode 100644 index 00000000..13c71748 --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/loop.go @@ -0,0 +1,192 @@ +package syncer + +import ( + "context" + "errors" + "fmt" + "math/big" + + "github.com/rs/zerolog/log" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/chainsegment" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/errs" +) + +var errUint64Overflow = errors.New("uint64 overflow in conversion from math.Big") + +func (f *Fetcher) triggerHandlerProcessing() { + // nonblocking setter of updates, + // in case this has already been triggered + select { + case f.processingTrig <- struct{}{}: + default: + } +} + +// success will be True, when we successfully applied the updated chain segment +// to the old chain. If there remains a gap, this has to return false. +func (f *Fetcher) handlerSync(ctx context.Context) (success bool, err error) { //nolint: funlen + var syncedChain, removedSegment, updatedSegment *chainsegment.ChainSegment + + syncedChain, err = f.chainCache.Get(ctx) + if f.chainUpdate == nil { + // nothing to update + success = true + return success, err + } + if errors.Is(err, ErrEmpy) { + // no chain-cache present yet, just set the chain-update without + // checking for reorgs + removedSegment = nil + updatedSegment = f.chainUpdate + log.Trace().Msg("internal chain cache empty, setting updated chain segment") + } else if err != nil { + success = false + return success, err + } else { + if new(big.Int).Add(syncedChain.Latest().Number, big.NewInt(1)). + Cmp(f.chainUpdate.Earliest().Number) == -1 { + //FIXME: overflow + diff := new(big.Int).Sub(f.chainUpdate.Earliest().Number, syncedChain.Latest().Number).Uint64() + queryBlocks := MaxRequestBlockRange + // cap the extend range at the diff to the update to not overshoot + if diff < uint64(queryBlocks) { + + diffBig := new(big.Int).Sub(f.chainUpdate.Earliest().Number, syncedChain.Latest().Number) + if !diffBig.IsUint64() { + success = false + return success, fmt.Errorf("chain-update difference too big: %w", errUint64Overflow) + } + diff := diffBig.Uint64() + queryBlocks := MaxRequestBlockRange + // cap the extend range at the diff to the update to not overshoot + if diff < uint64(queryBlocks) { + queryBlocks = int(diff) + } + + // we are not synced to the chain-update + // so first construct an update to the right of the synced chain + log.Trace(). + Uint64("synced-latest-blocknum", syncedChain.Latest().Number.Uint64()). + Uint64("update-earliest-blocknum", f.chainUpdate.Earliest().Number.Uint64()). + Int("num-query-blocks", queryBlocks). + Msg("chain update ahead of synced chain, fetching gap blocks") + updatedSegment, err = syncedChain.NewSegmentRight(ctx, f.ethClient, queryBlocks) + if errors.Is(err, chainsegment.ErrReorg) { + // this means we reorged the old chain segment. + f.chainUpdate.ExtendLeft(ctx, f.ethClient, queryBlocks) + f.chainUpdate.ExtendLeft(ctx, f.client, queryBlocks) + err = nil + f.chainUpdate, err = f.chainUpdate.ExtendLeft(ctx, f.ethClient, queryBlocks) + if err != nil { + err = fmt.Errorf("error while querying older blocks from reorg update: %w", err) + } + success = false + return success, err + } + removedSegment = nil + success = false + } else { + result, updateErr := syncedChain.UpdateLatest(ctx, f.ethClient, f.chainUpdate) + success = true + if updateErr != nil { + log.Error().Err(err).Msg("error updating chain with latest segment") + if errors.Is(err, chainsegment.ErrUpdateBlockTooFarInPast) { + // TODO: what should we do on 'ErrUpdateBlockTooFarInPast'? + // We can't provide handler calls with the same accuracy of + // information on the potentially "removed" chain-segment, + // since our chain-cache does not have the full old chain segment + // in it's storage anymore, and especially the block-hash + // of the reorged away chain is not present anymore. + // The client should probably panic with a critical log error. + // In general this is very unlikely when the chain-cache capacity is + // larger than the most unlikely, still realistic reorg-size. + // However the described condition might currently occur during + // initial syncing, when the block-cache is not filled to capacity + // yet. + log.Warn().Err(err).Msg("received a reorg that pre-dates the internal chain-cache." + + " ignoring chain-update for now, but this condition might be irrecoverable") + } + err = updateErr + // Now as long as the chain-cache eviction policy is not aggressive (easily doable) + removedSegment = result.RemovedSegment + updatedSegment = result.UpdatedSegment + // we will process the whole segment of the chain update + f.chainUpdate = nil + } + } + + update := ChainUpdateContext{ + Remove: removedSegment, + Append: updatedSegment, + } + if update.Append == nil { + return success, err + } + + // blocking call, until all handlers are done processing the + // new chain segment + err = f.FetchAndHandle(ctx, update) + if err != nil { + return false, err + } + err = f.chainCache.Update(ctx, update) + if err != nil { + return false, err + } + return success, err +} + +func (f *Fetcher) loop(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case newHeader, ok := <-f.inChan: + log.Debug().Msg("latest head stream closed, exiting handler loop") + f.log.Info("latest head stream closed, exiting handler loop") + return nil + log.Debug().Uint64("block-number", newHeader.Number.Uint64()).Msg("new latest head from l2 ws-stream") + f.log.Debug("new latest head from l2 ws-stream", "block-number", newHeader.Number.Uint64()) + newSegment := chainsegment.NewChainSegment(newHeader) + if f.chainUpdate != nil { + // apply the updates to the chain-update buffer that hasn't been processed + result, err := f.chainUpdate.Copy().UpdateLatest(ctx, f.ethClient, newSegment) + result, err := f.chainUpdate.Copy().UpdateLatest(ctx, f.client, newSegment) + fullUpdated := result.FullSegment + removed := result.RemovedSegment + if err != nil { + if errors.Is(err, chainsegment.ErrUpdateBlockTooFarInPast) { + // reorg beyond the chain-update segment, just set the new header + removed = f.chainUpdate + fullUpdated = newSegment + } + if errors.Is(err, errs.ErrCritical) { + return err + log.Error().Err(err).Msg("error updating chain segment") + f.log.Error("error updating chain segment", "error", err) + } + log.Info().Uint64("block-number", newHeader.Number.Uint64()).Msg("received a new reorg block") + f.log.Info("received a new reorg block", "block-number", newHeader.Number.Uint64()) + } + f.chainUpdate = fullUpdated + } else { + f.chainUpdate = newSegment + } + f.triggerHandlerProcessing() + + f.log.Trace("fetcher loop: received handler sync trigger") + success, err := f.handlerSync(ctx) + if err != nil { + if errors.Is(err, errs.ErrCritical) { + return err + log.Error().Err(err).Msg("error during handler-sync") + f.log.Error("error during handler-sync", "error", err) + } + if !success { + // keep processing the handler without waiting for updates + f.triggerHandlerProcessing() + } + } + } +} diff --git a/rolling-shutter/medley/chainsync/syncer/types.go b/rolling-shutter/medley/chainsync/syncer/types.go new file mode 100644 index 00000000..f454d555 --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer/types.go @@ -0,0 +1,70 @@ +package syncer + +import ( + "context" + "fmt" + "reflect" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/rs/zerolog" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/chainsegment" +) + +type ChainUpdateContext struct { + // a previously applied chainsegment that has to be + // removed from the state first + Remove *chainsegment.ChainSegment + // the chainsegment the passed in events are part of + Append *chainsegment.ChainSegment +} + +type ChainUpdateHandler interface { + Handle(ctx context.Context, update ChainUpdateContext) error +} + +// IContractEventHandler is the generic interface +// that should be implemented. +// This allows more narrowly typed implementations +// on a per-contracts-event basis, while offloading +// the dynamic typing to a single implementation +// (`contractEventHandler[T]`, complying to the +// ContractEventHandler interface). +type IContractEventHandler[T any] interface { + Address() common.Address + Event() string + ABI() abi.ABI + + Accept(context.Context, types.Header, T) (bool, error) + Handle(context.Context, ChainUpdateContext, []T) error +} + +// WrapHandler wraps the generic implementation into +// a dynamically typed handler complying to the +// `ContractEventHandler` interface. +func WrapHandler[T any](h IContractEventHandler[T]) (ContractEventHandler, error) { + var t T + if reflect.TypeOf(t).Kind() == reflect.Pointer { + return nil, fmt.Errorf("Handler must not receive pointer values for the event types.") + return nil, fmt.Errorf("handler must not receive pointer values for the event types") + return contractEventHandler[T]{ + h: h, + }, nil +} + +// ContractEventHandler is the dynamically typed +// interface that is accepted by the chainsync. +// Ideally this doesn't have to be implemented, +// but should be result of wrapping the more +// narrowly typed IContractEventHandler implementations. +type ContractEventHandler interface { + Topic() common.Hash + Address() common.Address + + Parse(log types.Log) (any, bool, error) + Accept(ctx context.Context, h types.Header, ev any) (bool, error) + Handle(ctx context.Context, update ChainUpdateContext, events []any) error + Logger() zerolog.Logger +} diff --git a/rolling-shutter/medley/chainsync/syncer_test/chainsegment_test.go b/rolling-shutter/medley/chainsync/syncer_test/chainsegment_test.go new file mode 100644 index 00000000..c62882db --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer_test/chainsegment_test.go @@ -0,0 +1,366 @@ +package tester + +import ( + "context" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + gethLog "github.com/ethereum/go-ethereum/log" + "golang.org/x/exp/slog" + "gotest.tools/v3/assert" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/chainsegment" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/client" +) + +func TestExtendLeft(t *testing.T) { + headers := MakeChain(1, common.BigToHash(big.NewInt(0)), 10, 42) + log := gethLog.NewLogger(slog.Default().Handler()) + clnt, ctl := client.NewTestClient(log) + + for _, h := range headers { + ctl.AppendNextHeader(h) + ctl.ProgressHead() + } + latest := chainsegment.NewChainSegment(headers[9]) + _, err := latest.ExtendLeft(context.Background(), clnt, 9) + assert.NilError(t, err) + assert.Equal(t, len(latest.Get()), 10) +} + +func TestUpdateLatest(t *testing.T) { //nolint: funlen + tests := map[string]struct { + chainArgs MakeChainSegmentsArgs + + expectedFullLength int + + expectedUpdatedLength int + expectedUpdatedEarliestNum int + + expectedRemovedLength int + expectedRemovedEarliestNum int + + expectedErrorString string + }{ + "close gap and reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 10, + }, + Update: MakeChainSegmentsChain{ + Length: 15, + }, + BranchOffBlock: 5, + UpdateSegmentLength: 1, + Reorg: true, + }, + + expectedFullLength: 20, + expectedUpdatedLength: 15, + expectedUpdatedEarliestNum: 5, + + expectedRemovedLength: 5, + expectedRemovedEarliestNum: 5, + expectedErrorString: "", + }, + "no gap and reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 10, + }, + Update: MakeChainSegmentsChain{ + Length: 15, + }, + BranchOffBlock: 5, + UpdateSegmentLength: 15, + Reorg: true, + }, + + expectedFullLength: 20, + + expectedUpdatedLength: 15, + expectedUpdatedEarliestNum: 5, + + expectedRemovedLength: 5, + expectedRemovedEarliestNum: 5, + expectedErrorString: "", + }, + "overlap and reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 10, + }, + Update: MakeChainSegmentsChain{ + Length: 15, + }, + BranchOffBlock: 5, + UpdateSegmentLength: 18, + Reorg: true, + }, + expectedFullLength: 20, + + expectedUpdatedLength: 15, + expectedUpdatedEarliestNum: 5, + + expectedRemovedLength: 5, + expectedRemovedEarliestNum: 5, + expectedErrorString: "", + }, + "append no reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 5, + }, + Update: MakeChainSegmentsChain{ + Length: 10, + }, + BranchOffBlock: 0, + // no gap, perfect alignment + UpdateSegmentLength: 5, + Reorg: false, + }, + + expectedFullLength: 10, + expectedUpdatedLength: 5, + expectedUpdatedEarliestNum: 5, + expectedRemovedLength: 0, + expectedRemovedEarliestNum: -1, + expectedErrorString: "", + }, + "close gap no reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 5, + }, + Update: MakeChainSegmentsChain{ + Length: 10, + }, + BranchOffBlock: 0, + // gap of 3 + UpdateSegmentLength: 2, + Reorg: false, + }, + expectedFullLength: 10, + expectedUpdatedLength: 5, + expectedUpdatedEarliestNum: 5, + expectedRemovedLength: 0, + expectedRemovedEarliestNum: -1, + expectedErrorString: "", + }, + "overlap no reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 5, + }, + Update: MakeChainSegmentsChain{ + Length: 10, + }, + BranchOffBlock: 0, + // overlap of 3 + UpdateSegmentLength: 8, + Reorg: false, + }, + + expectedFullLength: 10, + // overlap shouldn't be updated + expectedUpdatedLength: 5, + expectedUpdatedEarliestNum: 5, + expectedRemovedLength: 0, + expectedRemovedEarliestNum: -1, + expectedErrorString: "", + }, + "full overlap no reorg": { + chainArgs: MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 10, + }, + Update: MakeChainSegmentsChain{ + Length: 10, + }, + BranchOffBlock: 0, + // full overlap + UpdateSegmentLength: 10, + Reorg: false, + }, + + expectedFullLength: 10, + // overlap shouldn't be updated + expectedUpdatedLength: 0, + expectedUpdatedEarliestNum: -1, + expectedRemovedLength: 0, + expectedRemovedEarliestNum: -1, + expectedErrorString: "", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + chain := MakeChainSegments(t, test.chainArgs) + // this should poll all the header and detect the reorg + assert.Assert(t, chain.UpdateSegment.Len() > 0) + assert.Assert(t, chain.OriginalSegment.Len() > 0) + result, err := chain.OriginalSegment.UpdateLatest(context.Background(), chain.Client, chain.UpdateSegment) + assert.NilError(t, err) + full := result.FullSegment + removed := result.RemovedSegment + updated := result.UpdatedSegment + assert.Assert(t, full != nil) + assert.Assert(t, full.Len() > 0) + if test.expectedErrorString != "" { + assert.ErrorContains(t, err, test.expectedErrorString) + return + } + assert.NilError(t, err) + assert.Assert(t, full != nil) + assert.Assert(t, chain.OriginalSegment != nil) + + assert.Equal(t, full.Earliest().Number.Cmp(big.NewInt(0)), 0) + for i, h := range full.Get() { + t.Logf("full: index=%d, num=%d", i, h.Number.Uint64()) + } + assert.Equal(t, full.Latest().Number.Cmp(big.NewInt(int64(test.expectedFullLength)-1)), 0) + + if updated != nil { + assert.Assert(t, updated.Len() > 0) + assert.Equal(t, updated.Len(), test.expectedUpdatedLength) + assert.Equal(t, updated.Earliest().Number.Cmp(big.NewInt(int64(test.expectedUpdatedEarliestNum))), 0) + assert.Equal(t, updated.Latest().Number.Cmp(big.NewInt(int64(test.expectedUpdatedEarliestNum+test.expectedUpdatedLength)-1)), 0) + } else { + assert.Equal(t, 0, test.expectedUpdatedLength) + } + if removed != nil { + assert.Assert(t, removed.Len() > 0) + assert.Equal(t, removed.Len(), test.expectedRemovedLength) + assert.Equal(t, removed.Earliest().Number.Cmp(big.NewInt(int64(test.expectedRemovedEarliestNum))), 0) + assert.Equal(t, removed.Latest().Number.Cmp(big.NewInt(int64(test.expectedRemovedEarliestNum+test.expectedRemovedLength)-1)), 0) + } else { + assert.Equal(t, 0, test.expectedRemovedLength) + } + }) + } +} + +type LogFactory func(relativeIndex int, header *types.Header) ([]types.Log, error) + +type MakeChainSegmentsChain struct { + Length int + LogFactory LogFactory +} + +type MakeChainSegmentsArgs struct { + Original MakeChainSegmentsChain + Update MakeChainSegmentsChain + + BranchOffBlock int + // This will cut the updated chain segment + // so that it only has headers from + // `latest update header - UpdateSegmentLength` + // until + // `latest update header`, while the client + // still knows the state of all update headers + UpdateSegmentLength int + // uses a different mixin-seed value in the update chain, + // and induces a reorg + Reorg bool + // Will not set the internal state of the client + // to the latest head of the updated chain. + // If this is true, the Controller.ProgressHead() + // or Controller.ProgressAllHeads() have to + // be called manually so that blocks can + // be queried from the client. + ClientNoProgressHeads bool +} + +type MakeChainSegmentsResult struct { + Client client.Sync + Controller *client.TestClientController + OriginalSegment *chainsegment.ChainSegment + UpdateSegment *chainsegment.ChainSegment +} + +func MakeChainSegments(t *testing.T, args MakeChainSegmentsArgs) *MakeChainSegmentsResult { + t.Helper() + + var oldHeaders []*types.Header + newHeaders := []*types.Header{} + assert.Assert(t, args.BranchOffBlock < args.Original.Length) + newChainlength := args.Update.Length + args.BranchOffBlock + assert.Assert(t, args.UpdateSegmentLength <= newChainlength) + oldHeaders = MakeChain(0, common.BigToHash(big.NewInt(0)), uint(args.Original.Length), 42) + // TODO: header events + + // use different seed for the reorg chain to change the hashes + parentHash := common.BigToHash(big.NewInt(0)) + if args.BranchOffBlock != 0 { + parentHash = oldHeaders[args.BranchOffBlock-1].Hash() + } + var seed int64 = 42 + if args.Reorg { + seed = 442 + } + reorgHeaders := MakeChain(int64(args.BranchOffBlock), parentHash, uint(args.Update.Length), seed) + newHeaders = append(newHeaders, oldHeaders[:args.BranchOffBlock]...) + newHeaders = append(newHeaders, reorgHeaders...) + + // Make some assertions about the constructed chains + assert.Equal(t, len(oldHeaders), args.Original.Length) + assert.Equal(t, len(reorgHeaders), args.Update.Length) + assert.Equal(t, len(newHeaders), newChainlength) + assert.Assert(t, oldHeaders[len(oldHeaders)-1].Number.Cmp(big.NewInt(int64(len(oldHeaders)-1))) == 0) + assert.Assert(t, reorgHeaders[0].Number.Cmp(big.NewInt(int64(args.BranchOffBlock))) == 0) + + assert.Assert(t, reorgHeaders[len(reorgHeaders)-1].Number.Cmp(big.NewInt(int64(args.BranchOffBlock+args.Update.Length-1))) == 0) + + log := gethLog.NewLogger(slog.Default().Handler()) + testClient, testClientController := client.NewTestClient(log) + + for i, h := range oldHeaders { + var logs []types.Log + if args.Original.LogFactory != nil { + var err error + logs, err = args.Original.LogFactory(i, h) + assert.NilError(t, err) + } + testClientController.AppendNextHeader(h, logs...) + } + for i, h := range reorgHeaders { + var logs []types.Log + if args.Update.LogFactory != nil { + var err error + logs, err = args.Update.LogFactory(i, h) + assert.NilError(t, err) + } + testClientController.AppendNextHeader(h, logs...) + } + if !args.ClientNoProgressHeads { + testClientController.ProgressAllHeads() + } + original := chainsegment.NewChainSegment(oldHeaders...) + updateHeaders := newHeaders[len(newHeaders)-args.UpdateSegmentLength:] + update := chainsegment.NewChainSegment(updateHeaders...) + assert.Assert(t, update.Len() == args.UpdateSegmentLength) + assert.Assert(t, update.Len() > 0) + assert.Assert(t, original.Len() > 0) + assert.Assert(t, update.Len() > 0) + + return &MakeChainSegmentsResult{ + Client: testClient, + Controller: testClientController, + OriginalSegment: original, + UpdateSegment: update, + } +} + +func TestReplaceWholeSegment(t *testing.T) { + headers := MakeChain(1, common.BigToHash(big.NewInt(0)), 5, 42) + reorg := MakeChain(1, common.BigToHash(big.NewInt(0)), 5, 422) + + cs := chainsegment.NewChainSegment(headers...) + rcs := chainsegment.NewChainSegment(reorg...) + remove, update := cs.DiffLeftAligned(rcs) + + assert.Equal(t, len(remove.Get()), 5) + assert.Equal(t, len(update.Get()), 5) +} diff --git a/rolling-shutter/medley/chainsync/syncer_test/fetcher_test.go b/rolling-shutter/medley/chainsync/syncer_test/fetcher_test.go new file mode 100644 index 00000000..90601f8f --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer_test/fetcher_test.go @@ -0,0 +1,162 @@ +package tester + +import ( + "context" + "fmt" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + gethLog "github.com/ethereum/go-ethereum/log" + "github.com/shutter-network/shop-contracts/bindings" + "golang.org/x/exp/slog" + "gotest.tools/v3/assert" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/syncer" + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/service" +) + +// This serves as a smoketest for the home-baked PackKeyBroadcast function +// in the test and the handler parsing. +func TestPackAndParseEvent(t *testing.T) { + headers := MakeChain(1, common.BigToHash(big.NewInt(0)), 1, 42) + header := *headers[0] + eon := uint64(42) + key := []byte("thisisasecretbytekey") + + eventLog, err := PackKeyBroadcast(eon, key, header) + assert.NilError(t, err) + + log := gethLog.NewLogger(slog.Default().Handler()) + testHandler, err := NewTestKeyBroadcastHandler(log) + assert.NilError(t, err) + h, err := syncer.WrapHandler(testHandler) + assert.NilError(t, err) + + evAny, err := h.Parse(*eventLog) + assert.NilError(t, err) + + ev, ok := evAny.(bindings.KeyBroadcastContractEonKeyBroadcast) + assert.Assert(t, ok) + assert.Equal(t, ev.Eon, eon) + assert.DeepEqual(t, ev.Key, key) + assert.DeepEqual(t, ev.Raw.BlockHash, header.Hash()) +} + +func TestReorg(t *testing.T) { //nolint: funlen,gocyclo + log := gethLog.NewLogger(slog.Default().Handler()) + + var originalEvents LogFactory = func(relativeIndex int, header *types.Header) ([]types.Log, error) { + if relativeIndex == 1 { + // shouldn't be removed, is in non-reorged chainsegment + log, err := PackKeyBroadcast(1, []byte("key1"), *header) + if err != nil || log == nil { + return nil, err + } + return []types.Log{*log}, err + } + if relativeIndex == 7 { + // should be removed, is in reorged chainsegment + log, err := PackKeyBroadcast(2, []byte("key2"), *header) + if err != nil || log == nil { + return nil, err + } + return []types.Log{*log}, err + } + return nil, nil + } + var updateEvents LogFactory = func(relativeIndex int, header *types.Header) ([]types.Log, error) { + if relativeIndex == 1 { + // shouldn't be removed, is in non-reorged chainsegment + log, err := PackKeyBroadcast(3, []byte("key3"), *header) + if err != nil || log == nil { + return nil, err + } + return []types.Log{*log}, err + } + if relativeIndex == 7 { + // shouldn't be removed, is in non-reorged chainsegment + log, err := PackKeyBroadcast(4, []byte("key4"), *header) + if err != nil || log == nil { + return nil, err + } + return []types.Log{*log}, err + } + return nil, nil + } + chain := MakeChainSegments(t, + MakeChainSegmentsArgs{ + Original: MakeChainSegmentsChain{ + Length: 10, + LogFactory: originalEvents, + }, + Update: MakeChainSegmentsChain{ + Length: 10, + LogFactory: updateEvents, + }, + BranchOffBlock: 5, + UpdateSegmentLength: 1, + Reorg: true, + ClientNoProgressHeads: true, + }, + ) + + f := syncer.NewFetcher(chain.Client, syncer.NewMemoryChainCache(50, nil)) + + keyBroadcastHandler, err := NewTestKeyBroadcastHandler(log) + assert.NilError(t, err) + h, err := syncer.WrapHandler(keyBroadcastHandler) + assert.NilError(t, err) + + chainUpdateHandler, chainUpdateHandlerChannel, err := NewTestChainUpdateHandler(log) + assert.NilError(t, err) + + f.RegisterContractEventHandler(h) + f.RegisterChainUpdateHandler(chainUpdateHandler) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // we have to be able to query a latest head + ok := chain.Controller.ProgressHead() + assert.Assert(t, ok) + group, deferFn := service.RunBackground(ctx, f) + defer deferFn() + defer func() { + if ctx.Err() != nil { + // only wait for the error when the deadline + // raised + err := group.Wait() + if err != nil { + err = fmt.Errorf("Fetcher failed during test: %w", err) + } + assert.NilError(t, err) + } + }() + chain.Controller.WaitSubscribed(ctx) + + for { + ok := chain.Controller.ProgressHead() + if !ok { + break + } + err := chain.Controller.EmitLatestHead(ctx) + assert.NilError(t, err) + // Wait for the handler to be finished with processing + select { + case <-chainUpdateHandlerChannel: + case <-ctx.Done(): + t.FailNow() + } + } + uptodateEons := keyBroadcastHandler.GetEons() + t.Logf("eons: %v", uptodateEons) + for _, eon := range []uint64{1, 3, 4} { + _, ok := uptodateEons[eon] + assert.Assert(t, ok) + } + _ = group + // group.Wait() +} diff --git a/rolling-shutter/medley/chainsync/syncer_test/util.go b/rolling-shutter/medley/chainsync/syncer_test/util.go new file mode 100644 index 00000000..274e8e35 --- /dev/null +++ b/rolling-shutter/medley/chainsync/syncer_test/util.go @@ -0,0 +1,191 @@ +package tester + +import ( + "context" + "math/big" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/shutter-network/shop-contracts/bindings" + + "github.com/shutter-network/rolling-shutter/rolling-shutter/medley/chainsync/syncer" +) + +func init() { + var err error + KeyBroadcastContractABI, err = bindings.KeyBroadcastContractMetaData.GetAbi() + if err != nil { + panic(err) + } +} + +var KeyBroadcastContractAddress = common.BigToAddress(big.NewInt(42)) +var KeyBroadcastContractABI *abi.ABI + +func MakeChain(start int64, startParent common.Hash, numHeader uint, seed int64) []*types.Header { + n := numHeader + parent := startParent + num := big.NewInt(start) + h := []*types.Header{} + + // change the hashes for different seeds + mixinh := common.BigToHash(big.NewInt(seed)) + for n > 0 { + head := &types.Header{ + ParentHash: parent, + Number: num, + MixDigest: mixinh, + } + h = append(h, head) + num = new(big.Int).Add(num, big.NewInt(1)) + parent = head.Hash() + n-- + } + return h +} + +func NewTestKeyBroadcastHandler(logger log.Logger) (*TestKeyBroadcastHandler, error) { //nolint: unparam + return &TestKeyBroadcastHandler{ + log: logger, + evABI: KeyBroadcastContractABI, + address: KeyBroadcastContractAddress, + eons: map[common.Hash]uint64{}, + }, nil +} + +type TestKeyBroadcastHandler struct { + log log.Logger + evABI *abi.ABI + address common.Address + eons map[common.Hash]uint64 +} + +func (tkbh *TestKeyBroadcastHandler) Address() common.Address { + return tkbh.address +} + +func (tkbh *TestKeyBroadcastHandler) Log(msg string, ctx ...any) { + tkbh.log.Info(msg, ctx) +} + +func (tkbh *TestKeyBroadcastHandler) Event() string { + return "EonKeyBroadcast" //nolint: goconst +} + +func (tkbh *TestKeyBroadcastHandler) ABI() abi.ABI { + return *tkbh.evABI +} + +func (tkbh *TestKeyBroadcastHandler) Accept( + _ context.Context, + _ types.Header, + _ bindings.KeyBroadcastContractEonKeyBroadcast, +) (bool, error) { + return true, nil +} + +func (tkbh *TestKeyBroadcastHandler) Handle( + _ context.Context, + update syncer.ChainUpdateContext, + evs []bindings.KeyBroadcastContractEonKeyBroadcast, +) error { + if update.Remove != nil { + for _, h := range update.Remove.Get() { + _, ok := tkbh.eons[h.Hash()] + if ok { + delete(tkbh.eons, h.Hash()) + } + } + } + if update.Append != nil { + for _, ev := range evs { + tkbh.eons[ev.Raw.BlockHash] = ev.Eon + } + } + return nil +} + +func (tkbh *TestKeyBroadcastHandler) GetEons() map[uint64]struct{} { + m := map[uint64]struct{}{} + for _, v := range tkbh.eons { + m[v] = struct{}{} + } + return m +} +func (tkbh *TestKeyBroadcastHandler) GetBlockHashes() map[common.Hash]struct{} { + m := map[common.Hash]struct{}{} + for hsh := range tkbh.eons { + m[hsh] = struct{}{} + } + return m +} + +func NewTestChainUpdateHandler(logger log.Logger) (*TestChainUpdateHandler, chan syncer.ChainUpdateContext, error) { //nolint: unparam + querySyncChan := make(chan syncer.ChainUpdateContext) + return &TestChainUpdateHandler{ + log: logger, + querySyncChan: querySyncChan, + chainCache: syncer.NewMemoryChainCache(100, nil), + }, querySyncChan, nil +} + +type TestChainUpdateHandler struct { + log log.Logger + querySyncChan chan syncer.ChainUpdateContext + chainCache syncer.ChainCache +} + +func (tkbh *TestChainUpdateHandler) Handle( + ctx context.Context, + update syncer.ChainUpdateContext, +) error { + err := tkbh.chainCache.Update(ctx, update) + tkbh.querySyncChan <- update + return err +} + +func (tkbh *TestChainUpdateHandler) GetBlockHashes(ctx context.Context) (map[common.Hash]struct{}, error) { + m := map[common.Hash]struct{}{} + chain, err := tkbh.chainCache.Get(ctx) + if err != nil { + return m, err + } + for _, h := range chain.Get() { + m[h.Hash()] = struct{}{} + } + return m, nil +} + +func MustPackKeyBroadcast(eon uint64, key []byte, header types.Header) *types.Log { + l, err := PackKeyBroadcast(eon, key, header) + if err != nil { + panic("can't pack key broadcast event") + } + return l +} + +// This roughly emulates what the EVM does +// and packs a EonKeyBroadcast log. +func PackKeyBroadcast(eon uint64, key []byte, header types.Header) (*types.Log, error) { + event := "EonKeyBroadcast" + address := KeyBroadcastContractAddress + evABI := KeyBroadcastContractABI.Events[event] + + data, err := evABI.Inputs.Pack(eon, key) + if err != nil { + return nil, err + } + topics := []common.Hash{KeyBroadcastContractABI.Events[event].ID} + return &types.Log{ + Address: address, + Data: data, + Topics: topics, + BlockNumber: header.Number.Uint64(), + BlockHash: header.Hash(), + // NOTE: we don't set all the values here, make + // sure no reader relies on them when writing test handler + // (e.g. TxHash, TxIndex, ...) + }, nil +}