From 204c76e6cf729c9bb6e7366ac98d23c9a994d346 Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 15 Oct 2024 13:53:17 -0700 Subject: [PATCH] Get SSO mfa device concurrently. --- lib/services/local/users.go | 49 +++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/lib/services/local/users.go b/lib/services/local/users.go index 5bab8d0eb8187..e36c4b2199c1f 100644 --- a/lib/services/local/users.go +++ b/lib/services/local/users.go @@ -1379,6 +1379,45 @@ func (s *IdentityService) GetMFADevices(ctx context.Context, user string, withSe return nil, trace.BadParameter("missing parameter user") } + // get normal MFA devices and SSO mfa device concurrently, returning the first error we get. + errC := make(chan error) + + var devices []*types.MFADevice + go func() { + var err error + devices, err = s.getMFADevices(ctx, user, withSecrets) + if err != nil { + errC <- trace.Wrap(err) + return + } + errC <- nil + }() + + var ssoDev *types.MFADevice + go func() { + var err error + ssoDev, err = s.getSSOMFADevice(ctx, user) + if !trace.IsNotFound(err) { + errC <- trace.Wrap(err) + return + } + errC <- nil + }() + + for i := 0; i < 2; i++ { + if err := <-errC; err != nil { + return nil, trace.Wrap(err) + } + } + + if ssoDev != nil { + devices = append(devices, ssoDev) + } + + return devices, nil +} + +func (s *IdentityService) getMFADevices(ctx context.Context, user string, withSecrets bool) ([]*types.MFADevice, error) { startKey := backend.ExactKey(webPrefix, usersPrefix, user, mfaDevicePrefix) result, err := s.GetRange(ctx, startKey, backend.RangeEnd(startKey), backend.NoLimit) if err != nil { @@ -1399,16 +1438,6 @@ func (s *IdentityService) GetMFADevices(ctx context.Context, user string, withSe } devices = append(devices, &d) } - - // If the user originated from an SSO connector with MFA enabled, - // append the corresponding SSO MFA device for the user. - ssoDev, err := s.getSSOMFADevice(ctx, user) - if err == nil { - devices = append(devices, ssoDev) - } else if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - return devices, nil }