Skip to content

Commit

Permalink
Enable caching of agent readiness
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton-Kalpakchiev committed Nov 21, 2024
1 parent 059a132 commit 07ebb6f
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 14 deletions.
17 changes: 16 additions & 1 deletion agent/agentserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os"
"strings"
"sync"
"time"

"github.com/uber/kraken/build-index/tagclient"
"github.com/uber/kraken/core"
Expand All @@ -39,7 +40,10 @@ import (
)

// Config defines Server configuration.
type Config struct{}
type Config struct {
// How long a successful readiness check is valid for. If 0, disable caching successful readiness.
readinessCacheTTL time.Duration `yaml:"readiness_cache_ttl"`
}

// Server defines the agent HTTP server.
type Server struct {
Expand All @@ -50,6 +54,7 @@ type Server struct {
tags tagclient.Client
ac announceclient.Client
containerRuntime containerruntime.Factory
lastReady time.Time
}

// New creates a new Server.
Expand Down Expand Up @@ -208,6 +213,14 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) error {
}

func (s *Server) readinessCheckHandler(w http.ResponseWriter, r *http.Request) error {
if s.config.readinessCacheTTL != 0 {
rCacheValid := s.lastReady.Add(s.config.readinessCacheTTL).After(time.Now())
if rCacheValid {
io.WriteString(w, "OK")
return nil
}
}

var schedErr, buildIndexErr, trackerErr error
var wg sync.WaitGroup

Expand Down Expand Up @@ -236,6 +249,8 @@ func (s *Server) readinessCheckHandler(w http.ResponseWriter, r *http.Request) e
if len(errMsgs) != 0 {
return handler.Errorf("agent not ready: %v", strings.Join(errMsgs, "\n")).Status(http.StatusServiceUnavailable)
}

s.lastReady = time.Now()
io.WriteString(w, "OK")
return nil
}
Expand Down
95 changes: 82 additions & 13 deletions agent/agentserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/url"
"strings"
Expand Down Expand Up @@ -78,8 +79,8 @@ func newServerMocks(t *testing.T) (*serverMocks, func()) {
containerruntime, &cleanup}, cleanup.Run
}

func (m *serverMocks) startServer() string {
s := New(Config{}, tally.NoopScope, m.cads, m.sched, m.tags, m.ac, m.containerRuntime)
func (m *serverMocks) startServer(c Config) string {
s := New(c, tally.NoopScope, m.cads, m.sched, m.tags, m.ac, m.containerRuntime)
addr, stop := testutil.StartServer(s.Handler())
m.cleanup.Add(stop)
return addr
Expand All @@ -96,7 +97,7 @@ func TestGetTag(t *testing.T) {

mocks.tags.EXPECT().Get(tag).Return(d, nil)

c := agentclient.New(mocks.startServer())
c := agentclient.New(mocks.startServer(Config{}))

result, err := c.GetTag(tag)
require.NoError(err)
Expand All @@ -113,7 +114,7 @@ func TestGetTagNotFound(t *testing.T) {

mocks.tags.EXPECT().Get(tag).Return(core.Digest{}, tagclient.ErrTagNotFound)

c := agentclient.New(mocks.startServer())
c := agentclient.New(mocks.startServer(Config{}))

_, err := c.GetTag(tag)
require.Error(err)
Expand All @@ -134,7 +135,7 @@ func TestDownload(t *testing.T) {
return store.RunDownload(mocks.cads, d, blob.Content)
})

addr := mocks.startServer()
addr := mocks.startServer(Config{})
c := agentclient.New(addr)

r, err := c.Download(namespace, blob.Digest)
Expand All @@ -155,7 +156,7 @@ func TestDownloadNotFound(t *testing.T) {

mocks.sched.EXPECT().Download(namespace, blob.Digest).Return(scheduler.ErrTorrentNotFound)

addr := mocks.startServer()
addr := mocks.startServer(Config{})
c := agentclient.New(addr)

_, err := c.Download(namespace, blob.Digest)
Expand All @@ -174,7 +175,7 @@ func TestDownloadUnknownError(t *testing.T) {

mocks.sched.EXPECT().Download(namespace, blob.Digest).Return(fmt.Errorf("test error"))

addr := mocks.startServer()
addr := mocks.startServer(Config{})
c := agentclient.New(addr)

_, err := c.Download(namespace, blob.Digest)
Expand All @@ -199,7 +200,7 @@ func TestHealthHandler(t *testing.T) {

mocks.sched.EXPECT().Probe().Return(test.probeErr)

addr := mocks.startServer()
addr := mocks.startServer(Config{})

_, err := httputil.Get(fmt.Sprintf("http://%s/health", addr))
if test.probeErr != nil {
Expand Down Expand Up @@ -265,7 +266,7 @@ func TestReadinessCheckHandler(t *testing.T) {
mocks.tags.EXPECT().CheckReadiness().Return(tc.buildIndexErr)
mocks.ac.EXPECT().CheckReadiness().Return(tc.trackerErr)

addr := mocks.startServer()
addr := mocks.startServer(Config{})
_, err := httputil.Get(fmt.Sprintf("http://%s/readiness", addr))
if tc.wantErr == "" {
require.Nil(err)
Expand All @@ -276,13 +277,81 @@ func TestReadinessCheckHandler(t *testing.T) {
}
}

func TestReadinessCheckHandlerCache(t *testing.T) {
for _, tc := range []struct {
desc string
firstCallErr error
readinessCacheTTL time.Duration
waitInvalidation bool
}{
{
desc: "call 1 succeeds, so second call automatically succeeds",
firstCallErr: nil,
readinessCacheTTL: 10 * time.Minute,
waitInvalidation: false,
},
{
desc: "call 1 fails, so second call performs checks",
firstCallErr: errors.New("test-error"),
readinessCacheTTL: 10 * time.Minute,
waitInvalidation: false,
},
{
desc: "call 1 succeeds, but cache becomes invalid, so second call performs checks",
firstCallErr: nil,
readinessCacheTTL: 50 * time.Millisecond,
waitInvalidation: true,
},
{
desc: "call 1 succeeds, but caching is disabled, so second call performs checks",
firstCallErr: nil,
readinessCacheTTL: 0,
waitInvalidation: false,
},
} {
t.Run(tc.desc, func(t *testing.T) {
require := require.New(t)
mocks, cleanup := newServerMocks(t)
defer cleanup()

mocks.sched.EXPECT().Probe().Return(tc.firstCallErr).Times(1)
mocks.tags.EXPECT().CheckReadiness().Return(tc.firstCallErr).Times(1)
mocks.ac.EXPECT().CheckReadiness().Return(tc.firstCallErr).Times(1)
if tc.firstCallErr != nil || tc.waitInvalidation || tc.readinessCacheTTL == 0 {
mocks.sched.EXPECT().Probe().Return(nil).Times(1)
mocks.tags.EXPECT().CheckReadiness().Return(nil).Times(1)
mocks.ac.EXPECT().CheckReadiness().Return(nil).Times(1)
}

addr := mocks.startServer(Config{readinessCacheTTL: tc.readinessCacheTTL})
r, err := httputil.Get(fmt.Sprintf("http://%s/readiness", addr))
if tc.firstCallErr == nil {
require.Nil(err)
respBody, _ := io.ReadAll(r.Body)

Check failure on line 330 in agent/agentserver/server_test.go

View workflow job for this annotation

GitHub Actions / build (1.14)

undefined: io.ReadAll

Check failure on line 330 in agent/agentserver/server_test.go

View workflow job for this annotation

GitHub Actions / build (1.14)

undefined: io.ReadAll
require.Equal("OK", string(respBody))
} else {
require.Error(err)
}

if tc.waitInvalidation {
time.Sleep(tc.readinessCacheTTL)
}

r, err = httputil.Get(fmt.Sprintf("http://%s/readiness", addr))
require.Nil(err)
respBody, _ := io.ReadAll(r.Body)

Check failure on line 342 in agent/agentserver/server_test.go

View workflow job for this annotation

GitHub Actions / build (1.14)

undefined: io.ReadAll

Check failure on line 342 in agent/agentserver/server_test.go

View workflow job for this annotation

GitHub Actions / build (1.14)

undefined: io.ReadAll
require.Equal("OK", string(respBody))
})
}
}

func TestPatchSchedulerConfigHandler(t *testing.T) {
require := require.New(t)

mocks, cleanup := newServerMocks(t)
defer cleanup()

addr := mocks.startServer()
addr := mocks.startServer(Config{})

config := scheduler.Config{
ConnTTI: time.Minute,
Expand Down Expand Up @@ -311,7 +380,7 @@ func TestGetBlacklistHandler(t *testing.T) {
}}
mocks.sched.EXPECT().BlacklistSnapshot().Return(blacklist, nil)

addr := mocks.startServer()
addr := mocks.startServer(Config{})

resp, err := httputil.Get(fmt.Sprintf("http://%s/x/blacklist", addr))
require.NoError(err)
Expand All @@ -329,7 +398,7 @@ func TestDeleteBlobHandler(t *testing.T) {

d := core.DigestFixture()

addr := mocks.startServer()
addr := mocks.startServer(Config{})

mocks.sched.EXPECT().RemoveTorrent(d).Return(nil)

Expand Down Expand Up @@ -381,7 +450,7 @@ func TestPreloadHandler(t *testing.T) {
defer cleanup()

tt.setup(mocks)
addr := mocks.startServer()
addr := mocks.startServer(Config{})

_, err := httputil.Get(fmt.Sprintf("http://%s%s", addr, tt.url))
if tt.expectedError != "" {
Expand Down

0 comments on commit 07ebb6f

Please sign in to comment.