From 83c37af312a9b5606e950f3e0ca76bfc91dc86df Mon Sep 17 00:00:00 2001 From: Trent Clarke Date: Wed, 16 Oct 2024 10:26:02 +1100 Subject: [PATCH] Address review feedback --- lib/auth/authclient/clt.go | 8 -------- lib/cache/cache_test.go | 14 +++++++------- lib/cache/provisioning.go | 4 ++-- lib/cache/provisioning_test.go | 14 +++++++------- lib/services/local/provisioningstates.go | 8 ++++++-- lib/services/provisioningstates.go | 10 +++++----- 6 files changed, 27 insertions(+), 31 deletions(-) diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index bd61608ead62..02c00a7502e2 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -625,10 +625,6 @@ func (c *Client) OktaClient() services.Okta { return c.APIClient.OktaClient() } -func (c *Client) ProvisioningStatesClient() services.ProvisioningStates { - return nil -} - func (c *Client) SCIMClient() services.SCIM { return c.APIClient.SCIMClient() } @@ -1877,8 +1873,4 @@ type ClientI interface { // GenerateAppToken creates a JWT token with application access. GenerateAppToken(ctx context.Context, req types.GenerateAppTokenRequest) (string, error) - - // ProvisioningStatesClient manages access to the downstream user and group - // provisioning state storage service - ProvisioningStatesClient() services.ProvisioningStates } diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index 7340eae35416..7a9546bc1762 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -138,7 +138,7 @@ type testPack struct { spiffeFederations *local.SPIFFEFederationService staticHostUsers services.StaticHostUser autoUpdateService services.AutoUpdateService - provisioningStatesS services.ProvisioningStates + provisioningStates services.ProvisioningStates } // testFuncs are functions to support testing an object in a cache. @@ -382,7 +382,7 @@ func newPackWithoutCache(dir string, opts ...packOption) (*testPack, error) { return nil, trace.Wrap(err) } - p.provisioningStatesS, err = local.NewProvisioningStateService(p.backend) + p.provisioningStates, err = local.NewProvisioningStateService(p.backend) if err != nil { return nil, trace.Wrap(err) } @@ -437,7 +437,7 @@ func newPack(dir string, setupConfig func(c Config) Config, opts ...packOption) DatabaseObjects: p.databaseObjects, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, - ProvisioningStates: p.provisioningStatesS, + ProvisioningStates: p.provisioningStates, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -846,7 +846,7 @@ func TestCompletenessInit(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, - ProvisioningStates: p.provisioningStatesS, + ProvisioningStates: p.provisioningStates, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -928,7 +928,7 @@ func TestCompletenessReset(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, - ProvisioningStates: p.provisioningStatesS, + ProvisioningStates: p.provisioningStates, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) @@ -1136,7 +1136,7 @@ func TestListResources_NodesTTLVariant(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, - ProvisioningStates: p.provisioningStatesS, + ProvisioningStates: p.provisioningStates, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, neverOK: true, // ensure reads are never healthy @@ -1229,7 +1229,7 @@ func initStrategy(t *testing.T) { SPIFFEFederations: p.spiffeFederations, StaticHostUsers: p.staticHostUsers, AutoUpdateService: p.autoUpdateService, - ProvisioningStates: p.provisioningStatesS, + ProvisioningStates: p.provisioningStates, MaxRetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, })) diff --git a/lib/cache/provisioning.go b/lib/cache/provisioning.go index 69ac3e8e767d..883fb8cd69fb 100644 --- a/lib/cache/provisioning.go +++ b/lib/cache/provisioning.go @@ -29,7 +29,7 @@ import ( type provisioningStateGetter interface { GetProvisioningState(context.Context, services.DownstreamID, services.ProvisioningStateID) (*provisioningv1.PrincipalState, error) - ListAllProvisioningStates(context.Context, int, *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) + ListProvisioningStatesForAllDownstreams(context.Context, int, *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) } type provisioningStateExecutor struct{} @@ -49,7 +49,7 @@ func (provisioningStateExecutor) getAll(ctx context.Context, cache *Cache, loadS var resourcesPage []*provisioningv1.PrincipalState var err error - resourcesPage, nextPage, err := cache.ProvisioningStates.ListAllProvisioningStates(ctx, 0, &page) + resourcesPage, nextPage, err := cache.ProvisioningStates.ListProvisioningStatesForAllDownstreams(ctx, 0, &page) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/cache/provisioning_test.go b/lib/cache/provisioning_test.go index 84a771f70b58..e13c363f7391 100644 --- a/lib/cache/provisioning_test.go +++ b/lib/cache/provisioning_test.go @@ -62,7 +62,7 @@ func TestProvisioningPrincipalState(t *testing.T) { var result []*provisioningv1.PrincipalState var pageToken pagination.PageRequestToken for { - page, nextPage, err := src.ListAllProvisioningStates(ctx, 0, &pageToken) + page, nextPage, err := src.ListProvisioningStatesForAllDownstreams(ctx, 0, &pageToken) if err != nil { return nil, trace.Wrap(err) } @@ -82,28 +82,28 @@ func TestProvisioningPrincipalState(t *testing.T) { return newProvisioningPrincipalState(s), nil }, create: func(ctx context.Context, item *provisioningv1.PrincipalState) error { - _, err := fixturePack.provisioningStatesS.CreateProvisioningState(ctx, item) + _, err := fixturePack.provisioningStates.CreateProvisioningState(ctx, item) return trace.Wrap(err) }, update: func(ctx context.Context, item *provisioningv1.PrincipalState) error { - _, err := fixturePack.provisioningStatesS.UpdateProvisioningState(ctx, item) + _, err := fixturePack.provisioningStates.UpdateProvisioningState(ctx, item) return trace.Wrap(err) }, list: func(ctx context.Context) ([]*provisioningv1.PrincipalState, error) { - return collect(ctx, fixturePack.provisioningStatesS) + return collect(ctx, fixturePack.provisioningStates) }, delete: func(ctx context.Context, id string) error { - return trace.Wrap(fixturePack.provisioningStatesS.DeleteProvisioningState( + return trace.Wrap(fixturePack.provisioningStates.DeleteProvisioningState( ctx, testDownstreamID, services.ProvisioningStateID(id))) }, deleteAll: func(ctx context.Context) error { - return trace.Wrap(fixturePack.provisioningStatesS.DeleteAllProvisioningStates(ctx)) + return trace.Wrap(fixturePack.provisioningStates.DeleteAllProvisioningStates(ctx)) }, cacheList: func(ctx context.Context) ([]*provisioningv1.PrincipalState, error) { return collect(ctx, fixturePack.cache.provisioningStatesCache) }, cacheGet: func(ctx context.Context, id string) (*provisioningv1.PrincipalState, error) { - r, err := fixturePack.provisioningStatesS.GetProvisioningState( + r, err := fixturePack.provisioningStates.GetProvisioningState( ctx, testDownstreamID, services.ProvisioningStateID(id)) return r, trace.Wrap(err) }, diff --git a/lib/services/local/provisioningstates.go b/lib/services/local/provisioningstates.go index 3910321e84ca..407c964fa385 100644 --- a/lib/services/local/provisioningstates.go +++ b/lib/services/local/provisioningstates.go @@ -140,11 +140,15 @@ func (ss *ProvisioningStateService) ListProvisioningStates(ctx context.Context, return resp, pagination.NextPageToken(nextPage), nil } -// ListAllProvisioningStates lists all provisioning state records for all +// ListProvisioningStatesForAllDownstreams lists all provisioning state records for all // downstream receivers. Note that the returned record names may not be unique // across all downstream receivers. Check the records' `DownstreamID` field // to disambiguate. -func (ss *ProvisioningStateService) ListAllProvisioningStates(ctx context.Context, pageSize int, page *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) { +func (ss *ProvisioningStateService) ListProvisioningStatesForAllDownstreams( + ctx context.Context, + pageSize int, + page *pagination.PageRequestToken, +) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) { if pageSize == 0 { pageSize = provisioningStatePageSize } diff --git a/lib/services/provisioningstates.go b/lib/services/provisioningstates.go index 0e686e6d987a..e51e4b9b782f 100644 --- a/lib/services/provisioningstates.go +++ b/lib/services/provisioningstates.go @@ -73,11 +73,11 @@ type DownstreamProvisioningStates interface { type ProvisioningStates interface { DownstreamProvisioningStates - // ListProvisioningStates lists all provisioning state records for all - // downstream receivers. Note that the returned record names may not be unique - // across all downstream receivers. Check the records' `DownstreamID` field - // to disambiguate. - ListAllProvisioningStates(context.Context, int, *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) + // ListProvisioningStatesForAllDownstreams lists all provisioning state + // records for all downstream receivers. Note that the returned record names + // may not be unique across all downstream receivers. Check the records' + // `DownstreamID` field to disambiguate. + ListProvisioningStatesForAllDownstreams(context.Context, int, *pagination.PageRequestToken) ([]*provisioningv1.PrincipalState, pagination.NextPageToken, error) // DeleteAllProvisioningStates deletes all provisioning state records for // all downstream receivers.