diff --git a/registry/storage/shares.go b/registry/storage/shares.go index d075e1fa27..cf02f0d3fc 100644 --- a/registry/storage/shares.go +++ b/registry/storage/shares.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/gob" "encoding/hex" + "errors" "fmt" "sync" @@ -57,8 +58,8 @@ type sharesStorage struct { prefix []byte shares map[string]*types.SSVShare validatorStore *validatorStore + dbmu sync.Mutex // parent lock for in-memory mutex mu sync.RWMutex - dbmu sync.Mutex } type storageOperator struct { @@ -133,12 +134,12 @@ func NewSharesStorage(logger *zap.Logger, db basedb.Database, prefix []byte) (Sh // load reads all shares from db. func (s *sharesStorage) load() error { - s.mu.Lock() - defer s.mu.Unlock() - s.dbmu.Lock() defer s.dbmu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() + return s.db.GetAll(append(s.prefix, sharesPrefix...), func(i int, obj basedb.Obj) error { val := &storageShare{} if err := val.Decode(obj.Value); err != nil { @@ -210,50 +211,55 @@ func (s *sharesStorage) Save(rw basedb.ReadWriter, shares ...*types.SSVShare) er } } + s.dbmu.Lock() + defer s.dbmu.Unlock() + + // Update in-memory. err := func() error { - s.dbmu.Lock() - defer s.dbmu.Unlock() - - return s.db.Using(rw).SetMany(s.prefix, len(shares), func(i int) (basedb.Obj, error) { - share := specShareToStorageShare(shares[i]) - value, err := share.Encode() - if err != nil { - return basedb.Obj{}, fmt.Errorf("failed to serialize share: %w", err) - } - return basedb.Obj{Key: s.storageKey(share.ValidatorPubKey[:]), Value: value}, nil - }) - }() - if err != nil { - return err - } + s.mu.Lock() + defer s.mu.Unlock() - s.mu.Lock() - defer s.mu.Unlock() + updateShares := make([]*types.SSVShare, 0, len(shares)) + addShares := make([]*types.SSVShare, 0, len(shares)) - updateShares := make([]*types.SSVShare, 0, len(shares)) - addShares := make([]*types.SSVShare, 0, len(shares)) + for _, share := range shares { + key := hex.EncodeToString(share.ValidatorPubKey[:]) - for _, share := range shares { - key := hex.EncodeToString(share.ValidatorPubKey[:]) + // Update validatorStore indices. + if _, ok := s.shares[key]; ok { + updateShares = append(updateShares, share) + } else { + addShares = append(addShares, share) + } + s.shares[key] = share + } - // Update validatorStore indices. - if _, ok := s.shares[key]; ok { - updateShares = append(updateShares, share) - } else { - addShares = append(addShares, share) + if err := s.validatorStore.handleSharesUpdated(updateShares...); err != nil { + return err } - s.shares[key] = share - } - if err := s.validatorStore.handleSharesUpdated(updateShares...); err != nil { - return err - } + if err := s.validatorStore.handleSharesAdded(addShares...); err != nil { + return err + } - if err := s.validatorStore.handleSharesAdded(addShares...); err != nil { + return nil + }() + if err != nil { return err } - return nil + return s.unsafeSave(rw, shares...) +} + +func (s *sharesStorage) unsafeSave(rw basedb.ReadWriter, shares ...*types.SSVShare) error { + return s.db.Using(rw).SetMany(s.prefix, len(shares), func(i int) (basedb.Obj, error) { + share := specShareToStorageShare(shares[i]) + value, err := share.Encode() + if err != nil { + return basedb.Obj{}, fmt.Errorf("failed to serialize share: %w", err) + } + return basedb.Obj{Key: s.storageKey(share.ValidatorPubKey[:]), Value: value}, nil + }) } func specShareToStorageShare(share *types.SSVShare) *storageShare { @@ -317,39 +323,46 @@ func (s *sharesStorage) storageShareToSpecShare(share *storageShare) (*types.SSV return specShare, nil } -func (s *sharesStorage) Delete(rw basedb.ReadWriter, pubKey []byte) error { - s.mu.Lock() - defer s.mu.Unlock() +var errShareNotFound = errors.New("share not found") +func (s *sharesStorage) Delete(rw basedb.ReadWriter, pubKey []byte) error { s.dbmu.Lock() defer s.dbmu.Unlock() - // Delete the share from the database - if err := s.db.Using(rw).Delete(s.prefix, s.storageKey(pubKey)); err != nil { - return err - } + err := func() error { + s.mu.Lock() + defer s.mu.Unlock() - share := s.shares[hex.EncodeToString(pubKey)] - if share == nil { - return nil - } + share, found := s.shares[hex.EncodeToString(pubKey)] + if !found { + return errShareNotFound + } - // Remove the share from local storage map - delete(s.shares, hex.EncodeToString(pubKey)) + // Remove the share from local storage map + delete(s.shares, hex.EncodeToString(pubKey)) - // Remove the share from the validator store. This method will handle its own locking. - if err := s.validatorStore.handleShareRemoved(share); err != nil { + // Remove the share from the validator store. This method will handle its own locking. + return s.validatorStore.handleShareRemoved(share) + }() + if errors.Is(err, errShareNotFound) { + return nil + } + if err != nil { return err } - return nil + // Delete the share from the database + return s.db.Using(rw).Delete(s.prefix, s.storageKey(pubKey)) } // UpdateValidatorsMetadata updates the metadata of the given validator func (s *sharesStorage) UpdateValidatorsMetadata(data map[spectypes.ValidatorPK]*beaconprotocol.ValidatorMetadata) error { var shares []*types.SSVShare - func() { + s.dbmu.Lock() + defer s.dbmu.Unlock() + + err := func() error { s.mu.RLock() defer s.mu.RUnlock() @@ -367,30 +380,33 @@ func (s *sharesStorage) UpdateValidatorsMetadata(data map[spectypes.ValidatorPK] share.Share.ValidatorIndex = metadata.Index shares = append(shares, share) } + + return s.validatorStore.handleSharesUpdated(shares...) }() + if err != nil { + return err + } - return s.Save(nil, shares...) + return s.unsafeSave(nil, shares...) } // Drop deletes all shares. func (s *sharesStorage) Drop() error { - s.mu.Lock() - defer s.mu.Unlock() - s.dbmu.Lock() defer s.dbmu.Unlock() - err := s.db.DropPrefix(bytes.Join( + func() { + s.mu.Lock() + defer s.mu.Unlock() + + s.shares = make(map[string]*types.SSVShare) + s.validatorStore.handleDrop() + }() + + return s.db.DropPrefix(bytes.Join( [][]byte{s.prefix, sharesPrefix, []byte("/")}, nil, )) - if err != nil { - return err - } - - s.shares = make(map[string]*types.SSVShare) - s.validatorStore.handleDrop() - return nil } // storageKey builds share key using sharesPrefix & validator public key, e.g. "shares/0x00..01"