diff --git a/api/events.go b/api/events.go index 85fe857d9..ea237cade 100644 --- a/api/events.go +++ b/api/events.go @@ -11,6 +11,7 @@ import ( ) const ( + ModuleACL = "acl" ModuleConsensus = "consensus" ModuleContract = "contract" ModuleContractSet = "contract_set" @@ -28,6 +29,12 @@ var ( ) type ( + EventACLUpdate struct { + Allowlist []types.PublicKey `json:"allowlist"` + Blocklist []string `json:"blocklist"` + Timestamp time.Time `json:"timestamp"` + } + EventConsensusUpdate struct { ConsensusState TransactionFee types.Currency `json:"transactionFee"` @@ -73,6 +80,15 @@ type ( ) var ( + WebhookACLUpdate = func(url string, headers map[string]string) webhooks.Webhook { + return webhooks.Webhook{ + Event: EventUpdate, + Headers: headers, + Module: ModuleACL, + URL: url, + } + } + WebhookConsensusUpdate = func(url string, headers map[string]string) webhooks.Webhook { return webhooks.Webhook{ Event: EventUpdate, @@ -143,6 +159,14 @@ func ParseEventWebhook(event webhooks.Event) (interface{}, error) { return nil, err } switch event.Module { + case ModuleACL: + if event.Event == EventUpdate { + var e EventACLUpdate + if err := json.Unmarshal(bytes, &e); err != nil { + return nil, err + } + return e, nil + } case ModuleContract: switch event.Event { case EventAdd: diff --git a/api/host.go b/api/host.go index 66e974a8b..128b75555 100644 --- a/api/host.go +++ b/api/host.go @@ -15,6 +15,7 @@ import ( const ( ContractFilterModeAll = "all" ContractFilterModeActive = "active" + ContractFilterModeDownload = "download" ContractFilterModeArchived = "archived" HostFilterModeAll = "all" diff --git a/bus/routes.go b/bus/routes.go index 029b61ab6..fbe83034e 100644 --- a/bus/routes.go +++ b/bus/routes.go @@ -638,6 +638,20 @@ func (b *Bus) hostsAllowlistHandlerPUT(jc jape.Context) { return } else if jc.Check("couldn't update allowlist entries", b.store.UpdateHostAllowlistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { return + } else { + if allowlist, err := b.store.HostAllowlist(ctx); jc.Check("couldn't fetch allowlist", err) == nil { + if blocklist, err := b.store.HostBlocklist(ctx); jc.Check("couldn't fetch blocklist", err) == nil { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleACL, + Event: api.EventUpdate, + Payload: api.EventACLUpdate{ + Allowlist: allowlist, + Blocklist: blocklist, + Timestamp: time.Now().UTC(), + }, + }) + } + } } } } @@ -658,6 +672,20 @@ func (b *Bus) hostsBlocklistHandlerPUT(jc jape.Context) { return } else if jc.Check("couldn't update blocklist entries", b.store.UpdateHostBlocklistEntries(ctx, req.Add, req.Remove, req.Clear)) != nil { return + } else { + if allowlist, err := b.store.HostAllowlist(ctx); jc.Check("couldn't fetch allowlist", err) == nil { + if blocklist, err := b.store.HostBlocklist(ctx); jc.Check("couldn't fetch blocklist", err) == nil { + b.broadcastAction(webhooks.Event{ + Module: api.ModuleACL, + Event: api.EventUpdate, + Payload: api.EventACLUpdate{ + Allowlist: allowlist, + Blocklist: blocklist, + Timestamp: time.Now().UTC(), + }, + }) + } + } } } } @@ -676,6 +704,7 @@ func (b *Bus) contractsHandlerGET(jc jape.Context) { case api.ContractFilterModeAll: case api.ContractFilterModeActive: case api.ContractFilterModeArchived: + case api.ContractFilterModeDownload: default: jc.Error(fmt.Errorf("invalid filter mode: '%v'", filterMode), http.StatusBadRequest) return diff --git a/internal/test/e2e/blocklist_test.go b/internal/test/e2e/blocklist_test.go index b2c55fdea..0fbf026e6 100644 --- a/internal/test/e2e/blocklist_test.go +++ b/internal/test/e2e/blocklist_test.go @@ -1,6 +1,7 @@ package e2e import ( + "bytes" "context" "fmt" "testing" @@ -9,9 +10,15 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/test" + "go.uber.org/zap" + "lukechampine.com/frand" ) func TestBlocklist(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + ctx := context.Background() // create a new test cluster @@ -21,10 +28,17 @@ func TestBlocklist(t *testing.T) { defer cluster.Shutdown() b := cluster.Bus tt := cluster.tt + cs := test.AutopilotConfig.Contracts.Set // fetch contracts - opts := api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set} - contracts, err := b.Contracts(ctx, opts) + contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs}) + tt.OK(err) + if len(contracts) != 3 { + t.Fatalf("unexpected number of contracts, %v != 3", len(contracts)) + } + + // fetch again using filter mode + contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) tt.OK(err) if len(contracts) != 3 { t.Fatalf("unexpected number of contracts, %v != 3", len(contracts)) @@ -37,12 +51,21 @@ func TestBlocklist(t *testing.T) { err = b.UpdateHostAllowlist(ctx, []types.PublicKey{hk1, hk2}, nil, false) tt.OK(err) + // assert the contract of h3 can't be used for downloads + contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) + tt.OK(err) + if len(contracts) != 2 { + t.Fatalf("unexpected number of contracts, %v != 2", len(contracts)) + } else if contracts[0].HostKey == hk3 || contracts[1].HostKey == hk3 { + t.Fatal("unexpected download contract") + } + // assert h3 is no longer in the contract set tt.Retry(100, 100*time.Millisecond, func() error { - contracts, err := b.Contracts(ctx, opts) + contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs}) tt.OK(err) if len(contracts) != 2 { - return fmt.Errorf("unexpected number of contracts in set '%v', %v != 2", opts.ContractSet, len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 2", cs, len(contracts)) } for _, c := range contracts { if c.HostKey == hk3 { @@ -57,12 +80,21 @@ func TestBlocklist(t *testing.T) { tt.OK(err) tt.OK(b.UpdateHostBlocklist(ctx, []string{h1.NetAddress}, nil, false)) + // assert the contract of h1 can't be used for downloads + contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) + tt.OK(err) + if len(contracts) != 1 { + t.Fatalf("unexpected number of contracts, %v != 1", len(contracts)) + } else if contracts[0].HostKey != hk2 { + t.Fatal("unexpected download contract") + } + // assert h1 is no longer in the contract set tt.Retry(100, 100*time.Millisecond, func() error { contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: test.AutopilotConfig.Contracts.Set}) tt.OK(err) if len(contracts) != 1 { - return fmt.Errorf("unexpected number of contracts in set '%v', %v != 1", opts.ContractSet, len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 1", cs, len(contracts)) } for _, c := range contracts { if c.HostKey == hk1 { @@ -75,11 +107,19 @@ func TestBlocklist(t *testing.T) { // clear the allowlist and blocklist and assert we have 3 contracts again tt.OK(b.UpdateHostAllowlist(ctx, nil, []types.PublicKey{hk1, hk2}, false)) tt.OK(b.UpdateHostBlocklist(ctx, nil, []string{h1.NetAddress}, false)) + + // fetch again using filter mode + contracts, err = b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) + tt.OK(err) + if len(contracts) != 3 { + t.Fatalf("unexpected number of contracts, %v != 3", len(contracts)) + } + tt.Retry(100, 100*time.Millisecond, func() error { - contracts, err := b.Contracts(ctx, opts) + contracts, err := b.Contracts(ctx, api.ContractsOpts{ContractSet: cs}) tt.OK(err) if len(contracts) != 3 { - return fmt.Errorf("unexpected number of contracts in set '%v', %v != 3", opts.ContractSet, len(contracts)) + return fmt.Errorf("unexpected number of contracts in set '%v', %v != 3", cs, len(contracts)) } return nil }) @@ -156,3 +196,62 @@ func TestBlocklist(t *testing.T) { t.Fatal("unexpected number of hosts", len(hosts)) } } + +func TestBlocklistUploadDownload(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + // create a new test cluster + cluster := newTestCluster(t, testClusterOptions{ + logger: zap.NewNop(), + hosts: test.RedundancySettings.TotalShards, + }) + defer cluster.Shutdown() + b := cluster.Bus + w := cluster.Worker + tt := cluster.tt + + // prepare a file + data := make([]byte, 128) + tt.OKAll(frand.Read(data)) + + // upload the data + tt.OKAll(w.UploadObject(context.Background(), bytes.NewReader(data), testBucket, "/foo", api.UploadObjectOptions{})) + + // download data + var buffer bytes.Buffer + tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{})) + + // block two hosts + h1 := cluster.hosts[0] + h2 := cluster.hosts[1] + h1Addr := h1.settings.Settings().NetAddress + h2Addr := h2.settings.Settings().NetAddress + tt.OK(b.UpdateHostBlocklist(context.Background(), []string{h1Addr, h2Addr}, nil, false)) + + // download data again and expect it to fail + tt.Fail(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{})) + + // unblock one of the hosts and expect it to succeed + buffer.Reset() + tt.OK(b.UpdateHostBlocklist(context.Background(), nil, []string{h1Addr}, false)) + tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{})) + + // clear blocklist and set allowlist to allow one host + tt.OK(b.UpdateHostBlocklist(context.Background(), nil, nil, true)) + tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h1.PublicKey()}, nil, false)) + + c, err := b.Contracts(context.Background(), api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) + tt.OK(err) + if len(c) != 1 { + t.Fatal("unexpected number of contracts", len(c)) + } + // download data again and expect it to fail + tt.Fail(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{})) + + // extend allowlist with one more host and expect download to succeed + buffer.Reset() + tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h2.PublicKey()}, nil, false)) + tt.OK(w.DownloadObject(context.Background(), &buffer, testBucket, "/foo", api.DownloadObjectOptions{})) +} diff --git a/internal/test/e2e/events_test.go b/internal/test/e2e/events_test.go index 5fa7dd768..b4f808e5e 100644 --- a/internal/test/e2e/events_test.go +++ b/internal/test/e2e/events_test.go @@ -21,6 +21,7 @@ import ( func TestEvents(t *testing.T) { // list all webhooks allEvents := []func(string, map[string]string) webhooks.Webhook{ + api.WebhookACLUpdate, api.WebhookConsensusUpdate, api.WebhookContractArchive, api.WebhookContractRenew, @@ -121,13 +122,16 @@ func TestEvents(t *testing.T) { gp, err := b.GougingParams(context.Background()) tt.OK(err) + // update ACL + h := cluster.hosts[0] + tt.OK(b.UpdateHostAllowlist(context.Background(), []types.PublicKey{h.PublicKey()}, nil, false)) + // update settings gs := gp.GougingSettings gs.HostBlockHeightLeeway = 100 tt.OK(b.UpdateGougingSettings(context.Background(), gs)) // update host setting - h := cluster.hosts[0] settings := h.settings.Settings() settings.NetAddress = "127.0.0.1:0" tt.OK(h.UpdateSettings(settings)) diff --git a/internal/test/tt.go b/internal/test/tt.go index d44152eda..4a1b6c44c 100644 --- a/internal/test/tt.go +++ b/internal/test/tt.go @@ -11,6 +11,7 @@ type ( AssertContains(err error, target string) AssertIs(err, target error) + Fail(err error) FailAll(vs ...interface{}) OK(err error) OKAll(vs ...interface{}) @@ -63,6 +64,13 @@ func (t impl) AssertIs(err, target error) { t.AssertContains(err, target.Error()) } +func (t impl) Fail(err error) { + t.Helper() + if err == nil { + t.Fatal("should've failed") + } +} + func (t impl) FailAll(vs ...interface{}) { t.Helper() for _, v := range vs { diff --git a/internal/worker/cache.go b/internal/worker/cache.go index d357293ff..be80c2794 100644 --- a/internal/worker/cache.go +++ b/internal/worker/cache.go @@ -119,14 +119,14 @@ func (c *cache) DownloadContracts(ctx context.Context) (contracts []api.Contract // fetch directly from bus if the cache is not ready if !c.isReady() { c.logger.Warn(errCacheNotReady) - contracts, err = c.b.Contracts(ctx, api.ContractsOpts{}) + contracts, err = c.b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) return } // fetch from bus if it's not cached or expired value, found, expired := c.cache.Get(cacheKeyDownloadContracts) if !found || expired { - contracts, err = c.b.Contracts(ctx, api.ContractsOpts{}) + contracts, err = c.b.Contracts(ctx, api.ContractsOpts{FilterMode: api.ContractFilterModeDownload}) if err == nil { c.cache.Set(cacheKeyDownloadContracts, contracts) } @@ -187,6 +187,9 @@ func (c *cache) HandleEvent(event webhooks.Event) (err error) { case api.EventContractRenew: log = log.With("fcid", e.Renewal.ID, "renewedFrom", e.Renewal.RenewedFrom, "ts", e.Timestamp) c.handleContractRenew(e) + case api.EventACLUpdate: + log = log.With("ts", e.Timestamp) + c.handleACLUpdate(e) case api.EventHostUpdate: log = log.With("hk", e.HostKey, "ts", e.Timestamp) c.handleHostUpdate(e) @@ -316,6 +319,10 @@ func (c *cache) handleHostUpdate(e api.EventHostUpdate) { c.cache.Set(cacheKeyDownloadContracts, contracts) } +func (c *cache) handleACLUpdate(_ api.EventACLUpdate) { + c.cache.Invalidate(cacheKeyDownloadContracts) +} + func (c *cache) handleSettingUpdate(e api.EventSettingUpdate) { // return early if the cache doesn't have gouging params to update value, found, _ := c.cache.Get(cacheKeyGougingParams) diff --git a/internal/worker/events.go b/internal/worker/events.go index e0960fd5c..a0aa2debc 100644 --- a/internal/worker/events.go +++ b/internal/worker/events.go @@ -111,6 +111,7 @@ func (e *eventSubscriber) Register(ctx context.Context, eventsURL string, opts . // prepare webhooks webhooks := []webhooks.Webhook{ + api.WebhookACLUpdate(eventsURL, headers), api.WebhookConsensusUpdate(eventsURL, headers), api.WebhookContractAdd(eventsURL, headers), api.WebhookContractArchive(eventsURL, headers), diff --git a/internal/worker/events_test.go b/internal/worker/events_test.go index 95a74da91..2ade704c0 100644 --- a/internal/worker/events_test.go +++ b/internal/worker/events_test.go @@ -156,8 +156,8 @@ func TestEventSubscriber(t *testing.T) { time.Sleep(testRegisterInterval) // assert webhook was registered - if webhooks := w.Webhooks(); len(webhooks) != 6 { - t.Fatal("expected 6 webhooks, got", len(webhooks)) + if webhooks := w.Webhooks(); len(webhooks) != 7 { + t.Fatal("expected 7 webhooks, got", len(webhooks)) } // send the same event again diff --git a/stores/sql/main.go b/stores/sql/main.go index a5db9f4d4..df8942a63 100644 --- a/stores/sql/main.go +++ b/stores/sql/main.go @@ -292,11 +292,26 @@ func Contracts(ctx context.Context, tx sql.Tx, opts api.ContractsOpts) ([]api.Co whereArgs = append(whereArgs, contractSetID) } + var hasAllowlist, hasBlocklist bool + if err := tx.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM host_allowlist_entries)").Scan(&hasAllowlist); err != nil { + return nil, fmt.Errorf("failed to check for allowlist: %w", err) + } else if err := tx.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM host_blocklist_entries)").Scan(&hasBlocklist); err != nil { + return nil, fmt.Errorf("failed to check for blocklist: %w", err) + } + if opts.FilterMode != "" { // validate filter mode switch opts.FilterMode { case api.ContractFilterModeActive: whereExprs = append(whereExprs, "c.archival_reason IS NULL") + case api.ContractFilterModeDownload: + whereExprs = append(whereExprs, "c.archival_reason IS NULL") + if hasAllowlist { + whereExprs = append(whereExprs, "EXISTS (SELECT 1 FROM host_allowlist_entry_hosts hbeh WHERE hbeh.db_host_id = h.id)") + } + if hasBlocklist { + whereExprs = append(whereExprs, "NOT EXISTS (SELECT 1 FROM host_blocklist_entry_hosts hbeh WHERE hbeh.db_host_id = h.id)") + } case api.ContractFilterModeArchived: whereExprs = append(whereExprs, "c.archival_reason IS NOT NULL") case api.ContractFilterModeAll: diff --git a/stores/sql/mysql/chain.go b/stores/sql/mysql/chain.go index 35f1dc90f..f96a4e3eb 100644 --- a/stores/sql/mysql/chain.go +++ b/stores/sql/mysql/chain.go @@ -253,42 +253,49 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, return fmt.Errorf("failed to fetch host id: %w", err) } - // update allow list - rows, err := c.tx.Query(c.ctx, "SELECT id, entry FROM host_allowlist_entries") + // prepare statements + insertAllowlistLinkStmt, err := c.tx.Prepare(c.ctx, "INSERT IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer insertAllowlistLinkStmt.Close() + + insertBlocklistLinkStmt, err := c.tx.Prepare(c.ctx, "INSERT IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer insertBlocklistLinkStmt.Close() + + deleteBlocklistLinkStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? AND db_host_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer deleteBlocklistLinkStmt.Close() + + // fetch allowlist entries + rows, err := c.tx.Query(c.ctx, "SELECT id FROM host_allowlist_entries WHERE entry = ?", ssql.PublicKey(hk)) if err != nil { return fmt.Errorf("failed to fetch allow list: %w", err) } defer rows.Close() - allowlistEntries := make(map[types.PublicKey]int64) + var allowlistEntryIDs []int64 for rows.Next() { var id int64 - var pk ssql.PublicKey - if err := rows.Scan(&id, &pk); err != nil { + if err := rows.Scan(&id); err != nil { return fmt.Errorf("failed to scan row: %w", err) } - allowlistEntries[types.PublicKey(pk)] = id + allowlistEntryIDs = append(allowlistEntryIDs, id) } - for pk, id := range allowlistEntries { - if hk == types.PublicKey(pk) { - if _, err := c.tx.Exec(c.ctx, - "INSERT IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)", - id, - hostID, - ); err != nil { - return fmt.Errorf("failed to insert host into allowlist: %w", err) - } + // insert allowlist links + for _, entryID := range allowlistEntryIDs { + if _, err := insertAllowlistLinkStmt.Exec(c.ctx, entryID, hostID); err != nil { + return fmt.Errorf("failed to insert host into allowlist: %w", err) } } - // update blocklist - values := []string{ha.NetAddress} - host, _, err := net.SplitHostPort(ha.NetAddress) - if err == nil { - values = append(values, host) - } - + // fetch blocklist entries rows, err = c.tx.Query(c.ctx, "SELECT id, entry FROM host_blocklist_entries") if err != nil { return fmt.Errorf("failed to fetch block list: %w", err) @@ -308,6 +315,14 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, entries = append(entries, r) } + // prepare blocklist values + values := []string{ha.NetAddress} + host, _, err := net.SplitHostPort(ha.NetAddress) + if err == nil { + values = append(values, host) + } + + // insert blocklist links for _, row := range entries { var blocked bool for _, value := range values { @@ -317,19 +332,11 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, } } if blocked { - if _, err := c.tx.Exec(c.ctx, - "INSERT IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)", - row.id, - hostID, - ); err != nil { + if _, err := insertBlocklistLinkStmt.Exec(c.ctx, row.id, hostID); err != nil { return fmt.Errorf("failed to insert host into blocklist: %w", err) } } else { - if _, err := c.tx.Exec(c.ctx, - "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? AND db_host_id = ?", - row.id, - hostID, - ); err != nil { + if _, err := deleteBlocklistLinkStmt.Exec(c.ctx, row.id, hostID); err != nil { return fmt.Errorf("failed to remove host from blocklist: %w", err) } } diff --git a/stores/sql/sqlite/chain.go b/stores/sql/sqlite/chain.go index 00ad9f405..3eb19a161 100644 --- a/stores/sql/sqlite/chain.go +++ b/stores/sql/sqlite/chain.go @@ -265,42 +265,49 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, } } - // update allow list - rows, err := c.tx.Query(c.ctx, "SELECT id, entry FROM host_allowlist_entries") + // prepare statements + insertAllowlistLinkStmt, err := c.tx.Prepare(c.ctx, "INSERT OR IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer insertAllowlistLinkStmt.Close() + + insertBlocklistLinkStmt, err := c.tx.Prepare(c.ctx, "INSERT OR IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer insertBlocklistLinkStmt.Close() + + deleteBlocklistLinkStmt, err := c.tx.Prepare(c.ctx, "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? AND db_host_id = ?") + if err != nil { + return fmt.Errorf("failed to prepare statement, %w", err) + } + defer deleteBlocklistLinkStmt.Close() + + // fetch allowlist entries + rows, err := c.tx.Query(c.ctx, "SELECT id FROM host_allowlist_entries WHERE entry = ?", ssql.PublicKey(hk)) if err != nil { return fmt.Errorf("failed to fetch allow list: %w", err) } defer rows.Close() - allowlistEntries := make(map[types.PublicKey]int64) + var allowlistEntryIDs []int64 for rows.Next() { var id int64 - var pk ssql.PublicKey - if err := rows.Scan(&id, &pk); err != nil { + if err := rows.Scan(&id); err != nil { return fmt.Errorf("failed to scan row: %w", err) } - allowlistEntries[types.PublicKey(pk)] = id + allowlistEntryIDs = append(allowlistEntryIDs, id) } - for pk, id := range allowlistEntries { - if hk == types.PublicKey(pk) { - if _, err := c.tx.Exec(c.ctx, - "INSERT OR IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) VALUES (?,?)", - id, - hostID, - ); err != nil { - return fmt.Errorf("failed to insert host into allowlist: %w", err) - } + // insert allowlist links + for _, entryID := range allowlistEntryIDs { + if _, err := insertAllowlistLinkStmt.Exec(c.ctx, entryID, hostID); err != nil { + return fmt.Errorf("failed to insert host into allowlist: %w", err) } } - // update blocklist - values := []string{ha.NetAddress} - host, _, err := net.SplitHostPort(ha.NetAddress) - if err == nil { - values = append(values, host) - } - + // fetch blocklist entries rows, err = c.tx.Query(c.ctx, "SELECT id, entry FROM host_blocklist_entries") if err != nil { return fmt.Errorf("failed to fetch block list: %w", err) @@ -320,6 +327,14 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, entries = append(entries, r) } + // prepare blocklist values + values := []string{ha.NetAddress} + host, _, err := net.SplitHostPort(ha.NetAddress) + if err == nil { + values = append(values, host) + } + + // insert blocklist links for _, row := range entries { var blocked bool for _, value := range values { @@ -329,19 +344,11 @@ func (c chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, } } if blocked { - if _, err := c.tx.Exec(c.ctx, - "INSERT OR IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) VALUES (?,?)", - row.id, - hostID, - ); err != nil { + if _, err := insertBlocklistLinkStmt.Exec(c.ctx, row.id, hostID); err != nil { return fmt.Errorf("failed to insert host into blocklist: %w", err) } } else { - if _, err := c.tx.Exec(c.ctx, - "DELETE FROM host_blocklist_entry_hosts WHERE db_blocklist_entry_id = ? AND db_host_id = ?", - row.id, - hostID, - ); err != nil { + if _, err := deleteBlocklistLinkStmt.Exec(c.ctx, row.id, hostID); err != nil { return fmt.Errorf("failed to remove host from blocklist: %w", err) } }