From 1084d9986d6379cb0a0532b31b57985471bb20b3 Mon Sep 17 00:00:00 2001 From: Dusan Maksimovic <94966669+dusan-maksimovic@users.noreply.github.com> Date: Mon, 3 Jul 2023 13:28:14 +0200 Subject: [PATCH] EVM-716 Check if the error is produced for duplicate TXs Co-authored-by: Igor Crevar --- txpool/account.go | 99 +++- txpool/event_manager.go | 2 +- txpool/event_subscription.go | 4 +- txpool/queue_account.go | 11 +- txpool/slot_gauge.go | 5 + txpool/txpool.go | 152 +++-- txpool/txpool_test.go | 1067 +++++++++++++++++++++++++--------- 7 files changed, 975 insertions(+), 365 deletions(-) diff --git a/txpool/account.go b/txpool/account.go index fc65791bbf..a18aadc68c 100644 --- a/txpool/account.go +++ b/txpool/account.go @@ -12,8 +12,7 @@ import ( type accountsMap struct { sync.Map - count uint64 - + count uint64 maxEnqueuedLimit uint64 } @@ -22,6 +21,7 @@ func (m *accountsMap) initOnce(addr types.Address, nonce uint64) *account { a, loaded := m.LoadOrStore(addr, &account{ enqueued: newAccountQueue(), promoted: newAccountQueue(), + nonceToTx: newNonceToTxLookup(), maxEnqueued: m.maxEnqueuedLimit, nextNonce: nonce, }) @@ -136,6 +136,43 @@ func (m *accountsMap) allTxs(includeEnqueued bool) ( return } +type nonceToTxLookup struct { + mapping map[uint64]*types.Transaction + mutex sync.Mutex +} + +func newNonceToTxLookup() *nonceToTxLookup { + return &nonceToTxLookup{ + mapping: make(map[uint64]*types.Transaction), + } +} + +func (m *nonceToTxLookup) lock() { + m.mutex.Lock() +} + +func (m *nonceToTxLookup) unlock() { + m.mutex.Unlock() +} + +func (m *nonceToTxLookup) get(nonce uint64) *types.Transaction { + return m.mapping[nonce] +} + +func (m *nonceToTxLookup) set(tx *types.Transaction) { + m.mapping[tx.Nonce] = tx +} + +func (m *nonceToTxLookup) reset() { + m.mapping = make(map[uint64]*types.Transaction) +} + +func (m *nonceToTxLookup) remove(txs ...*types.Transaction) { + for _, tx := range txs { + delete(m.mapping, tx.Nonce) + } +} + // An account is the core structure for processing // transactions from a specific address. The nextNonce // field is what separates the enqueued from promoted transactions: @@ -147,10 +184,13 @@ func (m *accountsMap) allTxs(includeEnqueued bool) ( // a promoteRequest is signaled for this account // indicating the account's enqueued transaction(s) // are ready to be moved to the promoted queue. +// lock order is important! promoted.lock(true), enqueued.lock(true), nonceToTx.lock() type account struct { enqueued, promoted *accountQueue - nextNonce uint64 - demotions uint64 + nonceToTx *nonceToTxLookup + + nextNonce uint64 + demotions uint64 // the number of consecutive blocks that don't contain account's transaction skips uint64 @@ -192,21 +232,27 @@ func (a *account) reset(nonce uint64, promoteCh chan<- promoteRequest) ( prunedEnqueued []*types.Transaction, ) { a.promoted.lock(true) - defer a.promoted.unlock() + a.enqueued.lock(true) + a.nonceToTx.lock() + + defer func() { + a.nonceToTx.unlock() + a.enqueued.unlock() + a.promoted.unlock() + }() // prune the promoted txs prunedPromoted = a.promoted.prune(nonce) + a.nonceToTx.remove(prunedPromoted...) if nonce <= a.getNonce() { // only the promoted queue needed pruning return } - a.enqueued.lock(true) - defer a.enqueued.unlock() - // prune the enqueued txs prunedEnqueued = a.enqueued.prune(nonce) + a.nonceToTx.remove(prunedEnqueued...) // update nonce expected for this account a.setNonce(nonce) @@ -222,24 +268,31 @@ func (a *account) reset(nonce uint64, promoteCh chan<- promoteRequest) ( return } -// enqueue attempts tp push the transaction onto the enqueued queue. -func (a *account) enqueue(tx *types.Transaction) error { - a.enqueued.lock(true) - defer a.enqueued.unlock() +// enqueue push the transaction onto the enqueued queue or replace it +func (a *account) enqueue(tx *types.Transaction, replace bool) { + replaceInQueue := func(queue minNonceQueue) bool { + for i, x := range queue { + if x.Nonce == tx.Nonce { + queue[i] = tx // replace - if a.enqueued.length() == a.maxEnqueued { - return ErrMaxEnqueuedLimitReached - } + return true + } + } - // reject low nonce tx - if tx.Nonce < a.getNonce() { - return ErrNonceTooLow + return false } - // enqueue tx - a.enqueued.push(tx) + a.nonceToTx.set(tx) - return nil + if !replace { + a.enqueued.push(tx) + } else { + // first -> try to replace in enqueued + if !replaceInQueue(a.enqueued.queue) { + // .. then try to replace in promoted + replaceInQueue(a.promoted.queue) + } + } } // Promote moves eligible transactions from enqueued to promoted. @@ -295,6 +348,10 @@ func (a *account) promote() (promoted []*types.Transaction, pruned []*types.Tran a.setNonce(nextNonce) } + a.nonceToTx.lock() + a.nonceToTx.remove(pruned...) + a.nonceToTx.unlock() + return } diff --git a/txpool/event_manager.go b/txpool/event_manager.go index c2e4dad907..f4c8c72cda 100644 --- a/txpool/event_manager.go +++ b/txpool/event_manager.go @@ -41,7 +41,7 @@ func (em *eventManager) subscribe(eventTypes []proto.EventType) *subscribeResult eventTypes: eventTypes, outputCh: make(chan *proto.TxPoolEvent), doneCh: make(chan struct{}), - notifyCh: make(chan struct{}, 1), + notifyCh: make(chan struct{}, 10), eventStore: &eventQueue{ events: make([]*proto.TxPoolEvent, 0), }, diff --git a/txpool/event_subscription.go b/txpool/event_subscription.go index 4eafe40840..97332a19f5 100644 --- a/txpool/event_subscription.go +++ b/txpool/event_subscription.go @@ -38,12 +38,12 @@ func (es *eventSubscription) eventSupported(eventType proto.EventType) bool { // close stops the event subscription func (es *eventSubscription) close() { close(es.doneCh) - close(es.outputCh) - close(es.notifyCh) } // runLoop is the main loop that listens for notifications and handles the event / close signals func (es *eventSubscription) runLoop() { + defer close(es.outputCh) + for { select { case <-es.doneCh: // Break if a close signal has been received diff --git a/txpool/queue_account.go b/txpool/queue_account.go index 63340b0b2d..78b938199f 100644 --- a/txpool/queue_account.go +++ b/txpool/queue_account.go @@ -144,10 +144,11 @@ func (q *minNonceQueue) Push(x interface{}) { } func (q *minNonceQueue) Pop() interface{} { - old := q - n := len(*old) - x := (*old)[n-1] - *q = (*old)[0 : n-1] + old := *q + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + *q = old[0 : n-1] - return x + return item } diff --git a/txpool/slot_gauge.go b/txpool/slot_gauge.go index 2df18c6df0..ee7550d781 100644 --- a/txpool/slot_gauge.go +++ b/txpool/slot_gauge.go @@ -40,6 +40,11 @@ func (g *slotGauge) highPressure() bool { return g.read() > (highPressureMark*g.max)/100 } +// free slots returns how many slots are currently available +func (g *slotGauge) freeSlots() uint64 { + return g.max - g.read() +} + // slotsRequired calculates the number of slots required for given transaction(s). func slotsRequired(txs ...*types.Transaction) uint64 { slots := uint64(0) diff --git a/txpool/txpool.go b/txpool/txpool.go index 57f2187edc..b5a2aca714 100644 --- a/txpool/txpool.go +++ b/txpool/txpool.go @@ -59,6 +59,8 @@ var ( ErrTipAboveFeeCap = errors.New("max priority fee per gas higher than max fee per gas") ErrTipVeryHigh = errors.New("max priority fee per gas higher than 2^256-1") ErrFeeCapVeryHigh = errors.New("max fee per gas higher than 2^256-1") + ErrNonceExistsInPool = errors.New("tx with the same nonce is already present") + ErrReplacementUnderpriced = errors.New("replacement tx underpriced") ) // indicates origin of a transaction @@ -162,7 +164,6 @@ type TxPool struct { // channels on which the pool's event loop // does dispatching/handling requests. - enqueueReqCh chan enqueueRequest promoteReqCh chan promoteRequest pruneCh chan struct{} @@ -208,7 +209,6 @@ func NewTxPool( priceLimit: config.PriceLimit, // main loop channels - enqueueReqCh: make(chan enqueueRequest), promoteReqCh: make(chan promoteRequest), pruneCh: make(chan struct{}), shutdownCh: make(chan struct{}), @@ -272,8 +272,6 @@ func (p *TxPool) Start() { select { case <-p.shutdownCh: return - case req := <-p.enqueueReqCh: - go p.handleEnqueueRequest(req) case req := <-p.promoteReqCh: go p.handlePromoteRequest(req) } @@ -284,7 +282,7 @@ func (p *TxPool) Start() { // Close shuts down the pool's main loop. func (p *TxPool) Close() { p.eventManager.Close() - p.shutdownCh <- struct{}{} + close(p.shutdownCh) } // SetSigner sets the signer the pool will use @@ -365,11 +363,19 @@ func (p *TxPool) Pop(tx *types.Transaction) { account := p.accounts.get(tx.From) account.promoted.lock(true) - defer account.promoted.unlock() + account.nonceToTx.lock() + + defer func() { + account.nonceToTx.unlock() + account.promoted.unlock() + }() // pop the top most promoted tx account.promoted.pop() + // update the account nonce -> *tx map + account.nonceToTx.remove(tx) + // successfully popping an account resets its demotions count to 0 account.resetDemotions() @@ -393,6 +399,13 @@ func (p *TxPool) Drop(tx *types.Transaction) { account.promoted.lock(true) account.enqueued.lock(true) + account.nonceToTx.lock() + + defer func() { + account.nonceToTx.unlock() + account.enqueued.unlock() + account.promoted.unlock() + }() // num of all txs dropped droppedCount := 0 @@ -406,15 +419,13 @@ func (p *TxPool) Drop(tx *types.Transaction) { droppedCount += len(txs) } - defer func() { - account.enqueued.unlock() - account.promoted.unlock() - }() - // rollback nonce nextNonce := tx.Nonce account.setNonce(nextNonce) + // reset accounts nonce map + account.nonceToTx.reset() + // drop promoted dropped := account.promoted.clear() clearAccountQueue(dropped) @@ -698,6 +709,9 @@ func (p *TxPool) pruneAccountsWithNonceHoles() { account.enqueued.lock(true) defer account.enqueued.unlock() + account.nonceToTx.lock() + defer account.nonceToTx.unlock() + firstTx := account.enqueued.peek() if firstTx == nil { @@ -710,6 +724,7 @@ func (p *TxPool) pruneAccountsWithNonceHoles() { removed := account.enqueued.clear() + account.nonceToTx.remove(removed...) p.index.remove(removed...) p.gauge.decrease(slotsRequired(removed...)) @@ -724,10 +739,7 @@ func (p *TxPool) pruneAccountsWithNonceHoles() { // (only once) and an enqueueRequest is signaled. func (p *TxPool) addTx(origin txOrigin, tx *types.Transaction) error { if p.logger.IsDebug() { - p.logger.Debug("add tx", - "origin", origin.String(), - "hash", tx.Hash.String(), - ) + p.logger.Debug("add tx", "origin", origin.String(), "hash", tx.Hash.String()) } // validate incoming tx @@ -735,25 +747,66 @@ func (p *TxPool) addTx(origin txOrigin, tx *types.Transaction) error { return err } + // calculate tx hash + tx.ComputeHash() + + // initialize account for this address once or retrieve existing one + account := p.getOrCreateAccount(tx.From) + // populate currently free slots + slotsFree := p.gauge.freeSlots() + + account.promoted.lock(true) + account.enqueued.lock(true) + account.nonceToTx.lock() + + defer func() { + account.nonceToTx.unlock() + account.enqueued.unlock() + account.promoted.unlock() + }() + + accountNonce := account.getNonce() + + // only accept transactions with expected nonce if p.gauge.highPressure() { p.signalPruning() - // only accept transactions with expected nonce - if account := p.accounts.get(tx.From); account != nil && - tx.Nonce > account.getNonce() { + if tx.Nonce > accountNonce { metrics.IncrCounter([]string{txPoolMetrics, "rejected_future_tx"}, 1) return ErrRejectFutureTx } } + // try to find if there is transaction with same nonce for this account + oldTxWithSameNonce := account.nonceToTx.get(tx.Nonce) + if oldTxWithSameNonce != nil { + if oldTxWithSameNonce.Hash == tx.Hash { + metrics.IncrCounter([]string{txPoolMetrics, "already_known_tx"}, 1) + + return ErrAlreadyKnown + } else if oldTxWithSameNonce.GasPrice.Cmp(tx.GasPrice) >= 0 { + // if tx with same nonce does exist and has same or better gas price -> return error + return ErrUnderpriced + } + + slotsFree += slotsRequired(oldTxWithSameNonce) // add old tx slots + } else { + if account.enqueued.length() == account.maxEnqueued { + return ErrMaxEnqueuedLimitReached + } + + // reject low nonce tx + if tx.Nonce < accountNonce { + return ErrNonceTooLow + } + } + // check for overflow - if slotsRequired(tx) > p.gauge.max-p.gauge.read() { + if slotsRequired(tx) > slotsFree { return ErrTxPoolOverflow } - tx.ComputeHash() - // add to index if ok := p.index.add(tx); !ok { metrics.IncrCounter([]string{txPoolMetrics, "already_known_tx"}, 1) @@ -761,53 +814,36 @@ func (p *TxPool) addTx(origin txOrigin, tx *types.Transaction) error { return ErrAlreadyKnown } - // initialize account for this address once - p.createAccountOnce(tx.From) + if oldTxWithSameNonce != nil { + p.index.remove(oldTxWithSameNonce) + p.gauge.decrease(slotsRequired(oldTxWithSameNonce)) + } else { + metrics.SetGauge([]string{txPoolMetrics, "added_tx"}, 1) + } - // send request [BLOCKING] - p.enqueueReqCh <- enqueueRequest{tx: tx} - p.eventManager.signalEvent(proto.EventType_ADDED, tx.Hash) + account.enqueue(tx, oldTxWithSameNonce != nil) // add or replace tx into account + p.gauge.increase(slotsRequired(tx)) - metrics.SetGauge([]string{txPoolMetrics, "added_tx"}, 1) + go p.invokePromotion(tx, tx.Nonce <= accountNonce) // don't signal promotion for higher nonce txs return nil } -// handleEnqueueRequest attempts to enqueue the transaction -// contained in the given request to the associated account. -// If, afterwards, the account is eligible for promotion, -// a promoteRequest is signaled. -func (p *TxPool) handleEnqueueRequest(req enqueueRequest) { - tx := req.tx - addr := req.tx.From - - // fetch account - account := p.accounts.get(addr) - - // enqueue tx - if err := account.enqueue(tx); err != nil { - p.logger.Error("enqueue request", "err", err) - - p.index.remove(tx) - - return - } +func (p *TxPool) invokePromotion(tx *types.Transaction, callPromote bool) { + p.eventManager.signalEvent(proto.EventType_ADDED, tx.Hash) if p.logger.IsDebug() { p.logger.Debug("enqueue request", "hash", tx.Hash.String()) } - p.gauge.increase(slotsRequired(tx)) - p.eventManager.signalEvent(proto.EventType_ENQUEUED, tx.Hash) - if tx.Nonce > account.getNonce() { - // don't signal promotion for - // higher nonce txs - return + if callPromote { + select { + case <-p.shutdownCh: + case p.promoteReqCh <- promoteRequest{account: tx.From}: // BLOCKING + } } - - p.promoteReqCh <- promoteRequest{account: addr} // BLOCKING } // handlePromoteRequest handles moving promotable transactions @@ -971,11 +1007,11 @@ func (p *TxPool) updateAccountSkipsCounts(latestActiveAccounts map[types.Address ) } -// createAccountOnce creates an account and +// getOrCreateAccount creates an account and // ensures it is only initialized once. -func (p *TxPool) createAccountOnce(newAddr types.Address) *account { - if p.accounts.exists(newAddr) { - return nil +func (p *TxPool) getOrCreateAccount(newAddr types.Address) *account { + if account := p.accounts.get(newAddr); account != nil { + return account } // fetch nonce from state diff --git a/txpool/txpool_test.go b/txpool/txpool_test.go index c19647bc23..8e80a83d04 100644 --- a/txpool/txpool_test.go +++ b/txpool/txpool_test.go @@ -4,9 +4,11 @@ import ( "context" "crypto/ecdsa" "crypto/rand" + "errors" "fmt" "math/big" "sync" + "sync/atomic" "testing" "time" @@ -268,12 +270,7 @@ func TestAddTxErrors(t *testing.T) { tx = signTx(tx) // enqueue tx - go func() { - assert.NoError(t, - pool.addTx(local, tx), - ) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + assert.NoError(t, pool.addTx(local, tx)) <-pool.promoteReqCh }) @@ -299,12 +296,7 @@ func TestAddTxErrors(t *testing.T) { tx = signTx(tx) // send the tx beforehand - go func() { - err := pool.addTx(local, tx) - assert.NoError(t, err) - }() - - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + assert.NoError(t, pool.addTx(local, tx)) <-pool.promoteReqCh assert.ErrorIs(t, @@ -313,6 +305,28 @@ func TestAddTxErrors(t *testing.T) { ) }) + t.Run("ErrAlreadyKnown", func(t *testing.T) { + t.Parallel() + pool := setupPool() + + tx := newTx(defaultAddr, 0, 1) + tx.GasPrice = big.NewInt(200) + tx = signTx(tx) + + // send the tx beforehand + assert.NoError(t, pool.addTx(local, tx)) + <-pool.promoteReqCh + + tx = newTx(defaultAddr, 0, 1) + tx.GasPrice = big.NewInt(100) + tx = signTx(tx) + + assert.ErrorIs(t, + pool.addTx(local, tx), + ErrUnderpriced, + ) + }) + t.Run("ErrOversizedData", func(t *testing.T) { t.Parallel() pool := setupPool() @@ -375,7 +389,7 @@ func TestPruneAccountsWithNonceHoles(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - pool.createAccountOnce(addr1) + pool.getOrCreateAccount(addr1) assert.Equal(t, uint64(0), pool.gauge.read()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) @@ -398,13 +412,11 @@ func TestPruneAccountsWithNonceHoles(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) + tx := newTx(addr1, 0, 1) + // enqueue tx - go func() { - assert.NoError(t, - pool.addTx(local, newTx(addr1, 0, 1)), - ) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + assert.NoError(t, pool.addTx(local, tx)) + acc := pool.accounts.get(addr1) <-pool.promoteReqCh assert.Equal(t, uint64(1), pool.gauge.read()) @@ -420,6 +432,8 @@ func TestPruneAccountsWithNonceHoles(t *testing.T) { assert.Equal(t, uint64(1), pool.gauge.read()) assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) + + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) }, ) @@ -432,14 +446,10 @@ func TestPruneAccountsWithNonceHoles(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - // enqueue tx - go func() { - assert.NoError(t, - pool.addTx(local, newTx(addr1, 5, 1)), - ) - }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + tx := newTx(addr1, 5, 1) + // enqueue tx + assert.NoError(t, pool.addTx(local, tx)) assert.Equal(t, uint64(1), pool.gauge.read()) assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) @@ -453,6 +463,10 @@ func TestPruneAccountsWithNonceHoles(t *testing.T) { assert.Equal(t, uint64(0), pool.gauge.read()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + + acc := pool.accounts.get(addr1) + + assert.Equal(t, int(0), len(acc.nonceToTx.mapping)) }, ) } @@ -483,9 +497,6 @@ func TestAddTxHighPressure(t *testing.T) { // pick up signal _, ok := <-pool.pruneCh assert.True(t, ok) - - // unblock the handler (handler would block entire test run) - <-pool.enqueueReqCh }, ) @@ -498,7 +509,7 @@ func TestAddTxHighPressure(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - pool.createAccountOnce(addr1) + pool.getOrCreateAccount(addr1) pool.accounts.get(addr1).nextNonce = 5 // mock high pressure @@ -509,6 +520,10 @@ func TestAddTxHighPressure(t *testing.T) { ErrRejectFutureTx, pool.addTx(local, newTx(addr1, 8, 1)), ) + + acc := pool.accounts.get(addr1) + + assert.Equal(t, int(0), len(acc.nonceToTx.mapping)) }, ) @@ -521,7 +536,7 @@ func TestAddTxHighPressure(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - pool.createAccountOnce(addr1) + pool.getOrCreateAccount(addr1) pool.accounts.get(addr1).nextNonce = 5 // mock high pressure @@ -529,15 +544,15 @@ func TestAddTxHighPressure(t *testing.T) { println("slots", slots, "max", pool.gauge.max) pool.gauge.increase(slots) - go func() { - assert.NoError(t, - pool.addTx(local, newTx(addr1, 5, 1)), - ) - }() - enq := <-pool.enqueueReqCh + tx := newTx(addr1, 5, 1) + assert.NoError(t, pool.addTx(local, tx)) - _, exists := pool.index.get(enq.tx.Hash) + _, exists := pool.index.get(tx.Hash) assert.True(t, exists) + + acc := pool.accounts.get(addr1) + + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) }, ) } @@ -564,15 +579,12 @@ func TestAddGossipTx(t *testing.T) { } // send tx - go func() { - protoTx := &proto.Txn{ - Raw: &any.Any{ - Value: signedTx.MarshalRLP(), - }, - } - pool.addGossipTx(protoTx, "") - }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + protoTx := &proto.Txn{ + Raw: &any.Any{ + Value: signedTx.MarshalRLP(), + }, + } + pool.addGossipTx(protoTx, "") assert.Equal(t, uint64(1), pool.accounts.get(sender).enqueued.length()) }) @@ -586,7 +598,7 @@ func TestAddGossipTx(t *testing.T) { pool.SetSealing(false) - pool.createAccountOnce(sender) + pool.getOrCreateAccount(sender) signedTx, err := signer.SignTx(tx, key) if err != nil { @@ -615,10 +627,7 @@ func TestDropKnownGossipTx(t *testing.T) { tx := newTx(addr1, 1, 1) // send tx as local - go func() { - assert.NoError(t, pool.addTx(local, tx)) - }() - <-pool.enqueueReqCh + assert.NoError(t, pool.addTx(local, tx)) _, exists := pool.index.get(tx.Hash) assert.True(t, exists) @@ -628,6 +637,10 @@ func TestDropKnownGossipTx(t *testing.T) { pool.addTx(gossip, tx), ErrAlreadyKnown, ) + + acc := pool.accounts.get(addr1) + + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) } func TestEnqueueHandler(t *testing.T) { @@ -643,11 +656,8 @@ func TestEnqueueHandler(t *testing.T) { pool.SetSigner(&mockSigner{}) // send higher nonce tx - go func() { - err := pool.addTx(local, newTx(addr1, 10, 1)) // 10 > 0 - assert.NoError(t, err) - }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, newTx(addr1, 10, 1)) // 10 > 0 + assert.NoError(t, err) assert.Equal(t, uint64(1), pool.gauge.read()) assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) @@ -664,18 +674,21 @@ func TestEnqueueHandler(t *testing.T) { pool.SetSigner(&mockSigner{}) // setup prestate - acc := pool.createAccountOnce(addr1) + acc := pool.getOrCreateAccount(addr1) acc.setNonce(20) // send tx go func() { err := pool.addTx(local, newTx(addr1, 10, 1)) // 10 < 20 - assert.NoError(t, err) + assert.EqualError(t, err, "nonce too low") }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) assert.Equal(t, uint64(0), pool.gauge.read()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + + acc.nonceToTx.lock() + assert.Equal(t, int(0), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() }, ) @@ -689,11 +702,8 @@ func TestEnqueueHandler(t *testing.T) { pool.SetSigner(&mockSigner{}) // send tx - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) // 0 == 0 - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, newTx(addr1, 0, 1)) // 0 == 0 + assert.NoError(t, err) // catch pending promotion <-pool.promoteReqCh @@ -701,6 +711,11 @@ func TestEnqueueHandler(t *testing.T) { assert.Equal(t, uint64(1), pool.gauge.read()) assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + + acc := pool.accounts.get(addr1) + acc.nonceToTx.lock() + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() }, ) @@ -712,23 +727,15 @@ func TestEnqueueHandler(t *testing.T) { fillEnqueued := func(pool *TxPool, num uint64) { // first tx will signal promotion, grab the signal // but don't execute the handler - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err := pool.addTx(local, newTx(addr1, 0, 1)) + assert.NoError(t, err) // catch pending promotion <-pool.promoteReqCh for i := uint64(1); i < num; i++ { - go func() { - err := pool.addTx(local, newTx(addr1, i, 1)) - assert.NoError(t, err) - }() - - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err := pool.addTx(local, newTx(addr1, i, 1)) + assert.NoError(t, err) } } @@ -746,17 +753,381 @@ func TestEnqueueHandler(t *testing.T) { // send next expected tx go func() { - assert.NoError(t, - pool.addTx(local, newTx(addr1, 1, 1)), - ) + err := pool.addTx(local, newTx(addr1, 1, 1)) + assert.True(t, errors.Is(err, ErrMaxEnqueuedLimitReached)) }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + // assert the transaction was rejected + assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(1), pool.gauge.read()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + + acc := pool.accounts.get(addr1) + acc.nonceToTx.lock() + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + }, + ) +} + +func TestAddTx(t *testing.T) { + t.Parallel() + + t.Run( + "return underpriced for cheaper tx when pricier tx exists in enqueued", + func(t *testing.T) { + t.Parallel() + + // helper + newPricedTx := func( + addr types.Address, + nonce, + gasPrice, + slots uint64, + ) *types.Transaction { + tx := newTx(addr, nonce, slots) + tx.GasPrice.SetUint64(gasPrice) + + return tx + } + + pool, err := newTestPool() + assert.NoError(t, err) + pool.SetSigner(&mockSigner{}) + + tx1 := newPricedTx(addr1, 0, 10, 2) + tx2 := newPricedTx(addr1, 0, 20, 3) + + // add the transactions + assert.NoError(t, pool.addTx(local, tx2)) + + _, exists := pool.index.get(tx2.Hash) + assert.True(t, exists) + + // check the account nonce before promoting + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + + assert.ErrorIs(t, pool.addTx(local, tx1), ErrUnderpriced) + + // execute the enqueue handlers + <-pool.promoteReqCh + + // at this point the pointer of the first tx should be overwritten by the second pricier tx + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + assert.Len(t, pool.index.all, int(1)) + + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + acc := pool.accounts.get(addr1) + acc.nonceToTx.lock() + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + }, + ) + + t.Run( + "return underpriced for cheaper tx when pricier tx exists in promoted", + func(t *testing.T) { + t.Parallel() + + // helper + newPricedTx := func( + addr types.Address, + nonce, + gasPrice, + slots uint64, + ) *types.Transaction { + tx := newTx(addr, nonce, slots) + tx.GasPrice.SetUint64(gasPrice) + + return tx + } + + pool, err := newTestPool() + assert.NoError(t, err) + pool.SetSigner(&mockSigner{}) + + tx1 := newPricedTx(addr1, 0, 10, 2) + tx2 := newPricedTx(addr1, 0, 20, 3) + + // add the transactions + assert.NoError(t, pool.addTx(local, tx2)) + + _, exists := pool.index.get(tx2.Hash) + assert.True(t, exists) + + // check the account nonce before promoting + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + + // execute the enqueue handlers + promoteReq := <-pool.promoteReqCh + pool.handlePromoteRequest(promoteReq) + + assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + assert.ErrorIs(t, pool.addTx(local, tx1), ErrUnderpriced) + + acc := pool.accounts.get(addr1) + acc.nonceToTx.lock() + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + }, + ) + + t.Run( + "addTx handler discards cheaper tx", + func(t *testing.T) { + t.Parallel() + + // helper + newPricedTx := func( + addr types.Address, + nonce, + gasPrice, + slots uint64, + ) *types.Transaction { + tx := newTx(addr, nonce, slots) + tx.GasPrice.SetUint64(gasPrice) + + return tx + } + + pool, err := newTestPool() + assert.NoError(t, err) + pool.SetSigner(&mockSigner{}) + + tx1 := newPricedTx(addr1, 0, 10, 2) + tx2 := newPricedTx(addr1, 0, 20, 3) + + // add the transactions + assert.NoError(t, pool.addTx(local, tx1)) + assert.NoError(t, pool.addTx(local, tx2)) + + _, exists := pool.index.get(tx1.Hash) + assert.False(t, exists) + + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + // check the account nonce before promoting + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + + // execute the enqueue handlers + promReq1 := <-pool.promoteReqCh + promReq2 := <-pool.promoteReqCh + + // at this point the pointer of the first tx should be overwritten by the second pricier tx + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + assert.Len(t, pool.index.all, int(1)) + + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + // promote the second Tx and remove the first Tx + pool.handlePromoteRequest(promReq1) + + assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) // should be empty + assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) + + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + assert.Equal(t, len(pool.index.all), int(1)) + + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + // should do nothing in the 2nd promotion + pool.handlePromoteRequest(promReq2) + + assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) + + // because the *tx1 and *tx2 now contain the same hash we only need to check for *tx2 existence + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + assert.Equal(t, len(pool.index.all), int(1)) + + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + acc := pool.accounts.get(addr1) + acc.nonceToTx.lock() + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + }, + ) + + t.Run( + "addTx discards cheaper tx from enqueued", + func(t *testing.T) { + t.Parallel() + + // helper + newPricedTx := func( + addr types.Address, + nonce, + gasPrice, + slots uint64, + ) *types.Transaction { + tx := newTx(addr, nonce, slots) + tx.GasPrice.SetUint64(gasPrice) + + return tx + } + + pool, err := newTestPool() + assert.NoError(t, err) + pool.SetSigner(&mockSigner{}) + + tx1 := newPricedTx(addr1, 0, 10, 2) + tx2 := newPricedTx(addr1, 0, 20, 3) + + // add the transactions + assert.NoError(t, pool.addTx(local, tx1)) + assert.NoError(t, pool.addTx(local, tx2)) + + acc := pool.accounts.get(addr1) + assert.NotNil(t, acc) + + _, exists := pool.index.get(tx1.Hash) + assert.False(t, exists) + + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + maptx2 := acc.nonceToTx.get(tx2.Nonce) + nonceMapLength := len(acc.nonceToTx.mapping) + + assert.NotNil(t, maptx2) + assert.Equal(t, tx2, maptx2) + assert.Equal(t, int(1), nonceMapLength) + + assert.Equal(t, len(pool.index.all), int(1)) + + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) + + assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + }, + ) + + t.Run( + "addTx discards cheaper tx from promoted", + func(t *testing.T) { + t.Parallel() + + // helper + newPricedTx := func( + addr types.Address, + nonce, + gasPrice, + slots uint64, + ) *types.Transaction { + tx := newTx(addr, nonce, slots) + tx.GasPrice.SetUint64(gasPrice) + + return tx + } + + pool, err := newTestPool() + assert.NoError(t, err) + pool.SetSigner(&mockSigner{}) + + tx1 := newPricedTx(addr1, 0, 10, 2) + tx2 := newPricedTx(addr1, 0, 20, 3) + + // add the transactions + assert.NoError(t, pool.addTx(local, tx1)) + + acc := pool.accounts.get(addr1) + assert.NotNil(t, acc) + + promReq1 := <-pool.promoteReqCh + pool.handlePromoteRequest(promReq1) + + assert.Equal(t, len(pool.index.all), int(1)) + + assert.Equal( + t, + slotsRequired(tx1), + pool.gauge.read(), + ) + + assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) + + _, exists := pool.index.get(tx1.Hash) + assert.True(t, exists) + + _, exists = pool.index.get(tx2.Hash) + assert.False(t, exists) + + assert.NoError(t, pool.addTx(local, tx2)) + + maptx2 := acc.nonceToTx.get(tx2.Nonce) + nonceMapLength := len(acc.nonceToTx.mapping) + + assert.NotNil(t, maptx2) + assert.Equal(t, tx2, maptx2) + assert.Equal(t, int(1), nonceMapLength) + + _, exists = pool.index.get(tx1.Hash) + assert.False(t, exists) + + _, exists = pool.index.get(tx2.Hash) + assert.True(t, exists) + + assert.Equal(t, len(pool.index.all), int(1)) + + assert.Equal( + t, + slotsRequired(tx2), + pool.gauge.read(), + ) - // assert the transaction was rejected - assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) - assert.Equal(t, uint64(1), pool.gauge.read()) - assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) + assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) }, ) } @@ -780,7 +1151,7 @@ func TestPromoteHandler(t *testing.T) { } // fresh account (queues are empty) - acc := pool.createAccountOnce(addr1) + acc := pool.getOrCreateAccount(addr1) acc.setNonce(7) // fake a promotion @@ -790,19 +1161,30 @@ func TestPromoteHandler(t *testing.T) { assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) // enqueue higher nonce tx - go func() { - err := pool.addTx(local, newTx(addr1, 10, 1)) - assert.NoError(t, err) - }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + tx := newTx(addr1, 10, 1) + err = pool.addTx(local, tx) + assert.NoError(t, err) + assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + mapLen := len(acc.nonceToTx.mapping) + maptx := acc.nonceToTx.get(tx.Nonce) + + assert.Equal(t, int(1), mapLen) + assert.Equal(t, tx, maptx) + // fake a promotion go signalPromotion() pool.handlePromoteRequest(<-pool.promoteReqCh) assert.Equal(t, uint64(1), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + + mapLen = len(acc.nonceToTx.mapping) + maptx = acc.nonceToTx.get(tx.Nonce) + + assert.Equal(t, int(1), mapLen) + assert.Equal(t, tx, maptx) }) t.Run("promote one tx", func(t *testing.T) { @@ -812,11 +1194,18 @@ func TestPromoteHandler(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + tx := newTx(addr1, 0, 1) + err = pool.addTx(local, tx) + assert.NoError(t, err) + + acc := pool.accounts.get(addr1) + assert.NotNil(t, acc) + + mapLen := len(acc.nonceToTx.mapping) + maptx := acc.nonceToTx.get(tx.Nonce) + + assert.Equal(t, int(1), mapLen) + assert.Equal(t, tx, maptx) // tx enqueued -> promotion signaled pool.handlePromoteRequest(<-pool.promoteReqCh) @@ -826,6 +1215,12 @@ func TestPromoteHandler(t *testing.T) { assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) + + mapLen = len(acc.nonceToTx.mapping) + maptx = acc.nonceToTx.get(tx.Nonce) + + assert.Equal(t, int(1), mapLen) + assert.Equal(t, tx, maptx) }) t.Run("promote several txs", func(t *testing.T) { @@ -839,25 +1234,34 @@ func TestPromoteHandler(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) + txs := make([]*types.Transaction, 10) + + for i := 0; i < 10; i++ { + txs[i] = newTx(addr1, uint64(i), 1) + } + // send the first (expected) tx -> signals promotion - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) // 0 == 0 - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, txs[0]) // 0 == 0 + assert.NoError(t, err) // save the promotion handler req := <-pool.promoteReqCh // send the remaining txs (all will be enqueued) - for nonce := uint64(1); nonce < 10; nonce++ { - go func() { - err := pool.addTx(local, newTx(addr1, nonce, 1)) - assert.NoError(t, err) - }() - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + for i := 1; i < 10; i++ { + err := pool.addTx(local, txs[i]) + assert.NoError(t, err) + } + + acc := pool.accounts.get(addr1) + + for _, tx := range txs { + maptx := acc.nonceToTx.get(tx.Nonce) + assert.Equal(t, tx, maptx) } + assert.Equal(t, int(10), len(acc.nonceToTx.mapping)) + // verify all 10 are enqueued assert.Equal(t, uint64(10), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) @@ -870,6 +1274,13 @@ func TestPromoteHandler(t *testing.T) { assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(10), pool.accounts.get(addr1).promoted.length()) + + for _, tx := range txs { + maptx := acc.nonceToTx.get(tx.Nonce) + assert.Equal(t, tx, maptx) + } + + assert.Equal(t, int(10), len(acc.nonceToTx.mapping)) }) t.Run("one tx -> one promotion", func(t *testing.T) { @@ -881,127 +1292,33 @@ func TestPromoteHandler(t *testing.T) { assert.NoError(t, err) pool.SetSigner(&mockSigner{}) - for nonce := uint64(0); nonce < 20; nonce++ { - go func(nonce uint64) { - err := pool.addTx(local, newTx(addr1, nonce, 1)) - assert.NoError(t, err) - }(nonce) - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + txs := make([]*types.Transaction, 20) + + for i := 0; i < 20; i++ { + txs[i] = newTx(addr1, uint64(i), 1) + } + + for _, tx := range txs { + err := pool.addTx(local, tx) + assert.NoError(t, err) pool.handlePromoteRequest(<-pool.promoteReqCh) } + acc := pool.accounts.get(addr1) + + for _, tx := range txs { + maptx := acc.nonceToTx.get(tx.Nonce) + assert.Equal(t, tx, maptx) + } + + assert.Equal(t, int(20), len(acc.nonceToTx.mapping)) + assert.Equal(t, uint64(20), pool.gauge.read()) assert.Equal(t, uint64(20), pool.accounts.get(addr1).getNonce()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(20), pool.accounts.get(addr1).promoted.length()) }) - - t.Run( - "promote handler discards cheaper tx", - func(t *testing.T) { - t.Parallel() - - // helper - newPricedTx := func( - addr types.Address, - nonce, - gasPrice, - slots uint64, - ) *types.Transaction { - tx := newTx(addr, nonce, slots) - tx.GasPrice.SetUint64(gasPrice) - - return tx - } - - pool, err := newTestPool() - assert.NoError(t, err) - pool.SetSigner(&mockSigner{}) - - addTx := func(tx *types.Transaction) enqueueRequest { - tx.ComputeHash() - - go func() { - assert.NoError(t, - pool.addTx(local, tx), - ) - }() - - // grab the enqueue signal - return <-pool.enqueueReqCh - } - - handleEnqueueRequest := func(req enqueueRequest) promoteRequest { - go func() { - pool.handleEnqueueRequest(req) - }() - - return <-pool.promoteReqCh - } - - assertTxExists := func(t *testing.T, tx *types.Transaction, shouldExists bool) { - t.Helper() - - _, exists := pool.index.get(tx.Hash) - assert.Equal(t, shouldExists, exists) - } - - tx1 := newPricedTx(addr1, 0, 10, 2) - tx2 := newPricedTx(addr1, 0, 20, 3) - - // add the transactions - enqTx1 := addTx(tx1) - enqTx2 := addTx(tx2) - - assertTxExists(t, tx1, true) - assertTxExists(t, tx2, true) - - // check the account nonce before promoting - assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) - - // execute the enqueue handlers - promReq1 := handleEnqueueRequest(enqTx1) - promReq2 := handleEnqueueRequest(enqTx2) - - assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) - assert.Equal(t, uint64(2), pool.accounts.get(addr1).enqueued.length()) - assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) - assert.Equal( - t, - slotsRequired(tx1)+slotsRequired(tx2), - pool.gauge.read(), - ) - - // promote the second Tx and remove the first Tx - pool.handlePromoteRequest(promReq1) - - assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) - assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) // should be empty - assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) - assertTxExists(t, tx1, false) - assertTxExists(t, tx2, true) - assert.Equal( - t, - slotsRequired(tx2), - pool.gauge.read(), - ) - - // should do nothing in the 2nd promotion - pool.handlePromoteRequest(promReq2) - - assert.Equal(t, uint64(1), pool.accounts.get(addr1).getNonce()) - assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) - assert.Equal(t, uint64(1), pool.accounts.get(addr1).promoted.length()) - assertTxExists(t, tx1, false) - assertTxExists(t, tx2, true) - assert.Equal( - t, - slotsRequired(tx2), - pool.gauge.read(), - ) - }, - ) } func TestResetAccount(t *testing.T) { @@ -1080,14 +1397,11 @@ func TestResetAccount(t *testing.T) { pool.SetSigner(&mockSigner{}) // setup prestate - acc := pool.createAccountOnce(addr1) + acc := pool.getOrCreateAccount(addr1) acc.setNonce(test.txs[0].Nonce) - go func() { - err := pool.addTx(local, test.txs[0]) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, test.txs[0]) + assert.NoError(t, err) // save the promotion req := <-pool.promoteReqCh @@ -1098,21 +1412,25 @@ func TestResetAccount(t *testing.T) { // first was handled continue } - go func(tx *types.Transaction) { - err := pool.addTx(local, tx) - assert.NoError(t, err) - }(tx) - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + + err := pool.addTx(local, tx) + assert.NoError(t, err) } pool.handlePromoteRequest(req) - assert.Equal(t, uint64(0), pool.accounts.get(addr1).enqueued.length()) - assert.Equal(t, uint64(len(test.txs)), pool.accounts.get(addr1).promoted.length()) + + assert.Equal(t, uint64(0), acc.enqueued.length()) + assert.Equal(t, uint64(len(test.txs)), acc.promoted.length()) + assert.Equal(t, len(test.txs), len(acc.nonceToTx.mapping)) pool.resetAccounts(map[types.Address]uint64{ addr1: test.newNonce, }) + assert.Equal(t, + int(test.expected.accounts[addr1].promoted+test.expected.accounts[addr1].enqueued), + len(acc.nonceToTx.mapping)) + assert.Equal(t, test.expected.slots, pool.gauge.read()) assert.Equal(t, // enqueued test.expected.accounts[addr1].enqueued, @@ -1219,13 +1537,13 @@ func TestResetAccount(t *testing.T) { // setup prestate for _, tx := range test.txs { - go func(tx *types.Transaction) { - err := pool.addTx(local, tx) - assert.NoError(t, err) - }(tx) - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err := pool.addTx(local, tx) + assert.NoError(t, err) } + acc := pool.accounts.get(addr1) + assert.Equal(t, len(test.txs), len(acc.nonceToTx.mapping)) + assert.Equal(t, uint64(len(test.txs)), pool.accounts.get(addr1).enqueued.length()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) @@ -1240,6 +1558,15 @@ func TestResetAccount(t *testing.T) { }) } + // not sure if this lock is needed + // but because resetAccounts() runs in a separate goroutine if it is triggered + // it is safer to have a lock on the nonce map + acc.nonceToTx.lock() + + assert.Equal(t, int(test.expected.accounts[addr1].enqueued+test.expected.accounts[addr1].promoted), len(acc.nonceToTx.mapping)) + + acc.nonceToTx.unlock() + assert.Equal(t, test.expected.slots, pool.gauge.read()) assert.Equal(t, // enqueued test.expected.accounts[addr1].enqueued, @@ -1366,14 +1693,11 @@ func TestResetAccount(t *testing.T) { pool.SetSigner(&mockSigner{}) // setup prestate - acc := pool.createAccountOnce(addr1) + acc := pool.getOrCreateAccount(addr1) acc.setNonce(test.txs[0].Nonce) - go func() { - err := pool.addTx(local, test.txs[0]) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, test.txs[0]) + assert.NoError(t, err) // save the promotion req := <-pool.promoteReqCh @@ -1384,13 +1708,13 @@ func TestResetAccount(t *testing.T) { // first was handled continue } - go func(tx *types.Transaction) { - err := pool.addTx(local, tx) - assert.NoError(t, err) - }(tx) - pool.handleEnqueueRequest(<-pool.enqueueReqCh) + + err := pool.addTx(local, tx) + assert.NoError(t, err) } + assert.Equal(t, len(test.txs), len(acc.nonceToTx.mapping)) + pool.handlePromoteRequest(req) if test.signal { @@ -1404,6 +1728,12 @@ func TestResetAccount(t *testing.T) { }) } + acc.nonceToTx.lock() + + assert.Equal(t, int(test.expected.accounts[addr1].enqueued+test.expected.accounts[addr1].promoted), len(acc.nonceToTx.mapping)) + + acc.nonceToTx.unlock() + assert.Equal(t, test.expected.slots, pool.gauge.read()) assert.Equal(t, // enqueued test.expected.accounts[addr1].enqueued, @@ -1424,11 +1754,15 @@ func TestPop(t *testing.T) { pool.SetSigner(&mockSigner{}) // send 1 tx and promote it - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + tx1 := newTx(addr1, 0, 1) + err = pool.addTx(local, tx1) + assert.NoError(t, err) + + acc := pool.accounts.get(addr1) + + assert.Equal(t, int(1), len(acc.nonceToTx.mapping)) + assert.Equal(t, tx1, acc.nonceToTx.get(tx1.Nonce)) + pool.handlePromoteRequest(<-pool.promoteReqCh) assert.Equal(t, uint64(1), pool.gauge.read()) @@ -1439,8 +1773,13 @@ func TestPop(t *testing.T) { tx := pool.Peek() pool.Pop(tx) + assert.Equal(t, tx1, tx) + assert.Equal(t, uint64(0), pool.gauge.read()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + + assert.Equal(t, int(0), len(acc.nonceToTx.mapping)) + assert.Equal(t, (*types.Transaction)(nil), acc.nonceToTx.get(tx1.Nonce)) } func TestDrop(t *testing.T) { @@ -1451,11 +1790,12 @@ func TestDrop(t *testing.T) { pool.SetSigner(&mockSigner{}) // send 1 tx and promote it - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + tx1 := newTx(addr1, 0, 1) + err = pool.addTx(local, tx1) + assert.NoError(t, err) + + acc := pool.accounts.get(addr1) + pool.handlePromoteRequest(<-pool.promoteReqCh) assert.Equal(t, uint64(1), pool.gauge.read()) @@ -1466,10 +1806,14 @@ func TestDrop(t *testing.T) { pool.Prepare(0) tx := pool.Peek() pool.Drop(tx) + assert.Equal(t, tx1, tx) assert.Equal(t, uint64(0), pool.gauge.read()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).getNonce()) assert.Equal(t, uint64(0), pool.accounts.get(addr1).promoted.length()) + + assert.Equal(t, int(0), len(acc.nonceToTx.mapping)) + assert.Equal(t, (*types.Transaction)(nil), acc.nonceToTx.get(tx1.Nonce)) } func TestDemote(t *testing.T) { @@ -1483,11 +1827,9 @@ func TestDemote(t *testing.T) { pool.SetSigner(&mockSigner{}) // send tx - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, newTx(addr1, 0, 1)) + assert.NoError(t, err) + pool.handlePromoteRequest(<-pool.promoteReqCh) assert.Equal(t, uint64(1), pool.gauge.read()) @@ -1517,11 +1859,9 @@ func TestDemote(t *testing.T) { pool.SetSigner(&mockSigner{}) // send tx - go func() { - err := pool.addTx(local, newTx(addr1, 0, 1)) - assert.NoError(t, err) - }() - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) + err = pool.addTx(local, newTx(addr1, 0, 1)) + assert.NoError(t, err) + pool.handlePromoteRequest(<-pool.promoteReqCh) assert.Equal(t, uint64(1), pool.gauge.read()) @@ -1557,16 +1897,11 @@ func Test_updateAccountSkipsCounts(t *testing.T) { ) { t.Helper() - go func() { - err := pool.addTx(local, tx) - assert.NoError(t, err) - }() + err := pool.addTx(local, tx) + assert.NoError(t, err) if shouldPromote { - go pool.handleEnqueueRequest(<-pool.enqueueReqCh) pool.handlePromoteRequest(<-pool.promoteReqCh) - } else { - pool.handleEnqueueRequest(<-pool.enqueueReqCh) } } @@ -2046,6 +2381,13 @@ func TestResetAccounts_Promoted(t *testing.T) { assert.Equal(t, // promoted expected.accounts[addr].promoted, pool.accounts.get(addr).promoted.length()) + + acc := pool.accounts.get(addr) + + acc.nonceToTx.lock() + + assert.Equal(t, int(expected.accounts[addr].enqueued+expected.accounts[addr].promoted), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() } } @@ -2166,6 +2508,14 @@ func TestResetAccounts_Enqueued(t *testing.T) { assert.Equal(t, expected.slots, pool.gauge.read()) commonAssert(expected.accounts, pool) + + for addr := range allTxs { + acc := pool.accounts.get(addr) + + acc.nonceToTx.lock() + assert.Equal(t, int(expected.accounts[addr].enqueued+expected.accounts[addr].promoted), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + } }) t.Run("reset will not promote", func(t *testing.T) { @@ -2248,6 +2598,14 @@ func TestResetAccounts_Enqueued(t *testing.T) { assert.Equal(t, expected.slots, pool.gauge.read()) commonAssert(expected.accounts, pool) + + for addr := range allTxs { + acc := pool.accounts.get(addr) + + acc.nonceToTx.lock() + assert.Equal(t, int(expected.accounts[addr].enqueued+expected.accounts[addr].promoted), len(acc.nonceToTx.mapping)) + acc.nonceToTx.unlock() + } }) } @@ -2560,7 +2918,7 @@ func TestRecovery(t *testing.T) { expectedEnqueued := uint64(0) for addr, txs := range test.allTxs { // preset nonce so promotions can happen - acc := pool.createAccountOnce(addr) + acc := pool.getOrCreateAccount(addr) acc.setNonce(txs[0].tx.Nonce) expectedEnqueued += test.expected.accounts[addr].enqueued @@ -2908,11 +3266,11 @@ func TestBatchTx_SingleAccount(t *testing.T) { // subscribe to enqueue and promote events subscription := pool.eventManager.subscribe([]proto.EventType{proto.EventType_ENQUEUED, proto.EventType_PROMOTED}) - defer pool.eventManager.cancelSubscription(subscription.subscriptionID) txHashMap := map[types.Hash]struct{}{} // mutex for txHashMap mux := &sync.RWMutex{} + counter := uint64(0) // run max number of addTx concurrently for i := 0; i < int(defaultMaxAccountEnqueued); i++ { @@ -2928,15 +3286,24 @@ func TestBatchTx_SingleAccount(t *testing.T) { // submit transaction to pool assert.NoError(t, pool.addTx(local, tx)) + + atomic.AddUint64(&counter, 1) }(uint64(i)) } enqueuedCount := 0 promotedCount := 0 + ev := (*proto.TxPoolEvent)(nil) // wait for all the submitted transactions to be promoted for { - ev := <-subscription.subscriptionChannel + select { + case ev = <-subscription.subscriptionChannel: + case <-time.After(time.Second * 3): + t.Fatal(fmt.Sprintf("timeout. processed: %d/%d and %d/%d. Added: %d", + enqueuedCount, defaultMaxAccountEnqueued, promotedCount, defaultMaxAccountEnqueued, + atomic.LoadUint64(&counter))) + } // check if valid transaction hash mux.Lock() @@ -2960,4 +3327,148 @@ func TestBatchTx_SingleAccount(t *testing.T) { break } } + + acc := pool.accounts.get(addr) + + acc.nonceToTx.lock() + + assert.Equal(t, int(defaultMaxAccountEnqueued), len(acc.nonceToTx.mapping)) + + acc.nonceToTx.unlock() +} + +func TestAddTxsInOrder(t *testing.T) { + t.Parallel() + + const accountCount = 10 + + type container struct { + key *ecdsa.PrivateKey + addr types.Address + txs []*types.Transaction + } + + addrsTxs := [accountCount]container{} + + for i := 0; i < accountCount; i++ { + key, err := crypto.GenerateECDSAKey() + require.NoError(t, err) + + addrsTxs[i] = container{ + key: key, + addr: crypto.PubKeyToAddress(&key.PublicKey), + txs: make([]*types.Transaction, defaultMaxAccountEnqueued), + } + + for j := uint64(0); j < defaultMaxAccountEnqueued; j++ { + addrsTxs[i].txs[j] = newTx(addrsTxs[i].addr, j, uint64(1)) + } + } + + pool, err := newTestPool() + require.NoError(t, err) + + signer := crypto.NewEIP155Signer(100, true) + + pool.SetSigner(signer) + pool.Start() + + wg := new(sync.WaitGroup) + wg.Add(len(addrsTxs) * int(defaultMaxAccountEnqueued)) + + for _, atx := range addrsTxs { + for i, tx := range atx.txs { + go func(i int, tx *types.Transaction, key *ecdsa.PrivateKey) { + if i%2 == 1 { + time.Sleep(time.Millisecond * 50) + } + + signedTx, err := signer.SignTx(tx, key) + require.NoError(t, err) + + require.NoError(t, pool.addTx(local, signedTx)) + + wg.Done() + }(i, tx, atx.key) + } + } + + wg.Wait() + + time.Sleep(time.Second * 2) + + pool.Close() + + for _, addrtx := range addrsTxs { + acc := pool.accounts.get(addrtx.addr) + require.NotNil(t, acc) + + assert.Equal(t, uint64(0), acc.enqueued.length()) + assert.Equal(t, len(acc.nonceToTx.mapping), int(acc.promoted.length())) + } +} + +func BenchmarkAddTxTime(b *testing.B) { + b.Run("benchmark add one tx", func(b *testing.B) { + signer := crypto.NewEIP155Signer(100, true) + + key, err := crypto.GenerateECDSAKey() + if err != nil { + b.Fatal(err) + } + + signedTx, err := signer.SignTx(newTx(crypto.PubKeyToAddress(&key.PublicKey), 0, 1), key) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + pool, err := newTestPool() + if err != nil { + b.Fatal("fail to create pool", "err", err) + } + + pool.SetSigner(signer) + + err = pool.addTx(local, signedTx) + if err != nil { + b.Fatal(err) + } + } + }) + + b.Run("benchmark fill account", func(b *testing.B) { + signer := crypto.NewEIP155Signer(100, true) + + key, err := crypto.GenerateECDSAKey() + if err != nil { + b.Fatal(err) + } + + addr := crypto.PubKeyToAddress(&key.PublicKey) + txs := make([]*types.Transaction, defaultMaxAccountEnqueued) + + for i := range txs { + txs[i], err = signer.SignTx(newTx(addr, uint64(i), uint64(1)), key) + if err != nil { + b.Fatal(err) + } + } + + for i := 0; i < b.N; i++ { + pool, err := newTestPool() + if err != nil { + b.Fatal("fail to create pool", "err", err) + } + + pool.SetSigner(signer) + + for i := 0; i < len(txs); i++ { + err = pool.addTx(local, txs[i]) + if err != nil { + b.Fatal(err) + } + } + } + }) }