From d3bfdb8bc5711e4cb0e1f298d1b35d6b4c237c43 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Thu, 7 Nov 2024 17:37:32 -0600 Subject: [PATCH 1/2] Refactor storage domains to prevent import cycles Currently, various storage domains are defined in different packages, such as: - common - runtime - stdlib This commit moves storage domains from different packages to a single place: common/storagedomain.go. This change will help us avoid import cycles when using domains across different packages. --- common/storagedomain.go | 117 ++++++++++++++++++++++++ interpreter/interpreter.go | 4 +- interpreter/value_composite.go | 2 +- runtime/capabilitycontrollers_test.go | 2 +- runtime/contract_test.go | 2 +- runtime/environment.go | 2 +- runtime/ft_test.go | 2 +- runtime/runtime_memory_metering_test.go | 2 +- runtime/sharedstate_test.go | 4 +- runtime/storage.go | 4 +- stdlib/account.go | 50 ++++------ stdlib/account_test.go | 11 ++- 12 files changed, 150 insertions(+), 52 deletions(-) create mode 100644 common/storagedomain.go diff --git a/common/storagedomain.go b/common/storagedomain.go new file mode 100644 index 0000000000..741756c44e --- /dev/null +++ b/common/storagedomain.go @@ -0,0 +1,117 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package common + +import ( + "github.com/onflow/cadence/errors" +) + +type StorageDomain uint8 + +const ( + StorageDomainUnknown StorageDomain = iota + + StorageDomainStorage + + StorageDomainPrivate + + StorageDomainPublic + + StorageDomainContract + + StorageDomainInbox + + // StorageDomainCapabilityController is the storage domain which stores + // capability controllers by capability ID + StorageDomainCapabilityController + + // StorageDomainCapabilityControllerTag is the storage domain which stores + // capability controller tags by capability ID + StorageDomainCapabilityControllerTag + + // StorageDomainPathCapability is the storage domain which stores + // capability ID dictionaries (sets) by storage path identifier + StorageDomainPathCapability + + // StorageDomainAccountCapability is the storage domain which + // records active account capability controller IDs + StorageDomainAccountCapability +) + +var AllStorageDomains = []StorageDomain{ + StorageDomainStorage, + StorageDomainPrivate, + StorageDomainPublic, + StorageDomainContract, + StorageDomainInbox, + StorageDomainCapabilityController, + StorageDomainCapabilityControllerTag, + StorageDomainPathCapability, + StorageDomainAccountCapability, +} + +var AllStorageDomainsByIdentifier = map[string]StorageDomain{} + +func init() { + for _, domain := range AllStorageDomains { + identifier := domain.Identifier() + AllStorageDomainsByIdentifier[identifier] = domain + } +} + +func StorageDomainFromIdentifier(domain string) (StorageDomain, bool) { + result, ok := AllStorageDomainsByIdentifier[domain] + if !ok { + return StorageDomainUnknown, false + } + return result, true +} + +func (d StorageDomain) Identifier() string { + switch d { + case StorageDomainStorage: + return PathDomainStorage.Identifier() + + case StorageDomainPrivate: + return PathDomainPrivate.Identifier() + + case StorageDomainPublic: + return PathDomainPublic.Identifier() + + case StorageDomainContract: + return "contract" + + case StorageDomainInbox: + return "inbox" + + case StorageDomainCapabilityController: + return "cap_con" + + case StorageDomainCapabilityControllerTag: + return "cap_tag" + + case StorageDomainPathCapability: + return "path_cap" + + case StorageDomainAccountCapability: + return "acc_cap" + } + + panic(errors.NewUnreachableError()) +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index a163efb01f..1c6533ba76 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -967,7 +967,7 @@ func (interpreter *Interpreter) declareSelfVariable(value Value, locationRange L } func (interpreter *Interpreter) visitAssignment( - transferOperation ast.TransferOperation, + _ ast.TransferOperation, targetGetterSetter getterSetter, targetType sema.Type, valueExpression ast.Expression, valueType sema.Type, position ast.HasPosition, @@ -1271,7 +1271,7 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue( functions.Set(resourceDefaultDestroyEventName(compositeType), destroyEventConstructor) } - applyDefaultFunctions := func(ty *sema.InterfaceType, code WrapperCode) { + applyDefaultFunctions := func(_ *sema.InterfaceType, code WrapperCode) { // Apply default functions, if conforming type does not provide the function diff --git a/interpreter/value_composite.go b/interpreter/value_composite.go index d11311d543..cf460d466d 100644 --- a/interpreter/value_composite.go +++ b/interpreter/value_composite.go @@ -1658,7 +1658,7 @@ func (v *CompositeValue) getBaseValue( return NewEphemeralReferenceValue(interpreter, functionAuthorization, v.base, baseType, locationRange) } -func (v *CompositeValue) setBaseValue(interpreter *Interpreter, base *CompositeValue) { +func (v *CompositeValue) setBaseValue(_ *Interpreter, base *CompositeValue) { v.base = base } diff --git a/runtime/capabilitycontrollers_test.go b/runtime/capabilitycontrollers_test.go index fc6f692fe4..0e3427d97c 100644 --- a/runtime/capabilitycontrollers_test.go +++ b/runtime/capabilitycontrollers_test.go @@ -3253,7 +3253,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { storageMap := storage.GetStorageMap( common.MustBytesToAddress([]byte{0x1}), - stdlib.PathCapabilityStorageDomain, + common.StorageDomainPathCapability.Identifier(), false, ) require.Zero(t, storageMap.Count()) diff --git a/runtime/contract_test.go b/runtime/contract_test.go index 2bb89a2cb5..7cd3e72d49 100644 --- a/runtime/contract_test.go +++ b/runtime/contract_test.go @@ -223,7 +223,7 @@ func TestRuntimeContract(t *testing.T) { getContractValueExists := func() bool { storageMap := NewStorage(storage, nil). - GetStorageMap(signerAddress, StorageDomainContract, false) + GetStorageMap(signerAddress, common.StorageDomainContract.Identifier(), false) if storageMap == nil { return false } diff --git a/runtime/environment.go b/runtime/environment.go index 7873d4e990..ca6d48f01a 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -1108,7 +1108,7 @@ func (e *interpreterEnvironment) loadContract( if addressLocation, ok := location.(common.AddressLocation); ok { storageMap := e.storage.GetStorageMap( addressLocation.Address, - StorageDomainContract, + common.StorageDomainContract.Identifier(), false, ) if storageMap != nil { diff --git a/runtime/ft_test.go b/runtime/ft_test.go index b9374d986e..44602a00f0 100644 --- a/runtime/ft_test.go +++ b/runtime/ft_test.go @@ -1085,7 +1085,7 @@ func TestRuntimeBrokenFungibleTokenRecovery(t *testing.T) { contractStorage := storage.GetStorageMap( contractsAddress, - StorageDomainContract, + common.StorageDomainContract.Identifier(), true, ) contractStorage.SetValue( diff --git a/runtime/runtime_memory_metering_test.go b/runtime/runtime_memory_metering_test.go index d6d389ff73..64512a1091 100644 --- a/runtime/runtime_memory_metering_test.go +++ b/runtime/runtime_memory_metering_test.go @@ -930,7 +930,7 @@ func TestRuntimeMemoryMeteringErrors(t *testing.T) { type memoryMeter map[common.MemoryKind]uint64 - runtimeInterface := func(meter memoryMeter) *TestRuntimeInterface { + runtimeInterface := func(memoryMeter) *TestRuntimeInterface { return &TestRuntimeInterface{ OnMeterMemory: func(usage common.MemoryUsage) error { if usage.Kind == common.MemoryKindStringValue || diff --git a/runtime/sharedstate_test.go b/runtime/sharedstate_test.go index 3008c85fff..4e3749a944 100644 --- a/runtime/sharedstate_test.go +++ b/runtime/sharedstate_test.go @@ -213,11 +213,11 @@ func TestRuntimeSharedState(t *testing.T) { []ownerKeyPair{ { owner: signerAddress[:], - key: []byte(StorageDomainContract), + key: []byte(common.StorageDomainContract.Identifier()), }, { owner: signerAddress[:], - key: []byte(StorageDomainContract), + key: []byte(common.StorageDomainContract.Identifier()), }, { owner: signerAddress[:], diff --git a/runtime/storage.go b/runtime/storage.go index 7b9a567285..94cb9431b9 100644 --- a/runtime/storage.go +++ b/runtime/storage.go @@ -32,8 +32,6 @@ import ( "github.com/onflow/cadence/interpreter" ) -const StorageDomainContract = "contract" - type Storage struct { *atree.PersistentSlabStorage NewStorageMaps *orderedmap.OrderedMap[interpreter.StorageKey, atree.SlabIndex] @@ -216,7 +214,7 @@ func (s *Storage) writeContractUpdate( key interpreter.StorageKey, contractValue *interpreter.CompositeValue, ) { - storageMap := s.GetStorageMap(key.Address, StorageDomainContract, true) + storageMap := s.GetStorageMap(key.Address, common.StorageDomainContract.Identifier(), true) // NOTE: pass nil instead of allocating a Value-typed interface that points to nil storageMapKey := interpreter.StringStorageMapKey(key.Key) if contractValue == nil { diff --git a/stdlib/account.go b/stdlib/account.go index 97a8e4d0f3..c1a303a2c5 100644 --- a/stdlib/account.go +++ b/stdlib/account.go @@ -938,8 +938,6 @@ func newAccountKeysRevokeFunction( } } -const InboxStorageDomain = "inbox" - func newAccountInboxPublishFunction( inter *interpreter.Interpreter, handler EventEmitter, @@ -996,7 +994,7 @@ func newAccountInboxPublishFunction( inter.WriteStored( provider, - InboxStorageDomain, + common.StorageDomainInbox.Identifier(), storageMapKey, publishedValue, ) @@ -1029,7 +1027,7 @@ func newAccountInboxUnpublishFunction( storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) - readValue := inter.ReadStored(provider, InboxStorageDomain, storageMapKey) + readValue := inter.ReadStored(provider, common.StorageDomainInbox.Identifier(), storageMapKey) if readValue == nil { return interpreter.Nil } @@ -1065,7 +1063,7 @@ func newAccountInboxUnpublishFunction( inter.WriteStored( provider, - InboxStorageDomain, + common.StorageDomainInbox.Identifier(), storageMapKey, nil, ) @@ -1114,7 +1112,7 @@ func newAccountInboxClaimFunction( storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) - readValue := inter.ReadStored(providerAddress, InboxStorageDomain, storageMapKey) + readValue := inter.ReadStored(providerAddress, common.StorageDomainInbox.Identifier(), storageMapKey) if readValue == nil { return interpreter.Nil } @@ -1155,7 +1153,7 @@ func newAccountInboxClaimFunction( inter.WriteStored( providerAddress, - InboxStorageDomain, + common.StorageDomainInbox.Identifier(), storageMapKey, nil, ) @@ -2983,10 +2981,6 @@ func IssueAccountCapabilityController( return capabilityIDValue } -// CapabilityControllerStorageDomain is the storage domain which stores -// capability controllers by capability ID -const CapabilityControllerStorageDomain = "cap_con" - // storeCapabilityController stores a capability controller in the account's capability ID to controller storage map func storeCapabilityController( inter *interpreter.Interpreter, @@ -2998,7 +2992,7 @@ func storeCapabilityController( existed := inter.WriteStored( address, - CapabilityControllerStorageDomain, + common.StorageDomainCapabilityController.Identifier(), storageMapKey, controller, ) @@ -3017,7 +3011,7 @@ func removeCapabilityController( existed := inter.WriteStored( address, - CapabilityControllerStorageDomain, + common.StorageDomainCapabilityController.Identifier(), storageMapKey, nil, ) @@ -3045,7 +3039,7 @@ func getCapabilityController( readValue := inter.ReadStored( address, - CapabilityControllerStorageDomain, + common.StorageDomainCapabilityController.Identifier(), storageMapKey, ) if readValue == nil { @@ -3225,10 +3219,6 @@ var capabilityIDSetStaticType = &interpreter.DictionaryStaticType{ ValueType: interpreter.NilStaticType, } -// PathCapabilityStorageDomain is the storage domain which stores -// capability ID dictionaries (sets) by storage path identifier -const PathCapabilityStorageDomain = "path_cap" - func recordStorageCapabilityController( inter *interpreter.Interpreter, locationRange interpreter.LocationRange, @@ -3254,7 +3244,7 @@ func recordStorageCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - PathCapabilityStorageDomain, + common.StorageDomainPathCapability.Identifier(), true, ) @@ -3296,7 +3286,7 @@ func getPathCapabilityIDSet( storageMap := inter.Storage().GetStorageMap( address, - PathCapabilityStorageDomain, + common.StorageDomainPathCapability.Identifier(), false, ) if storageMap == nil { @@ -3346,7 +3336,7 @@ func unrecordStorageCapabilityController( if capabilityIDSet.Count() == 0 { storageMap := inter.Storage().GetStorageMap( address, - PathCapabilityStorageDomain, + common.StorageDomainPathCapability.Identifier(), true, ) if storageMap == nil { @@ -3397,10 +3387,6 @@ func getStorageCapabilityControllerIDsIterator( return } -// AccountCapabilityStorageDomain is the storage domain which -// records active account capability controller IDs -const AccountCapabilityStorageDomain = "acc_cap" - func recordAccountCapabilityController( inter *interpreter.Interpreter, locationRange interpreter.LocationRange, @@ -3418,7 +3404,7 @@ func recordAccountCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - AccountCapabilityStorageDomain, + common.StorageDomainAccountCapability.Identifier(), true, ) @@ -3445,7 +3431,7 @@ func unrecordAccountCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - AccountCapabilityStorageDomain, + common.StorageDomainAccountCapability.Identifier(), true, ) @@ -3464,7 +3450,7 @@ func getAccountCapabilityControllerIDsIterator( ) { storageMap := inter.Storage().GetStorageMap( address, - AccountCapabilityStorageDomain, + common.StorageDomainAccountCapability.Identifier(), false, ) if storageMap == nil { @@ -4427,10 +4413,6 @@ func newAccountCapabilityControllerDeleteFunction( } } -// CapabilityControllerTagStorageDomain is the storage domain which stores -// capability controller tags by capability ID -const CapabilityControllerTagStorageDomain = "cap_tag" - func getCapabilityControllerTag( inter *interpreter.Interpreter, address common.Address, @@ -4439,7 +4421,7 @@ func getCapabilityControllerTag( value := inter.ReadStored( address, - CapabilityControllerTagStorageDomain, + common.StorageDomainCapabilityControllerTag.Identifier(), interpreter.Uint64StorageMapKey(capabilityID), ) if value == nil { @@ -4501,7 +4483,7 @@ func setCapabilityControllerTag( inter.WriteStored( address, - CapabilityControllerTagStorageDomain, + common.StorageDomainCapabilityControllerTag.Identifier(), interpreter.Uint64StorageMapKey(capabilityID), value, ) diff --git a/stdlib/account_test.go b/stdlib/account_test.go index 2429663cc2..2218655c86 100644 --- a/stdlib/account_test.go +++ b/stdlib/account_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/sema" . "github.com/onflow/cadence/test_utils/common_utils" ) @@ -33,11 +34,11 @@ func TestSemaCheckPathLiteralForInternalStorageDomains(t *testing.T) { t.Parallel() internalStorageDomains := []string{ - InboxStorageDomain, - AccountCapabilityStorageDomain, - CapabilityControllerStorageDomain, - PathCapabilityStorageDomain, - CapabilityControllerTagStorageDomain, + common.StorageDomainInbox.Identifier(), + common.StorageDomainAccountCapability.Identifier(), + common.StorageDomainCapabilityController.Identifier(), + common.StorageDomainPathCapability.Identifier(), + common.StorageDomainCapabilityControllerTag.Identifier(), } test := func(domain string) { From 2f8c2980677c2def8ec2174a5ccdc2d3ac841346 Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Thu, 7 Nov 2024 20:27:30 -0600 Subject: [PATCH 2/2] Refactor to get storage map with common.StorageDomain Currently, we use domain string to get StorageMap. Even though we use pre-defined domain strings, this can still be error-prone. This commit modifies GetStorageMap() to use common.StorageDomain instead of string. --- cmd/decode-state-values/main.go | 2 +- common/pathdomain.go | 15 ++++++++++ common/storagedomain.go | 33 +++++++++++++++------ interpreter/account_test.go | 2 +- interpreter/interpreter.go | 20 ++++++------- interpreter/misc_test.go | 2 +- interpreter/storage.go | 25 +++++++++++++--- interpreter/storage_test.go | 2 +- interpreter/stringatreevalue_test.go | 2 +- interpreter/value_storage_reference.go | 2 +- interpreter/value_test.go | 2 +- runtime/capabilitycontrollers_test.go | 8 +++--- runtime/contract_test.go | 2 +- runtime/environment.go | 2 +- runtime/ft_test.go | 4 +-- runtime/runtime.go | 2 +- runtime/storage.go | 27 +++++++++-------- runtime/storage_test.go | 13 +++++---- stdlib/account.go | 40 +++++++++++++------------- 19 files changed, 128 insertions(+), 77 deletions(-) diff --git a/cmd/decode-state-values/main.go b/cmd/decode-state-values/main.go index 70ecd7e7b4..99eeda2497 100644 --- a/cmd/decode-state-values/main.go +++ b/cmd/decode-state-values/main.go @@ -234,7 +234,7 @@ type interpreterStorage struct { var _ interpreter.Storage = &interpreterStorage{} -func (i interpreterStorage) GetStorageMap(_ common.Address, _ string, _ bool) *interpreter.StorageMap { +func (i interpreterStorage) GetStorageMap(_ common.Address, _ common.StorageDomain, _ bool) *interpreter.StorageMap { panic("unexpected GetStorageMap call") } diff --git a/common/pathdomain.go b/common/pathdomain.go index 943301dc4c..ea9e5609c8 100644 --- a/common/pathdomain.go +++ b/common/pathdomain.go @@ -70,3 +70,18 @@ func (i PathDomain) Identifier() string { panic(errors.NewUnreachableError()) } + +func (i PathDomain) StorageDomain() StorageDomain { + switch i { + case PathDomainStorage: + return StorageDomainPathStorage + + case PathDomainPrivate: + return StorageDomainPathPrivate + + case PathDomainPublic: + return StorageDomainPathPublic + } + + panic(errors.NewUnreachableError()) +} diff --git a/common/storagedomain.go b/common/storagedomain.go index 741756c44e..108196cdac 100644 --- a/common/storagedomain.go +++ b/common/storagedomain.go @@ -19,6 +19,8 @@ package common import ( + "fmt" + "github.com/onflow/cadence/errors" ) @@ -27,11 +29,11 @@ type StorageDomain uint8 const ( StorageDomainUnknown StorageDomain = iota - StorageDomainStorage + StorageDomainPathStorage - StorageDomainPrivate + StorageDomainPathPrivate - StorageDomainPublic + StorageDomainPathPublic StorageDomainContract @@ -55,9 +57,9 @@ const ( ) var AllStorageDomains = []StorageDomain{ - StorageDomainStorage, - StorageDomainPrivate, - StorageDomainPublic, + StorageDomainPathStorage, + StorageDomainPathPrivate, + StorageDomainPathPublic, StorageDomainContract, StorageDomainInbox, StorageDomainCapabilityController, @@ -68,10 +70,14 @@ var AllStorageDomains = []StorageDomain{ var AllStorageDomainsByIdentifier = map[string]StorageDomain{} +var allStorageDomainsSet = map[StorageDomain]struct{}{} + func init() { for _, domain := range AllStorageDomains { identifier := domain.Identifier() AllStorageDomainsByIdentifier[identifier] = domain + + allStorageDomainsSet[domain] = struct{}{} } } @@ -83,15 +89,24 @@ func StorageDomainFromIdentifier(domain string) (StorageDomain, bool) { return result, true } +func StorageDomainFromUint64(i uint64) (StorageDomain, error) { + d := StorageDomain(i) + _, exists := allStorageDomainsSet[d] + if !exists { + return StorageDomainUnknown, fmt.Errorf("failed to convert %d to StorageDomain", i) + } + return d, nil +} + func (d StorageDomain) Identifier() string { switch d { - case StorageDomainStorage: + case StorageDomainPathStorage: return PathDomainStorage.Identifier() - case StorageDomainPrivate: + case StorageDomainPathPrivate: return PathDomainPrivate.Identifier() - case StorageDomainPublic: + case StorageDomainPathPublic: return PathDomainPublic.Identifier() case StorageDomainContract: diff --git a/interpreter/account_test.go b/interpreter/account_test.go index a003477a56..595d0d7352 100644 --- a/interpreter/account_test.go +++ b/interpreter/account_test.go @@ -492,7 +492,7 @@ func testAccountWithErrorHandler( } storageKey := storageKey{ address: storageMapKey.Address, - domain: storageMapKey.Key, + domain: storageMapKey.Domain.Identifier(), key: key, } accountValues[storageKey] = value diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 1c6533ba76..efaa58ba94 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -237,7 +237,7 @@ func (c TypeCodes) Merge(codes TypeCodes) { type Storage interface { atree.SlabStorage - GetStorageMap(address common.Address, domain string, createIfNotExists bool) *StorageMap + GetStorageMap(address common.Address, domain common.StorageDomain, createIfNotExists bool) *StorageMap CheckHealth() error } @@ -2678,7 +2678,7 @@ func (interpreter *Interpreter) NewSubInterpreter( func (interpreter *Interpreter) StoredValueExists( storageAddress common.Address, - domain string, + domain common.StorageDomain, identifier StorageMapKey, ) bool { accountStorage := interpreter.Storage().GetStorageMap(storageAddress, domain, false) @@ -2690,7 +2690,7 @@ func (interpreter *Interpreter) StoredValueExists( func (interpreter *Interpreter) ReadStored( storageAddress common.Address, - domain string, + domain common.StorageDomain, identifier StorageMapKey, ) Value { accountStorage := interpreter.Storage().GetStorageMap(storageAddress, domain, false) @@ -2702,7 +2702,7 @@ func (interpreter *Interpreter) ReadStored( func (interpreter *Interpreter) WriteStored( storageAddress common.Address, - domain string, + domain common.StorageDomain, key StorageMapKey, value Value, ) (existed bool) { @@ -4069,7 +4069,7 @@ func (interpreter *Interpreter) IsSubTypeOfSemaType(staticSubType StaticType, su } func (interpreter *Interpreter) domainPaths(address common.Address, domain common.PathDomain) []Value { - storageMap := interpreter.Storage().GetStorageMap(address, domain.Identifier(), false) + storageMap := interpreter.Storage().GetStorageMap(address, domain.StorageDomain(), false) if storageMap == nil { return []Value{} } @@ -4164,7 +4164,7 @@ func (interpreter *Interpreter) newStorageIterationFunction( parameterTypes := fnType.ParameterTypes() returnType := fnType.ReturnTypeAnnotation.Type - storageMap := config.Storage.GetStorageMap(address, domain.Identifier(), false) + storageMap := config.Storage.GetStorageMap(address, domain.StorageDomain(), false) if storageMap == nil { // if nothing is stored, no iteration is required return Void @@ -4327,7 +4327,7 @@ func (interpreter *Interpreter) authAccountSaveFunction( panic(errors.NewUnreachableError()) } - domain := path.Domain.Identifier() + domain := path.Domain.StorageDomain() identifier := path.Identifier // Prevent an overwrite @@ -4390,7 +4390,7 @@ func (interpreter *Interpreter) authAccountTypeFunction( panic(errors.NewUnreachableError()) } - domain := path.Domain.Identifier() + domain := path.Domain.StorageDomain() identifier := path.Identifier storageMapKey := StringStorageMapKey(identifier) @@ -4448,7 +4448,7 @@ func (interpreter *Interpreter) authAccountReadFunction( panic(errors.NewUnreachableError()) } - domain := path.Domain.Identifier() + domain := path.Domain.StorageDomain() identifier := path.Identifier storageMapKey := StringStorageMapKey(identifier) @@ -4589,7 +4589,7 @@ func (interpreter *Interpreter) authAccountCheckFunction( panic(errors.NewUnreachableError()) } - domain := path.Domain.Identifier() + domain := path.Domain.StorageDomain() identifier := path.Identifier storageMapKey := StringStorageMapKey(identifier) diff --git a/interpreter/misc_test.go b/interpreter/misc_test.go index 8ec9d08532..de93c0e6f8 100644 --- a/interpreter/misc_test.go +++ b/interpreter/misc_test.go @@ -5349,7 +5349,7 @@ func TestInterpretReferenceFailableDowncasting(t *testing.T) { true, // r is standalone. ) - domain := storagePath.Domain.Identifier() + domain := storagePath.Domain.StorageDomain() storageMap := storage.GetStorageMap(storageAddress, domain, true) storageMapKey := interpreter.StringStorageMapKey(storagePath.Identifier) storageMap.WriteValue(inter, storageMapKey, r) diff --git a/interpreter/storage.go b/interpreter/storage.go index c5bdcb0b87..d5e53c9a4e 100644 --- a/interpreter/storage.go +++ b/interpreter/storage.go @@ -101,6 +101,23 @@ func ConvertStoredValue(gauge common.MemoryGauge, value atree.Value) (Value, err } } +type StorageDomainKey struct { + Domain common.StorageDomain + Address common.Address +} + +func NewStorageDomainKey( + memoryGauge common.MemoryGauge, + address common.Address, + domain common.StorageDomain, +) StorageDomainKey { + common.UseMemory(memoryGauge, common.StorageKeyMemoryUsage) + return StorageDomainKey{ + Address: address, + Domain: domain, + } +} + type StorageKey struct { Key string Address common.Address @@ -130,7 +147,7 @@ func (k StorageKey) IsLess(o StorageKey) bool { // InMemoryStorage type InMemoryStorage struct { *atree.BasicSlabStorage - StorageMaps map[StorageKey]*StorageMap + StorageMaps map[StorageDomainKey]*StorageMap memoryGauge common.MemoryGauge } @@ -158,19 +175,19 @@ func NewInMemoryStorage(memoryGauge common.MemoryGauge) InMemoryStorage { return InMemoryStorage{ BasicSlabStorage: slabStorage, - StorageMaps: make(map[StorageKey]*StorageMap), + StorageMaps: make(map[StorageDomainKey]*StorageMap), memoryGauge: memoryGauge, } } func (i InMemoryStorage) GetStorageMap( address common.Address, - domain string, + domain common.StorageDomain, createIfNotExists bool, ) ( storageMap *StorageMap, ) { - key := NewStorageKey(i.memoryGauge, address, domain) + key := NewStorageDomainKey(i.memoryGauge, address, domain) storageMap = i.StorageMaps[key] if storageMap == nil && createIfNotExists { storageMap = NewStorageMap(i.memoryGauge, i, atree.Address(address)) diff --git a/interpreter/storage_test.go b/interpreter/storage_test.go index 0693df2d9b..21e0f298d3 100644 --- a/interpreter/storage_test.go +++ b/interpreter/storage_test.go @@ -524,7 +524,7 @@ func TestStorageOverwriteAndRemove(t *testing.T) { const storageMapKey = StringStorageMapKey("test") - storageMap := storage.GetStorageMap(address, "storage", true) + storageMap := storage.GetStorageMap(address, common.StorageDomainPathStorage, true) storageMap.WriteValue(inter, storageMapKey, array1) // Overwriting delete any existing child slabs diff --git a/interpreter/stringatreevalue_test.go b/interpreter/stringatreevalue_test.go index f2e622a8a9..00fa5e988c 100644 --- a/interpreter/stringatreevalue_test.go +++ b/interpreter/stringatreevalue_test.go @@ -38,7 +38,7 @@ func TestLargeStringAtreeValueInSeparateSlab(t *testing.T) { storageMap := storage.GetStorageMap( common.MustBytesToAddress([]byte{0x1}), - common.PathDomainStorage.Identifier(), + common.PathDomainStorage.StorageDomain(), true, ) diff --git a/interpreter/value_storage_reference.go b/interpreter/value_storage_reference.go index 741fd4ae18..906edc5819 100644 --- a/interpreter/value_storage_reference.go +++ b/interpreter/value_storage_reference.go @@ -123,7 +123,7 @@ func (*StorageReferenceValue) IsImportable(_ *Interpreter, _ LocationRange) bool func (v *StorageReferenceValue) dereference(interpreter *Interpreter, locationRange LocationRange) (*Value, error) { address := v.TargetStorageAddress - domain := v.TargetPath.Domain.Identifier() + domain := v.TargetPath.Domain.StorageDomain() identifier := v.TargetPath.Identifier storageMapKey := StringStorageMapKey(identifier) diff --git a/interpreter/value_test.go b/interpreter/value_test.go index 295f5a4346..c0bbe137d1 100644 --- a/interpreter/value_test.go +++ b/interpreter/value_test.go @@ -3806,7 +3806,7 @@ func TestValue_ConformsToStaticType(t *testing.T) { ) require.NoError(t, err) - storageMap := storage.GetStorageMap(testAddress, "storage", true) + storageMap := storage.GetStorageMap(testAddress, common.StorageDomainPathStorage, true) storageMap.WriteValue(inter, StringStorageMapKey("test"), TrueValue) value := valueFactory(inter) diff --git a/runtime/capabilitycontrollers_test.go b/runtime/capabilitycontrollers_test.go index 0e3427d97c..2e396121df 100644 --- a/runtime/capabilitycontrollers_test.go +++ b/runtime/capabilitycontrollers_test.go @@ -3253,7 +3253,7 @@ func TestRuntimeCapabilityControllers(t *testing.T) { storageMap := storage.GetStorageMap( common.MustBytesToAddress([]byte{0x1}), - common.StorageDomainPathCapability.Identifier(), + common.StorageDomainPathCapability, false, ) require.Zero(t, storageMap.Count()) @@ -3842,7 +3842,7 @@ func TestRuntimeCapabilitiesGetBackwardCompatibility(t *testing.T) { publicStorageMap := storage.GetStorageMap( testAddress, - common.PathDomainPublic.Identifier(), + common.PathDomainPublic.StorageDomain(), true, ) @@ -3949,7 +3949,7 @@ func TestRuntimeCapabilitiesPublishBackwardCompatibility(t *testing.T) { publicStorageMap := storage.GetStorageMap( testAddress, - common.PathDomainStorage.Identifier(), + common.PathDomainStorage.StorageDomain(), true, ) @@ -4039,7 +4039,7 @@ func TestRuntimeCapabilitiesUnpublishBackwardCompatibility(t *testing.T) { publicStorageMap := storage.GetStorageMap( testAddress, - common.PathDomainPublic.Identifier(), + common.PathDomainPublic.StorageDomain(), true, ) diff --git a/runtime/contract_test.go b/runtime/contract_test.go index 7cd3e72d49..32dc0b12cd 100644 --- a/runtime/contract_test.go +++ b/runtime/contract_test.go @@ -223,7 +223,7 @@ func TestRuntimeContract(t *testing.T) { getContractValueExists := func() bool { storageMap := NewStorage(storage, nil). - GetStorageMap(signerAddress, common.StorageDomainContract.Identifier(), false) + GetStorageMap(signerAddress, common.StorageDomainContract, false) if storageMap == nil { return false } diff --git a/runtime/environment.go b/runtime/environment.go index ca6d48f01a..4c4d13b69a 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -1108,7 +1108,7 @@ func (e *interpreterEnvironment) loadContract( if addressLocation, ok := location.(common.AddressLocation); ok { storageMap := e.storage.GetStorageMap( addressLocation.Address, - common.StorageDomainContract.Identifier(), + common.StorageDomainContract, false, ) if storageMap != nil { diff --git a/runtime/ft_test.go b/runtime/ft_test.go index 44602a00f0..739920488f 100644 --- a/runtime/ft_test.go +++ b/runtime/ft_test.go @@ -1085,7 +1085,7 @@ func TestRuntimeBrokenFungibleTokenRecovery(t *testing.T) { contractStorage := storage.GetStorageMap( contractsAddress, - common.StorageDomainContract.Identifier(), + common.StorageDomainContract, true, ) contractStorage.SetValue( @@ -1120,7 +1120,7 @@ func TestRuntimeBrokenFungibleTokenRecovery(t *testing.T) { userStorage := storage.GetStorageMap( userAddress, - common.PathDomainStorage.Identifier(), + common.PathDomainStorage.StorageDomain(), true, ) const storagePathIdentifier = "exampleTokenVault" diff --git a/runtime/runtime.go b/runtime/runtime.go index e385f2c0c7..c6277c55ae 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -612,7 +612,7 @@ func (r *interpreterRuntime) ReadStored( pathValue := valueImporter{inter: inter}.importPathValue(path) - domain := pathValue.Domain.Identifier() + domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier storageMapKey := interpreter.StringStorageMapKey(identifier) diff --git a/runtime/storage.go b/runtime/storage.go index 94cb9431b9..d1c3cd2495 100644 --- a/runtime/storage.go +++ b/runtime/storage.go @@ -34,8 +34,8 @@ import ( type Storage struct { *atree.PersistentSlabStorage - NewStorageMaps *orderedmap.OrderedMap[interpreter.StorageKey, atree.SlabIndex] - storageMaps map[interpreter.StorageKey]*interpreter.StorageMap + NewStorageMaps *orderedmap.OrderedMap[interpreter.StorageDomainKey, atree.SlabIndex] + storageMaps map[interpreter.StorageDomainKey]*interpreter.StorageMap contractUpdates *orderedmap.OrderedMap[interpreter.StorageKey, *interpreter.CompositeValue] Ledger atree.Ledger memoryGauge common.MemoryGauge @@ -76,7 +76,7 @@ func NewStorage(ledger atree.Ledger, memoryGauge common.MemoryGauge) *Storage { return &Storage{ Ledger: ledger, PersistentSlabStorage: persistentSlabStorage, - storageMaps: map[interpreter.StorageKey]*interpreter.StorageMap{}, + storageMaps: map[interpreter.StorageDomainKey]*interpreter.StorageMap{}, memoryGauge: memoryGauge, } } @@ -85,12 +85,12 @@ const storageIndexLength = 8 func (s *Storage) GetStorageMap( address common.Address, - domain string, + domain common.StorageDomain, createIfNotExists bool, ) ( storageMap *interpreter.StorageMap, ) { - key := interpreter.NewStorageKey(s.memoryGauge, address, domain) + key := interpreter.NewStorageDomainKey(s.memoryGauge, address, domain) storageMap = s.storageMaps[key] if storageMap == nil { @@ -100,7 +100,10 @@ func (s *Storage) GetStorageMap( var data []byte var err error errors.WrapPanic(func() { - data, err = s.Ledger.GetValue(key.Address[:], []byte(key.Key)) + data, err = s.Ledger.GetValue( + key.Address[:], + []byte(key.Domain.Identifier()), + ) }) if err != nil { panic(interpreter.WrappedExternalError(err)) @@ -112,7 +115,7 @@ func (s *Storage) GetStorageMap( // TODO: add dedicated error type? panic(errors.NewUnexpectedError( "invalid storage index for storage map with domain '%s': expected length %d, got %d", - domain, storageIndexLength, dataLength, + domain.Identifier(), storageIndexLength, dataLength, )) } @@ -143,15 +146,15 @@ func (s *Storage) loadExistingStorageMap(address atree.Address, slabIndex atree. return interpreter.NewStorageMapWithRootID(s, slabID) } -func (s *Storage) StoreNewStorageMap(address atree.Address, domain string) *interpreter.StorageMap { +func (s *Storage) StoreNewStorageMap(address atree.Address, domain common.StorageDomain) *interpreter.StorageMap { storageMap := interpreter.NewStorageMap(s.memoryGauge, s, address) slabIndex := storageMap.SlabID().Index() - storageKey := interpreter.NewStorageKey(s.memoryGauge, common.Address(address), domain) + storageKey := interpreter.NewStorageDomainKey(s.memoryGauge, common.Address(address), domain) if s.NewStorageMaps == nil { - s.NewStorageMaps = &orderedmap.OrderedMap[interpreter.StorageKey, atree.SlabIndex]{} + s.NewStorageMaps = &orderedmap.OrderedMap[interpreter.StorageDomainKey, atree.SlabIndex]{} } s.NewStorageMaps.Set(storageKey, slabIndex) @@ -214,7 +217,7 @@ func (s *Storage) writeContractUpdate( key interpreter.StorageKey, contractValue *interpreter.CompositeValue, ) { - storageMap := s.GetStorageMap(key.Address, common.StorageDomainContract.Identifier(), true) + storageMap := s.GetStorageMap(key.Address, common.StorageDomainContract, true) // NOTE: pass nil instead of allocating a Value-typed interface that points to nil storageMapKey := interpreter.StringStorageMapKey(key.Key) if contractValue == nil { @@ -277,7 +280,7 @@ func (s *Storage) commitNewStorageMaps() error { errors.WrapPanic(func() { err = s.Ledger.SetValue( pair.Key.Address[:], - []byte(pair.Key.Key), + []byte(pair.Key.Domain.Identifier()), pair.Value[:], ) }) diff --git a/runtime/storage_test.go b/runtime/storage_test.go index 5a7bb60bc9..aedd492c11 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -53,22 +53,23 @@ func withWritesToStorage( inter := NewTestInterpreter(tb) - address := common.MustBytesToAddress([]byte{0x1}) - for i := 0; i < count; i++ { randomIndex := random.Uint32() - storageKey := interpreter.StorageKey{ + var address common.Address + random.Read(address[:]) + + storageKey := interpreter.StorageDomainKey{ Address: address, - Key: fmt.Sprintf("%d", randomIndex), + Domain: common.StorageDomainPathStorage, } var slabIndex atree.SlabIndex binary.BigEndian.PutUint32(slabIndex[:], randomIndex) if storage.NewStorageMaps == nil { - storage.NewStorageMaps = &orderedmap.OrderedMap[interpreter.StorageKey, atree.SlabIndex]{} + storage.NewStorageMaps = &orderedmap.OrderedMap[interpreter.StorageDomainKey, atree.SlabIndex]{} } storage.NewStorageMaps.Set(storageKey, slabIndex) } @@ -3100,7 +3101,7 @@ func TestRuntimeStorageInternalAccess(t *testing.T) { }) require.NoError(t, err) - storageMap := storage.GetStorageMap(address, common.PathDomainStorage.Identifier(), false) + storageMap := storage.GetStorageMap(address, common.PathDomainStorage.StorageDomain(), false) require.NotNil(t, storageMap) // Read first diff --git a/stdlib/account.go b/stdlib/account.go index c1a303a2c5..201affa202 100644 --- a/stdlib/account.go +++ b/stdlib/account.go @@ -994,7 +994,7 @@ func newAccountInboxPublishFunction( inter.WriteStored( provider, - common.StorageDomainInbox.Identifier(), + common.StorageDomainInbox, storageMapKey, publishedValue, ) @@ -1027,7 +1027,7 @@ func newAccountInboxUnpublishFunction( storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) - readValue := inter.ReadStored(provider, common.StorageDomainInbox.Identifier(), storageMapKey) + readValue := inter.ReadStored(provider, common.StorageDomainInbox, storageMapKey) if readValue == nil { return interpreter.Nil } @@ -1063,7 +1063,7 @@ func newAccountInboxUnpublishFunction( inter.WriteStored( provider, - common.StorageDomainInbox.Identifier(), + common.StorageDomainInbox, storageMapKey, nil, ) @@ -1112,7 +1112,7 @@ func newAccountInboxClaimFunction( storageMapKey := interpreter.StringStorageMapKey(nameValue.Str) - readValue := inter.ReadStored(providerAddress, common.StorageDomainInbox.Identifier(), storageMapKey) + readValue := inter.ReadStored(providerAddress, common.StorageDomainInbox, storageMapKey) if readValue == nil { return interpreter.Nil } @@ -1153,7 +1153,7 @@ func newAccountInboxClaimFunction( inter.WriteStored( providerAddress, - common.StorageDomainInbox.Identifier(), + common.StorageDomainInbox, storageMapKey, nil, ) @@ -2992,7 +2992,7 @@ func storeCapabilityController( existed := inter.WriteStored( address, - common.StorageDomainCapabilityController.Identifier(), + common.StorageDomainCapabilityController, storageMapKey, controller, ) @@ -3011,7 +3011,7 @@ func removeCapabilityController( existed := inter.WriteStored( address, - common.StorageDomainCapabilityController.Identifier(), + common.StorageDomainCapabilityController, storageMapKey, nil, ) @@ -3039,7 +3039,7 @@ func getCapabilityController( readValue := inter.ReadStored( address, - common.StorageDomainCapabilityController.Identifier(), + common.StorageDomainCapabilityController, storageMapKey, ) if readValue == nil { @@ -3244,7 +3244,7 @@ func recordStorageCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainPathCapability.Identifier(), + common.StorageDomainPathCapability, true, ) @@ -3286,7 +3286,7 @@ func getPathCapabilityIDSet( storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainPathCapability.Identifier(), + common.StorageDomainPathCapability, false, ) if storageMap == nil { @@ -3336,7 +3336,7 @@ func unrecordStorageCapabilityController( if capabilityIDSet.Count() == 0 { storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainPathCapability.Identifier(), + common.StorageDomainPathCapability, true, ) if storageMap == nil { @@ -3404,7 +3404,7 @@ func recordAccountCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainAccountCapability.Identifier(), + common.StorageDomainAccountCapability, true, ) @@ -3431,7 +3431,7 @@ func unrecordAccountCapabilityController( storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainAccountCapability.Identifier(), + common.StorageDomainAccountCapability, true, ) @@ -3450,7 +3450,7 @@ func getAccountCapabilityControllerIDsIterator( ) { storageMap := inter.Storage().GetStorageMap( address, - common.StorageDomainAccountCapability.Identifier(), + common.StorageDomainAccountCapability, false, ) if storageMap == nil { @@ -3517,7 +3517,7 @@ func newAccountCapabilitiesPublishFunction( panic(errors.NewUnreachableError()) } - domain := pathValue.Domain.Identifier() + domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier capabilityType, ok := capabilityValue.StaticType(inter).(*interpreter.CapabilityStaticType) @@ -3636,7 +3636,7 @@ func newAccountCapabilitiesUnpublishFunction( panic(errors.NewUnreachableError()) } - domain := pathValue.Domain.Identifier() + domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier // Read/remove capability @@ -3910,7 +3910,7 @@ func newAccountCapabilitiesGetFunction( panic(errors.NewUnreachableError()) } - domain := pathValue.Domain.Identifier() + domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier // Get borrow type type argument @@ -4095,7 +4095,7 @@ func newAccountCapabilitiesExistsFunction( panic(errors.NewUnreachableError()) } - domain := pathValue.Domain.Identifier() + domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier // Read stored capability, if any @@ -4421,7 +4421,7 @@ func getCapabilityControllerTag( value := inter.ReadStored( address, - common.StorageDomainCapabilityControllerTag.Identifier(), + common.StorageDomainCapabilityControllerTag, interpreter.Uint64StorageMapKey(capabilityID), ) if value == nil { @@ -4483,7 +4483,7 @@ func setCapabilityControllerTag( inter.WriteStored( address, - common.StorageDomainCapabilityControllerTag.Identifier(), + common.StorageDomainCapabilityControllerTag, interpreter.Uint64StorageMapKey(capabilityID), value, )