Skip to content

Commit

Permalink
[v16] SSO Ceremony refactor (#47303)
Browse files Browse the repository at this point in the history
* Fix Hardware Key Login in Teleport Connect (#46793)

* Reuse tc.NewSSHLogin in Teleport Connect for consistency.

* reuse sso login function.

* SSO Ceremony refactor (#46983)

* Refactor SSO redirector for customization; add SSO MFA ceremony to redirector.

* Add customizable SSO ceremony.

* Create SSO login ceremony from client; Remove ProxySupportsKeyPolicyMessage for v17.

* Resolve comments.

* Refactor redirector closing.

* Add redirector tests.

* Add ceremony and redirector tests.

* Fix new copyright headers.

* Use slog.

* Fix lint.

* Fix lint.
  • Loading branch information
Joerger authored Oct 16, 2024
1 parent 4a59807 commit c9f7f5f
Show file tree
Hide file tree
Showing 13 changed files with 1,112 additions and 631 deletions.
106 changes: 43 additions & 63 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"time"
"unicode/utf8"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -76,6 +75,7 @@ import (
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
"github.com/gravitational/teleport/lib/authz"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/client/sso"
"github.com/gravitational/teleport/lib/client/terminal"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/devicetrust"
Expand Down Expand Up @@ -3566,19 +3566,13 @@ func (tc *TeleportClient) getSSHLoginFunc(pr *webclient.PingResponse) (SSHLoginF
}
case constants.OIDC:
oidc := pr.Auth.OIDC
return func(ctx context.Context, priv *keys.PrivateKey) (*authclient.SSHLoginResponse, error) {
return tc.ssoLogin(ctx, priv, oidc.Name, oidc.Display, constants.OIDC)
}, nil
return tc.SSOLoginFn(oidc.Name, oidc.Display, constants.OIDC), nil
case constants.SAML:
saml := pr.Auth.SAML
return func(ctx context.Context, priv *keys.PrivateKey) (*authclient.SSHLoginResponse, error) {
return tc.ssoLogin(ctx, priv, saml.Name, saml.Display, constants.SAML)
}, nil
return tc.SSOLoginFn(saml.Name, saml.Display, constants.SAML), nil
case constants.Github:
github := pr.Auth.Github
return func(ctx context.Context, priv *keys.PrivateKey) (*authclient.SSHLoginResponse, error) {
return tc.ssoLogin(ctx, priv, github.Name, github.Display, constants.Github)
}, nil
return tc.SSOLoginFn(github.Name, github.Display, constants.Github), nil
default:
return nil, trace.BadParameter("unsupported authentication type: %q", pr.Auth.Type)
}
Expand Down Expand Up @@ -3632,7 +3626,7 @@ func (tc *TeleportClient) pwdlessLoginWeb(ctx context.Context, priv *keys.Privat
user = tc.Username
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3680,7 +3674,7 @@ func (tc *TeleportClient) directLoginWeb(ctx context.Context, secondFactorType c
}
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand All @@ -3702,7 +3696,7 @@ func (tc *TeleportClient) mfaLocalLoginWeb(ctx context.Context, priv *keys.Priva
return nil, nil, trace.Wrap(err)
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3867,8 +3861,8 @@ func (tc *TeleportClient) GetNewLoginKey(ctx context.Context) (priv *keys.Privat
return priv, nil
}

// new SSHLogin generates a new SSHLogin using the given login key.
func (tc *TeleportClient) newSSHLogin(priv *keys.PrivateKey) (SSHLogin, error) {
// NewSSHLogin generates a new SSHLogin using the given login key.
func (tc *TeleportClient) NewSSHLogin(priv *keys.PrivateKey) (SSHLogin, error) {
return SSHLogin{
ProxyAddr: tc.WebProxyAddr,
PubKey: priv.MarshalSSHPublicKey(),
Expand Down Expand Up @@ -3898,7 +3892,7 @@ func (tc *TeleportClient) pwdlessLogin(ctx context.Context, priv *keys.PrivateKe
user = tc.Username
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3964,7 +3958,7 @@ func (tc *TeleportClient) directLogin(ctx context.Context, secondFactorType cons
}
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3994,7 +3988,7 @@ func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, priv *keys.PrivateK
return nil, trace.Wrap(err)
}

sshLogin, err := tc.newSSHLogin(priv)
sshLogin, err := tc.NewSSHLogin(priv)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -4047,55 +4041,41 @@ func (tc *TeleportClient) headlessLogin(ctx context.Context, priv *keys.PrivateK
// SSOLoginFunc is a function used in tests to mock SSO logins.
type SSOLoginFunc func(ctx context.Context, connectorID string, priv *keys.PrivateKey, protocol string) (*authclient.SSHLoginResponse, error)

// TODO(atburke): DELETE in v17.0.0
func versionSupportsKeyPolicyMessage(proxyVersion *semver.Version) bool {
switch proxyVersion.Major {
case 15:
return !proxyVersion.LessThan(*semver.New("15.2.5"))
case 14:
return !proxyVersion.LessThan(*semver.New("14.3.17"))
case 13:
return !proxyVersion.LessThan(*semver.New("13.4.22"))
default:
return proxyVersion.Major > 15
}
}
// SSOLoginFn returns a function that will carry out SSO login. A browser window will be opened
// for the user to authenticate through SSO. On completion they will be redirected to a success
// page and the resulting login session will be captured and returned.
func (tc *TeleportClient) SSOLoginFn(connectorID, connectorName, connectorType string) SSHLoginFunc {
return func(ctx context.Context, priv *keys.PrivateKey) (*authclient.SSHLoginResponse, error) {
if tc.MockSSOLogin != nil {
// sso login response is being mocked for testing purposes
return tc.MockSSOLogin(ctx, connectorID, priv, connectorType)
}

// samlLogin opens browser window and uses OIDC or SAML redirect cycle with browser
func (tc *TeleportClient) ssoLogin(ctx context.Context, priv *keys.PrivateKey, connectorID, connectorName, protocol string) (*authclient.SSHLoginResponse, error) {
if tc.MockSSOLogin != nil {
// sso login response is being mocked for testing purposes
return tc.MockSSOLogin(ctx, connectorID, priv, protocol)
}
// Set SAMLSingleLogoutEnabled from server settings.
pr, err := tc.Ping(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
if connectorType == constants.SAML && pr.Auth.SAML != nil {
tc.SAMLSingleLogoutEnabled = pr.Auth.SAML.SingleLogoutEnabled
}

sshLogin, err := tc.newSSHLogin(priv)
if err != nil {
return nil, trace.Wrap(err)
}
rdConfig, err := tc.ssoRedirectorConfig(ctx, connectorName)
if err != nil {
return nil, trace.Wrap(err)
}

pr, err := tc.Ping(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
proxyVersion := semver.New(pr.ServerVersion)
rd, err := sso.NewRedirector(rdConfig)
if err != nil {
return nil, trace.Wrap(err)
}
defer rd.Close()

if protocol == constants.SAML && pr.Auth.SAML != nil {
tc.SAMLSingleLogoutEnabled = pr.Auth.SAML.SingleLogoutEnabled
}
ssoCeremony := sso.NewCLICeremony(rd, tc.ssoLoginInitFn(priv, connectorID, connectorType))

// ask the CA (via proxy) to sign our public key:
response, err := SSHAgentSSOLogin(ctx, SSHLoginSSO{
SSHLogin: sshLogin,
ConnectorID: connectorID,
ConnectorName: connectorName,
Protocol: protocol,
BindAddr: tc.BindAddr,
CallbackAddr: tc.CallbackAddr,
Browser: tc.Browser,
PrivateKeyPolicy: tc.PrivateKeyPolicy,
ProxySupportsKeyPolicyMessage: versionSupportsKeyPolicyMessage(proxyVersion),
}, nil)
return response, trace.Wrap(err)
resp, err := ssoCeremony.Run(ctx)
return resp, trace.Wrap(err)
}
}

func (tc *TeleportClient) GetSAMLSingleLogoutURL(ctx context.Context, clt *ClusterClient, profile *ProfileStatus) (string, error) {
Expand All @@ -4121,7 +4101,7 @@ func (tc *TeleportClient) SAMLSingleLogout(ctx context.Context, SAMLSingleLogout
relayState := parsed.Query().Get("RelayState")
_, connectorName, _ := strings.Cut(relayState, ",")

err = OpenURLInBrowser(tc.Browser, SAMLSingleLogoutURL)
err = sso.OpenURLInBrowser(tc.Browser, SAMLSingleLogoutURL)
// If no browser was opened.
if err != nil || tc.Browser == teleport.BrowserNone {
fmt.Fprintf(os.Stderr, "Open the following link to log out of %s: %v\n", connectorName, SAMLSingleLogoutURL)
Expand Down
Loading

0 comments on commit c9f7f5f

Please sign in to comment.