diff --git a/udp_mux.go b/udp_mux.go index 47084f83..141ccb20 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -36,7 +36,7 @@ type UDPMuxDefault struct { connsIPv4, connsIPv6 map[string]*udpMuxedConn addressMapMu sync.RWMutex - addressMap map[string]*udpMuxedConn + addressMap map[udpMuxedConnAddr]*udpMuxedConn // Buffer pool to recycle buffers for net.UDPAddr encodes/decodes pool *sync.Pool @@ -105,7 +105,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } m := &UDPMuxDefault{ - addressMap: map[string]*udpMuxedConn{}, + addressMap: map[udpMuxedConnAddr]*udpMuxedConn{}, params: params, connsIPv4: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn), @@ -246,7 +246,7 @@ func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr udpMuxedConnAddr) { if m.IsClosed() { return } @@ -304,7 +304,7 @@ func (m *UDPMuxDefault) connWorker() { // If we have already seen this address dispatch to the appropriate destination m.addressMapMu.Lock() - destinationConn := m.addressMap[addr.String()] + destinationConn := m.addressMap[newUDPMuxedConnAddr(udpAddr)] m.addressMapMu.Unlock() // If we haven't seen this address before but is a STUN packet lookup by ufrag diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index 09e4b3a8..7c9b26e3 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -14,6 +14,17 @@ import ( "github.com/pion/transport/v2/packetio" ) +type udpMuxedConnAddr struct { + ip [16]byte + port uint16 +} + +func newUDPMuxedConnAddr(addr *net.UDPAddr) (a udpMuxedConnAddr) { + copy(a.ip[:], addr.IP.To16()) + a.port = uint16(addr.Port) + return a +} + type udpMuxedConnParams struct { Mux *UDPMuxDefault AddrPool *sync.Pool @@ -26,7 +37,7 @@ type udpMuxedConnParams struct { type udpMuxedConn struct { params *udpMuxedConnParams // Remote addresses that we have sent to on this conn - addresses []string + addresses []udpMuxedConnAddr // Channel holding incoming packets buf *packetio.Buffer @@ -81,7 +92,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { return 0, io.ErrClosedPipe } // Each time we write to a new address, we'll register it with the mux - addr := rAddr.String() + addr := newUDPMuxedConnAddr(rAddr.(*net.UDPAddr)) if !c.containsAddress(addr) { c.addAddress(addr) } @@ -127,15 +138,15 @@ func (c *udpMuxedConn) isClosed() bool { } } -func (c *udpMuxedConn) getAddresses() []string { +func (c *udpMuxedConn) getAddresses() []udpMuxedConnAddr { c.mu.Lock() defer c.mu.Unlock() - addresses := make([]string, len(c.addresses)) + addresses := make([]udpMuxedConnAddr, len(c.addresses)) copy(addresses, c.addresses) return addresses } -func (c *udpMuxedConn) addAddress(addr string) { +func (c *udpMuxedConn) addAddress(addr udpMuxedConnAddr) { c.mu.Lock() c.addresses = append(c.addresses, addr) c.mu.Unlock() @@ -144,11 +155,11 @@ func (c *udpMuxedConn) addAddress(addr string) { c.params.Mux.registerConnForAddress(c, addr) } -func (c *udpMuxedConn) removeAddress(addr string) { +func (c *udpMuxedConn) removeAddress(addr udpMuxedConnAddr) { c.mu.Lock() defer c.mu.Unlock() - newAddresses := make([]string, 0, len(c.addresses)) + newAddresses := make([]udpMuxedConnAddr, 0, len(c.addresses)) for _, a := range c.addresses { if a != addr { newAddresses = append(newAddresses, a) @@ -158,7 +169,7 @@ func (c *udpMuxedConn) removeAddress(addr string) { c.addresses = newAddresses } -func (c *udpMuxedConn) containsAddress(addr string) bool { +func (c *udpMuxedConn) containsAddress(addr udpMuxedConnAddr) bool { c.mu.Lock() defer c.mu.Unlock() for _, a := range c.addresses {