Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mempool): data race in mempool prepare proposal handler (backport #21413) #21541

Merged
merged 7 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* (x/consensus) [#21493](https://github.com/cosmos/cosmos-sdk/pull/21493) Fix regression that prevented to upgrade to > v0.50.7 without consensus version params.
* (baseapp) [#21256](https://github.com/cosmos/cosmos-sdk/pull/21256) Halt height will not commit the block indicated, meaning that if halt-height is set to 10, only blocks until 9 (included) will be committed. This is to go back to the original behavior before a change was introduced in v0.50.0.
* (baseapp) [#21444](https://github.com/cosmos/cosmos-sdk/pull/21444) Follow-up, Return PreBlocker events in FinalizeBlockResponse.
* (baseapp) [#21413](https://github.com/cosmos/cosmos-sdk/pull/21413) Fix data race in sdk mempool.

## [v0.50.9](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.50.9) - 2024-08-07

Expand Down
41 changes: 26 additions & 15 deletions baseapp/abci_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,14 @@ type (
// DefaultProposalHandler defines the default ABCI PrepareProposal and
// ProcessProposal handlers.
DefaultProposalHandler struct {
mempool mempool.Mempool
mempool mempool.ExtMempool
txVerifier ProposalTxVerifier
txSelector TxSelector
signerExtAdapter mempool.SignerExtractionAdapter
}
)

func NewDefaultProposalHandler(mp mempool.Mempool, txVerifier ProposalTxVerifier) *DefaultProposalHandler {
func NewDefaultProposalHandler(mp mempool.ExtMempool, txVerifier ProposalTxVerifier) *DefaultProposalHandler {
return &DefaultProposalHandler{
mempool: mp,
txVerifier: txVerifier,
Expand Down Expand Up @@ -279,14 +279,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
}

iterator := h.mempool.Select(ctx, req.Txs)
selectedTxsSignersSeqs := make(map[string]uint64)
var selectedTxsNums int
for iterator != nil {
memTx := iterator.Tx()
var (
resError error
selectedTxsNums int
invalidTxs []sdk.Tx // invalid txs to be removed out of the loop to avoid dead lock
)
h.mempool.SelectBy(ctx, req.Txs, func(memTx sdk.Tx) bool {
signerData, err := h.signerExtAdapter.GetSigners(memTx)
if err != nil {
return nil, err
// propagate the error to the caller
resError = err
return false
}

// If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before
Expand All @@ -310,8 +314,7 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
txSignersSeqs[signer.Signer.String()] = signer.Sequence
}
if !shouldAdd {
iterator = iterator.Next()
continue
return true
}

// NOTE: Since transaction verification was already executed in CheckTx,
Expand All @@ -320,14 +323,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
// check again.
txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx)
if err != nil {
err := h.mempool.Remove(memTx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
}
invalidTxs = append(invalidTxs, memTx)
} else {
stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx, txBz)
if stop {
break
return false
}

txsLen := len(h.txSelector.SelectedTxs(ctx))
Expand All @@ -348,7 +348,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
selectedTxsNums = txsLen
}

iterator = iterator.Next()
return true
})

if resError != nil {
return nil, resError
}

for _, tx := range invalidTxs {
err := h.mempool.Remove(tx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
}
}

return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
Expand Down
6 changes: 3 additions & 3 deletions baseapp/baseapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ type BaseApp struct {
txDecoder sdk.TxDecoder // unmarshal []byte into sdk.Tx
txEncoder sdk.TxEncoder // marshal sdk.Tx into []byte

mempool mempool.Mempool // application side mempool
anteHandler sdk.AnteHandler // ante handler for fee and auth
postHandler sdk.PostHandler // post handler, optional
mempool mempool.ExtMempool // application side mempool
anteHandler sdk.AnteHandler // ante handler for fee and auth
postHandler sdk.PostHandler // post handler, optional

initChainer sdk.InitChainer // ABCI InitChain handler
preBlocker sdk.PreBlocker // logic to run before BeginBlocker
Expand Down
4 changes: 2 additions & 2 deletions baseapp/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func SetSnapshot(snapshotStore *snapshots.Store, opts snapshottypes.SnapshotOpti
}

// SetMempool sets the mempool on BaseApp.
func SetMempool(mempool mempool.Mempool) func(*BaseApp) {
func SetMempool(mempool mempool.ExtMempool) func(*BaseApp) {
return func(app *BaseApp) { app.SetMempool(mempool) }
}

Expand Down Expand Up @@ -319,7 +319,7 @@ func (app *BaseApp) SetQueryMultiStore(ms storetypes.MultiStore) {
}

// SetMempool sets the mempool for the BaseApp and is required for the app to start up.
func (app *BaseApp) SetMempool(mempool mempool.Mempool) {
func (app *BaseApp) SetMempool(mempool mempool.ExtMempool) {
if app.sealed {
panic("SetMempool() on sealed BaseApp")
}
Expand Down
10 changes: 8 additions & 2 deletions types/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ type Mempool interface {
Insert(context.Context, sdk.Tx) error

// Select returns an Iterator over the app-side mempool. If txs are specified,
// then they shall be incorporated into the Iterator. The Iterator must
// closed by the caller.
// then they shall be incorporated into the Iterator. The Iterator is not thread-safe to use.
Select(context.Context, [][]byte) Iterator

// CountTx returns the number of transactions currently in the mempool.
Expand All @@ -25,6 +24,13 @@ type Mempool interface {
Remove(sdk.Tx) error
}

type ExtMempool interface {
Mempool

// SelectBy use callback to iterate over the mempool, it's thread-safe to use.
SelectBy(context.Context, [][]byte, func(sdk.Tx) bool)
}

// Iterator defines an app-side mempool iterator interface that is as minimal as
// possible. The order of iteration is determined by the app-side mempool
// implementation.
Expand Down
11 changes: 6 additions & 5 deletions types/mempool/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
)

var _ Mempool = (*NoOpMempool)(nil)
var _ ExtMempool = (*NoOpMempool)(nil)

// NoOpMempool defines a no-op mempool. Transactions are completely discarded and
// ignored when BaseApp interacts with the mempool.
Expand All @@ -16,7 +16,8 @@ var _ Mempool = (*NoOpMempool)(nil)
// is FIFO-ordered by default.
type NoOpMempool struct{}

func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) {}
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
21 changes: 18 additions & 3 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
)

var (
_ Mempool = (*PriorityNonceMempool[int64])(nil)
_ Iterator = (*PriorityNonceIterator[int64])(nil)
_ ExtMempool = (*PriorityNonceMempool[int64])(nil)
_ Iterator = (*PriorityNonceIterator[int64])(nil)
)

type (
Expand Down Expand Up @@ -350,9 +350,13 @@ func (i *PriorityNonceIterator[C]) Tx() sdk.Tx {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterator {
func (mp *PriorityNonceMempool[C]) Select(ctx context.Context, txs [][]byte) Iterator {
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.doSelect(ctx, txs)
}

func (mp *PriorityNonceMempool[C]) doSelect(_ context.Context, _ [][]byte) Iterator {
if mp.priorityIndex.Len() == 0 {
return nil
}
Expand All @@ -367,6 +371,17 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato
return iterator.iteratePriority()
}

// SelectBy will hold the mutex during the iteration, callback returns if continue.
func (mp *PriorityNonceMempool[C]) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) {
mp.mtx.Lock()
defer mp.mtx.Unlock()

iter := mp.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

type reorderKey[C comparable] struct {
deleteKey txMeta[C]
insertKey txMeta[C]
Expand Down
85 changes: 85 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mempool_test

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -396,6 +398,89 @@ func (s *MempoolTestSuite) TestIterator() {
}
}

func (s *MempoolTestSuite) TestIteratorConcurrency() {
t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

tests := []struct {
txs []txSpec
fail bool
}{
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: 8, n: 2, a: sb},
},
},
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: math.MinInt64, n: 2, a: sb},
},
},
}

for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.DefaultPriorityMempool()

// create test txs and insert into mempool
for i, ts := range tt.txs {
tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}

// iterate through txs
stdCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

id := len(tt.txs)
for {
select {
case <-stdCtx.Done():
return
default:
id++
tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}
}
}()

var i int
pool.SelectBy(ctx, nil, func(memTx sdk.Tx) bool {
tx := memTx.(testTx)
if tx.id < len(tt.txs) {
require.Equal(t, tt.txs[tx.id].p, int(tx.priority))
require.Equal(t, tt.txs[tx.id].n, int(tx.nonce))
require.Equal(t, tt.txs[tx.id].a, tx.address)
i++
}
return i < len(tt.txs)
})
require.Equal(t, i, len(tt.txs))
cancel()
wg.Wait()
})
}
}

func (s *MempoolTestSuite) TestPriorityTies() {
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3)
Expand Down
21 changes: 18 additions & 3 deletions types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
)

var (
_ Mempool = (*SenderNonceMempool)(nil)
_ Iterator = (*senderNonceMempoolIterator)(nil)
_ ExtMempool = (*SenderNonceMempool)(nil)
_ Iterator = (*senderNonceMempoolIterator)(nil)
)

var DefaultMaxTx = -1
Expand Down Expand Up @@ -158,9 +158,13 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
func (snm *SenderNonceMempool) Select(ctx context.Context, txs [][]byte) Iterator {
snm.mtx.Lock()
defer snm.mtx.Unlock()
return snm.doSelect(ctx, txs)
}

func (snm *SenderNonceMempool) doSelect(_ context.Context, _ [][]byte) Iterator {
var senders []string

senderCursors := make(map[string]*skiplist.Element)
Expand Down Expand Up @@ -188,6 +192,17 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
return iter.Next()
}

// SelectBy will hold the mutex during the iteration, callback returns if continue.
func (snm *SenderNonceMempool) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) {
snm.mtx.Lock()
defer snm.mtx.Unlock()

iter := snm.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

// CountTx returns the total count of txs in the mempool.
func (snm *SenderNonceMempool) CountTx() int {
snm.mtx.Lock()
Expand Down
Loading