Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidate worker cache on ACL update #1605

Closed
wants to merge 13 commits into from
Closed
24 changes: 24 additions & 0 deletions api/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
)

const (
ModuleACL = "acl"
ModuleConsensus = "consensus"
ModuleContract = "contract"
ModuleContractSet = "contract_set"
Expand All @@ -28,6 +29,12 @@ var (
)

type (
EventACLUpdate struct {
Allowlist []types.PublicKey `json:"allowlist"`
Blocklist []string `json:"blocklist"`
Timestamp time.Time `json:"timestamp"`
}

EventConsensusUpdate struct {
ConsensusState
TransactionFee types.Currency `json:"transactionFee"`
Expand Down Expand Up @@ -73,6 +80,15 @@ type (
)

var (
WebhookACLUpdate = func(url string, headers map[string]string) webhooks.Webhook {
return webhooks.Webhook{
Event: EventUpdate,
Headers: headers,
Module: ModuleACL,
URL: url,
}
}

WebhookConsensusUpdate = func(url string, headers map[string]string) webhooks.Webhook {
return webhooks.Webhook{
Event: EventUpdate,
Expand Down Expand Up @@ -143,6 +159,14 @@ func ParseEventWebhook(event webhooks.Event) (interface{}, error) {
return nil, err
}
switch event.Module {
case ModuleACL:
if event.Event == EventUpdate {
var e EventACLUpdate
if err := json.Unmarshal(bytes, &e); err != nil {
return nil, err
}
return e, nil
}
case ModuleContract:
switch event.Event {
case EventAdd:
Expand Down
1 change: 1 addition & 0 deletions api/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
const (
ContractFilterModeAll = "all"
ContractFilterModeActive = "active"
ContractFilterModeDownload = "download"
ContractFilterModeArchived = "archived"

HostFilterModeAll = "all"
Expand Down
29 changes: 29 additions & 0 deletions bus/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,20 @@ func (b *Bus) hostsAllowlistHandlerPUT(jc jape.Context) {
return
} else if jc.Check("couldn't update allowlist entries", b.store.UpdateHostAllowlistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil {
return
} else {
if allowlist, err := b.store.HostAllowlist(ctx); jc.Check("couldn't fetch allowlist", err) == nil {
if blocklist, err := b.store.HostBlocklist(ctx); jc.Check("couldn't fetch blocklist", err) == nil {
b.broadcastAction(webhooks.Event{
Module: api.ModuleACL,
Event: api.EventUpdate,
Payload: api.EventACLUpdate{
Allowlist: allowlist,
Blocklist: blocklist,
Timestamp: time.Now().UTC(),
},
})
}
}
}
}
}
Expand All @@ -658,6 +672,20 @@ func (b *Bus) hostsBlocklistHandlerPUT(jc jape.Context) {
return
} else if jc.Check("couldn't update blocklist entries", b.store.UpdateHostBlocklistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil {
return
} else {
if allowlist, err := b.store.HostAllowlist(ctx); jc.Check("couldn't fetch allowlist", err) == nil {
if blocklist, err := b.store.HostBlocklist(ctx); jc.Check("couldn't fetch blocklist", err) == nil {
b.broadcastAction(webhooks.Event{
Module: api.ModuleACL,
Event: api.EventUpdate,
Payload: api.EventACLUpdate{
Allowlist: allowlist,
Blocklist: blocklist,
Timestamp: time.Now().UTC(),
},
})
}
}
}
}
}
Expand All @@ -676,6 +704,7 @@ func (b *Bus) contractsHandlerGET(jc jape.Context) {
case api.ContractFilterModeAll:
case api.ContractFilterModeActive:
case api.ContractFilterModeArchived:
case api.ContractFilterModeDownload:
default:
jc.Error(fmt.Errorf("invalid filter mode: '%v'", filterMode), http.StatusBadRequest)
return
Expand Down
113 changes: 106 additions & 7 deletions internal/test/e2e/blocklist_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package e2e

import (
"bytes"
"context"
"fmt"
"testing"
Expand All @@ -9,9 +10,15 @@ import (
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/internal/test"
"go.uber.org/zap"
"lukechampine.com/frand"
)

func TestBlocklist(t *testing.T) {
if testing.Short() {
t.SkipNow()
}

ctx := context.Background()

// create a new test cluster
Expand All @@ -21,10 +28,17 @@ func TestBlocklist(t *testing.T) {
defer cluster.Shutdown()
b := cluster.Bus
tt := cluster.tt
cs := test.AutopilotConfig.Contracts.Set

// fetch contracts
opts := api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set}
contracts, err := b.Contracts(ctx, opts)
contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs})
tt.OK(err)
if len(contracts) != 3 {
t.Fatalf("unexpected number of contracts, %v != 3", len(contracts))
}

// fetch again using filter mode
contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
tt.OK(err)
if len(contracts) != 3 {
t.Fatalf("unexpected number of contracts, %v != 3", len(contracts))
Expand All @@ -37,12 +51,21 @@ func TestBlocklist(t *testing.T) {
err = b.UpdateHostAllowlist(ctx, []types.PublicKey{hk1, hk2}, nil, false)
tt.OK(err)

// assert the contract of h3 can't be used for downloads
contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
tt.OK(err)
if len(contracts) != 2 {
t.Fatalf("unexpected number of contracts, %v != 2", len(contracts))
} else if contracts[0].HostKey == hk3 || contracts[1].HostKey == hk3 {
t.Fatal("unexpected download contract")
}

// assert h3 is no longer in the contract set
tt.Retry(100, 100*time.Millisecond, func() error {
contracts, err := b.Contracts(ctx, opts)
contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs})
tt.OK(err)
if len(contracts) != 2 {
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 2", opts.ContractSet, len(contracts))
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 2", cs, len(contracts))
}
for _, c := range contracts {
if c.HostKey == hk3 {
Expand All @@ -57,12 +80,21 @@ func TestBlocklist(t *testing.T) {
tt.OK(err)
tt.OK(b.UpdateHostBlocklist(ctx, []string{h1.NetAddress}, nil, false))

// assert the contract of h1 can't be used for downloads
contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
tt.OK(err)
if len(contracts) != 1 {
t.Fatalf("unexpected number of contracts, %v != 1", len(contracts))
} else if contracts[0].HostKey != hk2 {
t.Fatal("unexpected download contract")
}

// assert h1 is no longer in the contract set
tt.Retry(100, 100*time.Millisecond, func() error {
contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set})
tt.OK(err)
if len(contracts) != 1 {
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 1", opts.ContractSet, len(contracts))
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 1", cs, len(contracts))
}
for _, c := range contracts {
if c.HostKey == hk1 {
Expand All @@ -75,11 +107,19 @@ func TestBlocklist(t *testing.T) {
// clear the allowlist and blocklist and assert we have 3 contracts again
tt.OK(b.UpdateHostAllowlist(ctx, nil, []types.PublicKey{hk1, hk2}, false))
tt.OK(b.UpdateHostBlocklist(ctx, nil, []string{h1.NetAddress}, false))

// fetch again using filter mode
contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
tt.OK(err)
if len(contracts) != 3 {
t.Fatalf("unexpected number of contracts, %v != 3", len(contracts))
}

tt.Retry(100, 100*time.Millisecond, func() error {
contracts, err := b.Contracts(ctx, opts)
contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs})
tt.OK(err)
if len(contracts) != 3 {
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 3", opts.ContractSet, len(contracts))
return fmt.Errorf("unexpected number of contracts in set '%v', %v != 3", cs, len(contracts))
}
return nil
})
Expand Down Expand Up @@ -156,3 +196,62 @@ func TestBlocklist(t *testing.T) {
t.Fatal("unexpected number of hosts", len(hosts))
}
}

func TestBlocklistUploadDownload(t *testing.T) {
if testing.Short() {
t.SkipNow()
}

// create a new test cluster
cluster := newTestCluster(t, testClusterOptions{
logger: zap.NewNop(),
hosts: test.RedundancySettings.TotalShards,
})
defer cluster.Shutdown()
b := cluster.Bus
w := cluster.Worker
tt := cluster.tt

// prepare a file
data := make([]byte, 128)
tt.OKAll(frand.Read(data))

// upload the data
tt.OKAll(w.UploadObject(context.Background(), bytes.NewReader(data), testBucket, "/foo", api.UploadObjectOptions{}))

// download data
var buffer bytes.Buffer
tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{}))

// block two hosts
h1 := cluster.hosts[0]
h2 := cluster.hosts[1]
h1Addr := h1.settings.Settings().NetAddress
h2Addr := h2.settings.Settings().NetAddress
tt.OK(b.UpdateHostBlocklist(context.Background(), []string{h1Addr, h2Addr}, nil, false))

// download data again and expect it to fail
tt.Fail(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{}))

// unblock one of the hosts and expect it to succeed
buffer.Reset()
tt.OK(b.UpdateHostBlocklist(context.Background(), nil, []string{h1Addr}, false))
tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{}))

// clear blocklist and set allowlist to allow one host
tt.OK(b.UpdateHostBlocklist(context.Background(), nil, nil, true))
tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h1.PublicKey()}, nil, false))

c, err := b.Contracts(context.Background(), api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
tt.OK(err)
if len(c) != 1 {
t.Fatal("unexpected number of contracts", len(c))
}
// download data again and expect it to fail
tt.Fail(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{}))

// extend allowlist with one more host and expect download to succeed
buffer.Reset()
tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h2.PublicKey()}, nil, false))
tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{}))
}
6 changes: 5 additions & 1 deletion internal/test/e2e/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
func TestEvents(t *testing.T) {
// list all webhooks
allEvents := []func(string, map[string]string) webhooks.Webhook{
api.WebhookACLUpdate,
peterjan marked this conversation as resolved.
Show resolved Hide resolved
api.WebhookConsensusUpdate,
api.WebhookContractArchive,
api.WebhookContractRenew,
Expand Down Expand Up @@ -121,13 +122,16 @@ func TestEvents(t *testing.T) {
gp, err := b.GougingParams(context.Background())
tt.OK(err)

// update ACL
h := cluster.hosts[0]
tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h.PublicKey()}, nil, false))

// update settings
gs := gp.GougingSettings
gs.HostBlockHeightLeeway = 100
tt.OK(b.UpdateGougingSettings(context.Background(), gs))

// update host setting
h := cluster.hosts[0]
settings := h.settings.Settings()
settings.NetAddress = "127.0.0.1:0"
tt.OK(h.UpdateSettings(settings))
Expand Down
8 changes: 8 additions & 0 deletions internal/test/tt.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type (

AssertContains(err error, target string)
AssertIs(err, target error)
Fail(err error)
FailAll(vs ...interface{})
OK(err error)
OKAll(vs ...interface{})
Expand Down Expand Up @@ -63,6 +64,13 @@ func (t impl) AssertIs(err, target error) {
t.AssertContains(err, target.Error())
}

func (t impl) Fail(err error) {
t.Helper()
if err == nil {
t.Fatal("should've failed")
}
}

func (t impl) FailAll(vs ...interface{}) {
t.Helper()
for _, v := range vs {
Expand Down
11 changes: 9 additions & 2 deletions internal/worker/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ func (c *cache) DownloadContracts(ctx context.Context) (contracts []api.Contract
// fetch directly from bus if the cache is not ready
if !c.isReady() {
c.logger.Warn(errCacheNotReady)
contracts, err = c.b.Contracts(ctx, api.ContractsOpts{})
contracts, err = c.b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
return
}

// fetch from bus if it's not cached or expired
value, found, expired := c.cache.Get(cacheKeyDownloadContracts)
if !found || expired {
contracts, err = c.b.Contracts(ctx, api.ContractsOpts{})
contracts, err = c.b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload})
if err == nil {
c.cache.Set(cacheKeyDownloadContracts, contracts)
}
Expand Down Expand Up @@ -187,6 +187,9 @@ func (c *cache) HandleEvent(event webhooks.Event) (err error) {
case api.EventContractRenew:
log = log.With("fcid", e.Renewal.ID, "renewedFrom", e.Renewal.RenewedFrom, "ts", e.Timestamp)
c.handleContractRenew(e)
case api.EventACLUpdate:
log = log.With("ts", e.Timestamp)
c.handleACLUpdate(e)
case api.EventHostUpdate:
log = log.With("hk", e.HostKey, "ts", e.Timestamp)
c.handleHostUpdate(e)
Expand Down Expand Up @@ -316,6 +319,10 @@ func (c *cache) handleHostUpdate(e api.EventHostUpdate) {
c.cache.Set(cacheKeyDownloadContracts, contracts)
}

func (c *cache) handleACLUpdate(_ api.EventACLUpdate) {
c.cache.Invalidate(cacheKeyDownloadContracts)
}

func (c *cache) handleSettingUpdate(e api.EventSettingUpdate) {
// return early if the cache doesn't have gouging params to update
value, found, _ := c.cache.Get(cacheKeyGougingParams)
Expand Down
1 change: 1 addition & 0 deletions internal/worker/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func (e *eventSubscriber) Register(ctx context.Context, eventsURL string, opts .

// prepare webhooks
webhooks := []webhooks.Webhook{
api.WebhookACLUpdate(eventsURL, headers),
api.WebhookConsensusUpdate(eventsURL, headers),
api.WebhookContractAdd(eventsURL, headers),
api.WebhookContractArchive(eventsURL, headers),
Expand Down
4 changes: 2 additions & 2 deletions internal/worker/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ func TestEventSubscriber(t *testing.T) {
time.Sleep(testRegisterInterval)

// assert webhook was registered
if webhooks := w.Webhooks(); len(webhooks) != 6 {
t.Fatal("expected 6 webhooks, got", len(webhooks))
if webhooks := w.Webhooks(); len(webhooks) != 7 {
t.Fatal("expected 7 webhooks, got", len(webhooks))
}

// send the same event again
Expand Down
Loading
Loading