Skip to content

Commit

Permalink
ensure account is not tested for V1 registers if V2 is disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
turbolent committed Nov 14, 2024
1 parent d579591 commit 99710ce
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 8 deletions.
201 changes: 200 additions & 1 deletion runtime/sharedstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,206 @@ import (
. "github.com/onflow/cadence/test_utils/runtime_utils"
)

func TestRuntimeSharedState(t *testing.T) {
func TestRuntimeSharedStateV1(t *testing.T) {

t.Parallel()

config := DefaultTestInterpreterConfig
config.StorageFormatV2Enabled = false
runtime := NewTestInterpreterRuntimeWithConfig(config)

signerAddress := common.MustBytesToAddress([]byte{0x1})

deploy1 := DeploymentTransaction("C1", []byte(`
access(all) contract C1 {
access(all) fun hello() {
log("Hello from C1!")
}
}
`))

deploy2 := DeploymentTransaction("C2", []byte(`
access(all) contract C2 {
access(all) fun hello() {
log("Hello from C2!")
}
}
`))

accountCodes := map[common.Location][]byte{}

var events []cadence.Event
var loggedMessages []string

var interpreterState *interpreter.SharedState

var ledgerReads []ownerKeyPair

ledger := NewTestLedger(
func(owner, key, value []byte) {
ledgerReads = append(
ledgerReads,
ownerKeyPair{
owner: owner,
key: key,
},
)
},
nil,
)

runtimeInterface := &TestRuntimeInterface{
Storage: ledger,
OnGetSigningAccounts: func() ([]Address, error) {
return []Address{signerAddress}, nil
},
OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error {
accountCodes[location] = code
return nil
},
OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) {
code = accountCodes[location]
return code, nil
},
OnRemoveAccountContractCode: func(location common.AddressLocation) error {
delete(accountCodes, location)
return nil
},
OnResolveLocation: MultipleIdentifierLocationResolver,
OnProgramLog: func(message string) {
loggedMessages = append(loggedMessages, message)
},
OnEmitEvent: func(event cadence.Event) error {
events = append(events, event)
return nil
},
OnSetInterpreterSharedState: func(state *interpreter.SharedState) {
interpreterState = state
},
OnGetInterpreterSharedState: func() *interpreter.SharedState {
return interpreterState
},
}

environment := NewBaseInterpreterEnvironment(config)

nextTransactionLocation := NewTransactionLocationGenerator()

// Deploy contracts

for _, source := range [][]byte{
deploy1,
deploy2,
} {
err := runtime.ExecuteTransaction(
Script{
Source: source,
},
Context{
Interface: runtimeInterface,
Location: nextTransactionLocation(),
Environment: environment,
},
)
require.NoError(t, err)
}

assert.NotEmpty(t, accountCodes)

// Call C1.hello using transaction

loggedMessages = nil

err := runtime.ExecuteTransaction(
Script{
Source: []byte(`
import C1 from 0x1
transaction {
prepare(signer: &Account) {
C1.hello()
}
}
`),
Arguments: nil,
},
Context{
Interface: runtimeInterface,
Location: nextTransactionLocation(),
Environment: environment,
},
)
require.NoError(t, err)

assert.Equal(t, []string{`"Hello from C1!"`}, loggedMessages)

// Call C1.hello manually

loggedMessages = nil

_, err = runtime.InvokeContractFunction(
common.AddressLocation{
Address: signerAddress,
Name: "C1",
},
"hello",
nil,
nil,
Context{
Interface: runtimeInterface,
Location: nextTransactionLocation(),
Environment: environment,
},
)
require.NoError(t, err)

assert.Equal(t, []string{`"Hello from C1!"`}, loggedMessages)

// Call C2.hello manually

loggedMessages = nil

_, err = runtime.InvokeContractFunction(
common.AddressLocation{
Address: signerAddress,
Name: "C2",
},
"hello",
nil,
nil,
Context{
Interface: runtimeInterface,
Location: nextTransactionLocation(),
Environment: environment,
},
)
require.NoError(t, err)

assert.Equal(t, []string{`"Hello from C2!"`}, loggedMessages)

// Assert shared state was used,
// i.e. data was not re-read

require.Equal(t,
[]ownerKeyPair{
{
owner: signerAddress[:],
key: []byte(common.StorageDomainContract.Identifier()),
},
{
owner: signerAddress[:],
key: []byte(common.StorageDomainContract.Identifier()),
},
{
owner: signerAddress[:],
key: []byte{'$', 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2},
},
},
ledgerReads,
)
}

func TestRuntimeSharedStateV2(t *testing.T) {

t.Parallel()

Expand Down
17 changes: 10 additions & 7 deletions runtime/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ func (s *Storage) GetDomainStorageMap(
}
}()

if s.IsV1Account(address) ||
!s.Config.StorageFormatV2Enabled {
if !s.Config.StorageFormatV2Enabled || s.IsV1Account(address) {

domainStorageMap = s.AccountStorageV1.GetDomainStorageMap(
address,
Expand All @@ -169,7 +168,7 @@ func (s *Storage) GetDomainStorageMap(
// mark the account as in storage format v1.

if domainStorageMap != nil {
s.v1Accounts.Set(address, true)
s.setIsV1Account(address, true)
}

} else {
Expand Down Expand Up @@ -214,10 +213,7 @@ func (s *Storage) IsV1Account(address common.Address) (isV1 bool) {
// Cache result

defer func() {
if s.v1Accounts == nil {
s.v1Accounts = &orderedmap.OrderedMap[common.Address, bool]{}
}
s.v1Accounts.Set(address, isV1)
s.setIsV1Account(address, isV1)
}()

// Check if a storage map register exists for any of the domains.
Expand All @@ -239,6 +235,13 @@ func (s *Storage) IsV1Account(address common.Address) (isV1 bool) {
return false
}

func (s *Storage) setIsV1Account(address common.Address, isV1 bool) {
if s.v1Accounts == nil {
s.v1Accounts = &orderedmap.OrderedMap[common.Address, bool]{}
}
s.v1Accounts.Set(address, isV1)
}

// getSlabIndexFromRegisterValue returns register value as atree.SlabIndex.
// This function returns error if
// - underlying ledger panics, or
Expand Down

0 comments on commit 99710ce

Please sign in to comment.