diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000000..c3d322163e2 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [netbirdio] diff --git a/client/cmd/status.go b/client/cmd/status.go index ed3daa2b5fd..6db52a67795 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { statusEval := false ipEval := false - nameEval := false + nameEval := true if statusFilter != "" { lowerStatusFilter := strings.ToLower(statusFilter) @@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { if len(prefixNamesFilter) > 0 { for prefixNameFilter := range prefixNamesFilterMap { - if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { - nameEval = true + if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { + nameEval = false break } } + } else { + nameEval = false } return statusEval || ipEval || nameEval diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index c853548f841..076d08ec27b 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -3,6 +3,7 @@ package firewall import ( + "errors" "fmt" "os" @@ -37,62 +38,55 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, errFw := createNativeFirewall(iface) + fm, err := createNativeFirewall(iface, stateManager) - if fm != nil { - if err := fm.Init(stateManager); err != nil { - log.Errorf("failed to init nftables manager: %s", err) - } + if !iface.IsUserspaceBind() { + return fm, err + } + + if err != nil { + log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } + return createUserspaceFirewall(iface, fm) +} - if iface.IsUserspaceBind() { - return createUserspaceFirewall(iface, fm, errFw) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { + fm, err := createFW(iface) + if err != nil { + return nil, fmt.Errorf("create firewall: %s", err) } - return fm, errFw + if err = fm.Init(stateManager); err != nil { + return nil, fmt.Errorf("init firewall: %s", err) + } + + return fm, nil } -func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper) (firewall.Manager, error) { switch check() { case IPTABLES: - return createIptablesFirewall(iface) + log.Info("creating an iptables firewall manager") + return nbiptables.Create(iface) case NFTABLES: - return createNftablesFirewall(iface) + log.Info("creating an nftables firewall manager") + return nbnftables.Create(iface) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") - return nil, fmt.Errorf("no firewall manager found") + return nil, errors.New("no firewall manager found") } } -func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) { - log.Info("creating an iptables firewall manager") - fm, err := nbiptables.Create(iface) - if err != nil { - log.Errorf("failed to create iptables manager: %s", err) - } - return fm, err -} - -func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) { - log.Info("creating an nftables firewall manager") - fm, err := nbnftables.Create(iface) - if err != nil { - log.Errorf("failed to create nftables manager: %s", err) - } - return fm, err -} - -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { var errUsp error - if errFw == nil { + if fm != nil { fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) } else { fm, errUsp = uspfilter.Create(iface) } if errUsp != nil { - log.Debugf("failed to create userspace filtering firewall: %s", errUsp) - return nil, errUsp + return nil, fmt.Errorf("create userspace firewall: %s", errUsp) } if err := fm.AllowNetbird(); err != nil { diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 90811ae1182..9b75640b4b5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -296,6 +296,8 @@ func (r *router) RemoveAllLegacyRouteRules() error { } if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } else { + delete(r.rules, k) } } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a4650f3b626..ea8912f27f5 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -230,23 +230,7 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - oldLegacy := m.router.legacyManagement - - if oldLegacy != isLegacy { - m.router.legacyManagement = isLegacy - log.Debugf("Set legacy management to %v", isLegacy) - } - - // client reconnected to a newer mgmt, we need to cleanup the legacy rules - if !isLegacy && oldLegacy { - if err := m.router.RemoveAllLegacyRouteRules(); err != nil { - return fmt.Errorf("remove legacy routing rules: %v", err) - } - - log.Debugf("Legacy routing rules removed") - } - - return nil + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9b28e4eb213..0e7ea71b774 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -551,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error { } if err := r.conn.DelRule(rule); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } else { + delete(r.rules, k) } + } return nberrors.FormatErrorOrNil(merr) } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 3829a9baffe..af5dc673393 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -237,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { } // SetLegacyManagement doesn't need to be implemented for this manager -func (m *Manager) SetLegacyManagement(_ bool) error { - return nil +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.SetLegacyManagement(isLegacy) } // Flush doesn't need to be implemented for this manager diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e986d6d7b07..e0883715a99 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } }() - buf := make([]byte, 1500) for { + buf := make([]byte, 1500) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index ce2a12af16f..5bb0905d2a7 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -3,6 +3,7 @@ package acl import ( "crypto/md5" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -10,14 +11,18 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) +var ErrSourceRangesEmpty = errors.New("sources range is empty") + // Manager is a ACL rules manager type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap) @@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { - var newRouteRules = make(map[id.RuleID]struct{}) + newRouteRules := make(map[id.RuleID]struct{}, len(rules)) + var merr *multierror.Error + + // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { id, err := d.applyRouteACL(rule) if err != nil { - return fmt.Errorf("apply route ACL: %w", err) + if errors.Is(err, ErrSourceRangesEmpty) { + log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + } else { + merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) + } + continue } newRouteRules[id] = struct{}{} } + // Clean up old firewall rules for id := range d.routeRules { - if _, ok := newRouteRules[id]; !ok { + if _, exists := newRouteRules[id]; !exists { if err := d.firewall.DeleteRouteRule(id); err != nil { - log.Errorf("failed to delete route firewall rule: %v", err) - continue + merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err)) } - delete(d.routeRules, id) + // implicitly deleted from the map } } + d.routeRules = newRouteRules - return nil + return nberrors.FormatErrorOrNil(merr) } func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { - return "", fmt.Errorf("source ranges is empty") + return "", ErrSourceRangesEmpty } var sources []netip.Prefix diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 56b772759a2..84a8c221fa6 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon return } + if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { + conn.log.Errorf("remote ICE connection is nil") + return + } + conn.log.Debugf("ICE connection is ready") if conn.currentConnPriority > priority { diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index b2a9669367e..dc4750f243a 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,13 +1,14 @@ package ice import ( - "github.com/netbirdio/netbird/client/internal/stdnet" + "time" + "github.com/pion/ice/v3" "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" - "runtime" - "time" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) const ( @@ -77,10 +78,7 @@ func CandidateTypes() []ice.CandidateType { if hasICEForceRelayConn() { return []ice.CandidateType{ice.CandidateTypeRelay} } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} } diff --git a/client/internal/peer/nilcheck.go b/client/internal/peer/nilcheck.go new file mode 100644 index 00000000000..058fe9a2697 --- /dev/null +++ b/client/internal/peer/nilcheck.go @@ -0,0 +1,21 @@ +package peer + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +func remoteConnNil(log *log.Entry, conn net.Conn) bool { + if conn == nil { + log.Errorf("ice conn is nil") + return true + } + + if conn.RemoteAddr() == nil { + log.Errorf("ICE remote address is nil") + return true + } + + return false +} diff --git a/funding.json b/funding.json new file mode 100644 index 00000000000..6b509a9921b --- /dev/null +++ b/funding.json @@ -0,0 +1,126 @@ +{ + "version": "v1.0.0", + "entity": { + "type": "organisation", + "role": "owner", + "name": "NetBird GmbH", + "email": "hello@netbird.io", + "phone": "", + "description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.", + "webpageUrl": { + "url": "https://github.com/netbirdio" + } + }, + "projects": [ + { + "guid": "netbird", + "name": "NetBird", + "description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.", + "webpageUrl": { + "url": "https://github.com/netbirdio/netbird" + }, + "repositoryUrl": { + "url": "https://github.com/netbirdio/netbird" + }, + "licenses": [ + "BSD-3" + ], + "tags": [ + "network-security", + "vpn", + "developer-tools", + "ztna", + "zero-trust", + "remote-access", + "wireguard", + "peer-to-peer", + "private-networking", + "software-defined-networking" + ] + } + ], + "funding": { + "channels": [ + { + "guid": "github-sponsors", + "type": "payment-provider", + "address": "https://github.com/sponsors/netbirdio", + "description": "" + }, + { + "guid": "bank-transfer", + "type": "bank", + "address": "", + "description": "Contact us at hello@netbird.io for bank transfer details." + } + ], + "plans": [ + { + "guid": "support-yearly", + "status": "active", + "name": "Support Open Source Development and Maintenance - Yearly", + "description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.", + "amount": 100000, + "currency": "USD", + "frequency": "yearly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-one-time-year", + "status": "active", + "name": "Support Open Source Development and Maintenance - One Year", + "description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.", + "amount": 100000, + "currency": "USD", + "frequency": "one-time", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-one-time-monthly", + "status": "active", + "name": "Support Open Source Development and Maintenance - Monthly", + "description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.", + "amount": 10000, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-monthly", + "status": "active", + "name": "Support Open Source Development and Maintenance - One Month", + "description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.", + "amount": 10000, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "goodwill", + "status": "active", + "name": "Goodwill Plan", + "description": "Pay anything you wish to show your goodwill for the project.", + "amount": 0, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + } + ], + "history": null + } +} diff --git a/go.mod b/go.mod index a6b83794dab..571b41abf19 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,6 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 - github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -156,7 +155,7 @@ require ( github.com/go-text/typesetting v0.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/google/btree v1.0.1 // indirect + github.com/google/btree v1.1.2 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect @@ -211,8 +210,6 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect @@ -231,7 +228,7 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect - gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect + gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect k8s.io/apimachinery v0.26.2 // indirect ) diff --git a/go.sum b/go.sum index 412542d5eb9..217d27e0adf 100644 --- a/go.sum +++ b/go.sum @@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= -github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= -github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= -github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -1238,8 +1232,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4 gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 16b2364fb56..0b2b651429e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.10' + image: 'ghcr.io/zitadel/zitadel:v2.64.1' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env diff --git a/management/server/account.go b/management/server/account.go index a8a244bdf1f..1810c6b41ec 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -153,6 +153,7 @@ type AccountManager interface { FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error } type DefaultAccountManager struct { diff --git a/management/server/account_test.go b/management/server/account_test.go index 3c3fcebc67f..1cd4ae449db 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1010,7 +1010,6 @@ func TestAccountManager_AddPeer(t *testing.T) { return } expectedPeerKey := key.PublicKey().String() - expectedSetupKey := setupKey.Key peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, @@ -1035,10 +1034,6 @@ func TestAccountManager_AddPeer(t *testing.T) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } - if peer.SetupKey != expectedSetupKey { - t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey) - } - if account.Network.CurrentSerial() != 1 { t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) } @@ -2367,7 +2362,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ @@ -2375,7 +2369,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, }, expiration: time.Second, @@ -2529,7 +2522,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ @@ -2537,7 +2529,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, }, expiration: time.Second, diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 188494241c6..603260dbcb2 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -146,6 +146,8 @@ const ( AccountPeerInactivityExpirationEnabled Activity = 65 AccountPeerInactivityExpirationDisabled Activity = 66 AccountPeerInactivityExpirationDurationUpdated Activity = 67 + + SetupKeyDeleted Activity = 68 ) var activityMap = map[Activity]Code{ @@ -219,6 +221,7 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, + SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go deleted file mode 100644 index de4aa334c17..00000000000 --- a/management/server/differs/netip.go +++ /dev/null @@ -1,82 +0,0 @@ -package differs - -import ( - "fmt" - "net/netip" - "reflect" - - "github.com/r3labs/diff/v3" -) - -// NetIPAddr is a custom differ for netip.Addr -type NetIPAddr struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPAddr) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) -} - -func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromAddr, ok1 := a.Interface().(netip.Addr) - toAddr, ok2 := b.Interface().(netip.Addr) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromAddr.String() != toAddr.String() { - cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) - } - - return nil -} - -func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} - -// NetIPPrefix is a custom differ for netip.Prefix -type NetIPPrefix struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPPrefix) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) -} - -func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromPrefix, ok1 := a.Interface().(netip.Prefix) - toPrefix, ok2 := b.Interface().(netip.Prefix) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromPrefix.String() != toPrefix.String() { - cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) - } - - return nil -} - -func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c675fc12c84..8a66da96c0f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -521,35 +522,56 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + // Creating DNS settings with groups that have no peers should not update account peers or send peer update + t.Run("creating dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } }) - assert.NoError(t, err) - _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ - IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, - }}, - []string{"groupA"}, - true, []string{}, true, userID, false, - ) - assert.NoError(t, err) + // Creating DNS settings with groups that have peers should update account peers and send peer update + t.Run("creating dns setting with used groups", func(t *testing.T) { + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) - // Saving DNS settings with groups that have peers should update account peers and send peer update - t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ - DisabledManagementGroups: []string{"groupA", "groupB"}, - }) + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) assert.NoError(t, err) select { @@ -559,12 +581,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged DNS settings with used groups should update account peers and not send peer update - // since there is no change in the network map - t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -576,7 +597,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") + t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/group_test.go b/management/server/group_test.go index 1e59b74ef5b..89184e81927 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -8,12 +8,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -536,29 +537,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } }) - // Saving an unchanged group should trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, - }) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // adding peer to a used group should update account peers and send peer update t.Run("adding peer to linked group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9d51482481a..9b4592ccf10 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -530,10 +530,9 @@ components: type: string example: reusable expires_in: - description: Expiration time in seconds + description: Expiration time in seconds, 0 will mean the key never expires type: integer - minimum: 86400 - maximum: 31536000 + minimum: 0 example: 86400 revoked: description: Setup key revocation status @@ -2018,6 +2017,32 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Setup Key + description: Delete a Setup Key + tags: [ Setup Keys ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: keyId + required: true + schema: + type: string + description: The unique identifier of a setup key + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/groups: get: summary: List all Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index e2870d5d8ef..c1ef1ba2122 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1101,7 +1101,7 @@ type SetupKeyRequest struct { // Ephemeral Indicate that the peer will be ephemeral or not Ephemeral *bool `json:"ephemeral,omitempty"` - // ExpiresIn Expiration time in seconds + // ExpiresIn Expiration time in seconds, 0 will mean the key never expires ExpiresIn int `json:"expires_in"` // Name Setup Key name diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3f8a8554d07..c3928bff681 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() { apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") + apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") } func (apiHandler *apiHandler) addPoliciesEndpoint() { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index f933eee1497..dd49c03b848 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,12 +13,13 @@ import ( "time" "github.com/gorilla/mux" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "golang.org/x/exp/maps" "github.com/stretchr/testify/assert" @@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) { peer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "PeerName", diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8514f0b556b..31859f59bf0 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request expiresIn := time.Duration(req.ExpiresIn) * time.Second - day := time.Hour * 24 - year := day * 365 - if expiresIn < day || expiresIn > year { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) + if expiresIn < 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w) return } @@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { @@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request return } - writeSuccess(r.Context(), w, setupKey) + apiSetupKeys := toResponseBody(setupKey) + // for the creation we need to send the plain key + apiSetupKeys.Key = setupKey.Key + + util.WriteJSONObject(r.Context(), w, apiSetupKeys) } // GetSetupKey is a GET request to get a SetupKey by ID @@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -181,6 +184,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } +func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + keyID := vars["keyId"] + if len(keyID) == 0 { + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) + return + } + + err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, emptyObject{}) +} + func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -206,7 +233,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey { return &api.SetupKey{ Id: key.Id, - Key: key.Key, + Key: key.KeySecret, Name: key.Name, Expires: key.ExpiresAt, Type: string(key.Type), diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 2d15287af25..09256d0ea5e 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, + + DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { + if keyID == defaultKey.Id { + return nil + } + return status.Errorf(status.NotFound, "key %s not found", keyID) + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := server.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID adminUser := server.NewAdminUser("test_user") - newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, + newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, server.SetupKeyUnlimitedUsage, true) + newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true + expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey.Key = plainKey tt := []struct { name string requestType string @@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) { []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(newSetupKey), + expectedSetupKey: expectedNewKey, }, { name: "Update Setup Key", @@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) { expectedBody: true, expectedSetupKey: toResponseBody(updatedDefaultSetupKey), }, + { + name: "Delete Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/" + defaultSetupKey.Id, + requestBody: bytes.NewBuffer([]byte("")), + expectedStatus: http.StatusOK, + expectedBody: false, + }, } handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) @@ -164,6 +182,7 @@ func TestSetupKeysHandlers(t *testing.T) { router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bdf744d211e..843fa575e83 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { peersSSHEnabled++ } - if peer.SetupKey == "" { + if peer.UserID != "" { userPeers++ } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 4c8baea5e87..6f12d94b401 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -2,13 +2,16 @@ package migration import ( "context" + "crypto/sha256" "database/sql" + b64 "encoding/base64" "encoding/gob" "encoding/json" "errors" "fmt" "net" "strings" + "unicode/utf8" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi return nil } + +func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error { + oldColumnName := "key" + newColumnName := "key_secret" + + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&model, newColumnName) { + log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName) + if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil { + return fmt.Errorf("add column %s: %w", newColumnName, err) + } + } + + var rows []map[string]any + if err := tx.Table(tableName). + Select("id", oldColumnName, newColumnName). + Where(newColumnName + " IS NULL OR " + newColumnName + " = ''"). + Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'"). + Find(&rows).Error; err != nil { + return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err) + } + + if len(rows) == 0 { + log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName) + return nil + } + + for _, row := range rows { + var plainKey string + if columnValue := row[oldColumnName]; columnValue != nil { + value, ok := columnValue.(string) + if !ok { + return fmt.Errorf("type assertion failed") + } + plainKey = value + } + + secretKey := hiddenKey(plainKey, 4) + + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil { + return fmt.Errorf("update row with secret key: %w", err) + } + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil { + return fmt.Errorf("update row with hashed key: %w", err) + } + } + + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { + log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) + } + + return nil + }); err != nil { + return err + } + + log.Printf("Migration of plain setup key to hashed setup key completed") + return nil +} + +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 5a192664169..51358c7ad67 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr) assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged") } + +func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + KeySecret: "EEFDA****", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 681bf533ae4..d7139bb2a5f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -109,6 +109,14 @@ type MockAccountManager struct { GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error +} + +func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + if am.DeleteSetupKeyFunc != nil { + return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) + } + return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 96637cd39a0..846dbf02370 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1065,36 +1065,6 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { } }) - // saving unchanged nameserver group should update account peers and not send peer update - t.Run("saving unchanged nameserver group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - newNameServerGroupB.NameServers = []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - } - err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting a nameserver group should update account peers and send peer update t.Run("deleting nameserver group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/network.go b/management/server/network.go index 8fb6a8b3c12..a5b188b4610 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 `diff:"-"` + Serial uint64 - mu sync.Mutex `json:"-" gorm:"-" diff:"-"` + mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index 80d43497a70..96ede151158 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "slices" @@ -396,6 +398,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } upperKey := strings.ToUpper(setupKey) + hashedKey := sha256.Sum256([]byte(upperKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) var accountID string var err error addedByUser := false @@ -403,7 +407,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser = true accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") @@ -448,7 +452,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -489,7 +493,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s ID: xid.New().String(), AccountID: accountID, Key: peer.Key, - SetupKey: upperKey, IP: freeIP, Meta: peer.Meta, Name: peer.Meta.Hostname, diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index ef96bce7dd8..34d7918446b 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -4,6 +4,7 @@ import ( "net" "net/netip" "slices" + "sort" "time" ) @@ -16,38 +17,36 @@ type Peer struct { AccountID string `json:"-" gorm:"index"` // WireGuard public key Key string `gorm:"index"` - // A setup key this peer was registered with - SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer - UserID string `diff:"-"` + UserID string // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool `diff:"-"` + LoginExpirationEnabled bool - InactivityExpirationEnabled bool `diff:"-"` + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation - LastLogin time.Time `diff:"-"` + LastLogin time.Time // CreatedAt records the time the peer was created - CreatedAt time.Time `diff:"-"` + CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool `diff:"-"` + Ephemeral bool // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` + Location Location `gorm:"embedded;embeddedPrefix:location_"` } type PeerStatus struct { //nolint:revive @@ -109,6 +108,12 @@ type PeerSystemMeta struct { //nolint:revive } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { + sort.Slice(p.NetworkAddresses, func(i, j int) bool { + return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac + }) + sort.Slice(other.NetworkAddresses, func(i, j int) bool { + return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac + }) equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool { return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP }) @@ -116,6 +121,12 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { return false } + sort.Slice(p.Files, func(i, j int) bool { + return p.Files[i].Path < p.Files[j].Path + }) + sort.Slice(other.Files, func(i, j int) bool { + return other.Files[i].Path < other.Files[j].Path + }) equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool { return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning }) @@ -172,23 +183,22 @@ func (p *Peer) Copy() *Peer { peerStatus = p.Status.Copy() } return &Peer{ - ID: p.ID, - AccountID: p.AccountID, - Key: p.Key, - SetupKey: p.SetupKey, - IP: p.IP, - Meta: p.Meta, - Name: p.Name, - DNSLabel: p.DNSLabel, - Status: peerStatus, - UserID: p.UserID, - SSHKey: p.SSHKey, - SSHEnabled: p.SSHEnabled, - LoginExpirationEnabled: p.LoginExpirationEnabled, - LastLogin: p.LastLogin, - CreatedAt: p.CreatedAt, - Ephemeral: p.Ephemeral, - Location: p.Location, + ID: p.ID, + AccountID: p.AccountID, + Key: p.Key, + IP: p.IP, + Meta: p.Meta, + Name: p.Name, + DNSLabel: p.DNSLabel, + Status: peerStatus, + UserID: p.UserID, + SSHKey: p.SSHKey, + SSHEnabled: p.SSHEnabled, + LoginExpirationEnabled: p.LoginExpirationEnabled, + LastLogin: p.LastLogin, + CreatedAt: p.CreatedAt, + Ephemeral: p.Ephemeral, + Location: p.Location, InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go index 7b94f68c67f..3d3a2e31108 100644 --- a/management/server/peer/peer_test.go +++ b/management/server/peer/peer_test.go @@ -2,6 +2,7 @@ package peer import ( "fmt" + "net/netip" "testing" ) @@ -29,3 +30,56 @@ func BenchmarkFQDN(b *testing.B) { } }) } + +func TestIsEqual(t *testing.T) { + meta1 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{{ + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + }, + Files: []File{ + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + }, + } + meta2 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{ + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + }, + Files: []File{ + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + }, + } + if !meta1.isEqual(meta2) { + t.Error("meta1 should be equal to meta2") + } +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7b2180bf019..5127f77fbe6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "io" "net" @@ -1090,7 +1092,6 @@ func Test_RegisterPeerByUser(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", @@ -1155,7 +1156,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1175,7 +1175,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - assert.Equal(t, peer.SetupKey, existingSetupKeyID) account, err := store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) @@ -1187,8 +1186,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) - assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) + + hashedKey := sha256.Sum256([]byte(existingSetupKeyID)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed) + assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes) } @@ -1221,7 +1223,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1250,8 +1251,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) - assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) + + hashedKey := sha256.Sum256([]byte(faultyKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC()) + assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes) } func TestPeerAccountPeersUpdate(t *testing.T) { diff --git a/management/server/policy.go b/management/server/policy.go index 05554243032..43a925f8850 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if anyGroupHasPeers(account, policy.ruleGroups()) { + am.updateAccountPeers(ctx, account) + } return nil } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 5b1411702b2..e7f0f9cd2f1 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -854,16 +854,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) - t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) - }) - // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { policy := Policy{ @@ -883,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -918,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -953,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -987,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1021,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1056,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -1090,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1104,46 +1099,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged policy should trigger account peers update but not send peer update - t.Run("saving unchanged policy", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg1) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { policyID := "policy-source-destination-peers" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1164,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1180,10 +1142,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7d31956f955..c63538b9d52 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" ) @@ -264,25 +265,6 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged posture check should not trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged posture check", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Removing posture check from policy should trigger account peers update and send peer update t.Run("removing posture check from policy", func(t *testing.T) { done := make(chan struct{}) @@ -412,50 +394,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Updating linked posture check to policy where source has peers but destination does not, - // should not trigger account peers update or send peer update - t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupB"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{postureCheck.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - postureCheck.Checks = posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.29.0", - }, - } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update - t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { policy = Policy{ ID: "policyB", Enabled: true, diff --git a/management/server/route_test.go b/management/server/route_test.go index a4b320c7ee2..4893e19b9f3 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1938,26 +1938,6 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) - // Updating unchanged route should update account peers and not send peer update - t.Run("updating unchanged route", func(t *testing.T) { - baseRoute.Groups = []string{routeGroup1, routeGroup2} - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting the route should update account peers and send peer update t.Run("deleting route", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index e84f8fcd687..43b6e02c936 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,6 +2,9 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" + "fmt" "hash/fnv" "strconv" "strings" @@ -73,6 +76,7 @@ type SetupKey struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` Key string + KeySecret string Name string Type SetupKeyType CreatedAt time.Time @@ -104,6 +108,7 @@ func (key *SetupKey) Copy() *SetupKey { Id: key.Id, AccountID: key.AccountID, Key: key.Key, + KeySecret: key.KeySecret, Name: key.Name, Type: key.Type, CreatedAt: key.CreatedAt, @@ -120,19 +125,17 @@ func (key *SetupKey) Copy() *SetupKey { // EventMeta returns activity event meta related to the setup key func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.HiddenCopy(1).Key} + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} } -// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix. +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. // E.g., "831F6*******************************" -func (key *SetupKey) HiddenCopy(length int) *SetupKey { - k := key.Copy() - prefix := k.Key[0:5] - if length > utf8.RuneCountInString(key.Key) { - length = utf8.RuneCountInString(key.Key) - len(prefix) - } - k.Key = prefix + strings.Repeat("*", length) - return k +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) } // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now @@ -155,6 +158,9 @@ func (key *SetupKey) IsRevoked() bool { // IsExpired if key was expired func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } return time.Now().After(key.ExpiresAt) } @@ -169,30 +175,40 @@ func (key *SetupKey) IsOverUsed() bool { // GenerateSetupKey generates a new setup key func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) *SetupKey { + usageLimit int, ephemeral bool) (*SetupKey, string) { key := strings.ToUpper(uuid.New().String()) limit := usageLimit if t == SetupKeyOneOff { limit = 1 } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + return &SetupKey{ Id: strconv.Itoa(int(Hash(key))), - Key: key, + Key: encodedHashedKey, + KeySecret: hiddenKey(key, 4), Name: name, Type: t, CreatedAt: time.Now().UTC(), - ExpiresAt: time.Now().UTC().Add(validFor), + ExpiresAt: expiresAt, UpdatedAt: time.Now().UTC(), Revoked: false, UsedTimes: 0, AutoGroups: autoGroups, UsageLimit: limit, Ephemeral: ephemeral, - } + }, key } // GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() *SetupKey { +func GenerateDefaultSetupKey() (*SetupKey, string) { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, SetupKeyUnlimitedUsage, false) } @@ -213,11 +229,6 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - keyDuration := DefaultSetupKeyDuration - if expiresIn != 0 { - keyDuration = expiresIn - } - account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -227,7 +238,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, err } - setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) + setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) account.SetupKeys[setupKey.Key] = setupKey err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -246,6 +257,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } } + // for the creation return the plain key to the caller + setupKey.Key = plainKey + return setupKey, nil } @@ -334,7 +348,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + return nil, status.NewUnauthorizedToViewSetupKeysError() } setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) @@ -342,18 +356,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - keys := make([]*SetupKey, 0, len(setupKeys)) - for _, key := range setupKeys { - var k *SetupKey - if !user.IsAdminOrServiceUser() { - k = key.HiddenCopy(999) - } else { - k = key.Copy() - } - keys = append(keys, k) - } - - return keys, nil + return setupKeys, nil } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -364,7 +367,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + return nil, status.NewUnauthorizedToViewSetupKeysError() } setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) @@ -377,11 +380,33 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use setupKey.UpdatedAt = setupKey.CreatedAt } - if !user.IsAdminOrServiceUser() { - setupKey = setupKey.HiddenCopy(999) + return setupKey, nil +} + +// DeleteSetupKey removes the setup key from the account +func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) } - return setupKey, nil + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return status.NewUnauthorizedToViewSetupKeysError() + } + + deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } + + err = am.Store.DeleteSetupKey(ctx, accountID, keyID) + if err != nil { + return fmt.Errorf("failed to delete setup key: %w", err) + } + + am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) + + return nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 651b5401047..2ed8aef95c6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -2,8 +2,11 @@ package server import ( "context" + "crypto/sha256" + "encoding/base64" "fmt" "strconv" + "strings" "testing" "time" @@ -66,7 +69,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, - key.Id, time.Now().UTC(), autoGroups) + key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) @@ -183,7 +186,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), - tCase.expectedUpdatedAt, tCase.expectedGroups) + tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) @@ -239,10 +242,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key := GenerateDefaultSetupKey() + key, plainKey := GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -256,41 +259,41 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) @@ -299,7 +302,7 @@ func TestSetupKey_IsValid(t *testing.T) { func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, - expectedUpdatedAt time.Time, expectedAutoGroups []string) { + expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() if key.Name != expectedName { t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name) @@ -329,13 +332,23 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt) } - _, err := uuid.Parse(key.Key) - if err != nil { - t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + if expectHashedKey { + if !isValidBase64SHA256(key.Key) { + t.Errorf("expected key to be hashed, got %v", key.Key) + } + } else { + _, err := uuid.Parse(key.Key) + if err != nil { + t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + } } - if key.Id != strconv.Itoa(int(Hash(key.Key))) { - t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id) + if !strings.HasSuffix(key.KeySecret, "****") { + t.Errorf("expected key secret to be secure, got %v", key.Key) + } + + if key.Id != expectedID { + t.Errorf("expected key Id %v, got %v", expectedID, key.Id) } if len(key.AutoGroups) != len(expectedAutoGroups) { @@ -344,13 +357,26 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal") } +func isValidBase64SHA256(encodedKey string) bool { + decoded, err := base64.StdEncoding.DecodeString(encodedKey) + if err != nil { + return false + } + + if len(decoded) != sha256.Size { + return false + } + + return true +} + func TestSetupKey_Copy(t *testing.T) { - key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, - key.UpdatedAt, key.AutoGroups) + key.UpdatedAt, key.AutoGroups, true) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 0098217c00e..b1b8330ba3b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -516,7 +516,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* startTime := time.Now() var key SetupKey - result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -842,7 +842,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) startTime := time.Now() var accountID string - result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1111,7 +1111,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") @@ -1414,6 +1414,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) } +func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { + return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID) +} + // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T @@ -1446,3 +1450,21 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// deleteRecordByID deletes a record by its ID and account ID from the database. +func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { + var record T + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "record not found") + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 000eb1b11b2..b371e231319 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "math/rand" "net" @@ -71,7 +73,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -81,7 +83,6 @@ func runLargeTest(t *testing.T, store Store) { peer := &nbpeer.Peer{ ID: peerID, Key: peerID, - SetupKey: "", IP: netIP, Name: peerID, DNSLabel: peerID, @@ -133,7 +134,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -215,30 +216,28 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account2) @@ -297,15 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -394,13 +392,12 @@ func TestSqlite_SavePeer(t *testing.T) { // save status of non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() err = store.SavePeer(ctx, account.Id, peer) @@ -453,13 +450,12 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -720,15 +716,14 @@ func newSqliteStore(t *testing.T) *SqlStore { func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey" + str, + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } return store.SaveAccount(context.Background(), account) @@ -760,30 +755,28 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account2) @@ -842,15 +835,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -921,13 +913,12 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -1118,12 +1109,17 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) - assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, encodedHashedKey, setupKey.Key) + assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } @@ -1138,24 +1134,28 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 0, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 1, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } @@ -1264,3 +1264,32 @@ func TestSqlite_GetGroupByName(t *testing.T) { require.NoError(t, err) require.Equal(t, "All", group.Name) } + +func Test_DeleteSetupKeySuccessfully(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + require.NoError(t, err) + + _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) + require.Error(t, err) +} + +func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nonExistingKeyID := "non-existing-key-id" + + err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + require.Error(t, err) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 0188b5feea0..f5befa096d4 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -119,4 +119,12 @@ func NewGetUserFromStoreError() error { // NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context func NewStoreContextCanceledError(duration time.Duration) error { return Errorf(Internal, "store access: context canceled after %v", duration) + + // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key +func NewInvalidKeyIDError() error { + return Errorf(InvalidArgument, "invalid key ID") } + +// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key +func NewUnauthorizedToViewSetupKeysError() error { + return Errorf(Unauthorized, "only users with admin power can view setup keys") diff --git a/management/server/store.go b/management/server/store.go index 131fd8aaab6..087c9884763 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -124,6 +124,7 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + DeleteSetupKey(ctx context.Context, accountID, keyID string) error } type StoreEngine string @@ -241,6 +242,9 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, + func(db *gorm.DB) error { + return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + }, } } diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 6fb96c97124..59b6fd09492 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,13 +2,9 @@ package server import ( "context" - "fmt" - "runtime/debug" "sync" "time" - "github.com/netbirdio/netbird/management/server/differs" - "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -25,8 +21,6 @@ type UpdateMessage struct { type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage - // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. - peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics @@ -36,10 +30,9 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - peerUpdateMessage: make(map[string]*UpdateMessage), - channelsMux: &sync.RWMutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -48,15 +41,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool - // skip sending sync update to the peer if there is no change in update message, - // it will not check on turn credential refresh as we do not send network map or client posture checks - if update.NetworkMap != nil { - updated := p.handlePeerMessageUpdate(ctx, peerID, update) - if !updated { - return - } - } - p.channelsMux.Lock() defer func() { @@ -66,16 +50,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() - if update.NetworkMap != nil { - lastSentUpdate := p.peerUpdateMessage[peerID] - if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { - log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", - peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) - return - } - p.peerUpdateMessage[peerID] = update - } - if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -108,7 +82,6 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) - delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -123,7 +96,6 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) - delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -200,72 +172,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } - -// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. -func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { - p.channelsMux.RLock() - lastSentUpdate := p.peerUpdateMessage[peerID] - p.channelsMux.RUnlock() - - if lastSentUpdate != nil { - updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update) - if err != nil { - log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) - return false - } - if !updated { - log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) - return false - } - } - - return true -} - -// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. -func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) { - defer func() { - if r := recover(); r != nil { - log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) - isNew, err = true, nil - } - }() - - if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { - return false, nil - } - - differ, err := diff.NewDiffer( - diff.CustomValueDiffers(&differs.NetIPAddr{}), - diff.CustomValueDiffers(&differs.NetIPPrefix{}), - ) - if err != nil { - return false, fmt.Errorf("failed to create differ: %v", err) - } - - lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) - currFiles := getChecksFiles(currUpdateToSend.Update.Checks) - - changelog, err := differ.Diff(lastSentFiles, currFiles) - if err != nil { - return false, fmt.Errorf("failed to diff checks: %v", err) - } - if len(changelog) > 0 { - return true, nil - } - - changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) - if err != nil { - return false, fmt.Errorf("failed to diff network map: %v", err) - } - return len(changelog) > 0, nil -} - -// getChecksFiles returns a list of files from the given checks. -func getChecksFiles(checks []*proto.Checks) []string { - files := make([]string, 0, len(checks)) - for _, check := range checks { - files = append(files, check.GetFiles()...) - } - return files -} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 52b715e9503..69f5b895c45 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,19 +2,10 @@ package server import ( "context" - "net" - "net/netip" "testing" "time" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - nbroute "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util" - "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -86,470 +77,3 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } - -func TestHandlePeerMessageUpdate(t *testing.T) { - tests := []struct { - name string - peerID string - existingUpdate *UpdateMessage - newUpdate *UpdateMessage - expectedResult bool - }{ - { - name: "update message with turn credentials update", - peerID: "peer", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - WiretrusteeConfig: &proto.WiretrusteeConfig{}, - }, - }, - expectedResult: true, - }, - { - name: "update message for peer without existing update", - peerID: "peer1", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with no changes in update", - peerID: "peer2", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - { - name: "update message with changes in checks", - peerID: "peer3", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - Checks: []*proto.Checks{ - { - Files: []string{"/usr/bin/netbird"}, - }, - }, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with lower serial number", - peerID: "peer4", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := NewPeersUpdateManager(nil) - ctx := context.Background() - - if tt.existingUpdate != nil { - p.peerUpdateMessage[tt.peerID] = tt.existingUpdate - } - - result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) - assert.Equal(t, tt.expectedResult, result) - }) - } -} - -func TestIsNewPeerUpdateMessage(t *testing.T) { - t.Run("Unchanged value", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Unchanged value with serial incremented", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Updating routes network", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - - }) - - t.Run("Updating routes groups", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating network map peers", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newPeer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.4"), - SSHEnabled: true, - Key: "peer4-key", - DNSLabel: "peer4", - SSHKey: "peer4-ssh-key", - } - newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating process check", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - - newUpdateMessage2 := createMockUpdateMessage(t) - newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.False(t, message) - - newUpdateMessage3 := createMockUpdateMessage(t) - newUpdateMessage3.Update.Checks = []*proto.Checks{} - newUpdateMessage3.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage4 := createMockUpdateMessage(t) - check := &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/local/netbird", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - } - newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage4.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage5 := createMockUpdateMessage(t) - check = &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/local/netbird", - }, - }, - }, - }, - } - newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage5.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating DNS configuration", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newDomain := "newexample.com" - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, - newDomain, - ) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating peer IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Add new firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newRule := &FirewallRule{ - PeerIP: "192.168.1.3", - Direction: firewallRuleDirectionOUT, - Action: string(PolicyTrafficActionDrop), - Protocol: string(PolicyRuleProtocolUDP), - Port: "53", - } - newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Removing nameserver", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating name server IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating custom DNS zone", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) - assert.NoError(t, err) - assert.True(t, message) - }) - -} - -func createMockUpdateMessage(t *testing.T) *UpdateMessage { - t.Helper() - - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") - if err != nil { - t.Fatal(err) - } - domainList, err := domain.FromStringList([]string{"example.com"}) - if err != nil { - t.Fatal(err) - } - - config := &Config{ - Signal: &Host{ - Proto: "https", - URI: "signal.uri", - Username: "", - Password: "", - }, - Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, - TURNConfig: &TURNConfig{ - Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, - }, - } - peer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.1"), - SSHEnabled: true, - Key: "peer-key", - DNSLabel: "peer1", - SSHKey: "peer1-ssh-key", - } - - secretManager := NewTimeBasedAuthSecretsManager( - NewPeersUpdateManager(nil), - &TURNConfig{ - TimeBasedCredentials: false, - CredentialsTTL: util.Duration{ - Duration: defaultDuration, - }, - Secret: "secret", - Turns: []*Host{TurnTestHost}, - }, - &Relay{ - Addresses: []string{"localhost:0"}, - CredentialsTTL: util.Duration{Duration: time.Hour}, - Secret: "secret", - }, - ) - - networkMap := &NetworkMap{ - Network: &Network{Net: *ipNet, Serial: 1000}, - Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, - Routes: []*nbroute.Route{ - { - ID: "route1", - Network: netip.MustParsePrefix("10.0.0.0/24"), - KeepRoute: true, - NetID: "route1", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - { - ID: "route2", - Domains: domainList, - KeepRoute: true, - NetID: "route2", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - }, - DNSConfig: nbdns.Config{ - ServiceEnable: true, - NameServerGroups: []*nbdns.NameServerGroup{ - { - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - { - ID: "ns1", - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Groups: []string{"group1"}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - }, - CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, - }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, - }, - } - dnsName := "example.com" - checks := []*posture.Checks{ - { - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - }, - } - dnsCache := &DNSConfigCache{} - - turnToken, err := secretManager.GenerateTurnToken() - if err != nil { - t.Fatal(err) - } - - relayToken, err := secretManager.GenerateRelayToken() - if err != nil { - t.Fatal(err) - } - - return &UpdateMessage{ - Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), - NetworkMap: networkMap, - } -}