Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Adding configurable audience property for flyte clients (#329)
Browse files Browse the repository at this point in the history
* Adding configurable audience property for flyte clients

Signed-off-by: pmahindrakar-oss <[email protected]>

* changed the const audience to audienceKey

Signed-off-by: pmahindrakar-oss <[email protected]>

* fixed unit tests

Signed-off-by: pmahindrakar-oss <[email protected]>

* fixed unit test

Signed-off-by: pmahindrakar-oss <[email protected]>

* nit

Signed-off-by: pmahindrakar-oss <[email protected]>

* feedback

Signed-off-by: pmahindrakar-oss <[email protected]>

* refactored unit tests

Signed-off-by: pmahindrakar-oss <[email protected]>

* Added UseAudienceFromAdmin property to force pull audience from admin config. Default is false and expects clients to pass it

Signed-off-by: pmahindrakar-oss <[email protected]>

* Added test for expected number of calls to the public admin endpoint

Signed-off-by: pmahindrakar-oss <[email protected]>

* fixed the tests

Signed-off-by: pmahindrakar-oss <[email protected]>

Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Jan 18, 2023
1 parent 9fbac98 commit c1c892a
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 36 deletions.
51 changes: 24 additions & 27 deletions clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,18 @@ import (
"sync"
"testing"

"github.com/flyteorg/flytestdlib/logger"

"k8s.io/apimachinery/pkg/util/rand"

mocks2 "github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/stretchr/testify/mock"

service2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/apimachinery/pkg/util/rand"

"github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"
)

// authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one
Expand All @@ -39,15 +35,15 @@ type authMetadataServer struct {
port int
grpcServer *grpc.Server
netListener net.Listener
impl service2.AuthMetadataServiceServer
impl service.AuthMetadataServiceServer
lck *sync.RWMutex
}

func (s authMetadataServer) GetOAuth2Metadata(ctx context.Context, in *service2.OAuth2MetadataRequest) (*service2.OAuth2MetadataResponse, error) {
func (s authMetadataServer) GetOAuth2Metadata(ctx context.Context, in *service.OAuth2MetadataRequest) (*service.OAuth2MetadataResponse, error) {
return s.impl.GetOAuth2Metadata(ctx, in)
}

func (s authMetadataServer) GetPublicClientConfig(ctx context.Context, in *service2.PublicClientAuthConfigRequest) (*service2.PublicClientAuthConfigResponse, error) {
func (s authMetadataServer) GetPublicClientConfig(ctx context.Context, in *service.PublicClientAuthConfigRequest) (*service.PublicClientAuthConfigResponse, error) {
return s.impl.GetPublicClientConfig(ctx, in)
}

Expand Down Expand Up @@ -84,7 +80,7 @@ func (s *authMetadataServer) Start(_ context.Context) error {
}

grpcS := grpc.NewServer()
service2.RegisterAuthMetadataServiceServer(grpcS, s)
service.RegisterAuthMetadataServiceServer(grpcS, s)
go func() {
_ = grpcS.Serve(lis)
//assert.NoError(s.t, err)
Expand All @@ -106,7 +102,7 @@ func (s *authMetadataServer) Close() {
s.s.Close()
}

func newAuthMetadataServer(t testing.TB, port int, impl service2.AuthMetadataServiceServer) *authMetadataServer {
func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServiceServer) *authMetadataServer {
return &authMetadataServer{
port: port,
t: t,
Expand All @@ -132,13 +128,13 @@ func Test_newAuthInterceptor(t *testing.T) {
}))

port := rand.IntnRange(10000, 60000)
m := &mocks2.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port),
}, nil)
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)
s := newAuthMetadataServer(t, port, m)
Expand Down Expand Up @@ -171,7 +167,7 @@ func Test_newAuthInterceptor(t *testing.T) {
}))

port := rand.IntnRange(10000, 60000)
m := &mocks2.AuthMetadataServiceServer{}
m := &adminMocks.AuthMetadataServiceServer{}
s := newAuthMetadataServer(t, port, m)
ctx := context.Background()
assert.NoError(t, s.Start(ctx))
Expand Down Expand Up @@ -201,13 +197,13 @@ func Test_newAuthInterceptor(t *testing.T) {
}))

port := rand.IntnRange(10000, 60000)
m := &mocks2.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port),
}, nil)
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)

Expand Down Expand Up @@ -237,8 +233,8 @@ func Test_newAuthInterceptor(t *testing.T) {

func TestMaterializeCredentials(t *testing.T) {
port := rand.IntnRange(10000, 60000)
t.Run("No public client config or oauth2 metadata endpoint lookup", func(t *testing.T) {
m := &mocks2.AuthMetadataServiceServer{}
t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) {
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config"))
s := newAuthMetadataServer(t, port, m)
Expand All @@ -256,12 +252,13 @@ func TestMaterializeCredentials(t *testing.T) {
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port),
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, &mocks.TokenCache{}, f)
assert.NoError(t, err)
})
t.Run("Failed to fetch client metadata", func(t *testing.T) {
m := &mocks2.AuthMetadataServiceServer{}
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
failedPublicClientConfigLookup := errors.New("expected err")
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup)
Expand Down
2 changes: 2 additions & 0 deletions clients/go/admin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ type Config struct {
ClientSecretLocation string `json:"clientSecretLocation" pflag:",File containing the client secret"`
ClientSecretEnvVar string `json:"clientSecretEnvVar" pflag:",Environment variable containing the client secret"`
Scopes []string `json:"scopes" pflag:",List of scopes to request"`
UseAudienceFromAdmin bool `json:"useAudienceFromAdmin" pflag:",Use Audience configured from admins public endpoint config."`
Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."`

// There are two ways to get the token URL. If the authorization server url is provided, the client will try to use RFC 8414 to
// try to get the token URL. Or it can be specified directly through TokenURL config.
Expand Down
2 changes: 2 additions & 0 deletions clients/go/admin/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions clients/go/admin/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 29 additions & 9 deletions clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
Expand All @@ -22,6 +23,10 @@ import (
"github.com/flyteorg/flytestdlib/logger"
)

const (
audienceKey = "audience"
)

// TokenSourceProvider defines the interface needed to provide a TokenSource that is used to
// create a client with authentication enabled.
type TokenSourceProvider interface {
Expand All @@ -46,15 +51,24 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
}

scopes := cfg.Scopes
if len(scopes) == 0 {
clientMetadata, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
audienceValue := cfg.Audience

if len(scopes) == 0 || cfg.UseAudienceFromAdmin {
publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
scopes = clientMetadata.Scopes
// Update scopes from publicClientConfig
if len(scopes) == 0 {
scopes = publicClientConfig.Scopes
}
// Update audience from publicClientConfig
if cfg.UseAudienceFromAdmin {
audienceValue = publicClientConfig.Audience
}
}

tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL)
tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -152,7 +166,7 @@ type ClientCredentialsTokenSourceProvider struct {
TokenRefreshWindow time.Duration
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string) (TokenSourceProvider, error) {
func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, audience string) (TokenSourceProvider, error) {
var secret string
if len(cfg.ClientSecretEnvVar) > 0 {
secret = os.Getenv(cfg.ClientSecretEnvVar)
Expand All @@ -164,13 +178,19 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
}
secret = string(secretBytes)
}
endpointParams := url.Values{}
if len(audience) > 0 {
endpointParams = url.Values{audienceKey: {audience}}
}
secret = strings.TrimSpace(secret)
return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
ClientID: cfg.ClientID,
ClientSecret: secret,
TokenURL: tokenURL,
Scopes: scopes},
ClientID: cfg.ClientID,
ClientSecret: secret,
TokenURL: tokenURL,
Scopes: scopes,
EndpointParams: endpointParams,
},
TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil
}

Expand Down
70 changes: 70 additions & 0 deletions clients/go/admin/token_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package admin

import (
"context"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"

tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
)

type DummyTestTokenSource struct {
Expand All @@ -25,3 +31,67 @@ func TestNewTokenSource(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "Bearer abc", metadata["test"])
}

func TestNewTokenSourceProvider(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
audienceCfg string
scopesCfg []string
useAudienceFromAdmin bool
clientConfigResponse service.PublicClientAuthConfigResponse
expectedAudience string
expectedScopes []string
expectedCallsPubEndpoint int
}{
{
name: "audience from client config",
audienceCfg: "clientConfiguredAud",
scopesCfg: []string{"all"},
clientConfigResponse: service.PublicClientAuthConfigResponse{},
expectedAudience: "clientConfiguredAud",
expectedScopes: []string{"all"},
expectedCallsPubEndpoint: 0,
},
{
name: "audience from public client response",
audienceCfg: "clientConfiguredAud",
useAudienceFromAdmin: true,
scopesCfg: []string{"all"},
clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}},
expectedAudience: "AdminConfiguredAud",
expectedScopes: []string{"all"},
expectedCallsPubEndpoint: 1,
},

{
name: "audience from client with useAudience from admin false",
audienceCfg: "clientConfiguredAud",
useAudienceFromAdmin: false,
scopesCfg: []string{"all"},
clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}},
expectedAudience: "clientConfiguredAud",
expectedScopes: []string{"all"},
expectedCallsPubEndpoint: 0,
},
}
for _, test := range tests {
cfg := GetConfig(ctx)
tokenCache := &tokenCacheMocks.TokenCache{}
metadataClient := &adminMocks.AuthMetadataServiceClient{}
metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil)
metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil)
cfg.AuthType = AuthTypeClientSecret
cfg.Audience = test.audienceCfg
cfg.Scopes = test.scopesCfg
cfg.UseAudienceFromAdmin = test.useAudienceFromAdmin
flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient)
assert.True(t, metadataClient.AssertNumberOfCalls(t, "GetPublicClientConfig", test.expectedCallsPubEndpoint))
assert.NoError(t, err)
assert.NotNil(t, flyteTokenSource)
clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider)
assert.True(t, ok)
assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes)
assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams)
}
}

0 comments on commit c1c892a

Please sign in to comment.