diff --git a/internal/proxy/tcp/connmap.go b/internal/proxy/tcp/connmap.go
new file mode 100644
index 0000000..b5cf364
--- /dev/null
+++ b/internal/proxy/tcp/connmap.go
@@ -0,0 +1,70 @@
+package tcp
+
+import (
+	"net"
+	"sync"
+
+	"github.com/sirupsen/logrus"
+	"go.uber.org/atomic"
+)
+
+// newConnMap returns an empty, ready-to-use connection registry.
+func newConnMap() *connMap {
+	return &connMap{conns: make(map[string]net.Conn)}
+}
+
+// connMap is a mutex-guarded registry of live connections keyed by
+// generated connection IDs; construct it with newConnMap.
+type connMap struct {
+	conns map[string]net.Conn
+	mu    sync.RWMutex
+	seq   atomic.Int32 // monotonic suffix for generated IDs
+}
+
+// add registers conn under a freshly generated ID and returns that ID.
+func (m *connMap) add(conn net.Conn) string {
+	id := genConnID(conn, int(m.seq.Inc()))
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.conns[id] = conn
+	return id
+}
+
+// remove forgets the connection with the given ID (no-op if absent).
+func (m *connMap) remove(connID string) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	delete(m.conns, connID)
+}
+
+// length reports how many connections are currently tracked.
+func (m *connMap) length() int {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return len(m.conns)
+}
+
+// get returns the connection registered under connID, or nil if none.
+func (m *connMap) get(connID string) net.Conn {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.conns[connID]
+}
+
+// closeAll closes every tracked connection; already-closed connections
+// are logged at debug level, other close errors at error level.
+func (m *connMap) closeAll(logger *logrus.Entry) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	logger.Debugf("Closing %d connections", len(m.conns))
+	for id, c := range m.conns {
+		cl := logger.WithField("conn", id)
+		if err := c.Close(); err != nil {
+			if isConnectionClosedErr(err) {
+				cl.Debugf("Connection already closed: %v", err)
+			} else {
+				cl.Errorf("Error closing connection: %v", err)
+			}
+		}
+	}
+}
diff --git a/internal/proxy/tcp/proxy.go b/internal/proxy/tcp/proxy.go
index e1a5342..0bb14f5 100644
--- a/internal/proxy/tcp/proxy.go
+++ b/internal/proxy/tcp/proxy.go
@@ -48,7 +48,7 @@ func NewProxy(cfg common.ServiceConfig, rs *filters.RuleSet) (*Proxy, error) {
 		serviceConfig: cfg,
 		logger:        logger,
 		filters:       fts,
-		conns:         make(map[string]net.Conn),
+		conns:         newConnMap(),
 		wg:            new(sync.WaitGroup),
 	}
 	return p, nil
@@ -61,8 +61,7 @@ type Proxy struct {
 	serviceConfig common.ServiceConfig
 	closing       bool
 	listening     atomic.Bool
-	conns         map[string]net.Conn
-	connSeq       atomic.Int32
+	conns         *connMap
 	wg            *sync.WaitGroup
 	listener      net.Listener
 	logger        *logrus.Entry
@@ -107,18 +106,7 @@ func (p *Proxy) Shutdown(ctx context.Context) error {
 	done := make(chan interface{}, 1)
 	go func() {
-		p.logger.Debugf("Closing %d connections", len(p.conns))
-		for id, c := range p.conns {
-			logger := p.logger.WithField("conn", id)
-			if err := c.Close(); err != nil {
-				if isConnectionClosedErr(err) {
-					logger.Debugf("Connection already closed: %v", err)
-				} else {
-					logger.Errorf("Error closing connection: %v", err)
-				}
-			}
-		}
-
+		p.conns.closeAll(p.logger)
 		p.wg.Wait()
 		done <- nil
 	}()
 
@@ -222,13 +210,13 @@ func (p Proxy) oneSideHandler(conn *Connection, logger *logrus.Entry, ingress bool) {
 
 func (p Proxy) handleConnection(id string) {
 	defer p.wg.Done()
-	conn := p.conns[id]
+	conn := p.conns.get(id)
 	connLogger := p.logger.WithField("conn", id)
 	defer func() {
 		if err := conn.Close(); err != nil && !isConnectionClosedErr(err) {
 			connLogger.Warningf("Error closing connection: %v", err)
 		}
-		delete(p.conns, id)
+		p.conns.remove(id)
 	}()
 
 	connLogger.Debugf("Connection received")
@@ -283,18 +271,15 @@ func (p *Proxy) serve() {
 			return
 		}
 
-		connID := genConnID(conn, int(p.connSeq.Inc()))
-		logger := p.logger.WithField("conn", connID)
-
 		if !p.GetListening() {
-			logger.Debugf("Proxy closed, dropping the connection")
+			p.logger.Debugf("Proxy closed, dropping the connection")
 			if err := conn.Close(); err != nil {
-				logger.Errorf("Error dropping the connection: %v", err)
+				p.logger.Errorf("Error dropping the connection: %v", err)
 			}
 			continue
 		}
 
-		p.conns[connID] = conn
+		connID := p.conns.add(conn)
 		p.wg.Add(1)
 		go p.handleConnection(connID)
 	}