Skip to content

Commit

Permalink
add support for nodes, cloning, and readonly filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Oct 16, 2024
1 parent 00868c0 commit c75114a
Show file tree
Hide file tree
Showing 20 changed files with 337 additions and 443 deletions.
75 changes: 75 additions & 0 deletions api/types/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,81 @@ type Server interface {
GetAWSAccountID() string
}

// ReadOnlyServer represents a Node, Proxy or Auth server in a Teleport cluster
type ReadOnlyServer interface {
// GetKind returns resource kind
GetKind() string
// GetSubKind returns resource subkind
GetSubKind() string
// GetVersion returns resource version
GetVersion() string
// GetName returns the name of the resource
GetName() string
// Expiry returns object expiry setting
Expiry() time.Time
// GetMetadata returns object metadata
GetMetadata() Metadata
// GetRevision returns the revision
GetRevision() string
// Origin returns the origin value of the resource.
Origin() string
// GetLabel retrieves the label with the provided key.
GetLabel(key string) (value string, ok bool)
// GetAllLabels returns all resource's labels.
GetAllLabels() map[string]string
// GetStaticLabels returns the resource's static labels.
GetStaticLabels() map[string]string
// SetStaticLabels sets the resource's static labels.
SetStaticLabels(sl map[string]string)
// MatchSearch goes through select field values of a resource
// and tries to match against the list of search values.
MatchSearch(searchValues []string) bool
// GetTeleportVersion returns the teleport version the server is running on
GetTeleportVersion() string
// GetAddr return server address
GetAddr() string
// GetHostname returns server hostname
GetHostname() string
// GetNamespace returns server namespace
GetNamespace() string
// GetLabels returns server's static label key pairs
GetLabels() map[string]string
// GetCmdLabels gets command labels
GetCmdLabels() map[string]CommandLabel
// GetPublicAddr returns a public address where this server can be reached.
GetPublicAddr() string
// GetPublicAddrs returns a list of public addresses where this server can be reached.
GetPublicAddrs() []string
// GetRotation gets the state of certificate authority rotation.
GetRotation() Rotation
// GetUseTunnel gets if a reverse tunnel should be used to connect to this node.
GetUseTunnel() bool
// String returns string representation of the server
String() string
// GetPeerAddr returns the peer address of the server.
GetPeerAddr() string
// GetProxyIDs returns a list of proxy ids this service is connected to.
GetProxyIDs() []string

// GetCloudMetadata gets the cloud metadata for the server.
GetCloudMetadata() *CloudMetadata
// GetAWSInfo returns the AWSInfo for the server.
GetAWSInfo() *AWSInfo

// IsOpenSSHNode returns whether the connection to this Server must use OpenSSH.
// This returns true for SubKindOpenSSHNode and SubKindOpenSSHEICENode.
IsOpenSSHNode() bool

// IsEICE returns whether the Node is an EICE instance.
// Must be `openssh-ec2-ice` subkind and have the AccountID and InstanceID information (AWS Metadata or Labels).
IsEICE() bool

// GetAWSInstanceID returns the AWS Instance ID if this node comes from an EC2 instance.
GetAWSInstanceID() string
// GetAWSAccountID returns the AWS Account ID if this node comes from an EC2 instance.
GetAWSAccountID() string
}

// NewServer creates an instance of Server.
func NewServer(name, kind string, spec ServerSpecV2) (Server, error) {
return NewServerWithLabels(name, kind, spec, map[string]string{})
Expand Down
4 changes: 2 additions & 2 deletions lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ type TLSServerConfig struct {
// kubernetes cluster name. Proxy uses this map to route requests to the correct
// kubernetes_service. The servers are kept in memory to avoid making unnecessary
// unmarshal calls followed by filtering and to improve memory usage.
KubernetesServersWatcher *services.GenericWatcher[types.KubeServer]
KubernetesServersWatcher *services.GenericWatcher[types.KubeServer, types.KubeServer]
// PROXYProtocolMode controls behavior related to unsigned PROXY protocol headers.
PROXYProtocolMode multiplexer.PROXYProtocolMode
// InventoryHandle is used to send kube server heartbeats via the inventory control stream.
Expand Down Expand Up @@ -170,7 +170,7 @@ type TLSServer struct {
closeContext context.Context
closeFunc context.CancelFunc
// kubeClusterWatcher monitors changes to kube cluster resources.
kubeClusterWatcher *services.GenericWatcher[types.KubeCluster]
kubeClusterWatcher *services.GenericWatcher[types.KubeCluster, types.KubeCluster]
// reconciler reconciles proxied kube clusters with kube_clusters resources.
reconciler *services.Reconciler[types.KubeCluster]
// monitoredKubeClusters contains all kube clusters the proxied kube_clusters are
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) {

// startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and
// registers/unregisters the proxied Kube Cluster accordingly.
func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster], error) {
func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.GenericWatcher[types.KubeCluster, types.KubeCluster], error) {
if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService {
s.log.Debug("Not initializing Kube Cluster resource watcher.")
return nil, nil
Expand Down
8 changes: 4 additions & 4 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check
// site is the minimum interface needed to match servers
// for a reversetunnelclient.RemoteSite. It makes testing easier.
type site interface {
GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error)
GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error)
GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error)
}

Expand All @@ -394,13 +394,13 @@ type remoteSite struct {
}

// GetNodes uses the wrapped sites NodeWatcher to filter nodes
func (r remoteSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
func (r remoteSite) GetNodes(ctx context.Context, fn func(n types.ReadOnlyServer) bool) ([]types.Server, error) {
watcher, err := r.site.NodeWatcher()
if err != nil {
return nil, trace.Wrap(err)
}

return watcher.GetNodes(ctx, fn), nil
return watcher.CurrentResourcesWithFilter(ctx, fn)
}

// GetClusterNetworkingConfig uses the wrapped sites cache to retrieve the ClusterNetworkingConfig
Expand Down Expand Up @@ -450,7 +450,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re

var maxScore int
scores := make(map[string]int)
matches, err := site.GetNodes(ctx, func(server services.Node) bool {
matches, err := site.GetNodes(ctx, func(server types.ReadOnlyServer) bool {
score := routeMatcher.RouteToServerScore(server)
if score < 1 {
return false
Expand Down
3 changes: 1 addition & 2 deletions lib/proxy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/observability/tracing"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
)
Expand All @@ -51,7 +50,7 @@ func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.Cluster
return t.cfg, nil
}

func (t testSite) GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) {
func (t testSite) GetNodes(ctx context.Context, fn func(n types.Server) bool) ([]types.Server, error) {
var out []types.Server
for _, s := range t.nodes {
if fn(s) {
Expand Down
22 changes: 16 additions & 6 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (s *localSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, err
}

// NodeWatcher returns a services.NodeWatcher for this cluster.
func (s *localSite) NodeWatcher() (*services.NodeWatcher, error) {
func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) {
return s.srv.NodeWatcher, nil
}

Expand Down Expand Up @@ -739,7 +739,11 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
return
case <-proxyResyncTicker.Chan():
var req discoveryRequest
req.SetProxies(s.srv.proxyWatcher.GetCurrent())
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
req.SetProxies(proxies)

if err := rconn.sendDiscoveryRequest(req); err != nil {
logger.WithError(err).Debug("Marking connection invalid on error")
Expand All @@ -764,9 +768,12 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
if firstHeartbeat {
// as soon as the agent connects and sends a first heartbeat
// send it the list of current proxies back
current := s.srv.proxyWatcher.GetCurrent()
if len(current) > 0 {
rconn.updateProxies(current)
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
if len(proxies) > 0 {
rconn.updateProxies(proxies)
}
reverseSSHTunnels.WithLabelValues(rconn.tunnelType).Inc()
firstHeartbeat = false
Expand Down Expand Up @@ -935,7 +942,7 @@ func (s *localSite) periodicFunctions() {

// sshTunnelStats reports SSH tunnel statistics for the cluster.
func (s *localSite) sshTunnelStats() error {
missing := s.srv.NodeWatcher.GetNodes(s.srv.ctx, func(server services.Node) bool {
missing, err := s.srv.NodeWatcher.CurrentResourcesWithFilter(s.srv.ctx, func(server types.ReadOnlyServer) bool {
// Skip over any servers that have a TTL larger than announce TTL (10
// minutes) and are non-IoT SSH servers (they won't have tunnels).
//
Expand Down Expand Up @@ -967,6 +974,9 @@ func (s *localSite) sshTunnelStats() error {

return err != nil
})
if err != nil {
return trace.Wrap(err)
}

// Update Prometheus metrics and also log if any tunnels are missing.
missingSSHTunnels.Set(float64(len(missing)))
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (p *clusterPeers) CachingAccessPoint() (authclient.RemoteProxyAccessPoint,
return peer.CachingAccessPoint()
}

func (p *clusterPeers) NodeWatcher() (*services.NodeWatcher, error) {
func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) {
peer, err := p.pickPeer()
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -202,7 +202,7 @@ func (s *clusterPeer) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, e
return nil, trace.ConnectionProblem(nil, "unable to fetch access point, this proxy %v has not been discovered yet, try again later", s)
}

func (s *clusterPeer) NodeWatcher() (*services.NodeWatcher, error) {
func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) {
return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s)
}

Expand Down
19 changes: 13 additions & 6 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type remoteSite struct {
remoteAccessPoint authclient.RemoteProxyAccessPoint

// nodeWatcher provides access the node set for the remote site
nodeWatcher *services.NodeWatcher
nodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]

// remoteCA is the last remote certificate authority recorded by the client.
// It is used to detect CA rotation status changes. If the rotation
Expand Down Expand Up @@ -164,7 +164,7 @@ func (s *remoteSite) CachingAccessPoint() (authclient.RemoteProxyAccessPoint, er
}

// NodeWatcher returns the services.NodeWatcher for the remote cluster.
func (s *remoteSite) NodeWatcher() (*services.NodeWatcher, error) {
func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error) {
return s.nodeWatcher, nil
}

Expand Down Expand Up @@ -429,7 +429,11 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
return
case <-proxyResyncTicker.Chan():
var req discoveryRequest
req.SetProxies(s.srv.proxyWatcher.GetCurrent())
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
req.SetProxies(proxies)

if err := conn.sendDiscoveryRequest(req); err != nil {
logger.WithError(err).Debug("Marking connection invalid on error")
Expand Down Expand Up @@ -458,9 +462,12 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch
if firstHeartbeat {
// as soon as the agent connects and sends a first heartbeat
// send it the list of current proxies back
current := s.srv.proxyWatcher.GetCurrent()
if len(current) > 0 {
conn.updateProxies(current)
proxies, err := s.srv.proxyWatcher.CurrentResources(s.srv.ctx)
if err != nil {
logger.WithError(err).Warn("Failed to get proxy set")
}
if len(proxies) > 0 {
conn.updateProxies(proxies)
}
firstHeartbeat = false
}
Expand Down
4 changes: 2 additions & 2 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ type server struct {

// proxyWatcher monitors changes to the proxies
// and broadcasts updates
proxyWatcher *services.GenericWatcher[types.Server]
proxyWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]

// offlineThreshold is how long to wait for a keep alive message before
// marking a reverse tunnel connection as invalid.
Expand Down Expand Up @@ -201,7 +201,7 @@ type Config struct {
LockWatcher *services.LockWatcher

// NodeWatcher is a node watcher.
NodeWatcher *services.NodeWatcher
NodeWatcher *services.GenericWatcher[types.Server, types.ReadOnlyServer]

// CertAuthorityWatcher is a cert authority watcher.
CertAuthorityWatcher *services.CertAuthorityWatcher
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnelclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ type RemoteSite interface {
// but is resilient to auth server crashes
CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error)
// NodeWatcher returns the node watcher that maintains the node set for the site
NodeWatcher() (*services.NodeWatcher, error)
NodeWatcher() (*services.GenericWatcher[types.Server, types.ReadOnlyServer], error)
// GetTunnelsCount returns the amount of active inbound tunnels
// from the remote cluster
GetTunnelsCount() int
Expand Down
Loading

0 comments on commit c75114a

Please sign in to comment.