Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

close wallet connections methods #5

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions connectionprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
type ConnectionProvider interface {
// Connection returns a connection and release function.
Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error)
Close(endpoints []*Endpoint)
}

// PuddleConnectionProvider provides connections using the Puddle connection pooler.
Expand All @@ -44,9 +45,30 @@ type PuddleConnectionProvider struct {
credentials credentials.TransportCredentials
}

// Close closes connections to the given endpoints and specific connection provider.
func (c *PuddleConnectionProvider) Close(endpoints []*Endpoint) {
connectionPoolsMu.Lock()
defer connectionPoolsMu.Unlock()
for i := range endpoints {
key := c.getConnectionKey(endpoints[i])
if pool, exists := connectionPools[key]; exists {
pool.Close()
delete(connectionPools, key)
}
}
}

func (c *PuddleConnectionProvider) getConnectionKey(endpoint *Endpoint) string {
key := fmt.Sprintf("%s:%d", endpoint.host, endpoint.port)
if c.name != "" {
key += fmt.Sprintf("-%s", c.name)
}
return key
}

// Connection returns a connection and release function.
func (c *PuddleConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) {
pool := c.obtainOrCreatePool(fmt.Sprintf("%s:%d", endpoint.host, endpoint.port))
pool := c.obtainOrCreatePool(c.getConnectionKey(endpoint), fmt.Sprintf("%s:%d", endpoint.host, endpoint.port))

res, err := pool.Acquire(ctx)
if err != nil {
Expand All @@ -56,9 +78,9 @@ func (c *PuddleConnectionProvider) Connection(ctx context.Context, endpoint *End
return res.Value(), res.Release, nil
}

func (c *PuddleConnectionProvider) obtainOrCreatePool(address string) *puddle.Pool[*grpc.ClientConn] {
func (c *PuddleConnectionProvider) obtainOrCreatePool(connectionKey, address string) *puddle.Pool[*grpc.ClientConn] {
connectionPoolsMu.Lock()
pool, exists := connectionPools[address]
pool, exists := connectionPools[connectionKey]
connectionPoolsMu.Unlock()
if !exists {
constructor := func(ctx context.Context) (*grpc.ClientConn, error) {
Expand Down Expand Up @@ -93,7 +115,7 @@ func (c *PuddleConnectionProvider) obtainOrCreatePool(address string) *puddle.Po
MaxSize: c.poolConnections,
})
connectionPoolsMu.Lock()
connectionPools[address] = pool
connectionPools[connectionKey] = pool
connectionPoolsMu.Unlock()
}

Expand Down
17 changes: 17 additions & 0 deletions grpc_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type ErroringConnectionProvider struct {
pb.UnimplementedListerServer
}

func (c *ErroringConnectionProvider) Close(_ []*Endpoint) {}

// Connection returns a connection and release function.
func (c *ErroringConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) {
return nil, nil, errors.New("mock error")
Expand Down Expand Up @@ -72,6 +74,21 @@ func (c *BufConnectionProvider) bufDialer(ctx context.Context, in string) (net.C
return c.listeners[in].Dial()
}

func (c *BufConnectionProvider) Close(endpoints []*Endpoint) {
c.mutex.Lock()
defer c.mutex.Unlock()
for i := range endpoints {
if server, exists := c.servers[endpoints[i].String()]; exists {
server.Stop()
delete(c.servers, endpoints[i].String())
}
if listener, exists := c.listeners[endpoints[i].String()]; exists {
listener.Close()
delete(c.listeners, endpoints[i].String())
}
}
}

// Connection returns a connection and release function.
func (c *BufConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) {
serverAddress := fmt.Sprintf("%s:%d", endpoint.host, endpoint.port)
Expand Down
9 changes: 9 additions & 0 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type parameters struct {
monitor Metrics
timeout time.Duration
name string
connectionName string
credentials credentials.TransportCredentials
endpoints []*Endpoint
poolConnections int32
Expand Down Expand Up @@ -82,6 +83,14 @@ func WithPoolConnections(connections int32) Parameter {
})
}

// WithConnectionProviderName sets the name for the connection provider.
// This is used to distinguish between different connection providers.
func WithConnectionProviderName(name string) Parameter {
return parameterFunc(func(p *parameters) {
p.connectionName = name
})
}

// parseAndCheckParameters parses and checks parameters to ensure that mandatory parameters are present and correct.
func parseAndCheckParameters(params ...Parameter) (*parameters, error) {
parameters := parameters{
Expand Down
7 changes: 6 additions & 1 deletion wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Open(ctx context.Context,
wallet.timeout = parameters.timeout
wallet.endpoints = make([]*Endpoint, len(parameters.endpoints))
wallet.connectionProvider = &PuddleConnectionProvider{
name: parameters.name,
name: parameters.connectionName,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name is mandatory but connectionname is not. So I think that this should use connectionName if present, otherwise it should fall back to using name.

poolConnections: parameters.poolConnections,
credentials: parameters.credentials.Clone(),
}
Expand Down Expand Up @@ -186,6 +186,11 @@ func (w *wallet) CreateDistributedAccount(ctx context.Context, name string, part
return w.GenerateDistributedAccount(ctx, name, participants, signingThreshold, passphrase)
}

// Close closes connections for wallet.
func (w *wallet) Close() {
w.connectionProvider.Close(w.endpoints)
}

// SetConnectionProvider sets a connection provider for the wallet.
// This should, in general, only be used for testing.
func (w *wallet) SetConnectionProvider(connectionProvider ConnectionProvider) {
Expand Down