From 4c758c6e526a5ba8ee0088263d5883a2c817a190 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:24:15 +0100 Subject: [PATCH 1/6] [management] remove network map diff calculations (#2820) --- go.mod | 3 - go.sum | 6 - management/server/differs/netip.go | 82 --- management/server/dns_test.go | 73 ++- management/server/group_test.go | 28 +- management/server/nameserver_test.go | 30 -- management/server/network.go | 4 +- management/server/peer/peer.go | 18 +- management/server/policy.go | 4 +- management/server/policy_test.go | 62 +-- management/server/posture_checks_test.go | 65 +-- management/server/route_test.go | 20 - .../server/telemetry/updatechannel_metrics.go | 12 - management/server/updatechannel.go | 110 +--- management/server/updatechannel_test.go | 482 ------------------ 15 files changed, 82 insertions(+), 917 deletions(-) delete mode 100644 management/server/differs/netip.go diff --git a/go.mod b/go.mod index 7223a446bb1..571b41abf19 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,6 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 - github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -211,8 +210,6 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index 5cd703bc894..217d27e0adf 100644 --- a/go.sum +++ b/go.sum @@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a 
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= -github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= -github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= -github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go deleted file mode 100644 index de4aa334c17..00000000000 --- a/management/server/differs/netip.go +++ /dev/null @@ -1,82 +0,0 @@ -package differs - -import ( - 
"fmt" - "net/netip" - "reflect" - - "github.com/r3labs/diff/v3" -) - -// NetIPAddr is a custom differ for netip.Addr -type NetIPAddr struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPAddr) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) -} - -func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromAddr, ok1 := a.Interface().(netip.Addr) - toAddr, ok2 := b.Interface().(netip.Addr) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromAddr.String() != toAddr.String() { - cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) - } - - return nil -} - -func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} - -// NetIPPrefix is a custom differ for netip.Prefix -type NetIPPrefix struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPPrefix) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) -} - -func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromPrefix, ok1 := a.Interface().(netip.Prefix) - toPrefix, ok2 := b.Interface().(netip.Prefix) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromPrefix.String() != toPrefix.String() { - 
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) - } - - return nil -} - -func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c675fc12c84..8a66da96c0f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -521,35 +522,56 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + // Creating DNS settings with groups that have no peers should not update account peers or send peer update + t.Run("creating dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } }) - assert.NoError(t, err) - _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ - IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, - }}, - []string{"groupA"}, - true, []string{}, true, userID, false, - 
) - assert.NoError(t, err) + // Creating DNS settings with groups that have peers should update account peers and send peer update + t.Run("creating dns setting with used groups", func(t *testing.T) { + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) - // Saving DNS settings with groups that have peers should update account peers and send peer update - t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ - DisabledManagementGroups: []string{"groupA", "groupB"}, - }) + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) assert.NoError(t, err) select { @@ -559,12 +581,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged DNS settings with used groups should update account peers and not send peer update - // since there is no change in the network map - t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -576,7 +597,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") + t.Error("timeout waiting for peerShouldReceiveUpdate") } }) 
diff --git a/management/server/group_test.go b/management/server/group_test.go index 1e59b74ef5b..89184e81927 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -8,12 +8,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -536,29 +537,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } }) - // Saving an unchanged group should trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, - }) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // adding peer to a used group should update account peers and send peer update t.Run("adding peer to linked group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 96637cd39a0..846dbf02370 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1065,36 +1065,6 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { } }) - // saving unchanged nameserver group should update account peers and not send peer update - t.Run("saving unchanged nameserver group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, 
updMsg) - close(done) - }() - - newNameServerGroupB.NameServers = []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - } - err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting a nameserver group should update account peers and send peer update t.Run("deleting nameserver group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/network.go b/management/server/network.go index 8fb6a8b3c12..a5b188b4610 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 `diff:"-"` + Serial uint64 - mu sync.Mutex `json:"-" gorm:"-" diff:"-"` + mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 82e0acf3ade..34d7918446b 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -20,33 +20,33 @@ type Peer struct { // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. 
peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer - UserID string `diff:"-"` + UserID string // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool `diff:"-"` + LoginExpirationEnabled bool - InactivityExpirationEnabled bool `diff:"-"` + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation - LastLogin time.Time `diff:"-"` + LastLogin time.Time // CreatedAt records the time the peer was created - CreatedAt time.Time `diff:"-"` + CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool `diff:"-"` + Ephemeral bool // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` + Location Location `gorm:"embedded;embeddedPrefix:location_"` } type PeerStatus struct { //nolint:revive diff --git a/management/server/policy.go b/management/server/policy.go index 05554243032..43a925f8850 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if anyGroupHasPeers(account, policy.ruleGroups()) { + am.updateAccountPeers(ctx, account) + } return nil } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 5b1411702b2..e7f0f9cd2f1 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ 
-854,16 +854,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) - t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) - }) - // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { policy := Policy{ @@ -883,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -918,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -953,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -987,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1021,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1056,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -1090,7 +1085,7 @@ func 
TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1104,46 +1099,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged policy should trigger account peers update but not send peer update - t.Run("saving unchanged policy", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg1) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { policyID := "policy-source-destination-peers" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1164,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1180,10 +1142,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + policyID := 
"policy-rule-groups-no-peers" done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7d31956f955..c63538b9d52 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" ) @@ -264,25 +265,6 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged posture check should not trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged posture check", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Removing posture check from policy should trigger account peers update and send peer update t.Run("removing posture check from policy", func(t *testing.T) { done := make(chan struct{}) @@ -412,50 +394,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Updating linked posture check to policy where source has peers but destination does not, - // should not trigger account peers update or send peer update - t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: 
[]string{"groupA"}, - Destinations: []string{"groupB"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{postureCheck.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - postureCheck.Checks = posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.29.0", - }, - } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update - t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { policy = Policy{ ID: "policyB", Enabled: true, diff --git a/management/server/route_test.go b/management/server/route_test.go index a4b320c7ee2..4893e19b9f3 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1938,26 +1938,6 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) - // Updating unchanged route should update account peers and not send peer update - t.Run("updating unchanged route", func(t *testing.T) { - baseRoute.Groups = []string{routeGroup1, routeGroup2} - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for 
peerShouldNotReceiveUpdate") - } - }) - // Deleting the route should update account peers and send peer update t.Run("deleting route", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/telemetry/updatechannel_metrics.go b/management/server/telemetry/updatechannel_metrics.go index fb33b663c62..2582006e517 100644 --- a/management/server/telemetry/updatechannel_metrics.go +++ b/management/server/telemetry/updatechannel_metrics.go @@ -18,7 +18,6 @@ type UpdateChannelMetrics struct { getAllConnectedPeersDurationMicro metric.Int64Histogram getAllConnectedPeers metric.Int64Histogram hasChannelDurationMicro metric.Int64Histogram - networkMapDiffDurationMicro metric.Int64Histogram ctx context.Context } @@ -64,11 +63,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh return nil, err } - networkMapDiffDurationMicro, err := meter.Int64Histogram("management.updatechannel.networkmap.diff.duration.micro") - if err != nil { - return nil, err - } - return &UpdateChannelMetrics{ createChannelDurationMicro: createChannelDurationMicro, closeChannelDurationMicro: closeChannelDurationMicro, @@ -78,7 +72,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro, getAllConnectedPeers: getAllConnectedPeers, hasChannelDurationMicro: hasChannelDurationMicro, - networkMapDiffDurationMicro: networkMapDiffDurationMicro, ctx: ctx, }, nil } @@ -118,8 +111,3 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) { metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds()) } - -// CountNetworkMapDiffDurationMicro counts the duration of the NetworkMapDiff method -func (metrics *UpdateChannelMetrics) CountNetworkMapDiffDurationMicro(duration time.Duration) { - 
metrics.networkMapDiffDurationMicro.Record(metrics.ctx, duration.Microseconds()) -} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 7c73002223b..59b6fd09492 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,16 +2,12 @@ package server import ( "context" - "fmt" - "runtime/debug" "sync" "time" - "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" - "github.com/netbirdio/netbird/management/server/differs" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,8 +21,6 @@ type UpdateMessage struct { type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage - // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. - peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics @@ -36,10 +30,9 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - peerUpdateMessage: make(map[string]*UpdateMessage), - channelsMux: &sync.RWMutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -48,15 +41,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool - // skip sending sync update to the peer if there is no change in update message, - // it will not check on turn credential refresh as we do not send network map or client posture checks - if update.NetworkMap != nil { - updated := p.handlePeerMessageUpdate(ctx, peerID, update) - if !updated { - return - } - } - 
p.channelsMux.Lock() defer func() { @@ -66,16 +50,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() - if update.NetworkMap != nil { - lastSentUpdate := p.peerUpdateMessage[peerID] - if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { - log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", - peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) - return - } - p.peerUpdateMessage[peerID] = update - } - if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -108,7 +82,6 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) - delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -123,7 +96,6 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) - delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -200,79 +172,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } - -// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. 
-func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { - p.channelsMux.RLock() - lastSentUpdate := p.peerUpdateMessage[peerID] - p.channelsMux.RUnlock() - - if lastSentUpdate != nil { - updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update, p.metrics) - if err != nil { - log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) - return true - } - if !updated { - log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) - return false - } - } - - return true -} - -// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. -func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage, metric telemetry.AppMetrics) (isNew bool, err error) { - startTime := time.Now() - - defer func() { - if r := recover(); r != nil { - log.WithContext(ctx).Panicf("comparing peer update messages. 
Trace: %s", debug.Stack()) - isNew, err = true, nil - } - }() - - if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { - return false, nil - } - - differ, err := diff.NewDiffer( - diff.CustomValueDiffers(&differs.NetIPAddr{}), - diff.CustomValueDiffers(&differs.NetIPPrefix{}), - ) - if err != nil { - return false, fmt.Errorf("failed to create differ: %v", err) - } - - lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) - currFiles := getChecksFiles(currUpdateToSend.Update.Checks) - - changelog, err := differ.Diff(lastSentFiles, currFiles) - if err != nil { - return false, fmt.Errorf("failed to diff checks: %v", err) - } - if len(changelog) > 0 { - return true, nil - } - - changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) - if err != nil { - return false, fmt.Errorf("failed to diff network map: %v", err) - } - - if metric != nil { - metric.UpdateChannelMetrics().CountNetworkMapDiffDurationMicro(time.Since(startTime)) - } - - return len(changelog) > 0, nil -} - -// getChecksFiles returns a list of files from the given checks. -func getChecksFiles(checks []*proto.Checks) []string { - files := make([]string, 0, len(checks)) - for _, check := range checks { - files = append(files, check.GetFiles()...) 
- } - return files -} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index b8a0ce45f73..69f5b895c45 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,21 +2,10 @@ package server import ( "context" - "net" - "net/netip" "testing" "time" - "github.com/stretchr/testify/assert" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" - nbroute "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util" ) // var peersUpdater *PeersUpdateManager @@ -88,474 +77,3 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } - -func TestHandlePeerMessageUpdate(t *testing.T) { - tests := []struct { - name string - peerID string - existingUpdate *UpdateMessage - newUpdate *UpdateMessage - expectedResult bool - }{ - { - name: "update message with turn credentials update", - peerID: "peer", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - WiretrusteeConfig: &proto.WiretrusteeConfig{}, - }, - }, - expectedResult: true, - }, - { - name: "update message for peer without existing update", - peerID: "peer1", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with no changes in update", - peerID: "peer2", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - 
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - { - name: "update message with changes in checks", - peerID: "peer3", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - Checks: []*proto.Checks{ - { - Files: []string{"/usr/bin/netbird"}, - }, - }, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with lower serial number", - peerID: "peer4", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - } - - for _, tt := range tests { - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - if err != nil { - t.Fatal(err) - } - t.Run(tt.name, func(t *testing.T) { - p := NewPeersUpdateManager(metrics) - ctx := context.Background() - - if tt.existingUpdate != nil { - p.peerUpdateMessage[tt.peerID] = tt.existingUpdate - } - - result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) - assert.Equal(t, tt.expectedResult, result) - }) - } -} - -func TestIsNewPeerUpdateMessage(t *testing.T) { - t.Run("Unchanged value", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Unchanged value with serial incremented", func(t *testing.T) { - newUpdateMessage1 := 
createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Updating routes network", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - - }) - - t.Run("Updating routes groups", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating network map peers", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newPeer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.4"), - SSHEnabled: true, - Key: "peer4-key", - DNSLabel: "peer4", - SSHKey: "peer4-ssh-key", - } - newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating process check", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - - newUpdateMessage2 := createMockUpdateMessage(t) - 
newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - - newUpdateMessage3 := createMockUpdateMessage(t) - newUpdateMessage3.Update.Checks = []*proto.Checks{} - newUpdateMessage3.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3, nil) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage4 := createMockUpdateMessage(t) - check := &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/local/netbird", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - } - newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage4.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4, nil) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage5 := createMockUpdateMessage(t) - check = &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/local/netbird", - }, - }, - }, - }, - } - newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage5.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating DNS configuration", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newDomain := "newexample.com" - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( - 
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, - newDomain, - ) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating peer IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Add new firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newRule := &FirewallRule{ - PeerIP: "192.168.1.3", - Direction: firewallRuleDirectionOUT, - Action: string(PolicyTrafficActionDrop), - Protocol: string(PolicyRuleProtocolUDP), - Port: "53", - } - newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Removing nameserver", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - 
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating name server IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating custom DNS zone", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - -} - -func createMockUpdateMessage(t *testing.T) *UpdateMessage { - t.Helper() - - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") - if err != nil { - t.Fatal(err) - } - domainList, err := domain.FromStringList([]string{"example.com"}) - if err != nil { - t.Fatal(err) - } - - config := &Config{ - Signal: &Host{ - Proto: "https", - URI: "signal.uri", - Username: "", - Password: "", - }, - Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, - TURNConfig: &TURNConfig{ - Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, - }, - } - peer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.1"), - SSHEnabled: true, - Key: "peer-key", - DNSLabel: "peer1", - SSHKey: "peer1-ssh-key", - } - - secretManager := 
NewTimeBasedAuthSecretsManager( - NewPeersUpdateManager(nil), - &TURNConfig{ - TimeBasedCredentials: false, - CredentialsTTL: util.Duration{ - Duration: defaultDuration, - }, - Secret: "secret", - Turns: []*Host{TurnTestHost}, - }, - &Relay{ - Addresses: []string{"localhost:0"}, - CredentialsTTL: util.Duration{Duration: time.Hour}, - Secret: "secret", - }, - ) - - networkMap := &NetworkMap{ - Network: &Network{Net: *ipNet, Serial: 1000}, - Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, - Routes: []*nbroute.Route{ - { - ID: "route1", - Network: netip.MustParsePrefix("10.0.0.0/24"), - KeepRoute: true, - NetID: "route1", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - { - ID: "route2", - Domains: domainList, - KeepRoute: true, - NetID: "route2", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - }, - DNSConfig: nbdns.Config{ - ServiceEnable: true, - NameServerGroups: []*nbdns.NameServerGroup{ - { - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - { - ID: "ns1", - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Groups: []string{"group1"}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - }, - CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 
60, RData: "100.64.0.1"}}}}, - }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, - }, - } - dnsName := "example.com" - checks := []*posture.Checks{ - { - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - }, - } - dnsCache := &DNSConfigCache{} - - turnToken, err := secretManager.GenerateTurnToken() - if err != nil { - t.Fatal(err) - } - - relayToken, err := secretManager.GenerateRelayToken() - if err != nil { - t.Fatal(err) - } - - return &UpdateMessage{ - Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), - NetworkMap: networkMap, - } -} From ad4f0a6fdfae53a569cf20b78610567a78f83f02 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 31 Oct 2024 23:18:35 +0100 Subject: [PATCH 2/6] [client] Nil check on ICE remote conn (#2806) --- client/internal/peer/conn.go | 5 +++++ client/internal/peer/nilcheck.go | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 client/internal/peer/nilcheck.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 56b772759a2..84a8c221fa6 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon return } + if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { + conn.log.Errorf("remote ICE connection is nil") + return + } + conn.log.Debugf("ICE connection is ready") if conn.currentConnPriority > priority { diff --git a/client/internal/peer/nilcheck.go b/client/internal/peer/nilcheck.go new file mode 100644 index 00000000000..058fe9a2697 --- /dev/null +++ 
b/client/internal/peer/nilcheck.go @@ -0,0 +1,21 @@ +package peer + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +func remoteConnNil(log *log.Entry, conn net.Conn) bool { + if conn == nil { + log.Errorf("ice conn is nil") + return true + } + + if conn.RemoteAddr() == nil { + log.Errorf("ICE remote address is nil") + return true + } + + return false +} From 9812de853bac5c61ac38f6497014e7db1e4e290e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 1 Nov 2024 00:33:25 +0100 Subject: [PATCH 3/6] Allocate new buffer for every package (#2823) --- client/iface/wgproxy/bind/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e986d6d7b07..e0883715a99 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } }() - buf := make([]byte, 1500) for { + buf := make([]byte, 1500) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { From bac95ace1820d89f5642a528f877aad4e96539d9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:58:39 +0100 Subject: [PATCH 4/6] [management] Add DB access duration to logs for context cancel (#2781) --- management/server/sql_store.go | 202 ++++++++++++++++++++++++++++-- management/server/status/error.go | 6 + 2 files changed, 198 insertions(+), 10 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 27238d28e8a..b1b8330ba3b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -292,6 +292,8 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + startTime := time.Now() + // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. 
peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -317,6 +319,9 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. }) if err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return err } @@ -324,6 +329,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + startTime := time.Now() + accountCopy := Account{ Domain: domain, DomainCategory: category, @@ -336,6 +343,9 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -347,6 +357,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID } func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { + startTime := time.Now() + var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -359,6 +371,9 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -370,6 +385,8 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe } func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { + startTime := time.Now() + // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. 
var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, @@ -381,6 +398,9 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P Updates(peerCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -394,6 +414,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { + startTime := time.Now() + usersToSave := make([]User, 0, len(users)) for _, user := range users { user.AccountID = accountID @@ -403,15 +425,28 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { } usersToSave = append(usersToSave, *user) } - return s.db.Session(&gorm.Session{FullSaveAssociations: true}). + err := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.OnConflict{UpdateAll: true}). Create(&usersToSave).Error + if err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } + return status.Errorf(status.Internal, "failed to save users to store: %v", err) + } + + return nil } // SaveUser saves the given user to the database. 
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } return nil @@ -419,12 +454,17 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u // SaveGroups saves the given list of groups to the database. func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + startTime := time.Now() + if len(groups) == 0 { return nil } result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } return nil @@ -451,6 +491,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) } func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + startTime := time.Now() + var accountID string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? 
and domain_category = ?", @@ -460,6 +502,9 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } @@ -468,12 +513,17 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { + startTime := time.Now() + var key SetupKey result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewSetupKeyNotFoundError(result.Error) } @@ -485,12 +535,17 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { + startTime := time.Now() + var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return "", 
status.NewGetAccountFromStoreError(result.Error) } @@ -499,12 +554,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri } func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { + startTime := time.Now() + var token PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -528,6 +588,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { + startTime := time.Now() + var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
Preload(clause.Associations).First(&user, idQueryCondition, userID) @@ -535,6 +597,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetUserFromStoreError() } @@ -542,12 +607,17 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + startTime := time.Now() + var users []*User result := s.db.Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting users from store") } @@ -556,12 +626,17 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us } func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + startTime := time.Now() + var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting groups from store") } @@ 
-661,12 +736,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { + startTime := time.Now() + var user User result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -678,12 +758,17 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun } func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -695,13 +780,17 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco } func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - var peer nbpeer.Peer + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, 
status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -713,6 +802,8 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( } func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { + startTime := time.Now() + var peer nbpeer.Peer var accountID string result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) @@ -720,6 +811,9 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -727,12 +821,17 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { + startTime := time.Now() + var accountID string result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -740,12 +839,17 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { + startTime := time.Now() + var accountID string result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if 
result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewSetupKeyNotFoundError(result.Error) } @@ -757,6 +861,8 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) } func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + startTime := time.Now() + var ipJSONStrings []string // Fetch the IP addresses as JSON strings @@ -767,6 +873,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } @@ -784,8 +893,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - var labels []string + startTime := time.Now() + var labels []string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). 
Pluck("dns_label", &labels) @@ -794,6 +904,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } @@ -802,24 +915,33 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork + startTime := time.Now() + var accountNetwork AccountNetwork if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, 
status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } @@ -827,11 +949,16 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { + startTime := time.Now() + var accountSettings AccountSettings if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil @@ -839,13 +966,17 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. 
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + startTime := time.Now() + var user User result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -854,6 +985,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { + startTime := time.Now() + definitionJSON, err := json.Marshal(checks) if err != nil { return nil, err @@ -862,6 +995,9 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p var postureCheck posture.Checks err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error if err != nil { + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, err } @@ -971,6 +1107,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + startTime := time.Now() + var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
First(&setupKey, keyQueryCondition, key) @@ -978,12 +1116,17 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Model(&SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ @@ -992,6 +1135,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string }) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) } @@ -1003,13 +1149,17 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + startTime := time.Now() + var group nbgroup.Group result := s.db.WithContext(ctx).Where("account_id = ? 
AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } @@ -1022,6 +1172,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } @@ -1029,13 +1182,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + startTime := time.Now() + var group nbgroup.Group result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } @@ -1048,6 +1205,9 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1060,7 +1220,12 @@ func (s *SqlStore) GetUserPeers(ctx
context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + startTime := time.Now() + if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1068,8 +1233,13 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil @@ -1100,14 +1270,18 @@ func (s *SqlStore) GetDB() *gorm.DB { } func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - var accountDNSSettings AccountDNSSettings + startTime := time.Now() + var accountDNSSettings AccountDNSSettings result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). 
First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "dns settings not found") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) } return &accountDNSSettings.DNSSettings, nil @@ -1115,14 +1289,18 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { - var accountID string + startTime := time.Now() + var accountID string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return false, nil } + if errors.Is(result.Error, context.Canceled) { + return false, status.NewStoreContextCanceledError(time.Since(startTime)) + } return false, result.Error } @@ -1131,14 +1309,18 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account + startTime := time.Now() + var account Account result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). 
Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") } + if errors.Is(result.Error, context.Canceled) { + return "", "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } diff --git a/management/server/status/error.go b/management/server/status/error.go index e9fc8c15ef9..a145edf8002 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,6 +3,7 @@ package status import ( "errors" "fmt" + "time" ) const ( @@ -115,6 +116,11 @@ func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } +// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context +func NewStoreContextCanceledError(duration time.Duration) error { + return Errorf(Internal, "store access: context canceled after %v", duration) +} + // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") From 0eb99c266affccaa03d9c363862655edd8798b22 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 1 Nov 2024 12:33:29 +0100 Subject: [PATCH 5/6] Fix unused servers cleanup (#2826) The cleanup loop did not handle two situations well: when a connection to a server failed, and when the connection succeeded but the code had not yet added a peer connection to it.
- in the cleanup loop check if a connection failed to a server - after adding a foreign server connection force to keep it a minimum 5 sec --- relay/client/manager.go | 18 +++++++++++++++++- relay/client/manager_test.go | 5 +++-- relay/client/picker_test.go | 3 +-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/relay/client/manager.go b/relay/client/manager.go index 3981415fcd4..b14a7701bfb 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -16,6 +16,7 @@ import ( var ( relayCleanupInterval = 60 * time.Second + keepUnusedServerTime = 5 * time.Second ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -27,10 +28,13 @@ type RelayTrack struct { sync.RWMutex relayClient *Client err error + created time.Time } func NewRelayTrack() *RelayTrack { - return &RelayTrack{} + return &RelayTrack{ + created: time.Now(), + } } type OnServerCloseListener func() @@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() { for addr, rt := range m.relayClients { rt.Lock() + // if the connection failed to the server the relay client will be nil + // but the instance will be kept in the relayClients until the next locking + if rt.err != nil { + rt.Unlock() + continue + } + + if time.Since(rt.created) <= keepUnusedServerTime { + rt.Unlock() + continue + } + if rt.relayClient.HasConns() { rt.Unlock() continue diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index e9cc2c58154..bfc342f25f7 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) { t.Fatalf("failed to close connection: %s", err) } - t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second) - time.Sleep(relayCleanupInterval + 1*time.Second) + timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second + t.Logf("waiting for relay cleanup: %s", timeout) + time.Sleep(timeout) if len(mgr.relayClients) != 0 { t.Errorf("expected 0, got %d", 
len(mgr.relayClients)) } diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index eb14581e067..4800e05ba29 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "testing" - "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { @@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) { PeerID: "test", } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { From 5f06b202c364dd66e57b8a58f178a8647a6ddfce Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:08:22 +0100 Subject: [PATCH 6/6] [client] Log windows panics (#2829) --- client/server/panic_generic.go | 7 +++ client/server/panic_windows.go | 83 ++++++++++++++++++++++++++++++++++ client/server/server.go | 4 ++ 3 files changed, 94 insertions(+) create mode 100644 client/server/panic_generic.go create mode 100644 client/server/panic_windows.go diff --git a/client/server/panic_generic.go b/client/server/panic_generic.go new file mode 100644 index 00000000000..f027b954b34 --- /dev/null +++ b/client/server/panic_generic.go @@ -0,0 +1,7 @@ +//go:build !windows + +package server + +func handlePanicLog() error { + return nil +} diff --git a/client/server/panic_windows.go b/client/server/panic_windows.go new file mode 100644 index 00000000000..1d4ba4b756f --- /dev/null +++ b/client/server/panic_windows.go @@ -0,0 +1,83 @@ +package server + +import ( + "fmt" + "os" + "path/filepath" + "syscall" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +const ( + windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG" + // STD_ERROR_HANDLE ((DWORD)-12) = 4294967284 + stdErrorHandle = ^uintptr(11) +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + + // 
https://learn.microsoft.com/en-us/windows/console/setstdhandle + setStdHandleFn = kernel32.NewProc("SetStdHandle") +) + +func handlePanicLog() error { + logPath := os.Getenv(windowsPanicLogEnvVar) + if logPath == "" { + return nil + } + + // Ensure the directory exists + logDir := filepath.Dir(logPath) + if err := os.MkdirAll(logDir, 0750); err != nil { + return fmt.Errorf("create panic log directory: %w", err) + } + if err := util.EnforcePermission(logPath); err != nil { + return fmt.Errorf("enforce permission on panic log file: %w", err) + } + + // Open log file with append mode + f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return fmt.Errorf("open panic log file: %w", err) + } + + // Redirect stderr to the file + if err = redirectStderr(f); err != nil { + if closeErr := f.Close(); closeErr != nil { + log.Warnf("failed to close file after redirect error: %v", closeErr) + } + return fmt.Errorf("redirect stderr: %w", err) + } + + log.Infof("successfully configured panic logging to: %s", logPath) + return nil +} + +// redirectStderr redirects stderr to the provided file +func redirectStderr(f *os.File) error { + // Point the process-wide STD_ERROR_HANDLE at the file + if err := setStdHandle(f); err != nil { + return fmt.Errorf("failed to set stderr handle: %w", err) + } + + // Also set os.Stderr for Go's standard library + os.Stderr = f + + return nil +} + +func setStdHandle(f *os.File) error { + handle := f.Fd() + r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle) + if r0 == 0 { + if e1 != nil && e1 != syscall.Errno(0) { + return e1 + } + return syscall.EINVAL + } + return nil +} diff --git a/client/server/server.go b/client/server/server.go index a0332208194..4d921851f94 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -97,6 +97,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := handlePanicLog(); err != nil { + log.Warnf("failed to redirect stderr: %v", err) + } +
if err := restoreResidualState(s.rootCtx); err != nil { log.Warnf(errRestoreResidualState, err) }