Skip to content

Commit

Permalink
Fixed Node/ManagedNode concurrency issues (#808)
Browse files Browse the repository at this point in the history
Signed-off-by: Antonio Mindov <[email protected]>
  • Loading branch information
rokn authored Sep 20, 2023
1 parent 80b8a94 commit f7e9c4b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
20 changes: 20 additions & 0 deletions managed_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package hedera
*/

import (
"sync"
"time"
)

Expand Down Expand Up @@ -57,9 +58,12 @@ type _ManagedNode struct {
maxBackoff time.Duration
badGrpcStatusCount int64
readmitTime *time.Time
mutex sync.RWMutex
}

// _GetAttempts reports the node's current badGrpcStatusCount.
//
// The counter is copied while the read lock is held, so this may be called
// concurrently with the methods that modify it under the write lock.
func (node *_ManagedNode) _GetAttempts() int64 {
	node.mutex.RLock()
	attempts := node.badGrpcStatusCount
	node.mutex.RUnlock()

	return attempts
}

Expand All @@ -72,6 +76,8 @@ func (node *_ManagedNode) _GetAddress() string {
}

// _GetReadmitTime returns the node's readmitTime pointer, read while holding
// the read lock. The result is nil when no readmit time is set. The stored
// pointer itself is returned, not a copy of the time it points to.
func (node *_ManagedNode) _GetReadmitTime() *time.Time {
	node.mutex.RLock()
	readmit := node.readmitTime
	node.mutex.RUnlock()

	return readmit
}

Expand Down Expand Up @@ -109,11 +115,17 @@ func (node *_ManagedNode) _GetMaxBackoff() time.Duration {
}

// _InUse records one more use of the node: it increments useCount and stamps
// lastUsed with the current time. Both writes happen under the write lock so
// that concurrent readers never observe one without the other.
func (node *_ManagedNode) _InUse() {
	node.mutex.Lock()
	node.useCount++
	node.lastUsed = time.Now()
	node.mutex.Unlock()
}

func (node *_ManagedNode) _IsHealthy() bool {
node.mutex.RLock()
defer node.mutex.RUnlock()

if node.readmitTime == nil {
return true
}
Expand All @@ -122,6 +134,9 @@ func (node *_ManagedNode) _IsHealthy() bool {
}

func (node *_ManagedNode) _IncreaseBackoff() {
node.mutex.Lock()
defer node.mutex.Unlock()

node.badGrpcStatusCount++
node.currentBackoff *= 2
if node.currentBackoff > node.maxBackoff {
Expand All @@ -132,13 +147,18 @@ func (node *_ManagedNode) _IncreaseBackoff() {
}

// _DecreaseBackoff halves currentBackoff, never letting it drop below
// minBackoff. The update is performed under the write lock.
func (node *_ManagedNode) _DecreaseBackoff() {
	node.mutex.Lock()
	defer node.mutex.Unlock()

	// Compute the clamped value first so the field is assigned exactly once.
	halved := node.currentBackoff / 2
	if halved < node.minBackoff {
		halved = node.minBackoff
	}
	node.currentBackoff = halved
}

// _Wait returns the duration between lastUsed and readmitTime
// (readmitTime - lastUsed), read under the read lock.
//
// readmitTime is a pointer and may be nil; _IsHealthy treats that state as
// healthy. In that case there is nothing to wait for, so 0 is returned
// instead of dereferencing the nil pointer, which would panic.
func (node *_ManagedNode) _Wait() time.Duration {
	node.mutex.RLock()
	defer node.mutex.RUnlock()

	if node.readmitTime == nil {
		return 0
	}

	return node.readmitTime.Sub(node.lastUsed)
}

Expand Down
61 changes: 61 additions & 0 deletions network_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,64 @@ func TestUnitConcurrentGetNodeReadmit(t *testing.T) {
network._ReadmitNodes()
require.Equal(t, len(nodes), len(network.healthyNodes))
}

// TestUnitConcurrentNodeAccess exercises one node's accessors and mutators
// from several goroutines at once. It is intended to be run with the race
// detector (-race): the final assertion only checks that every node is
// healthy again after _ReadmitNodes, while data races on the node state are
// what the concurrent section is meant to surface.
func TestUnitConcurrentNodeAccess(t *testing.T) {
	t.Parallel()

	network := _NewNetwork()
	nodes := newNetworkMockNodes()
	// Check the error before touching the network any further; nothing below
	// is meaningful if the node set could not be installed.
	require.NoError(t, network.SetNetwork(nodes))
	network._SetMinNodeReadmitPeriod(0)
	network._SetMaxNodeReadmitPeriod(0)

	// NOTE(review): a negative max backoff presumably keeps computed readmit
	// times in the past so the nodes can be readmitted immediately — confirm
	// against _IncreaseBackoff.
	for _, node := range network.nodes {
		node._SetMaxBackoff(-1 * time.Minute)
	}

	const (
		numThreads = 3
		iterations = 20
	)

	// Every goroutine works on the same node so the calls really do contend.
	node := network._GetNode()

	var wg sync.WaitGroup
	wg.Add(numThreads)
	for i := 0; i < numThreads; i++ {
		go func() {
			// Deferred so the WaitGroup is released however the goroutine exits.
			defer wg.Done()

			for j := 0; j < iterations; j++ {
				network._GetNode()
				network._IncreaseBackoff(node)
				node._IsHealthy()
				node._GetAttempts()
				node._GetReadmitTime()
				node._Wait()
				node._InUse()
			}
		}()
	}
	wg.Wait()

	network._ReadmitNodes()
	require.Equal(t, len(nodes), len(network.healthyNodes))
}

// TestUnitConcurrentNodeGetChannel calls _GetChannel on a single node from
// many goroutines simultaneously and then checks that, after _ReadmitNodes,
// the number of healthy nodes equals the number of configured nodes.
func TestUnitConcurrentNodeGetChannel(t *testing.T) {
	t.Parallel()

	network := _NewNetwork()
	nodes := newNetworkMockNodes()
	err := network.SetNetwork(nodes)
	require.NoError(t, err)

	numThreads := 20
	var wg sync.WaitGroup
	// All goroutines share this one node so that its channel is requested
	// concurrently.
	target := network._GetNode()
	wg.Add(numThreads)
	logger := NewLogger("", LoggerLevelError)
	for worker := 0; worker < numThreads; worker++ {
		go func() {
			defer wg.Done()
			// The returned channel and error are deliberately discarded; only
			// the concurrent call itself is under test.
			_, _ = target._GetChannel(logger)
		}()
	}
	wg.Wait()

	network._ReadmitNodes()
	require.Equal(t, len(nodes), len(network.healthyNodes))
}
8 changes: 8 additions & 0 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"crypto/x509"
"encoding/hex"
"encoding/pem"
"sync"
"time"

"context"
Expand All @@ -45,6 +46,7 @@ type _Node struct {
channel *_Channel
addressBook *NodeAddress
verifyCertificate bool
channelMutex sync.Mutex
}

func _NewNode(accountID AccountID, address string, minBackoff time.Duration) (node *_Node, err error) {
Expand Down Expand Up @@ -121,6 +123,9 @@ func (node *_Node) _GetReadmitTime() *time.Time {
}

func (node *_Node) _GetChannel(logger Logger) (*_Channel, error) {
node.channelMutex.Lock()
defer node.channelMutex.Unlock()

if node.channel != nil {
return node.channel, nil
}
Expand Down Expand Up @@ -196,6 +201,9 @@ func (node *_Node) _GetChannel(logger Logger) (*_Channel, error) {
}

func (node *_Node) _Close() error {
node.channelMutex.Lock()
defer node.channelMutex.Unlock()

if node.channel != nil {
err := node.channel.client.Close()
node.channel = nil
Expand Down

0 comments on commit f7e9c4b

Please sign in to comment.