diff --git a/lib/auth/storage/storage.go b/lib/auth/storage/storage.go index 76db71182e98..dacf42179b8a 100644 --- a/lib/auth/storage/storage.go +++ b/lib/auth/storage/storage.go @@ -27,7 +27,9 @@ package storage import ( "context" "encoding/json" + "strconv" "strings" + "time" "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" @@ -233,6 +235,31 @@ func (p *ProcessStorage) WriteTeleportVersion(ctx context.Context, version *semv return nil } +func (p *ProcessStorage) rdpLicenseKey(majorVersion, minorVersion uint16, issuer, company, productID string) backend.Key { + return backend.NewKey("rdplicense", issuer, strconv.Itoa(int(majorVersion)), strconv.Itoa(int(minorVersion)), company, productID) +} + +// WriteRDPLicense writes an RDP license to local storage. +func (p *ProcessStorage) WriteRDPLicense(ctx context.Context, majorVersion, minorVersion uint16, issuer, company, productID string, license []byte) error { + item := backend.Item{ + Key: p.rdpLicenseKey(majorVersion, minorVersion, issuer, company, productID), + Value: license, + Expires: p.BackendStorage.Clock().Now().Add(28 * 24 * time.Hour), + } + _, err := p.stateStorage.Put(ctx, item) + return trace.Wrap(err) +} + +// ReadRDPLicense reads a previously obtained license from storage. +func (p *ProcessStorage) ReadRDPLicense(ctx context.Context, majorVersion, minorVersion uint16, issuer, company, productID string) ([]byte, error) { + item, err := p.stateStorage.Get(ctx, p.rdpLicenseKey(majorVersion, minorVersion, issuer, company, productID)) + if err != nil { + return nil, trace.Wrap(err) + } + + return item.Value, nil +} + // ReadLocalIdentity reads, parses and returns the given pub/pri key + cert from the // key storage (dataDir). func ReadLocalIdentity(dataDir string, id state.IdentityID) (*state.Identity, error) { diff --git a/lib/auth/webauthnwin/webauthn_windows.go b/lib/auth/webauthnwin/webauthn_windows.go index 681120739aa9..cdad0964ef7f 100644 --- a/lib/auth/webauthnwin/webauthn_windows.go +++ b/lib/auth/webauthnwin/webauthn_windows.go @@ -220,6 +220,8 @@ func isUVPlatformAuthenticatorAvailable() (bool, error) { // bytesFromCBytes gets slice of bytes from C type and copies it to new slice // so that it won't interfere when main objects is free. +// +// TODO(codingllama): can we use C.GoBytes() here instead? func bytesFromCBytes(size uint32, p *byte) []byte { if p == nil { return nil diff --git a/lib/service/desktop.go b/lib/service/desktop.go index 5ebc52af55a8..a49c0de404d1 100644 --- a/lib/service/desktop.go +++ b/lib/service/desktop.go @@ -210,6 +210,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(logger *slog srv, err := desktop.NewWindowsService(desktop.WindowsServiceConfig{ DataDir: process.Config.DataDir, + LicenseStore: process.storage, Logger: process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentWindowsDesktop, process.id)), Clock: process.Clock, Authorizer: authorizer, diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index 5b8a48a7eb9e..056a374efa4a 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -743,6 +743,62 @@ func toClient(handle C.uintptr_t) (value *Client, err error) { return cgo.Handle(handle).Value().(*Client), nil } +//export cgo_read_rdp_license +func cgo_read_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data_out **C.uint8_t, len_out *C.size_t) C.CGOErrCode { + *data_out = nil + *len_out = 0 + + client, err := toClient(handle) + if err != nil { + return C.ErrCodeFailure + } + + issuer := C.GoString(req.issuer) + company := C.GoString(req.company) + productID := C.GoString(req.product_id) + + license, err := client.readRDPLicense( + uint16(req.major_version), uint16(req.minor_version), + issuer, company, productID) + if err != nil { + return C.ErrCodeFailure + } + + // in this case, we expect the caller to use cgo_free_rdp_license + // when the data is no longer needed + *data_out = (*C.uint8_t)(C.CBytes(license)) + *len_out = C.size_t(len(license)) + return C.ErrCodeSuccess +} + +//export cgo_free_rdp_license +func cgo_free_rdp_license(p unsafe.Pointer) { + C.free(p) +} + +//export cgo_write_rdp_license +func cgo_write_rdp_license(handle C.uintptr_t, req *C.CGOLicenseRequest, data *C.uint8_t, length C.size_t) C.CGOErrCode { + client, err := toClient(handle) + if err != nil { + return C.ErrCodeFailure + } + + issuer := C.GoString(req.issuer) + company := C.GoString(req.company) + productID := C.GoString(req.product_id) + + licenseData := C.GoBytes(unsafe.Pointer(data), C.int(length)) + + err = client.writeRDPLicense( + uint16(req.major_version), uint16(req.minor_version), + issuer, company, productID, licenseData) + if err != nil { + return C.ErrCodeFailure + } + + return C.ErrCodeSuccess +} + //export cgo_handle_fastpath_pdu func cgo_handle_fastpath_pdu(handle C.uintptr_t, data *C.uint8_t, length C.uint32_t) C.CGOErrCode { goData := asRustBackedSlice(data, int(length)) @@ -753,6 +809,44 @@ func cgo_handle_fastpath_pdu(handle C.uintptr_t, data *C.uint8_t, length C.uint3 return client.handleRDPFastPathPDU(goData) } +func (c *Client) readRDPLicense(majorVersion, minorVersion uint16, issuer, company, productID string) ([]byte, error) { + log := c.cfg.Logger.With( + "issuer", issuer, + "company", company, + "version", fmt.Sprintf("%d.%d", majorVersion, minorVersion), + "product", productID, + ) + + license, err := c.cfg.LicenseStore.ReadRDPLicense(context.Background(), majorVersion, minorVersion, issuer, company, productID) + switch { + case trace.IsNotFound(err): + log.InfoContext(context.Background(), "existing RDP license not found") + case err != nil: + log.ErrorContext(context.Background(), "could not look up existing RDP license", "error", err) + case len(license) > 0: + log.InfoContext(context.Background(), "found existing RDP license") + } + + return license, trace.Wrap(err) +} + +func (c *Client) writeRDPLicense(majorVersion, minorVersion uint16, issuer, company, productID string, license []byte) error { + c.cfg.Logger.InfoContext(context.Background(), "writing RDP license to storage", + "issuer", issuer, + "company", company, + "version", fmt.Sprintf("%d.%d", majorVersion, minorVersion), + "product", productID, + ) + return trace.Wrap(c.cfg.LicenseStore.WriteRDPLicense( + context.Background(), + majorVersion, minorVersion, + issuer, + company, + productID, + license, + )) +} + func (c *Client) handleRDPFastPathPDU(data []byte) C.CGOErrCode { // Notify the input forwarding goroutine that we're ready for input. // Input can only be sent after connection was established, which we infer diff --git a/lib/srv/desktop/rdp/rdpclient/client_common.go b/lib/srv/desktop/rdp/rdpclient/client_common.go index 80e192e37427..4c6b4371d045 100644 --- a/lib/srv/desktop/rdp/rdpclient/client_common.go +++ b/lib/srv/desktop/rdp/rdpclient/client_common.go @@ -30,11 +30,20 @@ import ( "github.com/gravitational/teleport/lib/srv/desktop/tdp" ) +// LicenseStore implements client-side license storage for Microsoft +// Remote Desktop Services (RDS) licenses. +type LicenseStore interface { + WriteRDPLicense(ctx context.Context, majorVersion, minorVersion uint16, issuer, company, productID string, license []byte) error + ReadRDPLicense(ctx context.Context, majorVersion, minorVersion uint16, issuer, company, productID string) ([]byte, error) +} + // Config for creating a new Client. type Config struct { // Addr is the network address of the RDP server, in the form host:port. Addr string + LicenseStore LicenseStore + // UserCertGenerator generates user certificates for RDP authentication. GenerateUserCert GenerateUserCertFn CertTTL time.Duration diff --git a/lib/srv/desktop/rdp/rdpclient/src/lib.rs b/lib/srv/desktop/rdp/rdpclient/src/lib.rs index 05b16ca8fe6a..1d1cfec49d1f 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/lib.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/lib.rs @@ -464,6 +464,15 @@ pub unsafe extern "C" fn client_write_screen_resize( ) } +#[repr(C)] +pub struct CGOLicenseRequest { + major_version: u16, + minor_version: u16, + issuer: *const c_char, + company: *const c_char, + product_id: *const c_char, +} + #[repr(C)] pub struct CGOConnectParams { ad: bool, @@ -705,6 +714,18 @@ pub type CGOSharedDirectoryTruncateResponse = SharedDirectoryTruncateResponse; // These functions are defined on the Go side. // Look for functions with '//export funcname' comments. extern "C" { + fn cgo_read_rdp_license( + cgo_handle: CgoHandle, + req: *mut CGOLicenseRequest, + data_out: *mut *mut u8, + len_out: *mut usize, + ) -> CGOErrCode; + fn cgo_write_rdp_license( + cgo_handle: CgoHandle, + req: *mut CGOLicenseRequest, + data: *mut u8, + length: usize, + ) -> CGOErrCode; fn cgo_handle_remote_copy(cgo_handle: CgoHandle, data: *mut u8, len: u32) -> CGOErrCode; fn cgo_handle_fastpath_pdu(cgo_handle: CgoHandle, data: *mut u8, len: u32) -> CGOErrCode; fn cgo_handle_rdp_connection_activated( diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index 40c19a2ada1a..2d86732d1cb7 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -159,8 +159,9 @@ type WindowsServiceConfig struct { // Logger is the logger for the service. Logger *slog.Logger // Clock provides current time. - Clock clockwork.Clock - DataDir string + Clock clockwork.Clock + DataDir string + LicenseStore rdpclient.LicenseStore // Authorizer is used to authorize requests. Authorizer authz.Authorizer // LockWatcher is used to monitor for new locks. @@ -906,7 +907,8 @@ func (s *WindowsService) connectRDP(ctx context.Context, log *slog.Logger, tdpCo //nolint:staticcheck // SA4023. False positive, depends on build tags. rdpc, err := rdpclient.New(rdpclient.Config{ - Logger: log, + Logger: log, + LicenseStore: s.cfg.LicenseStore, GenerateUserCert: func(ctx context.Context, username string, ttl time.Duration) (certDER, keyDER []byte, err error) { return s.generateUserCert(ctx, username, ttl, desktop, createUsers, groups) },