diff --git a/api/types/server.go b/api/types/server.go index 47ec7c92ee1d0..4ea0ec8e51318 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -103,6 +103,81 @@ type Server interface { GetAWSAccountID() string } +// ReadOnlyServer represents a Node, Proxy or Auth server in a Teleport cluster +type ReadOnlyServer interface { + // GetKind returns resource kind + GetKind() string + // GetSubKind returns resource subkind + GetSubKind() string + // GetVersion returns resource version + GetVersion() string + // GetName returns the name of the resource + GetName() string + // Expiry returns object expiry setting + Expiry() time.Time + // GetMetadata returns object metadata + GetMetadata() Metadata + // GetRevision returns the revision + GetRevision() string + // Origin returns the origin value of the resource. + Origin() string + // GetLabel retrieves the label with the provided key. + GetLabel(key string) (value string, ok bool) + // GetAllLabels returns all resource's labels. + GetAllLabels() map[string]string + // GetStaticLabels returns the resource's static labels. + GetStaticLabels() map[string]string + // SetStaticLabels sets the resource's static labels. + SetStaticLabels(sl map[string]string) + // MatchSearch goes through select field values of a resource + // and tries to match against the list of search values. + MatchSearch(searchValues []string) bool + // GetTeleportVersion returns the teleport version the server is running on + GetTeleportVersion() string + // GetAddr return server address + GetAddr() string + // GetHostname returns server hostname + GetHostname() string + // GetNamespace returns server namespace + GetNamespace() string + // GetLabels returns server's static label key pairs + GetLabels() map[string]string + // GetCmdLabels gets command labels + GetCmdLabels() map[string]CommandLabel + // GetPublicAddr returns a public address where this server can be reached. + GetPublicAddr() string + // GetPublicAddrs returns a list of public addresses where this server can be reached. + GetPublicAddrs() []string + // GetRotation gets the state of certificate authority rotation. + GetRotation() Rotation + // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. + GetUseTunnel() bool + // String returns string representation of the server + String() string + // GetPeerAddr returns the peer address of the server. + GetPeerAddr() string + // GetProxyIDs returns a list of proxy ids this service is connected to. + GetProxyIDs() []string + + // GetCloudMetadata gets the cloud metadata for the server. + GetCloudMetadata() *CloudMetadata + // GetAWSInfo returns the AWSInfo for the server. + GetAWSInfo() *AWSInfo + + // IsOpenSSHNode returns whether the connection to this Server must use OpenSSH. + // This returns true for SubKindOpenSSHNode and SubKindOpenSSHEICENode. + IsOpenSSHNode() bool + + // IsEICE returns whether the Node is an EICE instance. + // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). + IsEICE() bool + + // GetAWSInstanceID returns the AWS Instance ID if this node comes from an EC2 instance. + GetAWSInstanceID() string + // GetAWSAccountID returns the AWS Account ID if this node comes from an EC2 instance. + GetAWSAccountID() string +} + // NewServer creates an instance of Server. 
func NewServer(name, kind string, spec ServerSpecV2) (Server, error) { return NewServerWithLabels(name, kind, spec, map[string]string{}) diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index f03308f61060d..bb3cd352373b6 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -98,7 +98,7 @@ type TLSServerConfig struct { // kubernetes cluster name. Proxy uses this map to route requests to the correct // kubernetes_service. The servers are kept in memory to avoid making unnecessary // unmarshal calls followed by filtering and to improve memory usage. - KubernetesServersWatcher *services.GenericWatcher[types.KubeServer] + KubernetesServersWatcher *services.GenericWatcher[types.KubeServer, types.KubeServer] // PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers. PROXYProtocolMode multiplexer.PROXYProtocolMode // InventoryHandle is used to send kube server heartbeats via the inventory control stream. @@ -170,7 +170,7 @@ type TLSServer struct { closeContext context.Context closeFunc context.CancelFunc // kubeClusterWatcher monitors changes to kube cluster resources. - kubeClusterWatcher *services.GenericWatcher[types.KubeCluster] + kubeClusterWatcher *services.GenericWatcher[types.KubeCluster, types.KubeCluster] // reconciler reconciles proxied kube clusters with kube_clusters resources. reconciler *services.Reconciler[types.KubeCluster] // monitoredKubeClusters contains all kube clusters the proxied kube_clusters are diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go index 1dafb4d96f4eb..ce3baf9e0ac23 100644 --- a/lib/kube/proxy/watcher.go +++ b/lib/kube/proxy/watcher.go @@ -87,7 +87,7 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) { // startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and // registers/unregisters the proxied Kube Cluster accordingly. -func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster], error) { +func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster, types.KubeCluster], error) { if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService { s.log.Debug("Not initializing Kube Cluster resource watcher.") return nil, nil diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 18e22adc798ee..99b216c35d6c4 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -383,7 +383,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check // site is the minimum interface needed to match servers // for a reversetunnelclient.RemoteSite. It makes testing easier. 
 type site interface {
-	GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error)
+	GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error)
 	GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error)
 }
 
@@ -394,13 +394,13 @@ type remoteSite struct {
 }
 
 // GetNodes uses the wrapped sites NodeWatcher to filter nodes
-func (r remoteSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
+func (r remoteSite) GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) {
 	watcher, err := r.site.NodeWatcher()
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
 
-	return watcher.GetNodes(ctx, fn), nil
+	return watcher.CurrentResourcesWithFilter(ctx, fn)
 }
 
 // GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig
@@ -450,7 +450,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re
 	var maxScore int
 	scores := make(map[string]int)
 
-	matches, err := site.GetNodes(ctx, func(server services.Node) bool {
+	matches, err := site.GetNodes(ctx, func(server types.ReadOnlyServer) bool {
 		score := routeMatcher.RouteToServerScore(server)
 		if score < 1 {
 			return false
diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go
index 48268cf355961..77675930db343 100644
--- a/lib/proxy/router_test.go
+++ b/lib/proxy/router_test.go
@@ -37,7 +37,6 @@ import (
 	"github.com/gravitational/teleport/lib/cryptosuites"
 	"github.com/gravitational/teleport/lib/observability/tracing"
 	"github.com/gravitational/teleport/lib/reversetunnelclient"
-	"github.com/gravitational/teleport/lib/services"
 	"github.com/gravitational/teleport/lib/teleagent"
 	"github.com/gravitational/teleport/lib/utils"
 )
@@ -51,7 +50,7 @@ func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.Cluster
 	return t.cfg, nil
 }
 
-func (t testSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
+func (t testSite) GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) {
 	var out []types.Server
 	for _, s := range t.nodes {
 		if fn(s) {
diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go
index 54e9d5db3a680..57a37e57490f2 100644
--- a/lib/reversetunnel/localsite.go
+++ b/lib/reversetunnel/localsite.go
@@ -180,7 +180,7 @@ func (s *localSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, err
 }
 
 // NodeWatcher returns a services.NodeWatcher for this cluster.
-func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return s.srv.NodeWatcher, nil } @@ -739,7 +739,11 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := rconn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -764,9 +768,12 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - rconn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + rconn.updateProxies(proxies) } reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc() firstHeartbeat = false @@ -935,7 +942,7 @@ func (s *localSite) periodicFunctions() { // sshTunnelStats reports SSH tunnel statistics for the cluster. func (s *localSite) sshTunnelStats() error { - missing := s.srv.NodeWatcher.GetNodes(s.srv.ctx, func(server services.Node) bool { + missing, err := s.srv.NodeWatcher.CurrentResourcesWithFilter(s.srv.ctx, func(server types.ReadOnlyServer) bool { // Skip over any servers that have a TTL larger than announce TTL (10 // minutes) and are non-IoT SSH servers (they won't have tunnels). // @@ -967,6 +974,9 @@ func (s *localSite) sshTunnelStats() error { return err != nil }) + if err != nil { + return trace.Wrap(err) + } // Update Prometheus metrics and also log if any tunnels are missing. 
missingSSHTunnels.Set(float64(len(missing))) diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index fc16cbe11cefa..2460538954acf 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -90,7 +90,7 @@ func (p *clusterPeers) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, return peer.CachingAccessPoint() } -func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) { +func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { peer, err := p.pickPeer() if err != nil { return nil, trace.Wrap(err) @@ -202,7 +202,7 @@ func (s *clusterPeer) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, e return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s) } -func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) { +func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 8e8b7e4c3fe79..d2c6e669b9442 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -85,7 +85,7 @@ type remoteSite struct { remoteAccessPoint authclient.RemoteProxyAccessPoint // nodeWatcher provides access the node set for the remote site - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // remoteCA is the last remote certificate authority recorded by the client. // It is used to detect CA rotation status changes. If the rotation @@ -164,7 +164,7 @@ func (s *remoteSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, er } // NodeWatcher returns the services.NodeWatcher for the remote cluster. 
-func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) { +func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) { return s.nodeWatcher, nil } @@ -429,7 +429,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch return case <-proxyResyncTicker.Chan(): var req discoveryRequest - req.SetProxies(s.srv.proxyWatcher.GetCurrent()) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + req.SetProxies(proxies) if err := conn.sendDiscoveryRequest(req); err != nil { logger.WithError(err).Debug("Marking connection invalid on error") @@ -458,9 +462,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch if firstHeartbeat { // as soon as the agent connects and sends a first heartbeat // send it the list of current proxies back - current := s.srv.proxyWatcher.GetCurrent() - if len(current) > 0 { - conn.updateProxies(current) + proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx) + if err != nil { + logger.WithError(err).Warn("Failed to get proxy set") + } + if len(proxies) > 0 { + conn.updateProxies(proxies) } firstHeartbeat = false } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index c28241d3c3990..be125e075f5e5 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -114,7 +114,7 @@ type server struct { // proxyWatcher monitors changes to the proxies // and broadcasts updates - proxyWatcher *services.GenericWatcher[types.Server] + proxyWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // offlineThreshold is how long to wait for a keep alive message before // marking a reverse tunnel connection as invalid. @@ -201,7 +201,7 @@ type Config struct { LockWatcher *services.LockWatcher // NodeWatcher is a node watcher. - NodeWatcher *services.NodeWatcher + NodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // CertAuthorityWatcher is a cert authority watcher. CertAuthorityWatcher *services.CertAuthorityWatcher diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index f7e8dfb47ef63..4334c05dcafae 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -123,7 +123,7 @@ type RemoteSite interface { // but is resilient to auth server crashes CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) // NodeWatcher returns the node watcher that maintains the node set for the site - NodeWatcher() (*services.NodeWatcher, error) + NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/services/watcher.go b/lib/services/watcher.go index db5fd6db2f9ba..adae27d19bc6d 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -88,23 +88,23 @@ func watchKindsString(kinds []types.WatchKind) string { // ResourceWatcherConfig configures resource watcher. type ResourceWatcherConfig struct { - // Component is a component used in logs. - Component string // TODO(tross): remove this once e has been updated. Log logrus.FieldLogger + // Clock is used to control time. + Clock clockwork.Clock + // Client is used to create new watchers + Client types.Events // Logger emits log messages. Logger *slog.Logger + // ResetC is a channel to notify of internal watcher reset (used in tests). 
+ ResetC chan time.Duration + // Component is a component used in logs. + Component string // MaxRetryPeriod is the maximum retry period on failed watchers. MaxRetryPeriod time.Duration - // Clock is used to control time. - Clock clockwork.Clock - // Client is used to create new watchers. - Client types.Events // MaxStaleness is a maximum acceptable staleness for the locally maintained // resources, zero implies no staleness detection. MaxStaleness time.Duration - // ResetC is a channel to notify of internal watcher reset (used in tests). - ResetC chan time.Duration // QueueSize is an optional queue size QueueSize int } @@ -168,28 +168,23 @@ func newResourceWatcher(ctx context.Context, collector resourceCollector, cfg Re // resourceWatcher monitors additions, updates and deletions // to a set of resources. type resourceWatcher struct { - ResourceWatcherConfig - collector resourceCollector - - // ctx is a context controlling the lifetime of this resourceWatcher - // instance. - ctx context.Context - cancel context.CancelFunc - - // retry is used to manage backoff logic for watchers. - retry retryutils.Retry - // failureStartedAt records when the current sync failures were first // detected, zero if there are no failures present. failureStartedAt time.Time - + collector resourceCollector + // ctx is a context controlling the lifetime of this resourceWatcher + // instance. + ctx context.Context + // retry is used to manage backoff logic for watchers. + retry retryutils.Retry + cancel context.CancelFunc // LoopC is a channel to check whether the watch loop is running // (used in tests). LoopC chan struct{} - // StaleC is a channel that can trigger the condition of resource staleness // (used in tests). StaleC chan struct{} + ResourceWatcherConfig } // Done returns a channel that signals resource watcher closure. @@ -383,7 +378,6 @@ func (p *resourceWatcher) watch() error { // ProxyWatcherConfig is a ProxyWatcher configuration. type ProxyWatcherConfig struct { - ResourceWatcherConfig // ProxyGetter is used to directly fetch the list of active proxies. ProxyGetter // ProxyDiffer is used to decide whether a put operation on an existing proxy should @@ -393,12 +387,13 @@ type ProxyWatcherConfig struct { // a fresh list at startup and subsequently a list of all known resourxes // whenever an addition or deletion is detected. ProxiesC chan []types.Server + ResourceWatcherConfig } // NewProxyWatcher returns a new instance of GenericWatcher that is configured // to watch for changes. -func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*GenericWatcher[types.Server], error) { - w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server]{ +func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*GenericWatcher[types.Server, types.ReadOnlyServer], error) { + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, types.ReadOnlyServer]{ ResourceWatcherConfig: cfg.ResourceWatcherConfig, ResourceKind: types.KindProxy, ResourceKey: types.Server.GetName, @@ -408,23 +403,24 @@ func NewProxyWatcher(ctx context.Context, cfg ProxyWatcherConfig) (*GenericWatch ResourcesC: cfg.ProxiesC, ResourceDiffer: cfg.ProxyDiffer, RequireResourcesForInitialBroadcast: true, + CloneFunc: types.Server.DeepCopy, }) return w, trace.Wrap(err) } // DatabaseWatcherConfig is a DatabaseWatcher configuration. type DatabaseWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. 
- ResourceWatcherConfig // DatabaseGetter is responsible for fetching database resources. DatabaseGetter // DatabasesC receives up-to-date list of all database resources. DatabasesC chan []types.Database + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig } // NewDatabaseWatcher returns a new instance of DatabaseWatcher. -func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*GenericWatcher[types.Database], error) { - w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Database]{ +func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*GenericWatcher[types.Database, types.Database], error) { + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Database, types.Database]{ ResourceWatcherConfig: cfg.ResourceWatcherConfig, ResourceKind: types.KindDatabase, ResourceKey: types.Database.GetName, @@ -432,23 +428,26 @@ func NewDatabaseWatcher(ctx context.Context, cfg DatabaseWatcherConfig) (*Generi return cfg.DatabaseGetter.GetDatabases(ctx) }, ResourcesC: cfg.DatabasesC, + CloneFunc: func(resource types.Database) types.Database { + return resource.Copy() + }, }) return w, trace.Wrap(err) } // AppWatcherConfig is an AppWatcher configuration. type AppWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig // AppGetter is responsible for fetching application resources. AppGetter // AppsC receives up-to-date list of all application resources. AppsC chan []types.Application + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig } // NewAppWatcher returns a new instance of AppWatcher. -func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[types.Application], error) { - w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Application]{ +func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[types.Application, types.Application], error) { + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Application, types.Application]{ ResourceWatcherConfig: cfg.ResourceWatcherConfig, ResourceKind: types.KindApp, ResourceKey: types.Application.GetName, @@ -456,6 +455,9 @@ func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[t return cfg.AppGetter.GetApps(ctx) }, ResourcesC: cfg.AppsC, + CloneFunc: func(resource types.Application) types.Application { + return resource.Copy() + }, }) return w, trace.Wrap(err) @@ -463,15 +465,15 @@ func NewAppWatcher(ctx context.Context, cfg AppWatcherConfig) (*GenericWatcher[t // KubeServerWatcherConfig is an KubeServerWatcher configuration. type KubeServerWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig // KubernetesServerGetter is responsible for fetching kube_server resources. KubernetesServerGetter + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig } // NewKubeServerWatcher returns a new instance of KubeServerWatcher. 
-func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*GenericWatcher[types.KubeServer], error) {
-	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeServer]{
+func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*GenericWatcher[types.KubeServer, types.KubeServer], error) {
+	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeServer, types.KubeServer]{
 		ResourceWatcherConfig: cfg.ResourceWatcherConfig,
 		ResourceKind: types.KindKubeServer,
 		ResourceGetter: func(ctx context.Context) ([]types.KubeServer, error) {
@@ -481,23 +483,24 @@ func NewKubeServerWatcher(ctx context.Context, cfg KubeServerWatcherConfig) (*Ge
 			return resource.GetHostID() + resource.GetName()
 		},
 		DisableUpdateBroadcast: true,
+		CloneFunc: types.KubeServer.Copy,
 	})
 
 	return w, trace.Wrap(err)
 }
 
 // KubeClusterWatcherConfig is an KubeClusterWatcher configuration.
 type KubeClusterWatcherConfig struct {
-	// ResourceWatcherConfig is the resource watcher configuration.
-	ResourceWatcherConfig
 	// KubernetesGetter is responsible for fetching kube_cluster resources.
 	KubernetesClusterGetter
 	// KubeClustersC receives up-to-date list of all kube_cluster resources.
 	KubeClustersC chan []types.KubeCluster
+	// ResourceWatcherConfig is the resource watcher configuration.
+	ResourceWatcherConfig
 }
 
 // NewKubeClusterWatcher returns a new instance of KubeClusterWatcher.
-func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*GenericWatcher[types.KubeCluster], error) {
-	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeCluster]{
+func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*GenericWatcher[types.KubeCluster, types.KubeCluster], error) {
+	w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.KubeCluster, types.KubeCluster]{
 		ResourceWatcherConfig: cfg.ResourceWatcherConfig,
 		ResourceKind: types.KindKubernetesCluster,
 		ResourceGetter: func(ctx context.Context) ([]types.KubeCluster, error) {
@@ -505,22 +508,34 @@ func NewKubeClusterWatcher(ctx context.Context, cfg KubeClusterWatcherConfig) (*
 		},
 		ResourceKey: types.KubeCluster.GetName,
 		ResourcesC: cfg.KubeClustersC,
+		CloneFunc: func(resource types.KubeCluster) types.KubeCluster {
+			return resource.Copy()
+		},
 	})
 
 	return w, trace.Wrap(err)
 }
 
 // GenericWatcherConfig is a generic resource watcher configuration.
-type GenericWatcherConfig[T any] struct {
-	ResourceWatcherConfig
-	// ResourceKind specifies the kind of resource the watcher is monitoring.
-	ResourceKind string
+type GenericWatcherConfig[T any, R any] struct {
 	// ResourceGetter is used to directly fetch the current set of resources.
 	ResourceGetter func(context.Context) ([]T, error)
 	// ResourceDiffer is used to decide whether a put operation on an existing ResourceGetter should
-	// trigger a event.
+	// trigger an event.
 	ResourceDiffer func(old, new T) bool
 	// ResourceKey defines how the resources should be keyed.
 	ResourceKey func(resource T) string
+	// ResourcesC is a channel used to report the current resource set. It receives
+	// a fresh list at startup and subsequently a list of all known resources
+	// whenever an addition or deletion is detected.
+	ResourcesC chan []T
+	// CloneFunc defines how a resource is cloned. All resources provided via
+	// the broadcast mechanism, or retrieved via [GenericWatcher.CurrentResources]
+	// or [GenericWatcher.CurrentResourcesWithFilter] will be cloned by this
+	// mechanism before being provided to callers.
+ CloneFunc func(resource T) T + ResourceWatcherConfig + // ResourceKind specifies the kind of resource the watcher is monitoring. + ResourceKind string // RequireResourcesForInitialBroadcast indicates whether an update should be // performed if the initial set of resources is empty. RequireResourcesForInitialBroadcast bool @@ -529,14 +544,10 @@ type GenericWatcherConfig[T any] struct { // [GenericWatcher.CurrentResourcesWithFilter] manually to retrieve the active // resource set. DisableUpdateBroadcast bool - // ResourcesC is a channel used to report the current resourxe set. It receives - // a fresh list at startup and subsequently a list of all known resourxes - // whenever an addition or deletion is detected. - ResourcesC chan []T } // CheckAndSetDefaults checks parameters and sets default values. -func (cfg *GenericWatcherConfig[T]) CheckAndSetDefaults() error { +func (cfg *GenericWatcherConfig[T, R]) CheckAndSetDefaults() error { if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } @@ -564,7 +575,7 @@ func (cfg *GenericWatcherConfig[T]) CheckAndSetDefaults() error { } // NewGenericResourceWatcher returns a new instance of resource watcher. -func NewGenericResourceWatcher[T any](ctx context.Context, cfg GenericWatcherConfig[T]) (*GenericWatcher[T], error) { +func NewGenericResourceWatcher[T any, R any](ctx context.Context, cfg GenericWatcherConfig[T, R]) (*GenericWatcher[T, R], error) { if err := cfg.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -578,7 +589,7 @@ func NewGenericResourceWatcher[T any](ctx context.Context, cfg GenericWatcherCon return nil, trace.Wrap(err) } - collector := &genericCollector[T]{ + collector := &genericCollector[T, R]{ GenericWatcherConfig: cfg, initializationC: make(chan struct{}), cache: cache, @@ -588,17 +599,25 @@ func NewGenericResourceWatcher[T any](ctx context.Context, cfg GenericWatcherCon if err != nil { return nil, trace.Wrap(err) } - return &GenericWatcher[T]{watcher, collector}, nil + return &GenericWatcher[T, R]{watcher, collector}, nil } // GenericWatcher is built on top of resourceWatcher to monitor additions // and deletions to the set of resources. -type GenericWatcher[T any] struct { +type GenericWatcher[T any, R any] struct { *resourceWatcher - *genericCollector[T] + *genericCollector[T, R] +} + +// ResourceCount returns the current number of resources known to the watcher. +func (g *GenericWatcher[T, R]) ResourceCount() int { + g.rw.RLock() + defer g.rw.RUnlock() + return len(g.current) } -func (g *GenericWatcher[T]) CurrentResources(ctx context.Context) ([]T, error) { +// CurrentResources returns a copy of the resources known to the watcher. +func (g *GenericWatcher[T, R]) CurrentResources(ctx context.Context) ([]T, error) { if err := g.refreshStaleResources(ctx); err != nil { return nil, trace.Wrap(err) } @@ -606,10 +625,12 @@ func (g *GenericWatcher[T]) CurrentResources(ctx context.Context) ([]T, error) { g.rw.RLock() defer g.rw.RUnlock() - return resourcesToSlice(g.current), nil + return resourcesToSlice(g.current, g.CloneFunc), nil } -func (g *GenericWatcher[T]) CurrentResourcesWithFilter(ctx context.Context, filter func(T) bool) ([]T, error) { +// CurrentResourcesWithFilter returns a copy of the resources known to the watcher +// that match the provided filter. 
+func (g *GenericWatcher[T, R]) CurrentResourcesWithFilter(ctx context.Context, filter func(R) bool) ([]T, error) { if err := g.refreshStaleResources(ctx); err != nil { return nil, trace.Wrap(err) } @@ -617,10 +638,14 @@ func (g *GenericWatcher[T]) CurrentResourcesWithFilter(ctx context.Context, filt g.rw.RLock() defer g.rw.RUnlock() + r := func(a any) R { + return a.(R) + } + var out []T for _, resource := range g.current { - if filter(resource) { - out = append(out, resource) + if filter(r(resource)) { + out = append(out, g.CloneFunc(resource)) } } @@ -628,36 +653,29 @@ func (g *GenericWatcher[T]) CurrentResourcesWithFilter(ctx context.Context, filt } // genericCollector accompanies resourceWatcher when monitoring proxies. -type genericCollector[T any] struct { - GenericWatcherConfig[T] +type genericCollector[T any, R any] struct { + GenericWatcherConfig[T, R] // current holds a map of the currently known resources (keyed by server name, // RWMutex protected). current map[string]T - rw sync.RWMutex initializationC chan struct{} - once sync.Once + // cache is a helper for temporarily storing the results of CurrentResources. + // It's used to limit the number of calls to the backend. + cache *utils.FnCache + rw sync.RWMutex + once sync.Once // stale is used to indicate that the watcher is stale and needs to be // refreshed. stale atomic.Bool - // cache is a helper for temporarily storing the results of CurrentResources. - // It's used to limit the amount of calls to the backend. - cache *utils.FnCache -} - -// GetCurrent returns the currently stored proxies. -func (p *genericCollector[T]) GetCurrent() []T { - p.rw.RLock() - defer p.rw.RUnlock() - return resourcesToSlice(p.current) } // resourceKinds specifies the resource kind to watch. -func (p *genericCollector[T]) resourceKinds() []types.WatchKind { +func (p *genericCollector[T, R]) resourceKinds() []types.WatchKind { return []types.WatchKind{{Kind: p.ResourceKind}} } -// getResourcesAndUpdateCurrent gets the list of current resources. -func (g *genericCollector[T]) getResources(ctx context.Context) (map[string]T, error) { +// getResources gets the list of current resources. +func (g *genericCollector[T, R]) getResources(ctx context.Context) (map[string]T, error) { resources, err := g.GenericWatcherConfig.ResourceGetter(ctx) if err != nil { return nil, trace.Wrap(err) @@ -670,7 +688,7 @@ func (g *genericCollector[T]) getResources(ctx context.Context) (map[string]T, e return current, nil } -func (g *genericCollector[T]) refreshStaleResources(ctx context.Context) error { +func (g *genericCollector[T, R]) refreshStaleResources(ctx context.Context) error { if !g.stale.Load() { return nil } @@ -697,7 +715,7 @@ func (g *genericCollector[T]) refreshStaleResources(ctx context.Context) error { // getResourcesAndUpdateCurrent is called when the resources should be // (re-)fetched directly. -func (p *genericCollector[T]) getResourcesAndUpdateCurrent(ctx context.Context) error { +func (p *genericCollector[T, R]) getResourcesAndUpdateCurrent(ctx context.Context) error { resources, err := p.ResourceGetter(ctx) if err != nil { return trace.Wrap(err) @@ -721,7 +739,7 @@ func (p *genericCollector[T]) getResourcesAndUpdateCurrent(ctx context.Context) return nil } -func (p *genericCollector[T]) defineCollectorAsInitialized() { +func (p *genericCollector[T, R]) defineCollectorAsInitialized() { p.once.Do(func() { // mark watcher as initialized. 
close(p.initializationC) @@ -729,7 +747,7 @@ func (p *genericCollector[T]) defineCollectorAsInitialized() { } // processEventsAndUpdateCurrent is called when a watcher event is received. -func (p *genericCollector[T]) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { +func (p *genericCollector[T, R]) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { p.rw.Lock() defer p.rw.Unlock() @@ -768,8 +786,8 @@ func (p *genericCollector[T]) processEventsAndUpdateCurrent(ctx context.Context, } } -// broadcastUpdate broadcasts information about updating the proxy set. -func (p *genericCollector[T]) broadcastUpdate(ctx context.Context) { +// broadcastUpdate broadcasts information about updating the resource set. +func (p *genericCollector[T, R]) broadcastUpdate(ctx context.Context) { if p.DisableUpdateBroadcast { return } @@ -781,18 +799,18 @@ func (p *genericCollector[T]) broadcastUpdate(ctx context.Context) { p.Logger.DebugContext(ctx, "List of known resources updated", "resources", names) select { - case p.ResourcesC <- resourcesToSlice(p.current): + case p.ResourcesC <- resourcesToSlice(p.current, p.CloneFunc): case <-ctx.Done(): } } // isInitialized is used to check that the cache has done its initial // sync -func (p *genericCollector[T]) initializationChan() <-chan struct{} { +func (p *genericCollector[T, R]) initializationChan() <-chan struct{} { return p.initializationC } -func (p *genericCollector[T]) isInitialized() bool { +func (p *genericCollector[T, R]) isInitialized() bool { select { case <-p.initializationC: return true @@ -801,14 +819,14 @@ func (p *genericCollector[T]) isInitialized() bool { } } -func (p *genericCollector[T]) notifyStale() { +func (p *genericCollector[T, R]) notifyStale() { p.stale.Store(true) } // LockWatcherConfig is a LockWatcher configuration. type LockWatcherConfig struct { - ResourceWatcherConfig LockGetter + ResourceWatcherConfig } // CheckAndSetDefaults checks parameters and sets default values. @@ -862,15 +880,15 @@ type lockCollector struct { LockWatcherConfig // current holds a map of the currently known locks (keyed by lock name). current map[string]types.Lock - // isStale indicates whether the local lock view (current) is stale. - isStale bool - // currentRW is a mutex protecting both current and isStale. - currentRW sync.RWMutex // fanout provides support for multiple subscribers to the lock updates. fanout *FanoutV2 // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once + // currentRW is a mutex protecting both current and isStale. + currentRW sync.RWMutex + once sync.Once + // isStale indicates whether the local lock view (current) is stale. + isStale bool } // IsStale is used to check whether the lock watcher is stale. @@ -1057,9 +1075,9 @@ func lockMapValues(lockMap map[string]types.Lock) []types.Lock { return locks } -func resourcesToSlice[T any](resources map[string]T) (slice []T) { +func resourcesToSlice[T any](resources map[string]T, cloneFunc func(T) T) (slice []T) { for _, resource := range resources { - slice = append(slice, resource) + slice = append(slice, cloneFunc(resource)) } return slice } @@ -1131,17 +1149,15 @@ type CertAuthorityWatcher struct { // caCollector accompanies resourceWatcher when monitoring cert authority resources. 
type caCollector struct { - CertAuthorityWatcherConfig fanout *FanoutV2 - - // lock protects concurrent access to cas - lock sync.RWMutex - // cas maps ca type -> cluster -> ca - cas map[types.CertAuthType]map[string]types.CertAuthority + cas map[types.CertAuthType]map[string]types.CertAuthority // initializationC is used to check whether the initial sync has completed initializationC chan struct{} - once sync.Once filter types.CertAuthorityFilter + CertAuthorityWatcherConfig + // lock protects concurrent access to cas + lock sync.RWMutex + once sync.Once } // Subscribe is used to subscribe to the lock updates. @@ -1278,285 +1294,36 @@ func (c *caCollector) notifyStale() {} // NodeWatcherConfig is a NodeWatcher configuration. type NodeWatcherConfig struct { - ResourceWatcherConfig // NodesGetter is used to directly fetch the list of active nodes. NodesGetter -} - -// CheckAndSetDefaults checks parameters and sets default values. -func (cfg *NodeWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.NodesGetter == nil { - getter, ok := cfg.Client.(NodesGetter) - if !ok { - return trace.BadParameter("missing parameter NodesGetter and Client not usable as NodesGetter") - } - cfg.NodesGetter = getter - } - return nil + ResourceWatcherConfig } // NewNodeWatcher returns a new instance of NodeWatcher. -func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*NodeWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - cache, err := utils.NewFnCache(utils.FnCacheConfig{ - Context: ctx, - TTL: 3 * time.Second, - Clock: cfg.Clock, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - collector := &nodeCollector{ - NodeWatcherConfig: cfg, - current: map[string]types.Server{}, - initializationC: make(chan struct{}), - cache: cache, - stale: true, - } - - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - return &NodeWatcher{resourceWatcher: watcher, nodeCollector: collector}, nil -} - -// NodeWatcher is built on top of resourceWatcher to monitor additions -// and deletions to the set of nodes. -type NodeWatcher struct { - *resourceWatcher - *nodeCollector -} - -// nodeCollector accompanies resourceWatcher when monitoring nodes. -type nodeCollector struct { - NodeWatcherConfig - - // initializationC is used to check whether the initial sync has completed - initializationC chan struct{} - once sync.Once - - cache *utils.FnCache - - rw sync.RWMutex - // current holds a map of the currently known nodes keyed by server name - current map[string]types.Server - stale bool -} - -// Node is a readonly subset of the types.Server interface which -// users may filter by in GetNodes. -type Node interface { - // ResourceWithLabels provides common resource headers - types.ResourceWithLabels - // GetTeleportVersion returns the teleport version the server is running on - GetTeleportVersion() string - // GetAddr return server address - GetAddr() string - // GetPublicAddrs returns all public addresses where this server can be reached. - GetPublicAddrs() []string - // GetHostname returns server hostname - GetHostname() string - // GetNamespace returns server namespace - GetNamespace() string - // GetCmdLabels gets command labels - GetCmdLabels() map[string]types.CommandLabel - // GetRotation gets the state of certificate authority rotation. 
- GetRotation() types.Rotation - // GetUseTunnel gets if a reverse tunnel should be used to connect to this node. - GetUseTunnel() bool - // GetProxyIDs returns a list of proxy ids this server is connected to. - GetProxyIDs() []string - // IsEICE returns whether the Node is an EICE instance. - // Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels). - IsEICE() bool -} - -// GetNodes allows callers to retrieve a subset of nodes that match the filter provided. The -// returned servers are a copy and can be safely modified. It is intentionally hard to retrieve -// the full set of nodes to reduce the number of copies needed since the number of nodes can get -// quite large and doing so can be expensive. -func (n *nodeCollector) GetNodes(ctx context.Context, fn func(n Node) bool) []types.Server { - // Attempt to freshen our data first. - n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - var matched []types.Server - for _, server := range n.current { - if fn(server) { - matched = append(matched, server.DeepCopy()) - } - } - - return matched -} - -// GetNode allows callers to retrieve a node based on its name. The -// returned server are a copy and can be safely modified. -func (n *nodeCollector) GetNode(ctx context.Context, name string) (types.Server, error) { - // Attempt to freshen our data first. - n.refreshStaleNodes(ctx) - - n.rw.RLock() - defer n.rw.RUnlock() - - server, found := n.current[name] - if !found { - return nil, trace.NotFound("server does not exist") - } - return server.DeepCopy(), nil -} - -// refreshStaleNodes attempts to reload nodes from the NodeGetter if -// the collecter is stale. This ensures that no matter the health of -// the collecter callers will be returned the most up to date node -// set as possible. -func (n *nodeCollector) refreshStaleNodes(ctx context.Context) error { - n.rw.RLock() - if !n.stale { - n.rw.RUnlock() - return nil - } - n.rw.RUnlock() - - _, err := utils.FnCacheGet(ctx, n.cache, "nodes", func(ctx context.Context) (any, error) { - current, err := n.getNodes(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - n.rw.Lock() - defer n.rw.Unlock() - - // There is a chance that the watcher reinitialized while - // getting nodes happened above. Check if we are still stale - // now that the lock is held to ensure that the refresh is - // still necessary. - if !n.stale { - return nil, nil - } - - n.current = current - return nil, trace.Wrap(err) - }) - - return trace.Wrap(err) -} - -func (n *nodeCollector) NodeCount() int { - n.rw.RLock() - defer n.rw.RUnlock() - return len(n.current) -} - -// resourceKinds specifies the resource kind to watch. -func (n *nodeCollector) resourceKinds() []types.WatchKind { - return []types.WatchKind{{Kind: types.KindNode}} -} - -// getResourcesAndUpdateCurrent is called when the resources should be -// (re-)fetched directly. 
-func (n *nodeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - newCurrent, err := n.getNodes(ctx) - if err != nil { - return trace.Wrap(err) - } - defer n.defineCollectorAsInitialized() - - if len(newCurrent) == 0 { - return nil - } - - n.rw.Lock() - defer n.rw.Unlock() - n.current = newCurrent - n.stale = false - return nil -} - -func (n *nodeCollector) getNodes(ctx context.Context) (map[string]types.Server, error) { - nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(nodes) == 0 { - return map[string]types.Server{}, nil - } - - current := make(map[string]types.Server, len(nodes)) - for _, node := range nodes { - current[node.GetName()] = node - } - - return current, nil -} - -func (n *nodeCollector) defineCollectorAsInitialized() { - n.once.Do(func() { - // mark watcher as initialized. - close(n.initializationC) +func NewNodeWatcher(ctx context.Context, cfg NodeWatcherConfig) (*GenericWatcher[types.Server, types.ReadOnlyServer], error) { + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, types.ReadOnlyServer]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindNode, + ResourceGetter: func(ctx context.Context) ([]types.Server, error) { + return cfg.NodesGetter.GetNodes(ctx, apidefaults.Namespace) + }, + ResourceKey: types.Server.GetName, + DisableUpdateBroadcast: true, + CloneFunc: types.Server.DeepCopy, }) -} - -// processEventsAndUpdateCurrent is called when a watcher event is received. -func (n *nodeCollector) processEventsAndUpdateCurrent(ctx context.Context, events []types.Event) { - n.rw.Lock() - defer n.rw.Unlock() - - for _, event := range events { - if event.Resource == nil || event.Resource.GetKind() != types.KindNode { - n.Logger.WarnContext(ctx, "Received unexpected event", "event", logutils.StringerAttr(event)) - continue - } - - switch event.Type { - case types.OpDelete: - delete(n.current, event.Resource.GetName()) - case types.OpPut: - server, ok := event.Resource.(types.Server) - if !ok { - n.Logger.WarnContext(ctx, "Received unexpected type", "resource", event.Resource.GetKind()) - continue - } - - n.current[server.GetName()] = server - default: - n.Logger.WarnContext(ctx, "Skipping unsupported event type", "event_type", event.Type) - } - } -} - -func (n *nodeCollector) initializationChan() <-chan struct{} { - return n.initializationC -} - -func (n *nodeCollector) notifyStale() { - n.rw.Lock() - defer n.rw.Unlock() - n.stale = true + return w, trace.Wrap(err) } // AccessRequestWatcherConfig is a AccessRequestWatcher configuration. type AccessRequestWatcherConfig struct { - // ResourceWatcherConfig is the resource watcher configuration. - ResourceWatcherConfig // AccessRequestGetter is responsible for fetching access request resources. AccessRequestGetter - // Filter is the filter to use to monitor access requests. - Filter types.AccessRequestFilter // AccessRequestsC receives up-to-date list of all access request resources. AccessRequestsC chan types.AccessRequests + // ResourceWatcherConfig is the resource watcher configuration. + ResourceWatcherConfig + // Filter is the filter to use to monitor access requests. + Filter types.AccessRequestFilter } // CheckAndSetDefaults checks parameters and sets default values. @@ -1605,11 +1372,11 @@ type accessRequestCollector struct { AccessRequestWatcherConfig // current holds a map of the currently known access request resources. 
current map[string]types.AccessRequest - // lock protects the "current" map. - lock sync.RWMutex // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // lock protects the "current" map. + lock sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. @@ -1669,7 +1436,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte delete(p.current, event.Resource.GetName()) select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } case types.OpPut: accessRequest, ok := event.Resource.(types.AccessRequest) @@ -1680,7 +1447,7 @@ func (p *accessRequestCollector) processEventsAndUpdateCurrent(ctx context.Conte p.current[accessRequest.GetName()] = accessRequest select { case <-ctx.Done(): - case p.AccessRequestsC <- resourcesToSlice(p.current): + case p.AccessRequestsC <- resourcesToSlice(p.current, types.AccessRequest.Copy): } default: @@ -1693,14 +1460,14 @@ func (*accessRequestCollector) notifyStale() {} // OktaAssignmentWatcherConfig is a OktaAssignmentWatcher configuration. type OktaAssignmentWatcherConfig struct { - // RWCfg is the resource watcher configuration. - RWCfg ResourceWatcherConfig // OktaAssignments is responsible for fetching Okta assignments. OktaAssignments OktaAssignmentsGetter - // PageSize is the number of Okta assignments to list at a time. - PageSize int // OktaAssignmentsC receives up-to-date list of all Okta assignment resources. OktaAssignmentsC chan types.OktaAssignments + // RWCfg is the resource watcher configuration. + RWCfg ResourceWatcherConfig + // PageSize is the number of Okta assignments to list at a time. + PageSize int } // CheckAndSetDefaults checks parameters and sets default values. @@ -1765,16 +1532,16 @@ func (o *OktaAssignmentWatcher) Done() <-chan struct{} { // oktaAssignmentCollector accompanies resourceWatcher when monitoring Okta assignment resources. type oktaAssignmentCollector struct { - logger *slog.Logger // OktaAssignmentWatcherConfig is the watcher configuration. - cfg OktaAssignmentWatcherConfig - // mu guards "current" - mu sync.RWMutex + cfg OktaAssignmentWatcherConfig + logger *slog.Logger // current holds a map of the currently known Okta assignment resources. current map[string]types.OktaAssignment // initializationC is used to check that the watcher has been initialized properly. initializationC chan struct{} - once sync.Once + // mu guards "current" + mu sync.RWMutex + once sync.Once } // resourceKinds specifies the resource kind to watch. 
@@ -1842,7 +1609,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont
 		switch event.Type {
 		case types.OpDelete:
 			delete(c.current, event.Resource.GetName())
-			resources := resourcesToSlice(c.current)
+			resources := resourcesToSlice(c.current, types.OktaAssignment.Copy)
 			select {
 			case <-ctx.Done():
 			case c.cfg.OktaAssignmentsC <- resources:
@@ -1854,7 +1621,7 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont
 				continue
 			}
 			c.current[oktaAssignment.GetName()] = oktaAssignment
-			resources := resourcesToSlice(c.current)
+			resources := resourcesToSlice(c.current, types.OktaAssignment.Copy)
 
 			select {
 			case <-ctx.Done():
diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go
index 966d19b5d7a18..6daf00c79c627 100644
--- a/lib/services/watcher_test.go
+++ b/lib/services/watcher_test.go
@@ -912,6 +912,7 @@ func TestNodeWatcherFallback(t *testing.T) {
 			},
 			MaxStaleness: time.Minute,
 		},
+		NodesGetter: presence,
 	})
 	require.NoError(t, err)
 	t.Cleanup(w.Close)
@@ -925,15 +926,14 @@ func TestNodeWatcherFallback(t *testing.T) {
 		nodes = append(nodes, node)
 	}
 
-	require.Empty(t, w.NodeCount())
+	require.Empty(t, w.ResourceCount())
 	require.False(t, w.IsInitialized())
 
-	got := w.GetNodes(ctx, func(n services.Node) bool {
-		return true
-	})
+	got, err := w.CurrentResources(ctx)
+	require.NoError(t, err)
 	require.Len(t, nodes, len(got))
 
-	require.Len(t, nodes, w.NodeCount())
+	require.Len(t, nodes, w.ResourceCount())
 	require.False(t, w.IsInitialized())
 }
 
@@ -964,6 +964,7 @@ func TestNodeWatcher(t *testing.T) {
 			},
 			MaxStaleness: time.Minute,
 		},
+		NodesGetter: presence,
 	})
 	require.NoError(t, err)
 	t.Cleanup(w.Close)
@@ -977,25 +978,27 @@ func TestNodeWatcher(t *testing.T) {
 		nodes = append(nodes, node)
 	}
 
-	require.Eventually(t, func() bool {
-		filtered := w.GetNodes(ctx, func(n services.Node) bool {
-			return true
-		})
-		return len(filtered) == len(nodes)
+	require.EventuallyWithT(t, func(t *assert.CollectT) {
+		filtered, err := w.CurrentResources(ctx)
+		assert.NoError(t, err)
+		assert.Len(t, filtered, len(nodes))
 	}, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.")
 
-	require.Len(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetUseTunnel() }), 3)
+	filtered, err := w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool { return n.GetUseTunnel() })
+	require.NoError(t, err)
+	require.Len(t, filtered, 3)
 
 	require.NoError(t, presence.DeleteNode(ctx, apidefaults.Namespace, nodes[0].GetName()))
 
-	require.Eventually(t, func() bool {
-		filtered := w.GetNodes(ctx, func(n services.Node) bool {
-			return true
-		})
-		return len(filtered) == len(nodes)-1
+	require.EventuallyWithT(t, func(t *assert.CollectT) {
+		filtered, err := w.CurrentResources(ctx)
+		assert.NoError(t, err)
+		assert.Len(t, filtered, len(nodes)-1)
 	}, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.")
 
-	require.Empty(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetName() == nodes[0].GetName() }))
+	filtered, err = w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool { return n.GetName() == nodes[0].GetName() })
+	require.NoError(t, err)
+	require.Empty(t, filtered)
 }
 
 func newNodeServer(t *testing.T, name, hostname, addr string, tunnel bool) types.Server {
diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go
index fc59911e00a39..5e8807c0e60df 100644
--- a/lib/srv/app/server.go
+++ b/lib/srv/app/server.go
@@ -151,7 +151,7 @@ type Server struct {
 	reconcileCh chan struct{}
 
 	// watcher monitors changes
to application resources. - watcher *services.GenericWatcher[types.Application] + watcher *services.GenericWatcher[types.Application, types.Application] } // monitoredApps is a collection of applications from different sources diff --git a/lib/srv/app/watcher.go b/lib/srv/app/watcher.go index 007e36fa4b35e..f64aef67be135 100644 --- a/lib/srv/app/watcher.go +++ b/lib/srv/app/watcher.go @@ -65,7 +65,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to application resources and // registers/unregisters the proxied applications accordingly. -func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Application], error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Application, types.Application], error) { if len(s.c.ResourceMatchers) == 0 { s.log.Debug("Not initializing application resource watcher.") return nil, nil diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 8cb65c84c1f93..e6a213501b765 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -314,7 +314,7 @@ type Server struct { // heartbeats holds heartbeats for database servers. heartbeats map[string]srv.HeartbeatI // watcher monitors changes to database resources. - watcher *services.GenericWatcher[types.Database] + watcher *services.GenericWatcher[types.Database, types.Database] // proxiedDatabases contains databases this server currently is proxying. // Proxied databases are reconciled against monitoredDatabases below. proxiedDatabases map[string]types.Database diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index 1a73ef292a9de..4891f523da395 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -69,7 +69,7 @@ func (s *Server) startReconciler(ctx context.Context) error { // startResourceWatcher starts watching changes to database resources and // registers/unregisters the proxied databases accordingly. -func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Database], error) { +func (s *Server) startResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.Database, types.Database], error) { if len(s.cfg.ResourceMatchers) == 0 { s.log.DebugContext(ctx, "Not starting database resource watcher.") return nil, nil diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index adb2c6727de8a..bdf0da5fc1068 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -266,7 +266,7 @@ type Server struct { // cancelfn is used with ctx when stopping the discovery server cancelfn context.CancelFunc // nodeWatcher is a node watcher. - nodeWatcher *services.NodeWatcher + nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer] // ec2Watcher periodically retrieves EC2 instances. 
 	ec2Watcher *server.Watcher
@@ -776,13 +776,16 @@ func (s *Server) initGCPWatchers(ctx context.Context, matchers []types.GCPMatche
 	return nil
 }
 
-func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) {
-	nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool {
+func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) error {
+	nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool {
 		labels := n.GetAllLabels()
 		_, accountOK := labels[types.AWSAccountIDLabel]
 		_, instanceOK := labels[types.AWSInstanceIDLabel]
 		return accountOK && instanceOK
 	})
+	if err != nil {
+		return trace.Wrap(err)
+	}
 
 	var filtered []server.EC2Instance
 outer:
@@ -799,6 +802,7 @@ outer:
 		filtered = append(filtered, inst)
 	}
 	instances.Instances = filtered
+	return nil
 }
 
 func genEC2InstancesLogStr(instances []server.EC2Instance) string {
@@ -903,12 +907,21 @@ func (s *Server) heartbeatEICEInstance(instances *server.EC2Instances) {
 			continue
 		}
 
-		existingNode, err := s.nodeWatcher.GetNode(s.ctx, eiceNode.GetName())
+		existingNodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(s types.ReadOnlyServer) bool {
+			return s.GetName() == eiceNode.GetName()
+		})
 		if err != nil && !trace.IsNotFound(err) {
 			s.Log.Warnf("Error finding the existing node with name %q: %v", eiceNode.GetName(), err)
 			continue
 		}
 
+		var existingNode types.Server
+		switch len(existingNodes) {
+		case 0:
+		case 1:
+			existingNode = existingNodes[0]
+		default:
+			s.Log.Warnf("Found %d nodes with name %q, expected at most one", len(existingNodes), eiceNode.GetName())
+			continue
+		}
+
 		// EICE Node's Name are deterministic (based on the Account and Instance ID).
 		//
 		// To reduce load, nodes are skipped if
@@ -1061,7 +1074,7 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
-	found := s.nodeWatcher.GetNodes(ctx, func(n services.Node) bool {
+	found, err := s.nodeWatcher.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool {
 		if n.GetSubKind() != types.SubKindOpenSSHNode {
 			return false
 		}
@@ -1074,6 +1087,9 @@ func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, err
 
 		return mostRecentCertRotation.After(n.GetRotation().LastRotated)
 	})
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
 
 	if len(found) == 0 {
 		return nil, trace.NotFound("no unrotated nodes found")
@@ -1115,13 +1131,18 @@ func (s *Server) handleEC2Discovery() {
 	}
 }
 
-func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) {
-	nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool {
+func (s *Server) filterExistingAzureNodes(instances *server.AzureInstances) error {
+	nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool {
 		labels := n.GetAllLabels()
 		_, subscriptionOK := labels[types.SubscriptionIDLabel]
 		_, vmOK := labels[types.VMIDLabel]
 		return subscriptionOK && vmOK
 	})
+
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
 	var filtered []*armcompute.VirtualMachine
 outer:
 	for _, inst := range instances.Instances {
@@ -1141,6 +1162,7 @@ outer:
 		filtered = append(filtered, inst)
 	}
 	instances.Instances = filtered
+	return nil
 }
 
 func (s *Server) handleAzureInstances(instances *server.AzureInstances) error {
@@ -1203,14 +1225,19 @@ func (s *Server) handleAzureDiscovery() {
 	}
 }
 
-func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) {
-	nodes := s.nodeWatcher.GetNodes(s.ctx, func(n services.Node) bool {
+func (s *Server) filterExistingGCPNodes(instances *server.GCPInstances) error {
+	nodes, err := s.nodeWatcher.CurrentResourcesWithFilter(s.ctx, func(n types.ReadOnlyServer) bool {
 		labels := n.GetAllLabels()
 		_, projectIDOK := labels[types.ProjectIDLabelDiscovery]
 		_, zoneOK := labels[types.ZoneLabelDiscovery]
 		_, nameOK := labels[types.NameLabelDiscovery]
 		return projectIDOK && zoneOK && nameOK
 	})
+
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
 	var filtered []*gcpimds.Instance
 outer:
 	for _, inst := range instances.Instances {
@@ -1227,6 +1254,7 @@ outer:
 		filtered = append(filtered, inst)
 	}
 	instances.Instances = filtered
+	return nil
 }
 
 func (s *Server) handleGCPInstances(instances *server.GCPInstances) error {
diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go
index 42df4c9d4017c..fa6a4e9e6fa8f 100644
--- a/lib/srv/regular/sshserver_test.go
+++ b/lib/srv/regular/sshserver_test.go
@@ -2865,7 +2865,7 @@ func newLockWatcher(ctx context.Context, t testing.TB, client types.Events) *ser
 	return lockWatcher
 }
 
-func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.NodeWatcher {
+func newNodeWatcher(ctx context.Context, t *testing.T, client types.Events) *services.GenericWatcher[types.Server, types.ReadOnlyServer] {
 	nodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{
 		ResourceWatcherConfig: services.ResourceWatcherConfig{
 			Component: "test",
diff --git a/lib/utils/fncache.go b/lib/utils/fncache.go
index e45a8b3a2d821..84f5be17478bb 100644
--- a/lib/utils/fncache.go
+++ b/lib/utils/fncache.go
@@ -245,6 +245,8 @@ func FnCacheGetWithTTL[T any](ctx context.Context, cache *FnCache, key any, ttl
 	switch {
 	case err != nil:
 		return ret, err
+	case t == nil:
+		return ret, nil
 	case !ok:
 		return ret, trace.BadParameter("value retrieved was %T, expected %T", t, ret)
 	}
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index 8087b5e216194..babf7c34efd02 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -162,7 +162,7 @@ type Handler struct {
 
 	// nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from
 	// the proxy's cache and get nodes in real time.
-	nodeWatcher *services.NodeWatcher
+	nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]
 
 	// tracer is used to create spans.
 	tracer oteltrace.Tracer
@@ -298,7 +298,7 @@ type Config struct {
 
 	// NodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from
 	// the proxy's cache and get nodes in real time.
-	NodeWatcher *services.NodeWatcher
+	NodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]
 
 	// PresenceChecker periodically runs the mfa ceremony for moderated
 	// sessions.
@@ -3531,9 +3531,12 @@ func (h *Handler) siteNodeConnect(
 		WebsocketConn: ws,
 		SSHDialTimeout: dialTimeout,
 		HostNameResolver: func(serverID string) (string, error) {
-			matches := nw.GetNodes(r.Context(), func(n services.Node) bool {
+			matches, err := nw.CurrentResourcesWithFilter(r.Context(), func(n types.ReadOnlyServer) bool {
 				return n.GetName() == serverID
 			})
+			if err != nil {
+				return "", trace.Wrap(err)
+			}
 			if len(matches) != 1 {
 				return "", trace.NotFound("unable to resolve hostname for server %s", serverID)
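Reviewer note: the snippet below is not part of the patch. It is a minimal sketch of the caller-side migration this diff applies wherever a *services.NodeWatcher was previously used, assuming the GenericWatcher, types.ReadOnlyServer, CurrentResources, and CurrentResourcesWithFilter APIs exactly as introduced above; the package, function, and variable names here are hypothetical.

// Sketch only: contrasts the old GetNodes pattern with the new
// CurrentResourcesWithFilter pattern used throughout this diff.
package example

import (
	"context"

	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
)

// tunneledNodeNames lists the names of nodes that dial back over a reverse tunnel.
func tunneledNodeNames(ctx context.Context, w *services.GenericWatcher[types.Server, types.ReadOnlyServer]) ([]string, error) {
	// Before this change:
	//   nodes := w.GetNodes(ctx, func(n services.Node) bool { return n.GetUseTunnel() })
	// and any failure to refresh a stale cache was silently ignored.
	//
	// After this change the filter receives the read-only view and an error is
	// surfaced when the stale cache cannot be refreshed, so callers must handle it.
	nodes, err := w.CurrentResourcesWithFilter(ctx, func(n types.ReadOnlyServer) bool {
		return n.GetUseTunnel()
	})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	names := make([]string, 0, len(nodes))
	for _, n := range nodes {
		// The returned servers are clones (see CloneFunc), so callers can
		// mutate them without corrupting the watcher's internal state.
		names = append(names, n.GetName())
	}
	return names, nil
}

This is also why functions such as filterExistingEC2Nodes, filterExistingAzureNodes, and filterExistingGCPNodes gained an error return value above: the previous GetNodes API had no way to report a failed refresh.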