From 99710ce200548e4a94be57cda10ec4ffb7e63525 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Wed, 13 Nov 2024 17:18:21 -0800 Subject: [PATCH] ensure account is not tested for V1 registers if V2 is disabled --- runtime/sharedstate_test.go | 201 +++++++++++++++++++++++++++++++++++- runtime/storage.go | 17 +-- 2 files changed, 210 insertions(+), 8 deletions(-) diff --git a/runtime/sharedstate_test.go b/runtime/sharedstate_test.go index 3bdbce758..b24051fd2 100644 --- a/runtime/sharedstate_test.go +++ b/runtime/sharedstate_test.go @@ -31,7 +31,206 @@ import ( . "github.com/onflow/cadence/test_utils/runtime_utils" ) -func TestRuntimeSharedState(t *testing.T) { +func TestRuntimeSharedStateV1(t *testing.T) { + + t.Parallel() + + config := DefaultTestInterpreterConfig + config.StorageFormatV2Enabled = false + runtime := NewTestInterpreterRuntimeWithConfig(config) + + signerAddress := common.MustBytesToAddress([]byte{0x1}) + + deploy1 := DeploymentTransaction("C1", []byte(` + access(all) contract C1 { + access(all) fun hello() { + log("Hello from C1!") + } + } + `)) + + deploy2 := DeploymentTransaction("C2", []byte(` + access(all) contract C2 { + access(all) fun hello() { + log("Hello from C2!") + } + } + `)) + + accountCodes := map[common.Location][]byte{} + + var events []cadence.Event + var loggedMessages []string + + var interpreterState *interpreter.SharedState + + var ledgerReads []ownerKeyPair + + ledger := NewTestLedger( + func(owner, key, value []byte) { + ledgerReads = append( + ledgerReads, + ownerKeyPair{ + owner: owner, + key: key, + }, + ) + }, + nil, + ) + + runtimeInterface := &TestRuntimeInterface{ + Storage: ledger, + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{signerAddress}, nil + }, + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnRemoveAccountContractCode: func(location common.AddressLocation) error { + delete(accountCodes, location) + return nil + }, + OnResolveLocation: MultipleIdentifierLocationResolver, + OnProgramLog: func(message string) { + loggedMessages = append(loggedMessages, message) + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnSetInterpreterSharedState: func(state *interpreter.SharedState) { + interpreterState = state + }, + OnGetInterpreterSharedState: func() *interpreter.SharedState { + return interpreterState + }, + } + + environment := NewBaseInterpreterEnvironment(config) + + nextTransactionLocation := NewTransactionLocationGenerator() + + // Deploy contracts + + for _, source := range [][]byte{ + deploy1, + deploy2, + } { + err := runtime.ExecuteTransaction( + Script{ + Source: source, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + Environment: environment, + }, + ) + require.NoError(t, err) + } + + assert.NotEmpty(t, accountCodes) + + // Call C1.hello using transaction + + loggedMessages = nil + + err := runtime.ExecuteTransaction( + Script{ + Source: []byte(` + import C1 from 0x1 + + transaction { + prepare(signer: &Account) { + C1.hello() + } + } + `), + Arguments: nil, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + Environment: environment, + }, + ) + require.NoError(t, err) + + assert.Equal(t, []string{`"Hello from C1!"`}, loggedMessages) + + // Call C1.hello manually + + loggedMessages = nil + + _, err = runtime.InvokeContractFunction( + common.AddressLocation{ + Address: signerAddress, + Name: "C1", + }, + "hello", + nil, + nil, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + Environment: environment, + }, + ) + require.NoError(t, err) + + assert.Equal(t, []string{`"Hello from C1!"`}, loggedMessages) + + // Call C2.hello manually + + loggedMessages = nil + + _, err = runtime.InvokeContractFunction( + common.AddressLocation{ + Address: signerAddress, + Name: "C2", + }, + "hello", + nil, + nil, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + Environment: environment, + }, + ) + require.NoError(t, err) + + assert.Equal(t, []string{`"Hello from C2!"`}, loggedMessages) + + // Assert shared state was used, + // i.e. data was not re-read + + require.Equal(t, + []ownerKeyPair{ + { + owner: signerAddress[:], + key: []byte(common.StorageDomainContract.Identifier()), + }, + { + owner: signerAddress[:], + key: []byte(common.StorageDomainContract.Identifier()), + }, + { + owner: signerAddress[:], + key: []byte{'$', 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2}, + }, + }, + ledgerReads, + ) +} + +func TestRuntimeSharedStateV2(t *testing.T) { t.Parallel() diff --git a/runtime/storage.go b/runtime/storage.go index d836fba47..94401c519 100644 --- a/runtime/storage.go +++ b/runtime/storage.go @@ -155,8 +155,7 @@ func (s *Storage) GetDomainStorageMap( } }() - if s.IsV1Account(address) || - !s.Config.StorageFormatV2Enabled { + if !s.Config.StorageFormatV2Enabled || s.IsV1Account(address) { domainStorageMap = s.AccountStorageV1.GetDomainStorageMap( address, @@ -169,7 +168,7 @@ func (s *Storage) GetDomainStorageMap( // mark the account as in storage format v1. if domainStorageMap != nil { - s.v1Accounts.Set(address, true) + s.setIsV1Account(address, true) } } else { @@ -214,10 +213,7 @@ func (s *Storage) IsV1Account(address common.Address) (isV1 bool) { // Cache result defer func() { - if s.v1Accounts == nil { - s.v1Accounts = &orderedmap.OrderedMap[common.Address, bool]{} - } - s.v1Accounts.Set(address, isV1) + s.setIsV1Account(address, isV1) }() // Check if a storage map register exists for any of the domains. @@ -239,6 +235,13 @@ func (s *Storage) IsV1Account(address common.Address) (isV1 bool) { return false } +func (s *Storage) setIsV1Account(address common.Address, isV1 bool) { + if s.v1Accounts == nil { + s.v1Accounts = &orderedmap.OrderedMap[common.Address, bool]{} + } + s.v1Accounts.Set(address, isV1) +} + // getSlabIndexFromRegisterValue returns register value as atree.SlabIndex. // This function returns error if // - underlying ledger panics, or