diff --git a/stores/accounts.go b/stores/accounts.go index 523ba2697..183582b8b 100644 --- a/stores/accounts.go +++ b/stores/accounts.go @@ -9,7 +9,7 @@ import ( // Accounts returns all accounts from the db. func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { accounts, err = tx.Accounts(ctx) return err }) @@ -21,7 +21,7 @@ func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err er // sync all accounts after an unclean shutdown and the bus will know not to // apply drift. func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.SetUncleanShutdown(ctx) }) } @@ -29,7 +29,7 @@ func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error { // SaveAccounts saves the given accounts in the db, overwriting any existing // ones. func (s *SQLStore) SaveAccounts(ctx context.Context, accounts []api.Account) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.SaveAccounts(ctx, accounts) }) } diff --git a/stores/autopilot.go b/stores/autopilot.go index 45b899576..9c557c3d7 100644 --- a/stores/autopilot.go +++ b/stores/autopilot.go @@ -9,7 +9,7 @@ import ( ) func (s *SQLStore) Autopilots(ctx context.Context) (aps []api.Autopilot, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { aps, err = tx.Autopilots(ctx) return }) @@ -17,7 +17,7 @@ func (s *SQLStore) Autopilots(ctx context.Context) (aps []api.Autopilot, _ error } func (s *SQLStore) Autopilot(ctx context.Context, id string) (ap api.Autopilot, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { ap, err = tx.Autopilot(ctx, id) return }) @@ -32,7 +32,7 @@ func (s *SQLStore) UpdateAutopilot(ctx context.Context, ap api.Autopilot) error if err := ap.Config.Validate(); err != nil { return err } - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateAutopilot(ctx, ap) }) } diff --git a/stores/chain.go b/stores/chain.go index 049d9843c..06710ae4e 100644 --- a/stores/chain.go +++ b/stores/chain.go @@ -14,25 +14,18 @@ var ( ) // ChainIndex returns the last stored chain index. -func (ss *SQLStore) ChainIndex(ctx context.Context) (types.ChainIndex, error) { - var ci dbConsensusInfo - if err := ss.db. - WithContext(ctx). - Where(&dbConsensusInfo{Model: Model{ID: consensusInfoID}}). - FirstOrCreate(&ci). - Error; err != nil { - return types.ChainIndex{}, err - } - return types.ChainIndex{ - Height: ci.Height, - ID: types.BlockID(ci.BlockID), - }, nil +func (s *SQLStore) ChainIndex(ctx context.Context) (ci types.ChainIndex, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + ci, err = tx.Tip(ctx) + return err + }) + return } // ProcessChainUpdate returns a callback function that process a chain update // inside a transaction. func (s *SQLStore) ProcessChainUpdate(ctx context.Context, applyFn chain.ApplyChainUpdateFn) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.ProcessChainUpdate(ctx, applyFn) }) } @@ -46,7 +39,7 @@ func (s *SQLStore) UpdateChainState(reverted []chain.RevertUpdate, applied []cha // ResetChainState deletes all chain data in the database. func (s *SQLStore) ResetChainState(ctx context.Context) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.ResetChainState(ctx) }) } diff --git a/stores/hostdb.go b/stores/hostdb.go index f13eb2534..452ee6acd 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -17,12 +17,6 @@ import ( "gorm.io/gorm/clause" ) -const ( - // consensusInfoID defines the primary key of the entry in the consensusInfo - // table. - consensusInfoID = 1 -) - var ( ErrNegativeMaxDowntime = errors.New("max downtime can not be negative") ) @@ -120,12 +114,6 @@ type ( Hosts []dbHost `gorm:"many2many:host_blocklist_entry_hosts;constraint:OnDelete:CASCADE"` } - dbConsensusInfo struct { - Model - Height uint64 - BlockID hash256 - } - // dbAnnouncement is a table used for storing all announcements. It // doesn't have any relations to dbHost which means it won't // automatically prune when a host is deleted. @@ -151,9 +139,6 @@ type ( // TableName implements the gorm.Tabler interface. func (dbAnnouncement) TableName() string { return "host_announcements" } -// TableName implements the gorm.Tabler interface. -func (dbConsensusInfo) TableName() string { return "consensus_infos" } - // TableName implements the gorm.Tabler interface. func (dbHost) TableName() string { return "hosts" } @@ -271,8 +256,8 @@ func (e *dbBlocklistEntry) blocks(h dbHost) bool { } // Host returns information about a host. -func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) { - hosts, err := ss.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1) +func (s *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) { + hosts, err := s.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1) if err != nil { return api.Host{}, err } else if len(hosts) == 0 { @@ -282,15 +267,15 @@ func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host } } -func (ss *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostCheck(ctx, autopilotID, hk, hc) }) } // HostsForScanning returns the address of hosts for scanning. -func (ss *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) (hosts []api.HostAddress, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) (hosts []api.HostAddress, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { hosts, err = tx.HostsForScanning(ctx, maxLastScan, offset, limit) return err }) @@ -306,9 +291,9 @@ func (s *SQLStore) ResetLostSectors(ctx context.Context, hk types.PublicKey) err }) } -func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { +func (s *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { var hosts []api.Host - err := ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { hosts, err = tx.SearchHosts(ctx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) return }) @@ -316,16 +301,16 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us } // Hosts returns non-blocked hosts at given offset and limit. -func (ss *SQLStore) Hosts(ctx context.Context, offset, limit int) ([]api.Host, error) { - return ss.SearchHosts(ctx, "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) +func (s *SQLStore) Hosts(ctx context.Context, offset, limit int) ([]api.Host, error) { + return s.SearchHosts(ctx, "", api.HostFilterModeAllowed, api.UsabilityFilterModeAll, "", nil, offset, limit) } -func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures uint64, maxDowntime time.Duration) (removed uint64, err error) { +func (s *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures uint64, maxDowntime time.Duration) (removed uint64, err error) { // sanity check 'maxDowntime' if maxDowntime < 0 { return 0, ErrNegativeMaxDowntime } - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { n, err := tx.RemoveOfflineHosts(ctx, minRecentFailures, maxDowntime) removed = uint64(n) return err @@ -333,50 +318,50 @@ func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures ui return } -func (ss *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) (err error) { +func (s *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) (err error) { // nothing to do if len(add)+len(remove) == 0 && !clear { return nil } - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostAllowlistEntries(ctx, add, remove, clear) }) } -func (ss *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) (err error) { +func (s *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) (err error) { // nothing to do if len(add)+len(remove) == 0 && !clear { return nil } - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateHostBlocklistEntries(ctx, add, remove, clear) }) } -func (ss *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { allowlist, err = tx.HostAllowlist(ctx) return err }) return } -func (ss *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { - err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { blocklist, err = tx.HostBlocklist(ctx) return err }) return } -func (ss *SQLStore) RecordHostScans(ctx context.Context, scans []api.HostScan) error { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) RecordHostScans(ctx context.Context, scans []api.HostScan) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RecordHostScans(ctx, scans) }) } -func (ss *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []api.HostPriceTableUpdate) error { - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { +func (s *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []api.HostPriceTableUpdate) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RecordPriceTables(ctx, priceTableUpdate) }) } diff --git a/stores/hostdb_test.go b/stores/hostdb_test.go index aa1b537e0..5947fdb6e 100644 --- a/stores/hostdb_test.go +++ b/stores/hostdb_test.go @@ -56,7 +56,7 @@ func TestSQLHostDB(t *testing.T) { // Fetch the host var h dbHost - tx := ss.db.Where("net_address = ?", "address").Find(&h) + tx := ss.gormDB.Where("net_address = ?", "address").Find(&h) if tx.Error != nil { t.Fatal(tx.Error) } else if types.PublicKey(h.PublicKey) != hk { @@ -321,7 +321,7 @@ func TestSearchHosts(t *testing.T) { // assert there are currently 3 checks var cnt int64 - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error + err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error if err != nil { t.Fatal(err) } else if cnt != 3 { @@ -397,11 +397,11 @@ func TestSearchHosts(t *testing.T) { } // assert cascade delete on host - err = ss.db.Exec("DELETE FROM hosts WHERE public_key = ?", publicKey(types.PublicKey{1})).Error + err = ss.gormDB.Exec("DELETE FROM hosts WHERE public_key = ?", publicKey(types.PublicKey{1})).Error if err != nil { t.Fatal(err) } - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error + err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error if err != nil { t.Fatal(err) } else if cnt != 2 { @@ -409,11 +409,11 @@ func TestSearchHosts(t *testing.T) { } // assert cascade delete on autopilot - err = ss.db.Exec("DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2).Error + err = ss.gormDB.Exec("DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2).Error if err != nil { t.Fatal(err) } - err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error + err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error if err != nil { t.Fatal(err) } else if cnt != 0 { @@ -452,7 +452,7 @@ func TestRecordScan(t *testing.T) { } // Fetch the host directly to get the creation time. - h, err := hostByPubKey(ss.db, hk) + h, err := hostByPubKey(ss.gormDB, hk) if err != nil { t.Fatal(err) } @@ -589,11 +589,11 @@ func TestInsertAnnouncements(t *testing.T) { ann3 := newTestAnnouncement(types.GeneratePrivateKey().PublicKey(), "") // Insert the first one and check that all fields are set. - if err := insertAnnouncements(ss.db, []announcement{ann1}); err != nil { + if err := insertAnnouncements(ss.gormDB, []announcement{ann1}); err != nil { t.Fatal(err) } var ann dbAnnouncement - if err := ss.db.Find(&ann).Error; err != nil { + if err := ss.gormDB.Find(&ann).Error; err != nil { t.Fatal(err) } ann.Model = Model{} // ignore @@ -607,12 +607,12 @@ func TestInsertAnnouncements(t *testing.T) { t.Fatal("mismatch", cmp.Diff(ann, expectedAnn)) } // Insert the first and second one. - if err := insertAnnouncements(ss.db, []announcement{ann1, ann2}); err != nil { + if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann2}); err != nil { t.Fatal(err) } // Insert the first one twice. The second one again and the third one. - if err := insertAnnouncements(ss.db, []announcement{ann1, ann2, ann1, ann3}); err != nil { + if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann2, ann1, ann3}); err != nil { t.Fatal(err) } @@ -627,7 +627,7 @@ func TestInsertAnnouncements(t *testing.T) { // There should be 7 announcements total. var announcements []dbAnnouncement - if err := ss.db.Find(&announcements).Error; err != nil { + if err := ss.gormDB.Find(&announcements).Error; err != nil { t.Fatal(err) } if len(announcements) != 7 { @@ -644,7 +644,7 @@ func TestInsertAnnouncements(t *testing.T) { // Insert multiple announcements for host 1 - this asserts that the UNIQUE // constraint on the blocklist table isn't triggered when inserting multiple // announcements for a host that's on the blocklist - if err := insertAnnouncements(ss.db, []announcement{ann1, ann1}); err != nil { + if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann1}); err != nil { t.Fatal(err) } } @@ -661,7 +661,7 @@ func TestRemoveHosts(t *testing.T) { } // fetch the host and assert the recent downtime is zero - h, err := hostByPubKey(ss.db, hk) + h, err := hostByPubKey(ss.gormDB, hk) if err != nil { t.Fatal(err) } @@ -691,7 +691,7 @@ func TestRemoveHosts(t *testing.T) { } // fetch the host and assert the recent downtime is 30 minutes and he has 2 recent scan failures - h, err = hostByPubKey(ss.db, hk) + h, err = hostByPubKey(ss.gormDB, hk) if err != nil { t.Fatal(err) } @@ -746,7 +746,7 @@ func TestRemoveHosts(t *testing.T) { } // assert host is removed from the database - if _, err = hostByPubKey(ss.db, hk); err != gorm.ErrRecordNotFound { + if _, err = hostByPubKey(ss.gormDB, hk); err != gorm.ErrRecordNotFound { t.Fatal("expected record not found error") } } @@ -777,7 +777,7 @@ func TestSQLHostAllowlist(t *testing.T) { numRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error + err := ss.gormDB.Table("host_allowlist_entry_hosts").Count(&cnt).Error if err != nil { t.Fatal(err) } @@ -883,7 +883,7 @@ func TestSQLHostAllowlist(t *testing.T) { } // remove host 1 - if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk1)}).Delete(&dbHost{}).Error; err != nil { + if err = ss.gormDB.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk1)}).Delete(&dbHost{}).Error; err != nil { t.Fatal(err) } if numHosts() != 0 { @@ -949,7 +949,7 @@ func TestSQLHostBlocklist(t *testing.T) { numAllowlistRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error + err := ss.gormDB.Table("host_allowlist_entry_hosts").Count(&cnt).Error if err != nil { t.Fatal(err) } @@ -958,7 +958,7 @@ func TestSQLHostBlocklist(t *testing.T) { numBlocklistRelations := func() (cnt int64) { t.Helper() - err := ss.db.Table("host_blocklist_entry_hosts").Count(&cnt).Error + err := ss.gormDB.Table("host_blocklist_entry_hosts").Count(&cnt).Error if err != nil { t.Fatal(err) } @@ -1067,7 +1067,7 @@ func TestSQLHostBlocklist(t *testing.T) { } // delete host 2 and assert the delete cascaded properly - if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk2)}).Delete(&dbHost{}).Error; err != nil { + if err = ss.gormDB.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk2)}).Delete(&dbHost{}).Error; err != nil { t.Fatal(err) } if numHosts() != 2 { @@ -1234,7 +1234,7 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error { } // fetch blocklists - allowlist, blocklist, err := getBlocklists(s.db) + allowlist, blocklist, err := getBlocklists(s.gormDB) if err != nil { return err } @@ -1246,7 +1246,7 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error { dbAllowlist = append(dbAllowlist, entry) } } - if err := s.db.Model(&host).Association("Allowlist").Replace(&dbAllowlist); err != nil { + if err := s.gormDB.Model(&host).Association("Allowlist").Replace(&dbAllowlist); err != nil { return err } @@ -1257,21 +1257,21 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error { dbBlocklist = append(dbBlocklist, entry) } } - return s.db.Model(&host).Association("Blocklist").Replace(&dbBlocklist) + return s.gormDB.Model(&host).Association("Blocklist").Replace(&dbBlocklist) } // announceHost adds a host announcement to the database. func (s *SQLStore) announceHost(hk types.PublicKey, na string) (host dbHost, err error) { - err = s.db.Transaction(func(tx *gorm.DB) error { + err = s.gormDB.Transaction(func(tx *gorm.DB) error { host = dbHost{ PublicKey: publicKey(hk), LastAnnouncement: time.Now().UTC().Round(time.Second), NetAddress: na, } - if err := s.db.Create(&host).Error; err != nil { + if err := s.gormDB.Create(&host).Error; err != nil { return err } - return s.db.Create(&dbAnnouncement{ + return s.gormDB.Create(&dbAnnouncement{ HostKey: publicKey(hk), BlockHeight: 42, BlockID: types.BlockID{1, 2, 3}.String(), @@ -1285,7 +1285,7 @@ func (s *SQLStore) announceHost(hk types.PublicKey, na string) (host dbHost, err // interactions for all hosts is expensive in production. func (db *SQLStore) hosts() ([]dbHost, error) { var hosts []dbHost - tx := db.db.Find(&hosts) + tx := db.gormDB.Find(&hosts) if tx.Error != nil { return nil, tx.Error } diff --git a/stores/metadata.go b/stores/metadata.go index fbbbb0450..1f3cac17a 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -423,7 +423,7 @@ func (raw rawObject) toSlabSlice() (slice object.SlabSlice, _ error) { } func (s *SQLStore) Bucket(ctx context.Context, bucket string) (b api.Bucket, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { b, err = tx.Bucket(ctx, bucket) return }) @@ -431,25 +431,25 @@ func (s *SQLStore) Bucket(ctx context.Context, bucket string) (b api.Bucket, err } func (s *SQLStore) CreateBucket(ctx context.Context, bucket string, policy api.BucketPolicy) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.CreateBucket(ctx, bucket, policy) }) } func (s *SQLStore) UpdateBucketPolicy(ctx context.Context, bucket string, policy api.BucketPolicy) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateBucketPolicy(ctx, bucket, policy) }) } func (s *SQLStore) DeleteBucket(ctx context.Context, bucket string) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteBucket(ctx, bucket) }) } func (s *SQLStore) ListBuckets(ctx context.Context) (buckets []api.Bucket, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { buckets, err = tx.ListBuckets(ctx) return }) @@ -460,7 +460,7 @@ func (s *SQLStore) ListBuckets(ctx context.Context) (buckets []api.Bucket, err e // reduce locking and make sure all results are consistent, everything is done // within a single transaction. func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (resp api.ObjectsStatsResponse, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.ObjectsStats(ctx, opts) return }) @@ -470,7 +470,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) func (s *SQLStore) SlabBuffers(ctx context.Context) ([]api.SlabBuffer, error) { var err error var fileNameToContractSet map[string]string - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { fileNameToContractSet, err = tx.SlabBuffers(ctx) return err }) @@ -488,7 +488,7 @@ func (s *SQLStore) SlabBuffers(ctx context.Context) ([]api.SlabBuffer, error) { func (s *SQLStore) AddContract(ctx context.Context, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, state string) (_ api.ContractMetadata, err error) { var contract api.ContractMetadata - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { contract, err = tx.InsertContract(ctx, c, contractPrice, totalCost, startHeight, types.FileContractID{}, state) return err }) @@ -501,7 +501,7 @@ func (s *SQLStore) AddContract(ctx context.Context, c rhpv2.ContractRevision, co func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) { var contracts []api.ContractMetadata - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { contracts, err = tx.Contracts(ctx, opts) return }) @@ -513,7 +513,7 @@ func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api // contracts and moved to the archive. Both new and old contract will be linked // to each other through the RenewedFrom and RenewedTo fields respectively. func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevision, contractPrice, totalCost types.Currency, startHeight uint64, renewedFrom types.FileContractID, state string) (renewed api.ContractMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { renewed, err = tx.RenewContract(ctx, c, contractPrice, totalCost, startHeight, renewedFrom, state) return err }) @@ -524,7 +524,7 @@ func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevis } func (s *SQLStore) AncestorContracts(ctx context.Context, id types.FileContractID, startHeight uint64) (ancestors []api.ArchivedContract, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { ancestors, err = tx.AncestorContracts(ctx, id, startHeight) return err }) @@ -549,7 +549,7 @@ func (s *SQLStore) ArchiveContracts(ctx context.Context, toArchive map[types.Fil // archive the contract but don't interrupt the process if one contract // fails - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.ArchiveContract(ctx, fcid, reason) }); err != nil { errs = append(errs, fmt.Sprintf("%v: %v", fcid, err)) @@ -575,7 +575,7 @@ func (s *SQLStore) ArchiveAllContracts(ctx context.Context, reason string) error } func (s *SQLStore) Contract(ctx context.Context, id types.FileContractID) (cm api.ContractMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { cm, err = tx.Contract(ctx, id) return err }) @@ -583,7 +583,7 @@ func (s *SQLStore) Contract(ctx context.Context, id types.FileContractID) (cm ap } func (s *SQLStore) ContractRoots(ctx context.Context, id types.FileContractID) (roots []types.Hash256, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { roots, err = tx.ContractRoots(ctx, id) return err }) @@ -591,7 +591,7 @@ func (s *SQLStore) ContractRoots(ctx context.Context, id types.FileContractID) ( } func (s *SQLStore) ContractSets(ctx context.Context) (sets []string, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { sets, err = tx.ContractSets(ctx) return err }) @@ -599,7 +599,7 @@ func (s *SQLStore) ContractSets(ctx context.Context) (sets []string, err error) } func (s *SQLStore) ContractSizes(ctx context.Context) (sizes map[types.FileContractID]api.ContractSize, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { sizes, err = tx.ContractSizes(ctx) return err }) @@ -607,7 +607,7 @@ func (s *SQLStore) ContractSizes(ctx context.Context) (sizes map[types.FileContr } func (s *SQLStore) ContractSize(ctx context.Context, id types.FileContractID) (cs api.ContractSize, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { cs, err = tx.ContractSize(ctx, id) return }) @@ -691,13 +691,13 @@ func (s *SQLStore) SetContractSet(ctx context.Context, name string, contractIds } func (s *SQLStore) RemoveContractSet(ctx context.Context, name string) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.RemoveContractSet(ctx, name) }) } func (s *SQLStore) RenewedContract(ctx context.Context, renewedFrom types.FileContractID) (cm api.ContractMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { cm, err = tx.RenewedContract(ctx, renewedFrom) return err }) @@ -705,7 +705,7 @@ func (s *SQLStore) RenewedContract(ctx context.Context, renewedFrom types.FileCo } func (s *SQLStore) SearchObjects(ctx context.Context, bucket, substring string, offset, limit int) (objects []api.ObjectMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { objects, err = tx.SearchObjects(ctx, bucket, substring, offset, limit) return err }) @@ -713,7 +713,7 @@ func (s *SQLStore) SearchObjects(ctx context.Context, bucket, substring string, } func (s *SQLStore) ObjectEntries(ctx context.Context, bucket, path, prefix, sortBy, sortDir, marker string, offset, limit int) (metadata []api.ObjectMetadata, hasMore bool, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { metadata, hasMore, err = tx.ObjectEntries(ctx, bucket, path, prefix, sortBy, sortDir, marker, offset, limit) return err }) @@ -850,7 +850,7 @@ func fetchUsedContracts(tx *gorm.DB, usedContractsByHost map[types.PublicKey]map } func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew string, force bool) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // create new dir dirID, err := tx.MakeDirsForPath(ctx, keyNew) if err != nil { @@ -868,7 +868,7 @@ func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew stri } func (s *SQLStore) RenameObjects(ctx context.Context, bucket, prefixOld, prefixNew string, force bool) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // create new dir dirID, err := tx.MakeDirsForPath(ctx, prefixNew) if err != nil { @@ -891,7 +891,7 @@ func (s *SQLStore) AddPartialSlab(ctx context.Context, data []byte, minShards, t } func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath, dstPath, mimeType string, metadata api.ObjectUserMetadata) (om api.ObjectMetadata, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { if srcBucket != dstBucket || srcPath != dstPath { _, err = tx.DeleteObject(ctx, dstBucket, dstPath) if err != nil { @@ -905,7 +905,7 @@ func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath } func (s *SQLStore) DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) (deletedSectors int, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { deletedSectors, err = tx.DeleteHostSector(ctx, hk, root) return err }) @@ -925,7 +925,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, // UpdateObject is ACID. var prune bool - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // Try to delete. We want to get rid of the object and its slices if it // exists. // @@ -966,7 +966,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, func (s *SQLStore) RemoveObject(ctx context.Context, bucket, path string) error { var prune bool - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { prune, err = tx.DeleteObject(ctx, bucket, path) return }) @@ -986,7 +986,7 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) err start := time.Now() var done bool var duration time.Duration - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { deleted, err := tx.DeleteObjects(ctx, bucket, prefix, objectDeleteBatchSizes[batchSizeIdx]) if err != nil { return err @@ -1014,22 +1014,22 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) err } func (s *SQLStore) Slab(ctx context.Context, key object.EncryptionKey) (slab object.Slab, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { slab, err = tx.Slab(ctx, key) return err }) return } -func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet string) error { +func (s *SQLStore) UpdateSlab(ctx context.Context, slab object.Slab, contractSet string) error { // sanity check the shards don't contain an empty root - for _, s := range s.Shards { - if s.Root == (types.Hash256{}) { + for _, shard := range slab.Shards { + if shard.Root == (types.Hash256{}) { return errors.New("shard root can never be the empty root") } } // Sanity check input. - for i, shard := range s.Shards { + for i, shard := range slab.Shards { // Verify that all hosts have a contract. if len(shard.Contracts) == 0 { return fmt.Errorf("missing hosts for slab %d", i) @@ -1037,8 +1037,8 @@ func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet s } // Update slab. - return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { - return tx.UpdateSlab(ctx, s, contractSet, s.Contracts()) + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.UpdateSlab(ctx, slab, contractSet, slab.Contracts()) }) } @@ -1046,7 +1046,7 @@ func (s *SQLStore) RefreshHealth(ctx context.Context) error { for { // update slabs var rowsAffected int64 - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { rowsAffected, err = tx.UpdateSlabHealth(ctx, refreshHealthBatchSize, refreshHealthMinHealthValidity, refreshHealthMaxHealthValidity) return }) @@ -1200,7 +1200,7 @@ func (s *SQLStore) objectHydrate(tx *gorm.DB, bucket, path string, obj rawObject // ObjectMetadata returns an object's metadata func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (obj api.Object, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { obj, err = tx.ObjectMetadata(ctx, bucket, path) return err }) @@ -1402,7 +1402,7 @@ func (s *SQLStore) pruneSlabsLoop() { pruneSuccess := true for { var deleted int64 - err := s.bMain.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { + err := s.db.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { var err error deleted, err = dt.PruneSlabs(s.shutdownCtx, slabPruningBatchSize) return err @@ -1430,7 +1430,7 @@ func (s *SQLStore) pruneSlabsLoop() { } // prune dirs - err := s.bMain.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { + err := s.db.Transaction(s.shutdownCtx, func(dt sql.DatabaseTx) error { return dt.PruneEmptydirs(s.shutdownCtx) }) if err != nil { @@ -1469,7 +1469,7 @@ func (s *SQLStore) triggerSlabPruning() { func (s *SQLStore) invalidateSlabHealthByFCID(ctx context.Context, fcids []types.FileContractID) error { for { var affected int64 - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { affected, err = tx.InvalidateSlabHealthByFCID(ctx, fcids, refreshHealthBatchSize) return }) @@ -1486,7 +1486,7 @@ func (s *SQLStore) invalidateSlabHealthByFCID(ctx context.Context, fcids []types // a delimiter for now (see backend.go) but it would be interesting to have // arbitrary 'delim' support in ListObjects. func (s *SQLStore) ListObjects(ctx context.Context, bucket, prefix, sortBy, sortDir, marker string, limit int) (resp api.ObjectsListResponse, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { resp, err = tx.ListObjects(ctx, bucket, prefix, sortBy, sortDir, marker, limit) return err }) diff --git a/stores/metadata_test.go b/stores/metadata_test.go index 4dd1d84ec..a5ff944df 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -168,11 +168,11 @@ func TestObjectBasic(t *testing.T) { // delete a sector var sectors []dbSector - if err := ss.db.Find(§ors).Error; err != nil { + if err := ss.gormDB.Find(§ors).Error; err != nil { t.Fatal(err) } else if len(sectors) != 2 { t.Fatal("unexpected number of sectors") - } else if tx := ss.db.Delete(sectors[0]); tx.Error != nil || tx.RowsAffected != 1 { + } else if tx := ss.gormDB.Delete(sectors[0]); tx.Error != nil || tx.RowsAffected != 1 { t.Fatal("unexpected number of sectors deleted", tx.Error, tx.RowsAffected) } @@ -261,7 +261,7 @@ func TestObjectMetadata(t *testing.T) { // assert metadata CASCADE on object delete var cnt int64 - if err := ss.db.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { + if err := ss.gormDB.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { t.Fatal(err) } else if cnt != 2 { t.Fatal("unexpected number of metadata entries", cnt) @@ -273,7 +273,7 @@ func TestObjectMetadata(t *testing.T) { } // assert records are gone - if err := ss.db.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { + if err := ss.gormDB.Model(&dbObjectUserMetadata{}).Count(&cnt).Error; err != nil { t.Fatal(err) } else if cnt != 0 { t.Fatal("unexpected number of metadata entries", cnt) @@ -462,7 +462,7 @@ func TestSQLContractStore(t *testing.T) { // Make sure the db was cleaned up properly through the CASCADE delete. tableCountCheck := func(table interface{}, tblCount int64) error { var count int64 - if err := ss.db.Model(table).Count(&count).Error; err != nil { + if err := ss.gormDB.Model(table).Count(&count).Error; err != nil { return err } if count != tblCount { @@ -476,7 +476,7 @@ func TestSQLContractStore(t *testing.T) { // Check join table count as well. var count int64 - if err := ss.db.Table("contract_sectors").Count(&count).Error; err != nil { + if err := ss.gormDB.Table("contract_sectors").Count(&count).Error; err != nil { t.Fatal(err) } if count != 0 { @@ -741,7 +741,7 @@ func TestRenewedContract(t *testing.T) { // Archived contract should exist. var ac dbArchivedContract - err = ss.db.Model(&dbArchivedContract{}). + err = ss.gormDB.Model(&dbArchivedContract{}). Where("fcid", fileContractID(fcid1)). Take(&ac). Error @@ -905,7 +905,7 @@ func TestArchiveContracts(t *testing.T) { ffcids[0] = fileContractID(fcids[1]) ffcids[1] = fileContractID(fcids[2]) var acs []dbArchivedContract - err = ss.db.Model(&dbArchivedContract{}). + err = ss.gormDB.Model(&dbArchivedContract{}). Where("fcid IN (?)", ffcids). Find(&acs). Error @@ -1243,7 +1243,7 @@ func TestSQLMetadataStore(t *testing.T) { countCheck := func(objCount, sliceCount, slabCount, sectorCount int64) error { tableCountCheck := func(table interface{}, tblCount int64) error { var count int64 - if err := ss.db.Model(table).Count(&count).Error; err != nil { + if err := ss.gormDB.Model(table).Count(&count).Error; err != nil { return err } if count != tblCount { @@ -1486,7 +1486,7 @@ func TestObjectEntries(t *testing.T) { } // update health of objects to match the overridden health of the slabs - if err := updateAllObjectsHealth(ss.db); err != nil { + if err := updateAllObjectsHealth(ss.gormDB); err != nil { t.Fatal() } @@ -1921,7 +1921,7 @@ func TestUnhealthySlabsNoContracts(t *testing.T) { if err != nil { t.Fatal(err) } - if err := ss.db.Table("contract_sectors").Where("TRUE").Delete(&dbContractSector{}).Error; err != nil { + if err := ss.gormDB.Table("contract_sectors").Where("TRUE").Delete(&dbContractSector{}).Error; err != nil { t.Fatal(err) } @@ -2060,7 +2060,7 @@ func TestContractSectors(t *testing.T) { // Check the join table. Should be empty. var css []dbContractSector - if err := ss.db.Find(&css).Error; err != nil { + if err := ss.gormDB.Find(&css).Error; err != nil { t.Fatal(err) } if len(css) != 0 { @@ -2084,10 +2084,10 @@ func TestContractSectors(t *testing.T) { } // Delete the sector. - if err := ss.db.Delete(&dbSector{Model: Model{ID: 1}}).Error; err != nil { + if err := ss.gormDB.Delete(&dbSector{Model: Model{ID: 1}}).Error; err != nil { t.Fatal(err) } - if err := ss.db.Find(&css).Error; err != nil { + if err := ss.gormDB.Find(&css).Error; err != nil { t.Fatal(err) } if len(css) != 0 { @@ -2144,7 +2144,7 @@ func TestUpdateSlab(t *testing.T) { // helper to fetch a slab from the database fetchSlab := func() (slab dbSlab) { t.Helper() - if err = ss.db. + if err = ss.gormDB. Where(&dbSlab{Key: key}). Preload("Shards.Contracts"). Take(&slab). @@ -2227,7 +2227,7 @@ func TestUpdateSlab(t *testing.T) { // assert there's still only one entry in the dbslab table var cnt int64 - if err := ss.db.Model(&dbSlab{}).Count(&cnt).Error; err != nil { + if err := ss.gormDB.Model(&dbSlab{}).Count(&cnt).Error; err != nil { t.Fatal(err) } else if cnt != 1 { t.Fatalf("unexpected number of entries in dbslab, %v != 1", cnt) @@ -2262,7 +2262,7 @@ func TestUpdateSlab(t *testing.T) { t.Fatal(err) } var s dbSlab - if err := ss.db.Where(&dbSlab{Key: key}). + if err := ss.gormDB.Where(&dbSlab{Key: key}). Joins("DBContractSet"). Preload("Shards"). Take(&s). @@ -2497,7 +2497,7 @@ func TestRenameObjects(t *testing.T) { } var directories []dbDirectory test.Retry(100, 100*time.Millisecond, func() error { - if err := ss.db.Find(&directories).Error; err != nil { + if err := ss.gormDB.Find(&directories).Error; err != nil { return err } else if len(directories) != len(expectedDirs) { return fmt.Errorf("unexpected number of directories, %v != %v", len(directories), len(expectedDirs)) @@ -2565,7 +2565,7 @@ func TestObjectsStats(t *testing.T) { // Get all entries in contract_sectors and store them again with a different // contract id. This should cause the uploaded size to double. var contractSectors []dbContractSector - err = ss.db.Find(&contractSectors).Error + err = ss.gormDB.Find(&contractSectors).Error if err != nil { t.Fatal(err) } @@ -2717,7 +2717,7 @@ func TestPartialSlab(t *testing.T) { var buffer bufferedSlab sk, _ := slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { + if err := ss.gormDB.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { t.Fatal(err) } if buffer.Filename == "" { @@ -2779,7 +2779,7 @@ func TestPartialSlab(t *testing.T) { } buffer = bufferedSlab{} sk, _ = slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { + if err := ss.gormDB.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { t.Fatal(err) } assertBuffer(buffer1Name, 4194303, false, false) @@ -2820,13 +2820,13 @@ func TestPartialSlab(t *testing.T) { } buffer = bufferedSlab{} sk, _ = slabs[0].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { + if err := ss.gormDB.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { t.Fatal(err) } assertBuffer(buffer1Name, rhpv2.SectorSize, true, false) buffer = bufferedSlab{} sk, _ = slabs[1].Key.MarshalBinary() - if err := ss.db.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { + if err := ss.gormDB.Joins("DBSlab").Take(&buffer, "DBSlab.key = ?", secretKey(sk)).Error; err != nil { t.Fatal(err) } buffer2Name := buffer.Filename @@ -2854,11 +2854,11 @@ func TestPartialSlab(t *testing.T) { assertBuffer(buffer2Name, 1, false, false) var foo []bufferedSlab - if err := ss.db.Find(&foo).Error; err != nil { + if err := ss.gormDB.Find(&foo).Error; err != nil { t.Fatal(err) } buffer = bufferedSlab{} - if err := ss.db.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; err != nil { + if err := ss.gormDB.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; err != nil { t.Fatal(err) } @@ -2877,7 +2877,7 @@ func TestPartialSlab(t *testing.T) { } buffer = bufferedSlab{} - if err := ss.db.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { + if err := ss.gormDB.Take(&buffer, "id = ?", packedSlabs[0].BufferID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Fatal("shouldn't be able to find buffer", err) } assertBuffer(buffer2Name, 1, false, false) @@ -3084,7 +3084,7 @@ func TestContractSizes(t *testing.T) { // dbObject retrieves a dbObject from the store. func (s *SQLStore) dbObject(key string) (dbObject, error) { var obj dbObject - tx := s.db.Where(&dbObject{ObjectID: key}). + tx := s.gormDB.Where(&dbObject{ObjectID: key}). Preload("Slabs"). Take(&obj) if errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -3096,7 +3096,7 @@ func (s *SQLStore) dbObject(key string) (dbObject, error) { // dbSlab retrieves a dbSlab from the store. func (s *SQLStore) dbSlab(key []byte) (dbSlab, error) { var slab dbSlab - tx := s.db.Where(&dbSlab{Key: key}). + tx := s.gormDB.Where(&dbSlab{Key: key}). Preload("Shards.Contracts.Host"). Take(&slab) if errors.Is(tx.Error, gorm.ErrRecordNotFound) { @@ -3364,7 +3364,7 @@ func TestBucketObjects(t *testing.T) { // See if we can fetch the object by slab. var ec object.EncryptionKey - if obj, err := ss.objectRaw(ss.db, b1, "/bar"); err != nil { + if obj, err := ss.objectRaw(ss.gormDB, b1, "/bar"); err != nil { t.Fatal(err) } else if err := ec.UnmarshalBinary(obj[0].SlabKey); err != nil { t.Fatal(err) @@ -3498,7 +3498,7 @@ func TestMarkSlabUploadedAfterRenew(t *testing.T) { } var count int64 - if err := ss.db.Model(&dbContractSector{}).Count(&count).Error; err != nil { + if err := ss.gormDB.Model(&dbContractSector{}).Count(&count).Error; err != nil { t.Fatal(err) } else if count != 1 { t.Fatal("expected 1 sector", count) @@ -3549,7 +3549,7 @@ func TestListObjects(t *testing.T) { } // update health of objects to match the overridden health of the slabs - if err := updateAllObjectsHealth(ss.db); err != nil { + if err := updateAllObjectsHealth(ss.gormDB); err != nil { t.Fatal() } @@ -3633,7 +3633,7 @@ func TestDeleteHostSector(t *testing.T) { // get all contracts var dbContracts []dbContract - if err := ss.db.Model(&dbContract{}).Preload("Host").Find(&dbContracts).Error; err != nil { + if err := ss.gormDB.Model(&dbContract{}).Preload("Host").Find(&dbContracts).Error; err != nil { t.Fatal(err) } @@ -3654,13 +3654,13 @@ func TestDeleteHostSector(t *testing.T) { }, }, } - if err := ss.db.Create(&slab).Error; err != nil { + if err := ss.gormDB.Create(&slab).Error; err != nil { t.Fatal(err) } // Make sure 4 contractSector entries exist. var n int64 - if err := ss.db.Model(&dbContractSector{}). + if err := ss.gormDB.Model(&dbContractSector{}). Count(&n). Error; err != nil { t.Fatal(err) @@ -3676,7 +3676,7 @@ func TestDeleteHostSector(t *testing.T) { } // Make sure 2 contractSector entries exist. - if err := ss.db.Model(&dbContractSector{}). + if err := ss.gormDB.Model(&dbContractSector{}). Count(&n). Error; err != nil { t.Fatal(err) @@ -3686,7 +3686,7 @@ func TestDeleteHostSector(t *testing.T) { // Find the slab. It should have an invalid health. var s dbSlab - if err := ss.db.Preload("Shards").Take(&s).Error; err != nil { + if err := ss.gormDB.Preload("Shards").Take(&s).Error; err != nil { t.Fatal(err) } else if s.HealthValid() { t.Fatal("expected health to be invalid") @@ -3696,7 +3696,7 @@ func TestDeleteHostSector(t *testing.T) { // Fetch the sector and assert the contracts association. var sectors []dbSector - if err := ss.db.Model(&dbSector{}).Preload("Contracts").Find(§ors).Preload("Contracts").Error; err != nil { + if err := ss.gormDB.Model(&dbSector{}).Preload("Contracts").Find(§ors).Preload("Contracts").Error; err != nil { t.Fatal(err) } else if len(sectors) != 1 { t.Fatal("expected 1 sector", len(sectors)) @@ -3738,7 +3738,7 @@ func TestDeleteHostSector(t *testing.T) { } // Fetch the sector and check the public key has the default value - if err := ss.db.Model(&dbSector{}).Find(§ors).Error; err != nil { + if err := ss.gormDB.Model(&dbSector{}).Find(§ors).Error; err != nil { t.Fatal(err) } else if len(sectors) != 1 { t.Fatal("expected 1 sector", len(sectors)) @@ -3839,7 +3839,7 @@ func TestSlabHealthInvalidation(t *testing.T) { var slab dbSlab if key, err := slabKey.MarshalBinary(); err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { + } else if err := ss.gormDB.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { t.Fatal(err) } else if slab.HealthValid() != expected { t.Fatal("unexpected health valid", slab.HealthValid(), slab.HealthValidUntil, time.Now(), time.Unix(slab.HealthValidUntil, 0)) @@ -3961,7 +3961,7 @@ func TestSlabHealthInvalidation(t *testing.T) { // assert the health validity is always updated to a random time in the future that matches the boundaries for i := 0; i < 1e3; i++ { // reset health validity - if tx := ss.db.Exec("UPDATE slabs SET health_valid_until = 0;"); tx.Error != nil { + if tx := ss.gormDB.Exec("UPDATE slabs SET health_valid_until = 0;"); tx.Error != nil { t.Fatal(err) } @@ -3975,7 +3975,7 @@ func TestSlabHealthInvalidation(t *testing.T) { var slab dbSlab if key, err := s1.MarshalBinary(); err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { + } else if err := ss.gormDB.Model(&dbSlab{}).Where(&dbSlab{Key: key}).Take(&slab).Error; err != nil { t.Fatal(err) } @@ -4113,18 +4113,18 @@ func TestSlabCleanup(t *testing.T) { // create contract set cs := dbContractSet{} - if err := ss.db.Create(&cs).Error; err != nil { + if err := ss.gormDB.Create(&cs).Error; err != nil { t.Fatal(err) } // create buffered slab bsID := uint(1) - if err := ss.db.Exec("INSERT INTO buffered_slabs (filename) VALUES ('foo');").Error; err != nil { + if err := ss.gormDB.Exec("INSERT INTO buffered_slabs (filename) VALUES ('foo');").Error; err != nil { t.Fatal(err) } var dirID int64 - err := ss.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + err := ss.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { var err error dirID, err = tx.MakeDirsForPath(context.Background(), "1") return err @@ -4140,7 +4140,7 @@ func TestSlabCleanup(t *testing.T) { DBBucketID: ss.DefaultBucketID(), Health: 1, } - if err := ss.db.Create(&obj1).Error; err != nil { + if err := ss.gormDB.Create(&obj1).Error; err != nil { t.Fatal(err) } obj2 := dbObject{ @@ -4149,7 +4149,7 @@ func TestSlabCleanup(t *testing.T) { DBBucketID: ss.DefaultBucketID(), Health: 1, } - if err := ss.db.Create(&obj2).Error; err != nil { + if err := ss.gormDB.Create(&obj2).Error; err != nil { t.Fatal(err) } @@ -4161,7 +4161,7 @@ func TestSlabCleanup(t *testing.T) { Key: secretKey(ek), HealthValidUntil: 100, } - if err := ss.db.Create(&slab).Error; err != nil { + if err := ss.gormDB.Create(&slab).Error; err != nil { t.Fatal(err) } @@ -4170,14 +4170,14 @@ func TestSlabCleanup(t *testing.T) { DBObjectID: &obj1.ID, DBSlabID: slab.ID, } - if err := ss.db.Create(&slice1).Error; err != nil { + if err := ss.gormDB.Create(&slice1).Error; err != nil { t.Fatal(err) } slice2 := dbSlice{ DBObjectID: &obj2.ID, DBSlabID: slab.ID, } - if err := ss.db.Create(&slice2).Error; err != nil { + if err := ss.gormDB.Create(&slice2).Error; err != nil { t.Fatal(err) } @@ -4189,7 +4189,7 @@ func TestSlabCleanup(t *testing.T) { // check slice count var slabCntr int64 - if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + if err := ss.gormDB.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { t.Fatal(err) } else if slabCntr != 1 { t.Fatalf("expected 1 slabs, got %v", slabCntr) @@ -4199,7 +4199,7 @@ func TestSlabCleanup(t *testing.T) { err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, obj2.ObjectID) if err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + } else if err := ss.gormDB.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { t.Fatal(err) } else if slabCntr != 0 { t.Fatalf("expected 0 slabs, got %v", slabCntr) @@ -4214,7 +4214,7 @@ func TestSlabCleanup(t *testing.T) { Key: ek, HealthValidUntil: 100, } - if err := ss.db.Create(&bufferedSlab).Error; err != nil { + if err := ss.gormDB.Create(&bufferedSlab).Error; err != nil { t.Fatal(err) } obj3 := dbObject{ @@ -4223,17 +4223,17 @@ func TestSlabCleanup(t *testing.T) { DBBucketID: ss.DefaultBucketID(), Health: 1, } - if err := ss.db.Create(&obj3).Error; err != nil { + if err := ss.gormDB.Create(&obj3).Error; err != nil { t.Fatal(err) } slice := dbSlice{ DBObjectID: &obj3.ID, DBSlabID: bufferedSlab.ID, } - if err := ss.db.Create(&slice).Error; err != nil { + if err := ss.gormDB.Create(&slice).Error; err != nil { t.Fatal(err) } - if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + if err := ss.gormDB.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { t.Fatal(err) } else if slabCntr != 1 { t.Fatalf("expected 1 slabs, got %v", slabCntr) @@ -4243,7 +4243,7 @@ func TestSlabCleanup(t *testing.T) { err = ss.RemoveObjectBlocking(context.Background(), api.DefaultBucketName, obj3.ObjectID) if err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { + } else if err := ss.gormDB.Model(&dbSlab{}).Count(&slabCntr).Error; err != nil { t.Fatal(err) } else if slabCntr != 1 { t.Fatalf("expected 1 slabs, got %v", slabCntr) @@ -4306,7 +4306,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the object var dbObj dbObject - if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()).Take(&dbObj).Error; err != nil { + if err := ss.gormDB.Where("db_bucket_id", ss.DefaultBucketID()).Take(&dbObj).Error; err != nil { t.Fatal(err) } else if dbObj.ID != 1 { t.Fatal("unexpected id", dbObj.ID) @@ -4322,7 +4322,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch its slices var dbSlices []dbSlice - if err := ss.db.Where("db_object_id", dbObj.ID).Find(&dbSlices).Error; err != nil { + if err := ss.gormDB.Where("db_object_id", dbObj.ID).Find(&dbSlices).Error; err != nil { t.Fatal(err) } else if len(dbSlices) != 2 { t.Fatal("invalid number of slices", len(dbSlices)) @@ -4339,7 +4339,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the slab var dbSlab dbSlab key, _ := obj.Slabs[i].Key.MarshalBinary() - if err := ss.db.Where("id", dbSlice.DBSlabID).Take(&dbSlab).Error; err != nil { + if err := ss.gormDB.Where("id", dbSlice.DBSlabID).Take(&dbSlab).Error; err != nil { t.Fatal(err) } else if dbSlab.ID != uint(i+1) { t.Fatal("unexpected id", dbSlab.ID) @@ -4359,7 +4359,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the sectors var dbSectors []dbSector - if err := ss.db.Where("db_slab_id", dbSlab.ID).Find(&dbSectors).Error; err != nil { + if err := ss.gormDB.Where("db_slab_id", dbSlab.ID).Find(&dbSectors).Error; err != nil { t.Fatal(err) } else if len(dbSectors) != totalShards { t.Fatal("invalid number of sectors", len(dbSectors)) @@ -4412,7 +4412,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the object var dbObj2 dbObject - if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()). + if err := ss.gormDB.Where("db_bucket_id", ss.DefaultBucketID()). Where("object_id", "2"). Take(&dbObj2).Error; err != nil { t.Fatal(err) @@ -4424,7 +4424,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch its slices var dbSlices2 []dbSlice - if err := ss.db.Where("db_object_id", dbObj2.ID).Find(&dbSlices2).Error; err != nil { + if err := ss.gormDB.Where("db_object_id", dbObj2.ID).Find(&dbSlices2).Error; err != nil { t.Fatal(err) } else if len(dbSlices2) != 2 { t.Fatal("invalid number of slices", len(dbSlices)) @@ -4443,7 +4443,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the slab var dbSlab2 dbSlab key, _ := obj2.Slabs[0].Key.MarshalBinary() - if err := ss.db.Where("id", dbSlice2.DBSlabID).Take(&dbSlab2).Error; err != nil { + if err := ss.gormDB.Where("id", dbSlice2.DBSlabID).Take(&dbSlab2).Error; err != nil { t.Fatal(err) } else if dbSlab2.ID != uint(len(dbSlices)+1) { t.Fatal("unexpected id", dbSlab2.ID) @@ -4455,7 +4455,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { // fetch the sectors var dbSectors2 []dbSector - if err := ss.db.Where("db_slab_id", dbSlab2.ID).Find(&dbSectors2).Error; err != nil { + if err := ss.gormDB.Where("db_slab_id", dbSlab2.ID).Find(&dbSectors2).Error; err != nil { t.Fatal(err) } else if len(dbSectors2) != totalShards { t.Fatal("invalid number of sectors", len(dbSectors2)) @@ -4478,7 +4478,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) { } var contractSectors []dbContractSector - if err := ss.db.Find(&contractSectors).Error; err != nil { + if err := ss.gormDB.Find(&contractSectors).Error; err != nil { t.Fatal(err) } else if len(contractSectors) != 3*totalShards { t.Fatal("invalid number of contract sectors", len(contractSectors)) @@ -4613,7 +4613,7 @@ func TestFetchUsedContracts(t *testing.T) { // assert empty map returns no contracts usedContracts := make(map[types.PublicKey]map[types.FileContractID]struct{}) - contracts, err := fetchUsedContracts(ss.db, usedContracts) + contracts, err := fetchUsedContracts(ss.gormDB, usedContracts) if err != nil { t.Fatal(err) } else if len(contracts) != 0 { @@ -4625,7 +4625,7 @@ func TestFetchUsedContracts(t *testing.T) { usedContracts[hk1][types.FileContractID{1}] = struct{}{} // assert we get the used contract - contracts, err = fetchUsedContracts(ss.db, usedContracts) + contracts, err = fetchUsedContracts(ss.gormDB, usedContracts) if err != nil { t.Fatal(err) } else if len(contracts) != 1 { @@ -4642,7 +4642,7 @@ func TestFetchUsedContracts(t *testing.T) { } // assert used contracts contains one entry and it points to the renewal - contracts, err = fetchUsedContracts(ss.db, usedContracts) + contracts, err = fetchUsedContracts(ss.gormDB, usedContracts) if err != nil { t.Fatal(err) } else if len(contracts) != 1 { @@ -4658,7 +4658,7 @@ func TestFetchUsedContracts(t *testing.T) { // assert used contracts now contains an entry for both contracts and both // point to the renewed contract - contracts, err = fetchUsedContracts(ss.db, usedContracts) + contracts, err = fetchUsedContracts(ss.gormDB, usedContracts) if err != nil { t.Fatal(err) } else if len(contracts) != 2 { @@ -4685,7 +4685,7 @@ func TestDirectories(t *testing.T) { for _, o := range objects { var dirID int64 - err := ss.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + err := ss.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { var err error dirID, err = tx.MakeDirsForPath(context.Background(), o) return err @@ -4730,7 +4730,7 @@ func TestDirectories(t *testing.T) { } var dbDirs []dbDirectory - if err := ss.db.Find(&dbDirs).Error; err != nil { + if err := ss.gormDB.Find(&dbDirs).Error; err != nil { t.Fatal(err) } else if len(dbDirs) != len(expectedDirs) { t.Fatalf("expected %v dirs, got %v", len(expectedDirs), len(dbDirs)) @@ -4751,7 +4751,7 @@ func TestDirectories(t *testing.T) { }) var n int64 - if err := ss.db.Model(&dbDirectory{}).Count(&n).Error; err != nil { + if err := ss.gormDB.Model(&dbDirectory{}).Count(&n).Error; err != nil { t.Fatal(err) } else if n != 1 { t.Fatal("expected 1 dir, got", n) diff --git a/stores/metrics.go b/stores/metrics.go index 45003dfb9..62dbde8ce 100644 --- a/stores/metrics.go +++ b/stores/metrics.go @@ -9,7 +9,7 @@ import ( ) func (s *SQLStore) ContractMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractMetricsQueryOpts) (metrics []api.ContractMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractMetrics(ctx, start, n, interval, opts) return }) @@ -17,7 +17,7 @@ func (s *SQLStore) ContractMetrics(ctx context.Context, start time.Time, n uint6 } func (s *SQLStore) ContractPruneMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractPruneMetricsQueryOpts) (metrics []api.ContractPruneMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractPruneMetrics(ctx, start, n, interval, opts) return }) @@ -25,7 +25,7 @@ func (s *SQLStore) ContractPruneMetrics(ctx context.Context, start time.Time, n } func (s *SQLStore) ContractSetChurnMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractSetChurnMetricsQueryOpts) (metrics []api.ContractSetChurnMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractSetChurnMetrics(ctx, start, n, interval, opts) return }) @@ -33,7 +33,7 @@ func (s *SQLStore) ContractSetChurnMetrics(ctx context.Context, start time.Time, } func (s *SQLStore) ContractSetMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.ContractSetMetricsQueryOpts) (metrics []api.ContractSetMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.ContractSetMetrics(ctx, start, n, interval, opts) return }) @@ -41,7 +41,7 @@ func (s *SQLStore) ContractSetMetrics(ctx context.Context, start time.Time, n ui } func (s *SQLStore) PerformanceMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.PerformanceMetricsQueryOpts) (metrics []api.PerformanceMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.PerformanceMetrics(ctx, start, n, interval, opts) return }) @@ -49,43 +49,43 @@ func (s *SQLStore) PerformanceMetrics(ctx context.Context, start time.Time, n ui } func (s *SQLStore) RecordContractMetric(ctx context.Context, metrics ...api.ContractMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractMetric(ctx, metrics...) }) } func (s *SQLStore) RecordContractPruneMetric(ctx context.Context, metrics ...api.ContractPruneMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractPruneMetric(ctx, metrics...) }) } func (s *SQLStore) RecordContractSetChurnMetric(ctx context.Context, metrics ...api.ContractSetChurnMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractSetChurnMetric(ctx, metrics...) }) } func (s *SQLStore) RecordContractSetMetric(ctx context.Context, metrics ...api.ContractSetMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordContractSetMetric(ctx, metrics...) }) } func (s *SQLStore) RecordPerformanceMetric(ctx context.Context, metrics ...api.PerformanceMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordPerformanceMetric(ctx, metrics...) }) } func (s *SQLStore) RecordWalletMetric(ctx context.Context, metrics ...api.WalletMetric) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.RecordWalletMetric(ctx, metrics...) }) } func (s *SQLStore) WalletMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.WalletMetricsQueryOpts) (metrics []api.WalletMetric, err error) { - err = s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { + err = s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) (txErr error) { metrics, txErr = tx.WalletMetrics(ctx, start, n, interval, opts) return }) @@ -93,7 +93,7 @@ func (s *SQLStore) WalletMetrics(ctx context.Context, start time.Time, n uint64, } func (s *SQLStore) PruneMetrics(ctx context.Context, metric string, cutoff time.Time) error { - return s.bMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { + return s.dbMetrics.Transaction(ctx, func(tx sql.MetricsDatabaseTx) error { return tx.PruneMetrics(ctx, metric, cutoff) }) } diff --git a/stores/multipart.go b/stores/multipart.go index 95aad7104..ec987619f 100644 --- a/stores/multipart.go +++ b/stores/multipart.go @@ -12,7 +12,7 @@ import ( func (s *SQLStore) CreateMultipartUpload(ctx context.Context, bucket, path string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (api.MultipartCreateResponse, error) { var uploadID string - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { uploadID, err = tx.InsertMultipartUpload(ctx, bucket, path, ec, mimeType, metadata) return }) @@ -25,13 +25,13 @@ func (s *SQLStore) CreateMultipartUpload(ctx context.Context, bucket, path strin } func (s *SQLStore) AddMultipartPart(ctx context.Context, bucket, path, contractSet, eTag, uploadID string, partNumber int, slices []object.SlabSlice) (err error) { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AddMultipartPart(ctx, bucket, path, contractSet, eTag, uploadID, partNumber, slices) }) } func (s *SQLStore) MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUpload(ctx, uploadID) return }) @@ -39,7 +39,7 @@ func (s *SQLStore) MultipartUpload(ctx context.Context, uploadID string) (resp a } func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMarker, uploadIDMarker string, limit int) (resp api.MultipartListUploadsResponse, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUploads(ctx, bucket, prefix, keyMarker, uploadIDMarker, limit) return }) @@ -47,7 +47,7 @@ func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMark } func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object string, uploadID string, marker int, limit int64) (resp api.MultipartListPartsResponse, _ error) { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { resp, err = tx.MultipartUploadParts(ctx, bucket, object, uploadID, marker, limit) return }) @@ -55,7 +55,7 @@ func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object stri } func (s *SQLStore) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AbortMultipartUpload(ctx, bucket, path, uploadID) }) if err != nil { @@ -80,7 +80,7 @@ func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path str var eTag string var prune bool - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { // Delete potentially existing object. prune, err = tx.DeleteObject(ctx, bucket, path) if err != nil { diff --git a/stores/multipart_test.go b/stores/multipart_test.go index 762ea45a9..369ee7b03 100644 --- a/stores/multipart_test.go +++ b/stores/multipart_test.go @@ -73,7 +73,7 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { // Assert metadata was persisted and is linked to the multipart upload var metadatas []dbObjectUserMetadata - if err := ss.db.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { + if err := ss.gormDB.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { t.Fatal(err) } else if len(metadatas) != len(testMetadata) { t.Fatal("expected metadata to be persisted") @@ -87,13 +87,13 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { // Complete the upload. Check that the number of slices stays the same. var nSlicesBefore int64 var nSlicesAfter int64 - if err := ss.db.Model(&dbSlice{}).Count(&nSlicesBefore).Error; err != nil { + if err := ss.gormDB.Model(&dbSlice{}).Count(&nSlicesBefore).Error; err != nil { t.Fatal(err) } else if nSlicesBefore == 0 { t.Fatal("expected some slices") } else if _, err = ss.CompleteMultipartUpload(ctx, api.DefaultBucketName, objName, resp.UploadID, parts, api.CompleteMultipartOptions{}); err != nil { t.Fatal(err) - } else if err := ss.db.Model(&dbSlice{}).Count(&nSlicesAfter).Error; err != nil { + } else if err := ss.gormDB.Model(&dbSlice{}).Count(&nSlicesAfter).Error; err != nil { t.Fatal(err) } else if nSlicesBefore != nSlicesAfter { t.Fatalf("expected number of slices to stay the same, but got %v before and %v after", nSlicesBefore, nSlicesAfter) @@ -115,7 +115,7 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { } // Assert metadata was converted and the multipart upload id was nullified - if err := ss.db.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { + if err := ss.gormDB.Model(&dbObjectUserMetadata{}).Find(&metadatas).Error; err != nil { t.Fatal(err) } else if len(metadatas) != len(testMetadata) { t.Fatal("expected metadata to be persisted") diff --git a/stores/peers.go b/stores/peers.go index 3c6f6036b..5937a9d68 100644 --- a/stores/peers.go +++ b/stores/peers.go @@ -15,14 +15,14 @@ var ( // AddPeer adds a peer to the store. If the peer already exists, nil should be // returned. func (s *SQLStore) AddPeer(addr string) error { - return s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { return tx.AddPeer(context.Background(), addr) }) } // Peers returns the set of known peers. func (s *SQLStore) Peers() (peers []syncer.PeerInfo, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { peers, txErr = tx.Peers(context.Background()) return }) @@ -32,7 +32,7 @@ func (s *SQLStore) Peers() (peers []syncer.PeerInfo, err error) { // PeerInfo returns the metadata for the specified peer or ErrPeerNotFound // if the peer wasn't found in the store. func (s *SQLStore) PeerInfo(addr string) (info syncer.PeerInfo, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { info, txErr = tx.PeerInfo(context.Background(), addr) return }) @@ -42,7 +42,7 @@ func (s *SQLStore) PeerInfo(addr string) (info syncer.PeerInfo, err error) { // UpdatePeerInfo updates the metadata for the specified peer. If the peer // is not found, the error should be ErrPeerNotFound. func (s *SQLStore) UpdatePeerInfo(addr string, fn func(*syncer.PeerInfo)) error { - return s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { return tx.UpdatePeerInfo(context.Background(), addr, fn) }) } @@ -50,14 +50,14 @@ func (s *SQLStore) UpdatePeerInfo(addr string, fn func(*syncer.PeerInfo)) error // Ban temporarily bans one or more IPs. The addr should either be a single // IP with port (e.g. 1.2.3.4:5678) or a CIDR subnet (e.g. 1.2.3.4/16). func (s *SQLStore) Ban(addr string, duration time.Duration, reason string) error { - return s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) error { + return s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) error { return tx.BanPeer(context.Background(), addr, duration, reason) }) } // Banned returns true, nil if the peer is banned. func (s *SQLStore) Banned(addr string) (banned bool, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (txErr error) { banned, txErr = tx.PeerBanned(context.Background(), addr) return }) diff --git a/stores/settingsdb.go b/stores/settingsdb.go index 08d0d3faf..7a895108c 100644 --- a/stores/settingsdb.go +++ b/stores/settingsdb.go @@ -13,7 +13,7 @@ func (s *SQLStore) DeleteSetting(ctx context.Context, key string) error { defer s.settingsMu.Unlock() // delete from database first - if err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + if err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteSettings(ctx, key) }); err != nil { return err @@ -36,7 +36,7 @@ func (s *SQLStore) Setting(ctx context.Context, key string) (string, error) { // Check database. var err error - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { value, err = tx.Setting(ctx, key) return err }) @@ -49,7 +49,7 @@ func (s *SQLStore) Setting(ctx context.Context, key string) (string, error) { // Settings implements the bus.SettingStore interface. func (s *SQLStore) Settings(ctx context.Context) (settings []string, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { settings, err = tx.Settings(ctx) return err }) @@ -62,7 +62,7 @@ func (s *SQLStore) UpdateSetting(ctx context.Context, key, value string) error { s.settingsMu.Lock() defer s.settingsMu.Unlock() - err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.UpdateSetting(ctx, key, value) }) if err != nil { diff --git a/stores/slabbuffer_test.go b/stores/slabbuffer_test.go index 145a1aae6..27098a350 100644 --- a/stores/slabbuffer_test.go +++ b/stores/slabbuffer_test.go @@ -13,7 +13,7 @@ func TestRecordAppendToCompletedBuffer(t *testing.T) { defer ss.Close() completionThreshold := int64(1000) - mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.bMain, ss.logger, completionThreshold, t.TempDir()) + mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.db, ss.logger, completionThreshold, t.TempDir()) if err != nil { t.Fatal(err) } @@ -21,7 +21,7 @@ func TestRecordAppendToCompletedBuffer(t *testing.T) { // get contract set for its id var set dbContractSet - if err := ss.db.Where("name", testContractSet).Take(&set).Error; err != nil { + if err := ss.gormDB.Where("name", testContractSet).Take(&set).Error; err != nil { t.Fatal(err) } @@ -66,7 +66,7 @@ func TestMarkBufferCompleteTwice(t *testing.T) { ss := newTestSQLStore(t, defaultTestSQLStoreConfig) defer ss.Close() - mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.bMain, ss.logger, 0, t.TempDir()) + mgr, err := newSlabBufferManager(context.Background(), ss.alerts, ss.db, ss.logger, 0, t.TempDir()) if err != nil { t.Fatal(err) } @@ -74,7 +74,7 @@ func TestMarkBufferCompleteTwice(t *testing.T) { // get contract set for its id var set dbContractSet - if err := ss.db.Where("name", testContractSet).Take(&set).Error; err != nil { + if err := ss.gormDB.Where("name", testContractSet).Take(&set).Error; err != nil { t.Fatal(err) } diff --git a/stores/sql.go b/stores/sql.go index 8465e671b..7315bf96a 100644 --- a/stores/sql.go +++ b/stores/sql.go @@ -50,11 +50,10 @@ type ( // SQLStore is a helper type for interacting with a SQL-based backend. SQLStore struct { - alerts alerts.Alerter - db *gorm.DB - bMain sql.Database - bMetrics sql.MetricsDatabase - logger *zap.SugaredLogger + alerts alerts.Alerter + db sql.Database + dbMetrics sql.MetricsDatabase + logger *zap.SugaredLogger walletAddress types.Address @@ -76,6 +75,8 @@ type ( mu sync.Mutex lastPrunedAt time.Time closed bool + + gormDB *gorm.DB // deprecated: don't use } ) @@ -168,11 +169,11 @@ func NewSQLStore(cfg Config) (*SQLStore, error) { shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background()) ss := &SQLStore{ - alerts: cfg.Alerts, - db: db, - bMain: dbMain, - bMetrics: dbMetrics, - logger: l, + alerts: cfg.Alerts, + gormDB: db, + db: dbMain, + dbMetrics: dbMetrics, + logger: l, settings: make(map[string]string), walletAddress: cfg.WalletAddress, @@ -215,7 +216,7 @@ func (s *SQLStore) initSlabPruning() error { }() // prune once to guarantee consistency on startup - return s.bMain.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error { _, err := tx.PruneSlabs(s.shutdownCtx, math.MaxInt64) return err }) @@ -230,11 +231,11 @@ func (s *SQLStore) Close() error { return err } - err = s.bMain.Close() + err = s.db.Close() if err != nil { return err } - err = s.bMetrics.Close() + err = s.dbMetrics.Close() if err != nil { return err } @@ -246,7 +247,7 @@ func (s *SQLStore) Close() error { } func (s *SQLStore) retryTransaction(ctx context.Context, fc func(tx *gorm.DB) error) error { - return retryTransaction(ctx, s.db, s.logger, s.retryTransactionIntervals, fc, s.retryAbortFn) + return retryTransaction(ctx, s.gormDB, s.logger, s.retryTransactionIntervals, fc, s.retryAbortFn) } func (s *SQLStore) retryAbortFn(err error) bool { diff --git a/stores/sql/main.go b/stores/sql/main.go index 28bc11af0..7238c5439 100644 --- a/stores/sql/main.go +++ b/stores/sql/main.go @@ -2340,7 +2340,11 @@ func Tip(ctx context.Context, tx sql.Tx) (types.ChainIndex, error) { var id Hash256 var height uint64 if err := tx.QueryRow(ctx, "SELECT height, block_id FROM consensus_infos WHERE id = ?", sql.ConsensusInfoID). - Scan(&id, &height); err != nil { + Scan(&height, &id); errors.Is(err, dsql.ErrNoRows) { + // init + _, err = tx.Exec(ctx, "INSERT INTO consensus_infos (id, height, block_id) VALUES (?, ?, ?)", sql.ConsensusInfoID, 0, Hash256{}) + return types.ChainIndex{}, err + } else if err != nil { return types.ChainIndex{}, err } return types.ChainIndex{ diff --git a/stores/sql_test.go b/stores/sql_test.go index 624536c7b..8deb485cf 100644 --- a/stores/sql_test.go +++ b/stores/sql_test.go @@ -180,7 +180,7 @@ func newTestSQLStore(t *testing.T, cfg testSQLStoreConfig) *testSQLStore { } func (s *testSQLStore) DB() *isql.DB { - switch db := s.bMain.(type) { + switch db := s.db.(type) { case *sqlite.MainDatabase: return db.DB() case *mysql.MainDatabase: @@ -192,7 +192,7 @@ func (s *testSQLStore) DB() *isql.DB { } func (s *testSQLStore) DBMetrics() *isql.DB { - switch db := s.bMetrics.(type) { + switch db := s.dbMetrics.(type) { case *sqlite.MetricsDatabase: return db.DB() case *mysql.MetricsDatabase: @@ -212,7 +212,7 @@ func (s *testSQLStore) Close() error { func (s *testSQLStore) DefaultBucketID() uint { var b dbBucket - if err := s.db. + if err := s.gormDB. Model(&dbBucket{}). Where("name = ?", api.DefaultBucketName). Take(&b). @@ -297,7 +297,7 @@ func (s *SQLStore) addTestRenewedContract(fcid, renewedFrom types.FileContractID } func (s *SQLStore) contractsCount() (cnt int64, err error) { - err = s.db. + err = s.gormDB. Model(&dbContract{}). Count(&cnt). Error @@ -305,7 +305,7 @@ func (s *SQLStore) contractsCount() (cnt int64, err error) { } func (s *SQLStore) overrideSlabHealth(objectID string, health float64) (err error) { - err = s.db.Exec(fmt.Sprintf(` + err = s.gormDB.Exec(fmt.Sprintf(` UPDATE slabs SET health = %v WHERE id IN ( SELECT * FROM ( SELECT sla.id @@ -371,16 +371,16 @@ func TestQueryPlan(t *testing.T) { } for _, query := range queries { - if isSQLite(ss.db) { + if isSQLite(ss.gormDB) { var explain sqliteQueryPlan - if err := ss.db.Raw(fmt.Sprintf("EXPLAIN QUERY PLAN %s;", query)).Scan(&explain).Error; err != nil { + if err := ss.gormDB.Raw(fmt.Sprintf("EXPLAIN QUERY PLAN %s;", query)).Scan(&explain).Error; err != nil { t.Fatal(err) } else if !explain.usesIndex() { t.Fatalf("query '%s' should use an index, instead the plan was %+v", query, explain) } } else { var explain mysqlQueryPlan - if err := ss.db.Raw(fmt.Sprintf("EXPLAIN %s;", query)).Scan(&explain).Error; err != nil { + if err := ss.gormDB.Raw(fmt.Sprintf("EXPLAIN %s;", query)).Scan(&explain).Error; err != nil { t.Fatal(err) } else if !explain.usesIndex() { t.Fatalf("query '%s' should use an index, instead the plan was %+v", query, explain) diff --git a/stores/types.go b/stores/types.go index 7740287fd..fd1950b27 100644 --- a/stores/types.go +++ b/stores/types.go @@ -28,7 +28,6 @@ type ( currency types.Currency bCurrency types.Currency fileContractID types.FileContractID - hash256 types.Hash256 publicKey types.PublicKey hostSettings rhpv2.HostSettings hostPriceTable rhpv3.HostPriceTable @@ -100,29 +99,6 @@ func (k secretKey) Value() (driver.Value, error) { return []byte(k), nil } -// GormDataType implements gorm.GormDataTypeInterface. -func (hash256) GormDataType() string { - return "bytes" -} - -// Scan scan value into address, implements sql.Scanner interface. -func (h *hash256) Scan(value interface{}) error { - bytes, ok := value.([]byte) - if !ok { - return errors.New(fmt.Sprint("failed to unmarshal hash256 value:", value)) - } - if len(bytes) != len(hash256{}) { - return fmt.Errorf("failed to unmarshal hash256 value due to invalid number of bytes %v != %v: %v", len(bytes), len(fileContractID{}), value) - } - *h = *(*hash256)(bytes) - return nil -} - -// Value returns an addr value, implements driver.Valuer interface. -func (h hash256) Value() (driver.Value, error) { - return h[:], nil -} - // GormDataType implements gorm.GormDataTypeInterface. func (fileContractID) GormDataType() string { return "bytes" diff --git a/stores/types_test.go b/stores/types_test.go index 9e03078a0..7bd4c75a1 100644 --- a/stores/types_test.go +++ b/stores/types_test.go @@ -25,24 +25,24 @@ func TestTypeCurrency(t *testing.T) { defer ss.Close() // prepare the table - if isSQLite(ss.db) { - if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil { + if isSQLite(ss.gormDB) { + if err := ss.gormDB.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil { t.Fatal(err) } } else { - if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil { + if err := ss.gormDB.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil { t.Fatal(err) } } // insert currencies in random order - if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil { + if err := ss.gormDB.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil { t.Fatal(err) } // fetch currencies and assert they're sorted var currencies []bCurrency - if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(¤cies).Error; err != nil { + if err := ss.gormDB.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(¤cies).Error; err != nil { t.Fatal(err) } else if !sort.SliceIsSorted(currencies, func(i, j int) bool { return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0 @@ -99,10 +99,10 @@ func TestTypeCurrency(t *testing.T) { for i, test := range tests { var result bool query := fmt.Sprintf("SELECT ? %s ?", test.cmp) - if !isSQLite(ss.db) { + if !isSQLite(ss.gormDB) { query = strings.ReplaceAll(query, "?", "HEX(?)") } - if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil { + if err := ss.gormDB.Raw(query, test.a, test.b).Scan(&result).Error; err != nil { t.Fatal(err) } else if !result { t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String()) @@ -121,13 +121,13 @@ func TestTypeMerkleProof(t *testing.T) { defer ss.Close() // prepare the table - if isSQLite(ss.db) { - if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);").Error; err != nil { + if isSQLite(ss.gormDB) { + if err := ss.gormDB.Exec("CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);").Error; err != nil { t.Fatal(err) } } else { - ss.db.Exec("DROP TABLE IF EXISTS merkle_proofs;") - if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);").Error; err != nil { + ss.gormDB.Exec("DROP TABLE IF EXISTS merkle_proofs;") + if err := ss.gormDB.Exec("CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);").Error; err != nil { t.Fatal(err) } } @@ -135,13 +135,13 @@ func TestTypeMerkleProof(t *testing.T) { // insert merkle proof mp1 := merkleProof{proof: []types.Hash256{{3}, {1}, {2}}} mp2 := merkleProof{proof: []types.Hash256{{4}}} - if err := ss.db.Exec("INSERT INTO merkle_proofs (merkle_proof) VALUES (?), (?);", mp1, mp2).Error; err != nil { + if err := ss.gormDB.Exec("INSERT INTO merkle_proofs (merkle_proof) VALUES (?), (?);", mp1, mp2).Error; err != nil { t.Fatal(err) } // fetch first proof var first merkleProof - if err := ss.db. + if err := ss.gormDB. Raw(`SELECT merkle_proof FROM merkle_proofs`). Take(&first). Error; err != nil { @@ -152,7 +152,7 @@ func TestTypeMerkleProof(t *testing.T) { // fetch both proofs var both []merkleProof - if err := ss.db. + if err := ss.gormDB. Raw(`SELECT merkle_proof FROM merkle_proofs`). Scan(&both). Error; err != nil { diff --git a/stores/wallet.go b/stores/wallet.go index d5c3e9a84..ff8053358 100644 --- a/stores/wallet.go +++ b/stores/wallet.go @@ -2,12 +2,10 @@ package stores import ( "context" - "errors" "go.sia.tech/core/types" "go.sia.tech/coreutils/wallet" "go.sia.tech/renterd/stores/sql" - "gorm.io/gorm" ) var ( @@ -16,24 +14,17 @@ var ( // Tip returns the consensus change ID and block height of the last wallet // change. -func (s *SQLStore) Tip() (types.ChainIndex, error) { - var cs dbConsensusInfo - if err := s.db. - Model(&dbConsensusInfo{}). - First(&cs).Error; errors.Is(err, gorm.ErrRecordNotFound) { - return types.ChainIndex{}, nil - } else if err != nil { - return types.ChainIndex{}, err - } - return types.ChainIndex{ - Height: cs.Height, - ID: types.BlockID(cs.BlockID), - }, nil +func (s *SQLStore) Tip() (ci types.ChainIndex, err error) { + err = s.db.Transaction(s.shutdownCtx, func(tx sql.DatabaseTx) error { + ci, err = tx.Tip(s.shutdownCtx) + return err + }) + return } // UnspentSiacoinElements returns a list of all unspent siacoin outputs func (s *SQLStore) UnspentSiacoinElements() (elements []types.SiacoinElement, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { elements, err = tx.UnspentSiacoinElements(context.Background()) return }) @@ -43,7 +34,7 @@ func (s *SQLStore) UnspentSiacoinElements() (elements []types.SiacoinElement, er // WalletEvents returns a paginated list of events, ordered by maturity height, // descending. If no more events are available, (nil, nil) is returned. func (s *SQLStore) WalletEvents(offset, limit int) (events []wallet.Event, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { events, err = tx.WalletEvents(context.Background(), offset, limit) return }) @@ -52,7 +43,7 @@ func (s *SQLStore) WalletEvents(offset, limit int) (events []wallet.Event, err e // WalletEventCount returns the number of events relevant to the wallet. func (s *SQLStore) WalletEventCount() (count uint64, err error) { - err = s.bMain.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { + err = s.db.Transaction(context.Background(), func(tx sql.DatabaseTx) (err error) { count, err = tx.WalletEventCount(context.Background()) return }) diff --git a/stores/webhooks.go b/stores/webhooks.go index 02516c419..e7e3782ea 100644 --- a/stores/webhooks.go +++ b/stores/webhooks.go @@ -8,19 +8,19 @@ import ( ) func (s *SQLStore) AddWebhook(ctx context.Context, wh webhooks.Webhook) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.AddWebhook(ctx, wh) }) } func (s *SQLStore) DeleteWebhook(ctx context.Context, wh webhooks.Webhook) error { - return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { return tx.DeleteWebhook(ctx, wh) }) } func (s *SQLStore) Webhooks(ctx context.Context) (whs []webhooks.Webhook, err error) { - err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error { whs, err = tx.Webhooks(ctx) return err })