Skip to content

Commit

Permalink
Refactor to get storage map with common.StorageDomain
Browse files Browse the repository at this point in the history
Currently, we use domain string to get StorageMap.  Even though
we use pre-defined domain strings, this can still be error-prone.

This commit modifies GetStorageMap() to use common.StorageDomain
instead of string.
  • Loading branch information
fxamacker committed Nov 8, 2024
1 parent d3bfdb8 commit 2f8c298
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 77 deletions.
2 changes: 1 addition & 1 deletion cmd/decode-state-values/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ type interpreterStorage struct {

var _ interpreter.Storage = &interpreterStorage{}

func (i interpreterStorage) GetStorageMap(_ common.Address, _ string, _ bool) *interpreter.StorageMap {
func (i interpreterStorage) GetStorageMap(_ common.Address, _ common.StorageDomain, _ bool) *interpreter.StorageMap {
panic("unexpected GetStorageMap call")
}

Expand Down
15 changes: 15 additions & 0 deletions common/pathdomain.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,18 @@ func (i PathDomain) Identifier() string {

panic(errors.NewUnreachableError())
}

func (i PathDomain) StorageDomain() StorageDomain {
switch i {
case PathDomainStorage:
return StorageDomainPathStorage

case PathDomainPrivate:
return StorageDomainPathPrivate

case PathDomainPublic:
return StorageDomainPathPublic
}

panic(errors.NewUnreachableError())
}
33 changes: 24 additions & 9 deletions common/storagedomain.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package common

import (
"fmt"

"github.com/onflow/cadence/errors"
)

Expand All @@ -27,11 +29,11 @@ type StorageDomain uint8
const (
StorageDomainUnknown StorageDomain = iota

StorageDomainStorage
StorageDomainPathStorage

StorageDomainPrivate
StorageDomainPathPrivate

StorageDomainPublic
StorageDomainPathPublic

StorageDomainContract

Expand All @@ -55,9 +57,9 @@ const (
)

var AllStorageDomains = []StorageDomain{
StorageDomainStorage,
StorageDomainPrivate,
StorageDomainPublic,
StorageDomainPathStorage,
StorageDomainPathPrivate,
StorageDomainPathPublic,
StorageDomainContract,
StorageDomainInbox,
StorageDomainCapabilityController,
Expand All @@ -68,10 +70,14 @@ var AllStorageDomains = []StorageDomain{

var AllStorageDomainsByIdentifier = map[string]StorageDomain{}

var allStorageDomainsSet = map[StorageDomain]struct{}{}

func init() {
for _, domain := range AllStorageDomains {
identifier := domain.Identifier()
AllStorageDomainsByIdentifier[identifier] = domain

allStorageDomainsSet[domain] = struct{}{}
}
}

Expand All @@ -83,15 +89,24 @@ func StorageDomainFromIdentifier(domain string) (StorageDomain, bool) {
return result, true
}

func StorageDomainFromUint64(i uint64) (StorageDomain, error) {
d := StorageDomain(i)
_, exists := allStorageDomainsSet[d]
if !exists {
return StorageDomainUnknown, fmt.Errorf("failed to convert %d to StorageDomain", i)
}
return d, nil
}

func (d StorageDomain) Identifier() string {
switch d {
case StorageDomainStorage:
case StorageDomainPathStorage:
return PathDomainStorage.Identifier()

case StorageDomainPrivate:
case StorageDomainPathPrivate:
return PathDomainPrivate.Identifier()

case StorageDomainPublic:
case StorageDomainPathPublic:
return PathDomainPublic.Identifier()

case StorageDomainContract:
Expand Down
2 changes: 1 addition & 1 deletion interpreter/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func testAccountWithErrorHandler(
}
storageKey := storageKey{
address: storageMapKey.Address,
domain: storageMapKey.Key,
domain: storageMapKey.Domain.Identifier(),
key: key,
}
accountValues[storageKey] = value
Expand Down
20 changes: 10 additions & 10 deletions interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (c TypeCodes) Merge(codes TypeCodes) {

type Storage interface {
atree.SlabStorage
GetStorageMap(address common.Address, domain string, createIfNotExists bool) *StorageMap
GetStorageMap(address common.Address, domain common.StorageDomain, createIfNotExists bool) *StorageMap
CheckHealth() error
}

Expand Down Expand Up @@ -2678,7 +2678,7 @@ func (interpreter *Interpreter) NewSubInterpreter(

func (interpreter *Interpreter) StoredValueExists(
storageAddress common.Address,
domain string,
domain common.StorageDomain,
identifier StorageMapKey,
) bool {
accountStorage := interpreter.Storage().GetStorageMap(storageAddress, domain, false)
Expand All @@ -2690,7 +2690,7 @@ func (interpreter *Interpreter) StoredValueExists(

func (interpreter *Interpreter) ReadStored(
storageAddress common.Address,
domain string,
domain common.StorageDomain,
identifier StorageMapKey,
) Value {
accountStorage := interpreter.Storage().GetStorageMap(storageAddress, domain, false)
Expand All @@ -2702,7 +2702,7 @@ func (interpreter *Interpreter) ReadStored(

func (interpreter *Interpreter) WriteStored(
storageAddress common.Address,
domain string,
domain common.StorageDomain,
key StorageMapKey,
value Value,
) (existed bool) {
Expand Down Expand Up @@ -4069,7 +4069,7 @@ func (interpreter *Interpreter) IsSubTypeOfSemaType(staticSubType StaticType, su
}

func (interpreter *Interpreter) domainPaths(address common.Address, domain common.PathDomain) []Value {
storageMap := interpreter.Storage().GetStorageMap(address, domain.Identifier(), false)
storageMap := interpreter.Storage().GetStorageMap(address, domain.StorageDomain(), false)
if storageMap == nil {
return []Value{}
}
Expand Down Expand Up @@ -4164,7 +4164,7 @@ func (interpreter *Interpreter) newStorageIterationFunction(
parameterTypes := fnType.ParameterTypes()
returnType := fnType.ReturnTypeAnnotation.Type

storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false)
storageMap := config.Storage.GetStorageMap(address, domain.StorageDomain(), false)
if storageMap == nil {
// if nothing is stored, no iteration is required
return Void
Expand Down Expand Up @@ -4327,7 +4327,7 @@ func (interpreter *Interpreter) authAccountSaveFunction(
panic(errors.NewUnreachableError())
}

domain := path.Domain.Identifier()
domain := path.Domain.StorageDomain()
identifier := path.Identifier

// Prevent an overwrite
Expand Down Expand Up @@ -4390,7 +4390,7 @@ func (interpreter *Interpreter) authAccountTypeFunction(
panic(errors.NewUnreachableError())
}

domain := path.Domain.Identifier()
domain := path.Domain.StorageDomain()
identifier := path.Identifier

storageMapKey := StringStorageMapKey(identifier)
Expand Down Expand Up @@ -4448,7 +4448,7 @@ func (interpreter *Interpreter) authAccountReadFunction(
panic(errors.NewUnreachableError())
}

domain := path.Domain.Identifier()
domain := path.Domain.StorageDomain()
identifier := path.Identifier

storageMapKey := StringStorageMapKey(identifier)
Expand Down Expand Up @@ -4589,7 +4589,7 @@ func (interpreter *Interpreter) authAccountCheckFunction(
panic(errors.NewUnreachableError())
}

domain := path.Domain.Identifier()
domain := path.Domain.StorageDomain()
identifier := path.Identifier

storageMapKey := StringStorageMapKey(identifier)
Expand Down
2 changes: 1 addition & 1 deletion interpreter/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5349,7 +5349,7 @@ func TestInterpretReferenceFailableDowncasting(t *testing.T) {
true, // r is standalone.
)

domain := storagePath.Domain.Identifier()
domain := storagePath.Domain.StorageDomain()
storageMap := storage.GetStorageMap(storageAddress, domain, true)
storageMapKey := interpreter.StringStorageMapKey(storagePath.Identifier)
storageMap.WriteValue(inter, storageMapKey, r)
Expand Down
25 changes: 21 additions & 4 deletions interpreter/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ func ConvertStoredValue(gauge common.MemoryGauge, value atree.Value) (Value, err
}
}

type StorageDomainKey struct {
Domain common.StorageDomain
Address common.Address
}

func NewStorageDomainKey(
memoryGauge common.MemoryGauge,
address common.Address,
domain common.StorageDomain,
) StorageDomainKey {
common.UseMemory(memoryGauge, common.StorageKeyMemoryUsage)
return StorageDomainKey{
Address: address,
Domain: domain,
}
}

type StorageKey struct {
Key string
Address common.Address
Expand Down Expand Up @@ -130,7 +147,7 @@ func (k StorageKey) IsLess(o StorageKey) bool {
// InMemoryStorage
type InMemoryStorage struct {
*atree.BasicSlabStorage
StorageMaps map[StorageKey]*StorageMap
StorageMaps map[StorageDomainKey]*StorageMap
memoryGauge common.MemoryGauge
}

Expand Down Expand Up @@ -158,19 +175,19 @@ func NewInMemoryStorage(memoryGauge common.MemoryGauge) InMemoryStorage {

return InMemoryStorage{
BasicSlabStorage: slabStorage,
StorageMaps: make(map[StorageKey]*StorageMap),
StorageMaps: make(map[StorageDomainKey]*StorageMap),
memoryGauge: memoryGauge,
}
}

func (i InMemoryStorage) GetStorageMap(
address common.Address,
domain string,
domain common.StorageDomain,
createIfNotExists bool,
) (
storageMap *StorageMap,
) {
key := NewStorageKey(i.memoryGauge, address, domain)
key := NewStorageDomainKey(i.memoryGauge, address, domain)
storageMap = i.StorageMaps[key]
if storageMap == nil && createIfNotExists {
storageMap = NewStorageMap(i.memoryGauge, i, atree.Address(address))
Expand Down
2 changes: 1 addition & 1 deletion interpreter/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func TestStorageOverwriteAndRemove(t *testing.T) {

const storageMapKey = StringStorageMapKey("test")

storageMap := storage.GetStorageMap(address, "storage", true)
storageMap := storage.GetStorageMap(address, common.StorageDomainPathStorage, true)
storageMap.WriteValue(inter, storageMapKey, array1)

// Overwriting delete any existing child slabs
Expand Down
2 changes: 1 addition & 1 deletion interpreter/stringatreevalue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestLargeStringAtreeValueInSeparateSlab(t *testing.T) {

storageMap := storage.GetStorageMap(
common.MustBytesToAddress([]byte{0x1}),
common.PathDomainStorage.Identifier(),
common.PathDomainStorage.StorageDomain(),
true,
)

Expand Down
2 changes: 1 addition & 1 deletion interpreter/value_storage_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (*StorageReferenceValue) IsImportable(_ *Interpreter, _ LocationRange) bool

func (v *StorageReferenceValue) dereference(interpreter *Interpreter, locationRange LocationRange) (*Value, error) {
address := v.TargetStorageAddress
domain := v.TargetPath.Domain.Identifier()
domain := v.TargetPath.Domain.StorageDomain()
identifier := v.TargetPath.Identifier

storageMapKey := StringStorageMapKey(identifier)
Expand Down
2 changes: 1 addition & 1 deletion interpreter/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,7 +3806,7 @@ func TestValue_ConformsToStaticType(t *testing.T) {
)
require.NoError(t, err)

storageMap := storage.GetStorageMap(testAddress, "storage", true)
storageMap := storage.GetStorageMap(testAddress, common.StorageDomainPathStorage, true)
storageMap.WriteValue(inter, StringStorageMapKey("test"), TrueValue)

value := valueFactory(inter)
Expand Down
8 changes: 4 additions & 4 deletions runtime/capabilitycontrollers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3253,7 +3253,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) {

storageMap := storage.GetStorageMap(
common.MustBytesToAddress([]byte{0x1}),
common.StorageDomainPathCapability.Identifier(),
common.StorageDomainPathCapability,
false,
)
require.Zero(t, storageMap.Count())
Expand Down Expand Up @@ -3842,7 +3842,7 @@ func TestRuntimeCapabilitiesGetBackwardCompatibility(t *testing.T) {

publicStorageMap := storage.GetStorageMap(
testAddress,
common.PathDomainPublic.Identifier(),
common.PathDomainPublic.StorageDomain(),
true,
)

Expand Down Expand Up @@ -3949,7 +3949,7 @@ func TestRuntimeCapabilitiesPublishBackwardCompatibility(t *testing.T) {

publicStorageMap := storage.GetStorageMap(
testAddress,
common.PathDomainStorage.Identifier(),
common.PathDomainStorage.StorageDomain(),
true,
)

Expand Down Expand Up @@ -4039,7 +4039,7 @@ func TestRuntimeCapabilitiesUnpublishBackwardCompatibility(t *testing.T) {

publicStorageMap := storage.GetStorageMap(
testAddress,
common.PathDomainPublic.Identifier(),
common.PathDomainPublic.StorageDomain(),
true,
)

Expand Down
2 changes: 1 addition & 1 deletion runtime/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func TestRuntimeContract(t *testing.T) {

getContractValueExists := func() bool {
storageMap := NewStorage(storage, nil).
GetStorageMap(signerAddress, common.StorageDomainContract.Identifier(), false)
GetStorageMap(signerAddress, common.StorageDomainContract, false)
if storageMap == nil {
return false
}
Expand Down
2 changes: 1 addition & 1 deletion runtime/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ func (e *interpreterEnvironment) loadContract(
if addressLocation, ok := location.(common.AddressLocation); ok {
storageMap := e.storage.GetStorageMap(
addressLocation.Address,
common.StorageDomainContract.Identifier(),
common.StorageDomainContract,
false,
)
if storageMap != nil {
Expand Down
4 changes: 2 additions & 2 deletions runtime/ft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ func TestRuntimeBrokenFungibleTokenRecovery(t *testing.T) {

contractStorage := storage.GetStorageMap(
contractsAddress,
common.StorageDomainContract.Identifier(),
common.StorageDomainContract,
true,
)
contractStorage.SetValue(
Expand Down Expand Up @@ -1120,7 +1120,7 @@ func TestRuntimeBrokenFungibleTokenRecovery(t *testing.T) {

userStorage := storage.GetStorageMap(
userAddress,
common.PathDomainStorage.Identifier(),
common.PathDomainStorage.StorageDomain(),
true,
)
const storagePathIdentifier = "exampleTokenVault"
Expand Down
2 changes: 1 addition & 1 deletion runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ func (r *interpreterRuntime) ReadStored(

pathValue := valueImporter{inter: inter}.importPathValue(path)

domain := pathValue.Domain.Identifier()
domain := pathValue.Domain.StorageDomain()
identifier := pathValue.Identifier

storageMapKey := interpreter.StringStorageMapKey(identifier)
Expand Down
Loading

0 comments on commit 2f8c298

Please sign in to comment.