diff --git a/connectionprovider.go b/connectionprovider.go index e98562f..e7c8c8c 100644 --- a/connectionprovider.go +++ b/connectionprovider.go @@ -35,6 +35,7 @@ var ( type ConnectionProvider interface { // Connection returns a connection and release function. Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) + Close(endpoints []*Endpoint) } // PuddleConnectionProvider provides connections using the Puddle connection pooler. @@ -44,9 +45,30 @@ type PuddleConnectionProvider struct { credentials credentials.TransportCredentials } +// Close closes connections to the given endpoints and specific connection provider. +func (c *PuddleConnectionProvider) Close(endpoints []*Endpoint) { + connectionPoolsMu.Lock() + defer connectionPoolsMu.Unlock() + for i := range endpoints { + key := c.getConnectionKey(endpoints[i]) + if pool, exists := connectionPools[key]; exists { + pool.Close() + delete(connectionPools, key) + } + } +} + +func (c *PuddleConnectionProvider) getConnectionKey(endpoint *Endpoint) string { + key := fmt.Sprintf("%s:%d", endpoint.host, endpoint.port) + if c.name != "" { + key += fmt.Sprintf("-%s", c.name) + } + return key +} + // Connection returns a connection and release function. func (c *PuddleConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) { - pool := c.obtainOrCreatePool(fmt.Sprintf("%s:%d", endpoint.host, endpoint.port)) + pool := c.obtainOrCreatePool(c.getConnectionKey(endpoint), fmt.Sprintf("%s:%d", endpoint.host, endpoint.port)) res, err := pool.Acquire(ctx) if err != nil { @@ -56,9 +78,9 @@ func (c *PuddleConnectionProvider) Connection(ctx context.Context, endpoint *End return res.Value(), res.Release, nil } -func (c *PuddleConnectionProvider) obtainOrCreatePool(address string) *puddle.Pool[*grpc.ClientConn] { +func (c *PuddleConnectionProvider) obtainOrCreatePool(connectionKey, address string) *puddle.Pool[*grpc.ClientConn] { connectionPoolsMu.Lock() - pool, exists := connectionPools[address] + pool, exists := connectionPools[connectionKey] connectionPoolsMu.Unlock() if !exists { constructor := func(ctx context.Context) (*grpc.ClientConn, error) { @@ -93,7 +115,7 @@ func (c *PuddleConnectionProvider) obtainOrCreatePool(address string) *puddle.Po MaxSize: c.poolConnections, }) connectionPoolsMu.Lock() - connectionPools[address] = pool + connectionPools[connectionKey] = pool connectionPoolsMu.Unlock() } diff --git a/grpc_internal_test.go b/grpc_internal_test.go index 6dd9c9f..43e729c 100644 --- a/grpc_internal_test.go +++ b/grpc_internal_test.go @@ -37,6 +37,8 @@ type ErroringConnectionProvider struct { pb.UnimplementedListerServer } +func (c *ErroringConnectionProvider) Close(_ []*Endpoint) {} + // Connection returns a connection and release function. func (c *ErroringConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) { return nil, nil, errors.New("mock error") @@ -72,6 +74,21 @@ func (c *BufConnectionProvider) bufDialer(ctx context.Context, in string) (net.C return c.listeners[in].Dial() } +func (c *BufConnectionProvider) Close(endpoints []*Endpoint) { + c.mutex.Lock() + defer c.mutex.Unlock() + for i := range endpoints { + if server, exists := c.servers[endpoints[i].String()]; exists { + server.Stop() + delete(c.servers, endpoints[i].String()) + } + if listener, exists := c.listeners[endpoints[i].String()]; exists { + listener.Close() + delete(c.listeners, endpoints[i].String()) + } + } +} + // Connection returns a connection and release function. func (c *BufConnectionProvider) Connection(ctx context.Context, endpoint *Endpoint) (*grpc.ClientConn, func(), error) { serverAddress := fmt.Sprintf("%s:%d", endpoint.host, endpoint.port) diff --git a/parameters.go b/parameters.go index d8a7d6d..077ee51 100644 --- a/parameters.go +++ b/parameters.go @@ -24,6 +24,7 @@ type parameters struct { monitor Metrics timeout time.Duration name string + connectionName string credentials credentials.TransportCredentials endpoints []*Endpoint poolConnections int32 @@ -82,6 +83,14 @@ func WithPoolConnections(connections int32) Parameter { }) } +// WithConnectionProviderName sets the name for the connection provider. +// This is used to distinguish between different connection providers. +func WithConnectionProviderName(name string) Parameter { + return parameterFunc(func(p *parameters) { + p.connectionName = name + }) +} + // parseAndCheckParameters parses and checks parameters to ensure that mandatory parameters are present and correct. func parseAndCheckParameters(params ...Parameter) (*parameters, error) { parameters := parameters{ diff --git a/wallet.go b/wallet.go index c478e9c..7293fea 100644 --- a/wallet.go +++ b/wallet.go @@ -67,7 +67,7 @@ func Open(ctx context.Context, wallet.timeout = parameters.timeout wallet.endpoints = make([]*Endpoint, len(parameters.endpoints)) wallet.connectionProvider = &PuddleConnectionProvider{ - name: parameters.name, + name: parameters.connectionName, poolConnections: parameters.poolConnections, credentials: parameters.credentials.Clone(), } @@ -186,6 +186,11 @@ func (w *wallet) CreateDistributedAccount(ctx context.Context, name string, part return w.GenerateDistributedAccount(ctx, name, participants, signingThreshold, passphrase) } +// Close closes connections for wallet. +func (w *wallet) Close() { + w.connectionProvider.Close(w.endpoints) +} + // SetConnectionProvider sets a connection provider for the wallet. // This should, in general, only be used for testing. func (w *wallet) SetConnectionProvider(connectionProvider ConnectionProvider) {