Skip to content

Commit

Permalink
stores: add merkleproof type
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjan committed Feb 29, 2024
1 parent 7e519a9 commit fd12fd4
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 97 deletions.
97 changes: 0 additions & 97 deletions stores/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -4308,99 +4307,3 @@ func TestUpdateObjectReuseSlab(t *testing.T) {
}
}
}

func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert currencies in random order
if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
t.Fatal(err)
}

// fetch currencies and assert they're sorted
var currencies []bCurrency
if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}

// convenience variables
c0 := currencies[0]
c1 := currencies[1]
cM := currencies[2]

tests := []struct {
a bCurrency
b bCurrency
cmp string
}{
{
a: c0,
b: c1,
cmp: "<",
},
{
a: c1,
b: c0,
cmp: ">",
},
{
a: c0,
b: c1,
cmp: "!=",
},
{
a: c1,
b: c1,
cmp: "=",
},
{
a: c0,
b: cM,
cmp: "<",
},
{
a: cM,
b: c0,
cmp: ">",
},
{
a: cM,
b: cM,
cmp: "=",
},
}
for i, test := range tests {
var result bool
query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
if !isSQLite(ss.db) {
query = strings.Replace(query, "?", "HEX(?)", -1)
}
if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
t.Fatal(err)
} else if !result {
t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
t.Fatal("invalid result")
} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
t.Fatal("invalid result")
} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
t.Fatal("invalid result")
}
}
}
39 changes: 39 additions & 0 deletions stores/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
)

const (
proofHashSize = 32
secretKeySize = 32
)

Expand All @@ -35,6 +36,10 @@ type (
balance big.Int
unsigned64 uint64 // used for storing large uint64 values in sqlite
secretKey []byte

// NOTE: we have to wrap the proof here because Gorm can't scan bytes into
// multiple slices, all bytes are scanned into the first row
merkleProof struct{ proof []types.Hash256 }
)

// GormDataType implements gorm.GormDataTypeInterface.
Expand Down Expand Up @@ -341,6 +346,7 @@ func (u unsigned64) Value() (driver.Value, error) {
return int64(u), nil
}

// GormDataType implements gorm.GormDataTypeInterface.
func (bCurrency) GormDataType() string {
return "bytes"
}
Expand All @@ -366,3 +372,36 @@ func (sc bCurrency) Value() (driver.Value, error) {
binary.BigEndian.PutUint64(buf[8:], sc.Lo)
return buf, nil
}

// GormDataType implements gorm.GormDataTypeInterface.
func (mp *merkleProof) GormDataType() string {
return "bytes"
}

// Scan scans value into mp, implements sql.Scanner interface.
func (mp *merkleProof) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("failed to unmarshal merkleProof value:", value))
} else if len(bytes) == 0 || len(bytes)%proofHashSize != 0 {
return fmt.Errorf("failed to unmarshal merkleProof value due to invalid number of bytes %v", len(bytes))
}

n := len(bytes) / proofHashSize
mp.proof = make([]types.Hash256, n)
for i := 0; i < n; i++ {
copy(mp.proof[i][:], bytes[:proofHashSize])
bytes = bytes[proofHashSize:]
}
return nil
}

// Value returns a merkle proof value, implements driver.Valuer interface.
func (mp merkleProof) Value() (driver.Value, error) {
var i int
out := make([]byte, len(mp.proof)*proofHashSize)
for _, ph := range mp.proof {
i += copy(out[i:], ph[:])
}
return out, nil
}
154 changes: 154 additions & 0 deletions stores/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package stores

import (
"fmt"
"sort"
"strings"
"testing"

"go.sia.tech/core/types"
)

func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert currencies in random order
if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
t.Fatal(err)
}

// fetch currencies and assert they're sorted
var currencies []bCurrency
if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}

// convenience variables
c0 := currencies[0]
c1 := currencies[1]
cM := currencies[2]

tests := []struct {
a bCurrency
b bCurrency
cmp string
}{
{
a: c0,
b: c1,
cmp: "<",
},
{
a: c1,
b: c0,
cmp: ">",
},
{
a: c0,
b: c1,
cmp: "!=",
},
{
a: c1,
b: c1,
cmp: "=",
},
{
a: c0,
b: cM,
cmp: "<",
},
{
a: cM,
b: c0,
cmp: ">",
},
{
a: cM,
b: cM,
cmp: "=",
},
}
for i, test := range tests {
var result bool
query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
if !isSQLite(ss.db) {
query = strings.Replace(query, "?", "HEX(?)", -1)
}
if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
t.Fatal(err)
} else if !result {
t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
t.Fatal("invalid result")
} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
t.Fatal("invalid result")
} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
t.Fatal("invalid result")
}
}
}

func TestTypeMerkleProof(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
ss.db.Exec("DROP TABLE IF EXISTS merkle_proofs;")
if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert merkle proof
mp1 := merkleProof{proof: []types.Hash256{{3}, {1}, {2}}}
mp2 := merkleProof{proof: []types.Hash256{{4}}}
if err := ss.db.Exec("INSERT INTO merkle_proofs (merkle_proof) VALUES (?), (?);", mp1, mp2).Error; err != nil {
t.Fatal(err)
}

// fetch first proof
var first merkleProof
if err := ss.db.
Raw(`SELECT merkle_proof FROM merkle_proofs`).
Take(&first).
Error; err != nil {
t.Fatal(err)
} else if first.proof[0] != (types.Hash256{3}) || first.proof[1] != (types.Hash256{1}) || first.proof[2] != (types.Hash256{2}) {
t.Fatalf("unexpected proof %+v", first)
}

// fetch both proofs
var both []merkleProof
if err := ss.db.
Raw(`SELECT merkle_proof FROM merkle_proofs`).
Scan(&both).
Error; err != nil {
t.Fatal(err)
} else if len(both) != 2 {
t.Fatalf("unexpected number of proofs: %d", len(both))
} else if both[1].proof[0] != (types.Hash256{4}) {
t.Fatalf("unexpected proof %+v", both)
}
}

0 comments on commit fd12fd4

Please sign in to comment.