Skip to content

Commit

Permalink
Fixed connection map for tcp proxy, battle-tested on M*CTF
Browse files Browse the repository at this point in the history
  • Loading branch information
pomo-mondreganto committed Dec 19, 2020
1 parent a936292 commit f9409f0
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 23 deletions.
64 changes: 64 additions & 0 deletions internal/proxy/tcp/connmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package tcp

import (
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
"net"
"sync"
)

func newConnMap() *connMap {
return &connMap{
conns: make(map[string]net.Conn),
mu: new(sync.RWMutex),
seq: new(atomic.Int32),
}
}

type connMap struct {
conns map[string]net.Conn
mu *sync.RWMutex
seq *atomic.Int32
}

func (m *connMap) add(conn net.Conn) string {
id := genConnID(conn, int(m.seq.Inc()))
m.mu.Lock()
defer m.mu.Unlock()
m.conns[id] = conn
return id
}

func (m *connMap) remove(connID string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.conns, connID)
}

func (m *connMap) length() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.conns)
}

func (m *connMap) get(connID string) net.Conn {
m.mu.RLock()
defer m.mu.RUnlock()
return m.conns[connID]
}

func (m *connMap) closeAll(logger *logrus.Entry) {
m.mu.Lock()
defer m.mu.Unlock()
logger.Debugf("Closing %d connections", len(m.conns))
for id, c := range m.conns {
cl := logger.WithField("conn", id)
if err := c.Close(); err != nil {
if isConnectionClosedErr(err) {
cl.Debugf("Connection already closed: %v", err)
} else {
cl.Errorf("Error closing connection: %v", err)
}
}
}
}
32 changes: 9 additions & 23 deletions internal/proxy/tcp/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewProxy(cfg common.ServiceConfig, rs *filters.RuleSet) (*Proxy, error) {
serviceConfig: cfg,
logger: logger,
filters: fts,
conns: make(map[string]net.Conn),
conns: newConnMap(),
wg: new(sync.WaitGroup),
}
return p, nil
Expand All @@ -61,8 +61,8 @@ type Proxy struct {
serviceConfig common.ServiceConfig
closing bool
listening atomic.Bool
conns map[string]net.Conn
connSeq atomic.Int32
conns *connMap
connSeq *atomic.Int32
wg *sync.WaitGroup
listener net.Listener
logger *logrus.Entry
Expand Down Expand Up @@ -107,18 +107,7 @@ func (p *Proxy) Shutdown(ctx context.Context) error {

done := make(chan interface{}, 1)
go func() {
p.logger.Debugf("Closing %d connections", len(p.conns))
for id, c := range p.conns {
logger := p.logger.WithField("conn", id)
if err := c.Close(); err != nil {
if isConnectionClosedErr(err) {
logger.Debugf("Connection already closed: %v", err)
} else {
logger.Errorf("Error closing connection: %v", err)
}
}
}

p.conns.closeAll(p.logger)
p.wg.Wait()
done <- nil
}()
Expand Down Expand Up @@ -222,13 +211,13 @@ func (p Proxy) oneSideHandler(conn *Connection, logger *logrus.Entry, ingress bo
func (p Proxy) handleConnection(id string) {
defer p.wg.Done()

conn := p.conns[id]
conn := p.conns.get(id)
connLogger := p.logger.WithField("conn", id)
defer func() {
if err := conn.Close(); err != nil && !isConnectionClosedErr(err) {
connLogger.Warningf("Error closing connection: %v", err)
}
delete(p.conns, id)
p.conns.remove(id)
}()

connLogger.Debugf("Connection received")
Expand Down Expand Up @@ -283,18 +272,15 @@ func (p *Proxy) serve() {
return
}

connID := genConnID(conn, int(p.connSeq.Inc()))
logger := p.logger.WithField("conn", connID)

if !p.GetListening() {
logger.Debugf("Proxy closed, dropping the connection")
p.logger.Debugf("Proxy closed, dropping the connection")
if err := conn.Close(); err != nil {
logger.Errorf("Error dropping the connection: %v", err)
p.logger.Errorf("Error dropping the connection: %v", err)
}
continue
}

p.conns[connID] = conn
connID := p.conns.add(conn)
p.wg.Add(1)
go p.handleConnection(connID)
}
Expand Down

0 comments on commit f9409f0

Please sign in to comment.