Skip to content

Commit

Permalink
fix: Address flakiness and revisit TestTeleportClient_Login_local (#4…
Browse files Browse the repository at this point in the history
…7406) (#47438)

* Replace prompt.Stdin() per TeleportClient instance

* Remove hasTouchIDCredentials global

* Remove secondFactor from test table

* Cancel OTP goroutine on PromptPIN

* Make sub-tests parallel

* Make context timeout apply strictly to tc.Login
  • Loading branch information
codingllama authored Oct 15, 2024
1 parent 5950338 commit f00b430
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 167 deletions.
15 changes: 13 additions & 2 deletions lib/auth/webauthncli/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type DefaultPrompt struct {
AcknowledgeTouchMessage string
PromptCredentialMessage string

// StdinFunc allows tests to override prompt.Stdin().
// If nil prompt.Stdin() is used.
StdinFunc func() prompt.StdinReader

ctx context.Context
out io.Writer

Expand All @@ -61,9 +65,16 @@ func NewDefaultPrompt(ctx context.Context, out io.Writer) *DefaultPrompt {
}
}

func (p *DefaultPrompt) stdin() prompt.StdinReader {
if p.StdinFunc == nil {
return prompt.Stdin()
}
return p.StdinFunc()
}

// PromptPIN prompts the user for a PIN.
func (p *DefaultPrompt) PromptPIN() (string, error) {
return prompt.Password(p.ctx, p.out, prompt.Stdin(), p.PINMessage)
return prompt.Password(p.ctx, p.out, p.stdin(), p.PINMessage)
}

// PromptTouch prompts the user for a security key touch, using different
Expand Down Expand Up @@ -105,7 +116,7 @@ func (p *DefaultPrompt) PromptCredential(creds []*CredentialInfo) (*CredentialIn
}

for {
numOrName, err := prompt.Input(p.ctx, p.out, prompt.Stdin(), p.PromptCredentialMessage)
numOrName, err := prompt.Input(p.ctx, p.out, p.stdin(), p.PromptCredentialMessage)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
37 changes: 28 additions & 9 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,14 @@ type Config struct {

// SSHDialTimeout is the timeout value that should be used for SSH connections.
SSHDialTimeout time.Duration

// StdinFunc allows tests to override prompt.Stdin().
// If nil prompt.Stdin() is used.
StdinFunc func() prompt.StdinReader

// HasTouchIDCredentialsFunc allows tests to override touchid.HasCredentials.
// If nil touchid.HasCredentials is used.
HasTouchIDCredentialsFunc func(rpID, user string) bool
}

// CachePolicy defines cache policy for local clients
Expand Down Expand Up @@ -1241,6 +1249,20 @@ func NewClient(c *Config) (tc *TeleportClient, err error) {
return tc, nil
}

func (tc *TeleportClient) stdin() prompt.StdinReader {
if tc.StdinFunc == nil {
return prompt.Stdin()
}
return tc.StdinFunc()
}

func (tc *TeleportClient) hasTouchIDCredentials(rpID, user string) bool {
if tc.HasTouchIDCredentialsFunc == nil {
return touchid.HasCredentials(rpID, user)
}
return tc.HasTouchIDCredentialsFunc(rpID, user)
}

func (tc *TeleportClient) ProfileStatus() (*ProfileStatus, error) {
status, err := tc.ClientStore.ReadProfileStatus(tc.WebProxyAddr)
if err != nil {
Expand Down Expand Up @@ -3694,9 +3716,6 @@ func (tc *TeleportClient) mfaLocalLoginWeb(ctx context.Context, priv *keys.Priva
return clt, session, trace.Wrap(err)
}

// hasTouchIDCredentials provides indirection for tests.
var hasTouchIDCredentials = touchid.HasCredentials

// canDefaultToPasswordless checks without user interaction
// if there is any registered passwordless login.
func (tc *TeleportClient) canDefaultToPasswordless(pr *webclient.PingResponse) bool {
Expand All @@ -3719,7 +3738,7 @@ func (tc *TeleportClient) canDefaultToPasswordless(pr *webclient.PingResponse) b
user = tc.Username
}

return hasTouchIDCredentials(pr.Auth.Webauthn.RPID, user)
return tc.hasTouchIDCredentials(pr.Auth.Webauthn.RPID, user)
}

// SSHLoginFunc is a function which carries out authn with an auth server and returns an auth response.
Expand Down Expand Up @@ -4269,7 +4288,7 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
fmt.Fprintln(tc.Stderr, motd.Text)

// If possible, prompt the user for acknowledement before continuing.
if stdin := prompt.Stdin(); stdin.IsTerminal() {
if stdin := tc.stdin(); stdin.IsTerminal() {
// We're re-using the password reader for user acknowledgment for
// aesthetic purposes, because we want to hide any garbage the
// user might enter at the prompt. Whatever the user enters will
Expand Down Expand Up @@ -4650,11 +4669,11 @@ func (tc *TeleportClient) AskOTP(ctx context.Context) (token string, err error)
)
defer span.End()

stdin := prompt.Stdin()
stdin := tc.stdin()
if !stdin.IsTerminal() {
return "", trace.Wrap(prompt.ErrNotTerminal, "cannot perform OTP login without a terminal")
}
return prompt.Password(ctx, tc.Stderr, prompt.Stdin(), "Enter your OTP token")
return prompt.Password(ctx, tc.Stderr, stdin, "Enter your OTP token")
}

// AskPassword prompts the user to enter the password
Expand All @@ -4666,7 +4685,7 @@ func (tc *TeleportClient) AskPassword(ctx context.Context) (pwd string, err erro
)
defer span.End()

stdin := prompt.Stdin()
stdin := tc.stdin()
if !stdin.IsTerminal() {
return "", trace.Wrap(prompt.ErrNotTerminal, "cannot perform password login without a terminal")
}
Expand Down Expand Up @@ -5116,7 +5135,7 @@ func (tc *TeleportClient) HeadlessApprove(ctx context.Context, headlessAuthentic
fmt.Fprintf(tc.Stdout, "Headless login attempt from IP address %q requires approval.\nContact your administrator if you didn't initiate this login attempt.\n", headlessAuthn.ClientIpAddress)

if confirm {
ok, err := prompt.Confirmation(ctx, tc.Stdout, prompt.Stdin(), "Approve?")
ok, err := prompt.Confirmation(ctx, tc.Stdout, tc.stdin(), "Approve?")
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit f00b430

Please sign in to comment.