From 4e918e55ba0fe9cd87a3b3eccc658d1a7deeda0f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 24 Oct 2024 11:43:14 +0200 Subject: [PATCH] [client] Fix controller re-connection (#2758) Rethink the peer reconnection implementation --- .github/workflows/golang-test-linux.yml | 7 +- client/iface/bind/ice_bind.go | 1 - client/iface/wgproxy/ebpf/portlookup.go | 4 +- client/iface/wgproxy/ebpf/portlookup_test.go | 3 + client/iface/wgproxy/factory_kernel.go | 2 + .../iface/wgproxy/factory_kernel_freebsd.go | 3 + client/iface/wgproxy/factory_usp.go | 3 + client/iface/wgproxy/udp/proxy.go | 28 ++- client/internal/engine.go | 23 +- client/internal/engine_test.go | 3 + client/internal/peer/conn.go | 226 +++++------------- client/internal/peer/conn_monitor.go | 212 ---------------- client/internal/peer/conn_test.go | 16 +- client/internal/peer/guard/guard.go | 194 +++++++++++++++ client/internal/peer/guard/ice_monitor.go | 135 +++++++++++ client/internal/peer/guard/sr_watcher.go | 119 +++++++++ client/internal/peer/ice/agent.go | 89 +++++++ client/internal/peer/ice/config.go | 22 ++ .../peer/{env_config.go => ice/env.go} | 22 +- client/internal/peer/{ => ice}/stdnet.go | 2 +- .../internal/peer/{ => ice}/stdnet_android.go | 2 +- client/internal/peer/worker_ice.go | 107 +-------- client/internal/peer/worker_relay.go | 11 +- relay/client/client.go | 26 +- relay/client/guard.go | 20 ++ relay/client/manager.go | 21 ++ signal/client/client.go | 1 + signal/client/grpc.go | 14 ++ signal/client/mock.go | 22 +- 29 files changed, 814 insertions(+), 524 deletions(-) delete mode 100644 client/internal/peer/conn_monitor.go create mode 100644 client/internal/peer/guard/guard.go create mode 100644 client/internal/peer/guard/ice_monitor.go create mode 100644 client/internal/peer/guard/sr_watcher.go create mode 100644 client/internal/peer/ice/agent.go create mode 100644 client/internal/peer/ice/config.go rename client/internal/peer/{env_config.go => ice/env.go} (80%) rename client/internal/peer/{ => ice}/stdnet.go (94%) rename client/internal/peer/{ => ice}/stdnet_android.go (94%) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 9457d3a6621..b584f0ff68c 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -79,9 +79,6 @@ jobs: - name: check git status run: git --no-pager diff --exit-code - - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock @@ -98,7 +95,7 @@ jobs: run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - name: Generate Peer Test bin - run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... + run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ - run: chmod +x *testing.bin @@ -106,7 +103,7 @@ jobs: run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - name: Run RouteManager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ccdcc2cda30..a9c25950d00 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -143,7 +143,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { conn, ok := b.endpoints[ep.DstIP()] b.endpointsMu.Unlock() if !ok { - log.Infof("failed to find endpoint for %s", ep.DstIP()) return b.StdNetBind.Send(bufs, ep) } diff --git a/client/iface/wgproxy/ebpf/portlookup.go b/client/iface/wgproxy/ebpf/portlookup.go index 0e2c20c9911..fce8f1507c6 100644 --- a/client/iface/wgproxy/ebpf/portlookup.go +++ b/client/iface/wgproxy/ebpf/portlookup.go @@ -5,9 +5,9 @@ import ( "net" ) -const ( +var ( portRangeStart = 3128 - portRangeEnd = 3228 + portRangeEnd = portRangeStart + 100 ) type portLookup struct { diff --git a/client/iface/wgproxy/ebpf/portlookup_test.go b/client/iface/wgproxy/ebpf/portlookup_test.go index 92f4b8eee9f..a2e92fc7926 100644 --- a/client/iface/wgproxy/ebpf/portlookup_test.go +++ b/client/iface/wgproxy/ebpf/portlookup_test.go @@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) { pl := portLookup{} + portRangeStart = 4128 + portRangeEnd = portRangeStart + 100 + allocatedPort, err := allocatePort(portRangeStart) if err != nil { t.Fatal(err) diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 32e96e34f2d..3ad7dc59dd9 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -22,9 +22,11 @@ func NewKernelFactory(wgPort int) *KernelFactory { ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) if err := ebpfProxy.Listen(); err != nil { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) return f } + log.Infof("WireGuard Proxy Factory will produce eBPF proxy") f.ebpfProxy = ebpfProxy return f } diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go index 7ac2f99a882..736944229fc 100644 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ b/client/iface/wgproxy/factory_kernel_freebsd.go @@ -1,6 +1,8 @@ package wgproxy import ( + log "github.com/sirupsen/logrus" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) @@ -10,6 +12,7 @@ type KernelFactory struct { } func NewKernelFactory(wgPort int) *KernelFactory { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") f := &KernelFactory{ wgPort: wgPort, } diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 99f5ada017a..e2d479331b7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -1,6 +1,8 @@ package wgproxy import ( + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) @@ -10,6 +12,7 @@ type USPFactory struct { } func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { + log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ bind: iceBind, } diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 8bee099014e..200d961f3c8 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -2,14 +2,16 @@ package udp import ( "context" + "errors" "fmt" + "io" "net" "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" + cerrors "github.com/netbirdio/netbird/client/errors" ) // WGUDPProxy proxies @@ -121,7 +123,7 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } - return errors.FormatErrorOrNil(result) + return cerrors.FormatErrorOrNil(result) } // proxyToRemote proxies from Wireguard to the RemoteKey @@ -160,18 +162,16 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { - log.Warnf("error in proxy to local loop: %s", err) + if !errors.Is(err, io.EOF) { + log.Warnf("error in proxy to local loop: %s", err) + } } }() buf := make([]byte, 1500) for { - n, err := p.remoteConn.Read(buf) + n, err := p.remoteConnRead(ctx, buf) if err != nil { - if ctx.Err() != nil { - return - } - log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } @@ -193,3 +193,15 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { } } } + +func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) { + n, err = p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err) + return + } + return +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 22dd1f584a7..af2817e6ed3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -30,6 +30,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -168,6 +170,7 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -263,6 +266,10 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.srWatcher != nil { + e.srWatcher.Close() + } + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) @@ -389,6 +396,18 @@ func (e *Engine) Start() error { return fmt.Errorf("initialize dns server: %w", err) } + iceCfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.UDPMuxDefault, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + + e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) + e.srWatcher.Start() + e.receiveSignalEvents() e.receiveManagementEvents() e.receiveProbeEvents() @@ -971,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e LocalWgPort: e.config.WgPort, RosenpassPubKey: e.getRosenpassPubKey(), RosenpassAddr: e.getRosenpassAddr(), - ICEConfig: peer.ICEConfig{ + ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, @@ -981,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d0ba1fffcf1..0018af6df8f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -29,6 +29,8 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.ctx = ctx + engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) type testCase struct { name string diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 99acfde314e..56b772759a2 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -18,6 +17,8 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -32,8 +33,6 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 - - reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -63,7 +62,7 @@ type ConnConfig struct { RosenpassAddr string // ICEConfig ICE protocol configuration - ICEConfig ICEConfig + ICEConfig icemaker.Config } type WorkerCallbacks struct { @@ -106,16 +105,12 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - // for reconnection operations - iCEDisconnected chan bool - relayDisconnected chan bool - connMonitor *ConnMonitor - reconnectCh <-chan struct{} + guard *guard.Guard } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -126,28 +121,18 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - signaler: signaler, - relayManager: relayManager, - allowedIP: allowedIP, - allowedNet: allowedNet.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), - iCEDisconnected: make(chan bool, 1), - relayDisconnected: make(chan bool, 1), - } - - conn.connMonitor, conn.reconnectCh = NewConnMonitor( - signaler, - iFaceDiscover, - config, - conn.relayDisconnected, - conn.iCEDisconnected, - ) + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + signaler: signaler, + relayManager: relayManager, + allowedIP: allowedIP, + allowedNet: allowedNet.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + } rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, @@ -159,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu OnStatusChanged: conn.onWorkerICEStateDisconnected, } - conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) + ctrl := isController(config) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) @@ -174,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } + conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher) + go conn.handshaker.Listen() return conn, nil @@ -184,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // be used. func (conn *Conn) Open() { conn.log.Debugf("open connection to peer") + conn.mu.Lock() defer conn.mu.Unlock() conn.opened = true @@ -200,24 +189,19 @@ func (conn *Conn) Open() { conn.log.Warnf("error while updating the state err: %v", err) } - go conn.startHandshakeAndReconnect() + go conn.startHandshakeAndReconnect(conn.ctx) } -func (conn *Conn) startHandshakeAndReconnect() { - conn.waitInitialRandomSleepTime() +func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() if err != nil { conn.log.Errorf("failed to send initial offer: %v", err) } - go conn.connMonitor.Start(conn.ctx) - - if conn.workerRelay.IsController() { - conn.reconnectLoopWithRetry() - } else { - conn.reconnectLoopForOnDisconnectedEvent() - } + go conn.guard.Start(ctx) + go conn.listenGuardEvent(ctx) } // Close closes this peer Conn issuing a close event to the Conn closeCh @@ -316,104 +300,6 @@ func (conn *Conn) GetKey() string { return conn.config.Key } -func (conn *Conn) reconnectLoopWithRetry() { - // Give chance to the peer to establish the initial connection. - // With it, we can decrease to send necessary offer - select { - case <-conn.ctx.Done(): - return - case <-time.After(3 * time.Second): - } - - ticker := conn.prepareExponentTicker() - defer ticker.Stop() - time.Sleep(1 * time.Second) - - for { - select { - case t := <-ticker.C: - if t.IsZero() { - // in case if the ticker has been canceled by context then avoid the temporary loop - return - } - - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) - } - } else { - if conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) - } - } - - // checks if there is peer connection is established via relay or ice - if conn.isConnected() { - continue - } - - err := conn.handshaker.sendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - - case <-conn.reconnectCh: - ticker.Stop() - ticker = conn.prepareExponentTicker() - - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - } -} - -func (conn *Conn) prepareExponentTicker() *backoff.Ticker { - bo := backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.1, - Multiplier: 2, - MaxInterval: conn.config.Timeout, - MaxElapsedTime: reconnectMaxElapsedTime, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, conn.ctx) - - ticker := backoff.NewTicker(bo) - <-ticker.C // consume the initial tick what is happening right after the ticker has been created - - return ticker -} - -// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer -// when the connection is lost. It will try to establish a connection only once time if before the connection was established -// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not -// mean that to switch to it. We always force to use the higher priority connection. -func (conn *Conn) reconnectLoopForOnDisconnectedEvent() { - for { - select { - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, try to send new offer") - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, try to send new offer") - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - - err := conn.handshaker.SendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - } -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() @@ -513,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - conn.notifyReconnectLoopICEDisconnected(changed) + conn.guard.SetICEConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -604,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - conn.notifyReconnectLoopRelayDisconnected(changed) + conn.guard.SetRelayedConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -617,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { } } +func (conn *Conn) listenGuardEvent(ctx context.Context) { + for { + select { + case <-conn.guard.Reconnect: + conn.log.Debugf("send offer to peer") + if err := conn.handshaker.SendOffer(); err != nil { + conn.log.Errorf("failed to send offer: %v", err) + } + case <-ctx.Done(): + return + } + } +} + func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { return conn.config.WgConfig.WgInterface.UpdatePeer( conn.config.WgConfig.RemoteKey, @@ -693,7 +593,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } } -func (conn *Conn) waitInitialRandomSleepTime() { +func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { minWait := 100 maxWait := 800 duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond @@ -702,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() { defer timeout.Stop() select { - case <-conn.ctx.Done(): + case <-ctx.Done(): case <-timeout.C: } } @@ -731,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusDisconnected } -func (conn *Conn) isConnected() bool { +func (conn *Conn) isConnectedOnAllWay() (connected bool) { conn.mu.Lock() defer conn.mu.Unlock() - if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { + defer func() { + if !connected { + conn.logTraceConnState() + } + }() + + if conn.statusICE.Get() == StatusDisconnected { return false } @@ -805,20 +711,6 @@ func (conn *Conn) removeWgPeer() error { return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) } -func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { - select { - case conn.relayDisconnected <- changed: - default: - } -} - -func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { - select { - case conn.iCEDisconnected <- changed: - default: - } -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { @@ -831,6 +723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { } } +func (conn *Conn) logTraceConnState() { + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) + } else { + conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) + } +} + +func isController(config ConnConfig) bool { + return config.LocalKey > config.Key +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go deleted file mode 100644 index 75722c99011..00000000000 --- a/client/internal/peer/conn_monitor.go +++ /dev/null @@ -1,212 +0,0 @@ -package peer - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/pion/ice/v3" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/stdnet" -) - -const ( - signalerMonitorPeriod = 5 * time.Second - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second -) - -type ConnMonitor struct { - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - config ConnConfig - relayDisconnected chan bool - iCEDisconnected chan bool - reconnectCh chan struct{} - currentCandidates []ice.Candidate - candidatesMu sync.Mutex -} - -func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { - reconnectCh := make(chan struct{}, 1) - cm := &ConnMonitor{ - signaler: signaler, - iFaceDiscover: iFaceDiscover, - config: config, - relayDisconnected: relayDisconnected, - iCEDisconnected: iCEDisconnected, - reconnectCh: reconnectCh, - } - return cm, reconnectCh -} - -func (cm *ConnMonitor) Start(ctx context.Context) { - signalerReady := make(chan struct{}, 1) - go cm.monitorSignalerReady(ctx, signalerReady) - - localCandidatesChanged := make(chan struct{}, 1) - go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) - - for { - select { - case changed := <-cm.relayDisconnected: - if !changed { - continue - } - log.Debugf("Relay state changed, triggering reconnect") - cm.triggerReconnect() - - case changed := <-cm.iCEDisconnected: - if !changed { - continue - } - log.Debugf("ICE state changed, triggering reconnect") - cm.triggerReconnect() - - case <-signalerReady: - log.Debugf("Signaler became ready, triggering reconnect") - cm.triggerReconnect() - - case <-localCandidatesChanged: - log.Debugf("Local candidates changed, triggering reconnect") - cm.triggerReconnect() - - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { - if cm.signaler == nil { - return - } - - ticker := time.NewTicker(signalerMonitorPeriod) - defer ticker.Stop() - - lastReady := true - for { - select { - case <-ticker.C: - currentReady := cm.signaler.Ready() - if !lastReady && currentReady { - select { - case signalerReady <- struct{}{}: - default: - } - } - lastReady = currentReady - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { - ufrag, pwd, err := generateICECredentials() - if err != nil { - log.Warnf("Failed to generate ICE credentials: %v", err) - return - } - - ticker := time.NewTicker(candidatesMonitorPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { - log.Warnf("Failed to handle candidate tick: %v", err) - } - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { - log.Debugf("Gathering ICE candidates") - - transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) - if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) - if err != nil { - return fmt.Errorf("create ICE agent: %w", err) - } - defer func() { - if err := agent.Close(); err != nil { - log.Warnf("Failed to close ICE agent: %v", err) - } - }() - - gatherDone := make(chan struct{}) - err = agent.OnCandidate(func(c ice.Candidate) { - log.Tracef("Got candidate: %v", c) - if c == nil { - close(gatherDone) - } - }) - if err != nil { - return fmt.Errorf("set ICE candidate handler: %w", err) - } - - if err := agent.GatherCandidates(); err != nil { - return fmt.Errorf("gather ICE candidates: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) - defer cancel() - - select { - case <-ctx.Done(): - return fmt.Errorf("wait for gathering: %w", ctx.Err()) - case <-gatherDone: - } - - candidates, err := agent.GetLocalCandidates() - if err != nil { - return fmt.Errorf("get local candidates: %w", err) - } - log.Tracef("Got candidates: %v", candidates) - - if changed := cm.updateCandidates(candidates); changed { - select { - case localCandidatesChanged <- struct{}{}: - default: - } - } - - return nil -} - -func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { - cm.candidatesMu.Lock() - defer cm.candidatesMu.Unlock() - - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates - return true - } - - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } - } - - return false -} - -func (cm *ConnMonitor) triggerReconnect() { - select { - case cm.reconnectCh <- struct{}{}: - default: - } -} diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index e68861c5f04..039952588d8 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -10,6 +10,8 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/peer/guard" + "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" ) @@ -19,7 +21,7 @@ var connConf = ConnConfig{ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Timeout: time.Second, LocalWgPort: 51820, - ICEConfig: ICEConfig{ + ICEConfig: ice.Config{ InterfaceBlackList: nil, }, } @@ -43,7 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) if err != nil { return } @@ -54,7 +57,8 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -87,7 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -119,7 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go new file mode 100644 index 00000000000..bf3527a6264 --- /dev/null +++ b/client/internal/peer/guard/guard.go @@ -0,0 +1,194 @@ +package guard + +import ( + "context" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" +) + +const ( + reconnectMaxElapsedTime = 30 * time.Minute +) + +type isConnectedFunc func() bool + +// Guard is responsible for the reconnection logic. +// It will trigger to send an offer to the peer then has connection issues. +// Watch these events: +// - Relay client reconnected to home server +// - Signal server connection state changed +// - ICE connection disconnected +// - Relayed connection disconnected +// - ICE candidate changes +type Guard struct { + Reconnect chan struct{} + log *log.Entry + isController bool + isConnectedOnAllWay isConnectedFunc + timeout time.Duration + srWatcher *SRWatcher + relayedConnDisconnected chan bool + iCEConnDisconnected chan bool +} + +func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { + return &Guard{ + Reconnect: make(chan struct{}, 1), + log: log, + isController: isController, + isConnectedOnAllWay: isConnectedFn, + timeout: timeout, + srWatcher: srWatcher, + relayedConnDisconnected: make(chan bool, 1), + iCEConnDisconnected: make(chan bool, 1), + } +} + +func (g *Guard) Start(ctx context.Context) { + if g.isController { + g.reconnectLoopWithRetry(ctx) + } else { + g.listenForDisconnectEvents(ctx) + } +} + +func (g *Guard) SetRelayedConnDisconnected(changed bool) { + select { + case g.relayedConnDisconnected <- changed: + default: + } +} + +func (g *Guard) SetICEConnDisconnected(changed bool) { + select { + case g.iCEConnDisconnected <- changed: + default: + } +} + +// reconnectLoopWithRetry periodically check (max 30 min) the connection status. +// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { + waitForInitialConnectionTry(ctx) + + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + ticker := g.prepareExponentTicker(ctx) + defer ticker.Stop() + + tickerChannel := ticker.C + + g.log.Infof("start reconnect loop...") + for { + select { + case t := <-tickerChannel: + if t.IsZero() { + g.log.Infof("retry timed out, stop periodic offer sending") + // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop + tickerChannel = make(<-chan time.Time) + continue + } + + if !g.isConnectedOnAllWay() { + g.triggerOfferSending() + } + + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE connection changed, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case <-srReconnectedChan: + g.log.Debugf("has network changes, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + } +} + +// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer +// when the connection is lost. It will try to establish a connection only once time if before the connection was established +// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not +// mean that to switch to it. We always force to use the higher priority connection. +func (g *Guard) listenForDisconnectEvents(ctx context.Context) { + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + g.log.Infof("start listen for reconnect events...") + for { + select { + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, triggering reconnect") + g.triggerOfferSending() + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE state changed, try to send new offer") + g.triggerOfferSending() + case <-srReconnectedChan: + g.triggerOfferSending() + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + } +} + +func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 0.1, + Multiplier: 2, + MaxInterval: g.timeout, + MaxElapsedTime: reconnectMaxElapsedTime, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + + ticker := backoff.NewTicker(bo) + <-ticker.C // consume the initial tick what is happening right after the ticker has been created + + return ticker +} + +func (g *Guard) triggerOfferSending() { + select { + case g.Reconnect <- struct{}{}: + default: + } +} + +// Give chance to the peer to establish the initial connection. +// With it, we can decrease to send necessary offer +func waitForInitialConnectionTry(ctx context.Context) { + select { + case <-ctx.Done(): + return + case <-time.After(3 * time.Second): + } +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go new file mode 100644 index 00000000000..b9c9aa1345c --- /dev/null +++ b/client/internal/peer/guard/ice_monitor.go @@ -0,0 +1,135 @@ +package guard + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ICEMonitor struct { + ReconnectCh chan struct{} + + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig icemaker.Config + + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { + cm := &ICEMonitor{ + ReconnectCh: make(chan struct{}, 1), + iFaceDiscover: iFaceDiscover, + iceConfig: config, + } + return cm +} + +func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { + ufrag, pwd, err := icemaker.GenerateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + changed, err := cm.handleCandidateTick(ctx, ufrag, pwd) + if err != nil { + log.Warnf("Failed to check ICE changes: %v", err) + continue + } + + if changed { + onChanged() + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { + log.Debugf("Gathering ICE candidates") + + agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return false, fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return false, fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return false, fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return false, fmt.Errorf("wait for gathering timed out") + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return false, fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + return cm.updateCandidates(candidates), nil +} + +func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func candidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go new file mode 100644 index 00000000000..90e45426f78 --- /dev/null +++ b/client/internal/peer/guard/sr_watcher.go @@ -0,0 +1,119 @@ +package guard + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +type chNotifier interface { + SetOnReconnectedListener(func()) + Ready() bool +} + +type SRWatcher struct { + signalClient chNotifier + relayManager chNotifier + + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config + + cancelIceMonitor context.CancelFunc +} + +// NewSRWatcher creates a new SRWatcher. This watcher will notify the listeners when the ICE candidates change or the +// Relay connection is reconnected or the Signal client reconnected. +func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher { + srw := &SRWatcher{ + signalClient: signalClient, + relayManager: relayManager, + iFaceDiscover: iFaceDiscover, + iceConfig: iceConfig, + listeners: make(map[chan struct{}]struct{}), + } + return srw +} + +func (w *SRWatcher) Start() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + w.cancelIceMonitor = cancel + + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + go iceMonitor.Start(ctx, w.onICEChanged) + w.signalClient.SetOnReconnectedListener(w.onReconnected) + w.relayManager.SetOnReconnectedListener(w.onReconnected) + +} + +func (w *SRWatcher) Close() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor == nil { + return + } + w.cancelIceMonitor() + w.signalClient.SetOnReconnectedListener(nil) + w.relayManager.SetOnReconnectedListener(nil) +} + +func (w *SRWatcher) NewListener() chan struct{} { + w.mu.Lock() + defer w.mu.Unlock() + + listenerChan := make(chan struct{}, 1) + w.listeners[listenerChan] = struct{}{} + return listenerChan +} + +func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.listeners, listenerChan) + close(listenerChan) +} + +func (w *SRWatcher) onICEChanged() { + if !w.signalClient.Ready() { + return + } + + log.Infof("network changes detected by ICE agent") + w.notify() +} + +func (w *SRWatcher) onReconnected() { + if !w.signalClient.Ready() { + return + } + if !w.relayManager.Ready() { + return + } + + log.Infof("reconnected to Signal or Relay server") + w.notify() +} + +func (w *SRWatcher) notify() { + w.mu.Lock() + defer w.mu.Unlock() + for listener := range w.listeners { + select { + case listener <- struct{}{}: + default: + } + } +} diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go new file mode 100644 index 00000000000..b2a9669367e --- /dev/null +++ b/client/internal/peer/ice/agent.go @@ -0,0 +1,89 @@ +package ice + +import ( + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/pion/ice/v3" + "github.com/pion/randutil" + "github.com/pion/stun/v2" + log "github.com/sirupsen/logrus" + "runtime" + "time" +) + +const ( + lenUFrag = 16 + lenPwd = 32 + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + iceKeepAliveDefault = 4 * time.Second + iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second +) + +var ( + failedTimeout = 6 * time.Second +) + +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList), + UDPMux: config.UDPMux, + UDPMuxSrflx: config.UDPMuxSrflx, + NAT1To1IPs: config.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + +func GenerateICECredentials() (string, string, error) { + ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) + if err != nil { + return "", "", err + } + + pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) + if err != nil { + return "", "", err + } + return ufrag, pwd, nil +} + +func CandidateTypes() []ice.CandidateType { + if hasICEForceRelayConn() { + return []ice.CandidateType{ice.CandidateTypeRelay} + } + // TODO: remove this once we have refactored userspace proxy into the bind package + if runtime.GOOS == "ios" { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} + } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} +} + +func CandidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go new file mode 100644 index 00000000000..8abc842f0d2 --- /dev/null +++ b/client/internal/peer/ice/config.go @@ -0,0 +1,22 @@ +package ice + +import ( + "sync/atomic" + + "github.com/pion/ice/v3" +) + +type Config struct { + // StunTurn is a list of STUN and TURN URLs + StunTurn *atomic.Value // []*stun.URI + + // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering + // (e.g. if eth0 is in the list, host candidate of this interface won't be used) + InterfaceBlackList []string + DisableIPv6Discovery bool + + UDPMux ice.UDPMux + UDPMuxSrflx ice.UniversalUDPMux + + NATExternalIPs []string +} diff --git a/client/internal/peer/env_config.go b/client/internal/peer/ice/env.go similarity index 80% rename from client/internal/peer/env_config.go rename to client/internal/peer/ice/env.go index 87b626df763..3b0cb74ad2a 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/ice/env.go @@ -1,4 +1,4 @@ -package peer +package ice import ( "os" @@ -10,12 +10,19 @@ import ( ) const ( + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + + msgWarnInvalidValue = "invalid value %s set for %s, using default %v" ) +func hasICEForceRelayConn() bool { + disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) + return strings.ToLower(disconnectedTimeoutEnv) == "true" +} + func iceKeepAlive() time.Duration { keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) if keepAliveEnv == "" { @@ -25,7 +32,7 @@ func iceKeepAlive() time.Duration { log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) + log.Warnf(msgWarnInvalidValue, keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) return iceKeepAliveDefault } @@ -41,7 +48,7 @@ func iceDisconnectedTimeout() time.Duration { log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) + log.Warnf(msgWarnInvalidValue, disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) return iceDisconnectedTimeoutDefault } @@ -57,14 +64,9 @@ func iceRelayAcceptanceMinWait() time.Duration { log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + log.Warnf(msgWarnInvalidValue, iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) return iceRelayAcceptanceMinWaitDefault } return time.Duration(disconnectedTimeoutSec) * time.Second } - -func hasICEForceRelayConn() bool { - disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) - return strings.ToLower(disconnectedTimeoutEnv) == "true" -} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/ice/stdnet.go similarity index 94% rename from client/internal/peer/stdnet.go rename to client/internal/peer/ice/stdnet.go index 96d211dbc77..3ce83727e6e 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -1,6 +1,6 @@ //go:build !android -package peer +package ice import ( "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go similarity index 94% rename from client/internal/peer/stdnet_android.go rename to client/internal/peer/ice/stdnet_android.go index a39a03b1c83..84c665e6f40 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,4 +1,4 @@ -package peer +package ice import "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c86c1858fdc..55894218d73 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,52 +5,20 @@ import ( "fmt" "net" "net/netip" - "runtime" "sync" - "sync/atomic" "time" "github.com/pion/ice/v3" - "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) -const ( - iceKeepAliveDefault = 4 * time.Second - iceDisconnectedTimeoutDefault = 6 * time.Second - // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package - iceRelayAcceptanceMinWaitDefault = 2 * time.Second - - lenUFrag = 16 - lenPwd = 32 - runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -) - -var ( - failedTimeout = 6 * time.Second -) - -type ICEConfig struct { - // StunTurn is a list of STUN and TURN URLs - StunTurn *atomic.Value // []*stun.URI - - // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering - // (e.g. if eth0 is in the list, host candidate of this interface won't be used) - InterfaceBlackList []string - DisableIPv6Discovery bool - - UDPMux ice.UDPMux - UDPMuxSrflx ice.UniversalUDPMux - - NATExternalIPs []string -} - type ICEConnInfo struct { RemoteConn net.Conn RosenpassPubKey []byte @@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal conn: callBacks, } - localUfrag, localPwd, err := generateICECredentials() + localUfrag, localPwd, err := icemaker.GenerateICECredentials() if err != nil { return nil, err } @@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { w.selectedPriority = connPriorityICEP2P - preferredCandidateTypes = candidateTypesP2P() + preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { w.selectedPriority = connPriorityICETurn - preferredCandidateTypes = candidateTypes() + preferredCandidateTypes = icemaker.CandidateTypes() } w.log.Debugf("recreate ICE agent") @@ -232,15 +200,10 @@ func (w *WorkerICE) Close() { } } -func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) - if err != nil { - w.log.Errorf("failed to create pion's stdnet: %s", err) - } - +func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { w.sentExtraSrflx = false - agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } @@ -365,36 +328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } -func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: candidateTypes, - InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), - UDPMux: config.ICEConfig.UDPMux, - UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: ufrag, - LocalPwd: pwd, - } - - if config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - - return ice.NewAgent(agentConfig) -} - func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool return false } -func candidateTypes() []ice.CandidateType { - if hasICEForceRelayConn() { - return []ice.CandidateType{ice.CandidateTypeRelay} - } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} -} - -func candidateTypesP2P() []ice.CandidateType { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} -} - func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } @@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } - -func generateICECredentials() (string, string, error) { - ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) - if err != nil { - return "", "", err - } - - pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) - if err != nil { - return "", "", err - } - return ufrag, pwd, nil -} diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c02fccebc47..c22dcdeda5d 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -31,6 +31,7 @@ type WorkerRelayCallbacks struct { type WorkerRelay struct { log *log.Entry + isController bool config ConnConfig relayManager relayClient.ManagerService callBacks WorkerRelayCallbacks @@ -44,9 +45,10 @@ type WorkerRelay struct { relaySupportedOnRemotePeer atomic.Bool } -func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { r := &WorkerRelay{ log: log, + isController: ctrl, config: config, relayManager: relayManager, callBacks: callbacks, @@ -80,6 +82,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Errorf("failed to open connection via Relay: %s", err) return } + w.relayLock.Lock() w.relayedConn = relayedConn w.relayLock.Unlock() @@ -136,10 +139,6 @@ func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool { return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() } -func (w *WorkerRelay) IsController() bool { - return w.config.LocalKey > w.config.Key -} - func (w *WorkerRelay) RelayIsSupportedLocally() bool { return w.relayManager.HasRelayAddress() } @@ -212,7 +211,7 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { } func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { - if w.IsController() { + if w.isController { return myRelayAddress } return remoteRelayAddress diff --git a/relay/client/client.go b/relay/client/client.go index 20a73f4b343..a82a75453bf 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,6 +142,7 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func() + onConnectedListener func() listenerMutex sync.Mutex } @@ -191,6 +192,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) + go c.notifyConnected() return nil } @@ -238,6 +240,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) { c.onDisconnectListener = fn } +func (c *Client) SetOnConnectedListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onConnectedListener = fn +} + // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -245,6 +253,12 @@ func (c *Client) HasConns() bool { return len(c.conns) > 0 } +func (c *Client) Ready() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.serviceIsRunning +} + // Close closes the connection to the relay server and all connections to other peers. func (c *Client) Close() error { return c.close(true) @@ -363,9 +377,9 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() - c.notifyDisconnected() c.wgReadLoop.Done() _ = c.close(false) + c.notifyDisconnected() } func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { @@ -544,6 +558,16 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener() } +func (c *Client) notifyConnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onConnectedListener == nil { + return + } + go c.onConnectedListener() +} + func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/guard.go b/relay/client/guard.go index f826cf1b600..d6b6b0da509 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -29,6 +29,10 @@ func NewGuard(context context.Context, relayClient *Client) *Guard { // OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent func (g *Guard) OnDisconnected() { + if g.quickReconnect() { + return + } + ticker := time.NewTicker(reconnectingTimeout) defer ticker.Stop() @@ -46,3 +50,19 @@ func (g *Guard) OnDisconnected() { } } } + +func (g *Guard) quickReconnect() bool { + ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) + defer cancel() + <-ctx.Done() + + if g.ctx.Err() != nil { + return false + } + + if err := g.relayClient.Connect(); err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + return false + } + return true +} diff --git a/relay/client/manager.go b/relay/client/manager.go index 4554c7c0f6e..3981415fcd4 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -65,6 +65,7 @@ type Manager struct { relayClientsMutex sync.RWMutex onDisconnectedListeners map[string]*list.List + onReconnectedListenerFn func() listenerLock sync.Mutex } @@ -101,6 +102,7 @@ func (m *Manager) Serve() error { m.relayClient = client m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(func() { m.onServerDisconnected(client.connectionURL) }) @@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { return netConn, err } +// Ready returns true if the home Relay client is connected to the relay server. +func (m *Manager) Ready() bool { + if m.relayClient == nil { + return false + } + return m.relayClient.Ready() +} + +func (m *Manager) SetOnReconnectedListener(f func()) { + m.onReconnectedListenerFn = f +} + // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { @@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return conn, nil } +func (m *Manager) onServerConnected() { + if m.onReconnectedListenerFn == nil { + return + } + go m.onReconnectedListenerFn() +} + func (m *Manager) onServerDisconnected(serverAddress string) { if serverAddress == m.relayClient.connectionURL { go m.reconnectGuard.OnDisconnected() diff --git a/signal/client/client.go b/signal/client/client.go index ced3fb7d0eb..eff1ccb8794 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -35,6 +35,7 @@ type Client interface { WaitStreamConnected() SendToStream(msg *proto.EncryptedMessage) error Send(msg *proto.Message) error + SetOnReconnectedListener(func()) } // UnMarshalCredential parses the credentials from the message and returns a Credential instance diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7a3b502ffc6..2ff84e46075 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -43,6 +43,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + + onReconnectedListenerFn func() } func (c *GrpcClient) StreamConnected() bool { @@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamConnected() { c.mux.Lock() defer c.mux.Unlock() + c.status = StreamConnected if c.connectedCh != nil { // there are goroutines waiting on this channel -> release them close(c.connectedCh) c.connectedCh = nil } + + if c.onReconnectedListenerFn != nil { + c.onReconnectedListenerFn() + } } func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { @@ -271,6 +278,13 @@ func (c *GrpcClient) WaitStreamConnected() { } } +func (c *GrpcClient) SetOnReconnectedListener(fn func()) { + c.mux.Lock() + defer c.mux.Unlock() + + c.onReconnectedListenerFn = fn +} + // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // GrpcClient.connWg can be used to wait diff --git a/signal/client/mock.go b/signal/client/mock.go index 70ecea9eda2..32236c82c09 100644 --- a/signal/client/mock.go +++ b/signal/client/mock.go @@ -7,14 +7,20 @@ import ( ) type MockClient struct { - CloseFunc func() error - GetStatusFunc func() Status - StreamConnectedFunc func() bool - ReadyFunc func() bool - WaitStreamConnectedFunc func() - ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error - SendToStreamFunc func(msg *proto.EncryptedMessage) error - SendFunc func(msg *proto.Message) error + CloseFunc func() error + GetStatusFunc func() Status + StreamConnectedFunc func() bool + ReadyFunc func() bool + WaitStreamConnectedFunc func() + ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error + SendToStreamFunc func(msg *proto.EncryptedMessage) error + SendFunc func(msg *proto.Message) error + SetOnReconnectedListenerFunc func(f func()) +} + +// SetOnReconnectedListener sets the function to be called when the client reconnects. +func (sm *MockClient) SetOnReconnectedListener(_ func()) { + // Do nothing } func (sm *MockClient) IsHealthy() bool {