diff --git a/tools/traffic/workers/blob_reader_test.go b/tools/traffic/workers/blob_reader_test.go index c93b7e03b..13554cd9b 100644 --- a/tools/traffic/workers/blob_reader_test.go +++ b/tools/traffic/workers/blob_reader_test.go @@ -3,12 +3,14 @@ package workers import ( "context" "crypto/md5" + "github.com/Layr-Labs/eigenda/api/clients" "github.com/Layr-Labs/eigenda/common" tu "github.com/Layr-Labs/eigenda/common/testutils" "github.com/Layr-Labs/eigenda/tools/traffic/config" "github.com/Layr-Labs/eigenda/tools/traffic/metrics" "github.com/Layr-Labs/eigenda/tools/traffic/table" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "golang.org/x/exp/rand" "sync" "testing" @@ -28,9 +30,15 @@ func TestBlobReader(t *testing.T) { readerMetrics := metrics.NewMockMetrics() - lock := sync.Mutex{} - chainClient := newMockChainClient(&lock) - retrievalClient := newMockRetrievalClient(t, &lock) + chainClient := &mockChainClient{} + chainClient.mock.On( + "FetchBatchHeader", + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything, + mock.Anything) + retrievalClient := &mockRetrievalClient{} blobReader := NewBlobReader( &ctx, @@ -82,7 +90,16 @@ func TestBlobReader(t *testing.T) { readPermits) assert.Nil(t, err) - retrievalClient.AddBlob(blobMetadata, blobData) + // Simplify tracking by hijacking the BlobHeaderLength field to store the blob index, + // which is used as a unique identifier within this test. + chunks := &clients.BlobChunks{BlobHeaderLength: blobMetadata.BlobIndex} + retrievalClient.mock.On("RetrieveBlobChunks", + blobMetadata.BatchHeaderHash, + uint32(blobMetadata.BlobIndex), + mock.Anything, + mock.Anything, + mock.Anything).Return(chunks, nil) + retrievalClient.mock.On("CombineChunks", chunks).Return(blobData, nil) blobTable.Add(blobMetadata) } @@ -92,11 +109,9 @@ func TestBlobReader(t *testing.T) { for i := uint(0); i < expectedTotalReads; i++ { blobReader.randomRead() - tu.AssertEventuallyTrue(t, func() bool { - return retrievalClient.RetrieveBlobChunksCount == i+1 && - retrievalClient.CombineChunksCount == i+1 && - chainClient.Count == i+1 - }, time.Second) + chainClient.mock.AssertNumberOfCalls(t, "FetchBatchHeader", int(i+1)) + retrievalClient.mock.AssertNumberOfCalls(t, "RetrieveBlobChunks", int(i+1)) + retrievalClient.mock.AssertNumberOfCalls(t, "CombineChunks", int(i+1)) remainingPermits := uint(0) for _, metadata := range blobTable.GetAll() { @@ -104,24 +119,16 @@ func TestBlobReader(t *testing.T) { } assert.Equal(t, remainingPermits, expectedTotalReads-i-1) - tu.AssertEventuallyTrue(t, func() bool { - return uint(readerMetrics.GetCount("read_success")) == i+1 && - uint(readerMetrics.GetCount("fetch_batch_header_success")) == i+1 && - uint(readerMetrics.GetCount("recombination_success")) == i+1 - }, time.Second) + assert.Equal(t, i+1, uint(readerMetrics.GetCount("read_success"))) + assert.Equal(t, i+1, uint(readerMetrics.GetCount("fetch_batch_header_success"))) + assert.Equal(t, i+1, uint(readerMetrics.GetCount("recombination_success"))) } expectedInvalidBlobs := uint(invalidBlobCount * readPermits) expectedValidBlobs := expectedTotalReads - expectedInvalidBlobs - tu.AssertEventuallyEquals(t, expectedValidBlobs, - func() any { - return uint(readerMetrics.GetCount("valid_blob")) - }, time.Second) - tu.AssertEventuallyEquals(t, expectedInvalidBlobs, - func() any { - return uint(readerMetrics.GetCount("invalid_blob")) - }, time.Second) + assert.Equal(t, expectedValidBlobs, uint(readerMetrics.GetCount("valid_blob"))) + assert.Equal(t, expectedInvalidBlobs, uint(readerMetrics.GetCount("invalid_blob"))) assert.Equal(t, uint(0), uint(readerMetrics.GetGaugeValue("required_read_pool_size"))) assert.Equal(t, uint(0), uint(readerMetrics.GetGaugeValue("optional_read_pool_size"))) @@ -137,7 +144,4 @@ func TestBlobReader(t *testing.T) { assert.Equal(t, expectedInvalidBlobs, uint(readerMetrics.GetCount("invalid_blob"))) cancel() - tu.ExecuteWithTimeout(func() { - waitGroup.Wait() - }, time.Second) } diff --git a/tools/traffic/workers/mock_chain_client.go b/tools/traffic/workers/mock_chain_client.go index 05830c324..4733897f5 100644 --- a/tools/traffic/workers/mock_chain_client.go +++ b/tools/traffic/workers/mock_chain_client.go @@ -4,20 +4,12 @@ import ( "context" binding "github.com/Layr-Labs/eigenda/contracts/bindings/EigenDAServiceManager" "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/mock" "math/big" - "sync" ) type mockChainClient struct { - lock *sync.Mutex - Count uint -} - -func newMockChainClient(lock *sync.Mutex) *mockChainClient { - return &mockChainClient{ - lock: lock, - } - + mock mock.Mock } func (m *mockChainClient) FetchBatchHeader( @@ -27,10 +19,6 @@ func (m *mockChainClient) FetchBatchHeader( fromBlock *big.Int, toBlock *big.Int) (*binding.IEigenDAServiceManagerBatchHeader, error) { - m.lock.Lock() - defer m.lock.Unlock() - - m.Count++ - + m.mock.Called(serviceManagerAddress, batchHeaderHash, fromBlock, toBlock) return &binding.IEigenDAServiceManagerBatchHeader{}, nil } diff --git a/tools/traffic/workers/mock_retrieval_client.go b/tools/traffic/workers/mock_retrieval_client.go index bd2ce6e2c..4afb183ee 100644 --- a/tools/traffic/workers/mock_retrieval_client.go +++ b/tools/traffic/workers/mock_retrieval_client.go @@ -4,59 +4,12 @@ import ( "context" "github.com/Layr-Labs/eigenda/api/clients" "github.com/Layr-Labs/eigenda/core" - "github.com/Layr-Labs/eigenda/tools/traffic/table" - "github.com/stretchr/testify/assert" - "sync" - "testing" + "github.com/stretchr/testify/mock" ) // mockRetrievalClient is a mock implementation of the clients.RetrievalClient interface. type mockRetrievalClient struct { - t *testing.T - - lock *sync.Mutex - - // Since it isn't being used during this test, blob index field is used - // as a convenient unique identifier for the blob. - - // A map from blob index to the blob data. - blobData map[uint]*[]byte - - // A map from blob index to the blob metadata. - blobMetadata map[uint]*table.BlobMetadata - - // A map from blob index to the blob chunks corresponding to that blob. - blobChunks map[uint]*clients.BlobChunks - - RetrieveBlobChunksCount uint - CombineChunksCount uint -} - -func newMockRetrievalClient(t *testing.T, lock *sync.Mutex) *mockRetrievalClient { - return &mockRetrievalClient{ - t: t, - lock: lock, - blobData: make(map[uint]*[]byte), - blobMetadata: make(map[uint]*table.BlobMetadata), - blobChunks: make(map[uint]*clients.BlobChunks), - } -} - -// AddBlob adds a blob to the mock retrieval client. Once added, the retrieval client will act as if -// it is able to retrieve the blob. -func (m *mockRetrievalClient) AddBlob(metadata *table.BlobMetadata, data []byte) { - m.lock.Lock() - defer m.lock.Unlock() - - m.blobData[metadata.BlobIndex] = &data - m.blobMetadata[metadata.BlobIndex] = metadata - - // The blob index is used in this test as a convenient unique identifier for the blob. - - m.blobChunks[metadata.BlobIndex] = &clients.BlobChunks{ - // Since it isn't otherwise used in this field, we can use it to store the unique identifier for the blob. - BlobHeaderLength: metadata.BlobIndex, - } + mock mock.Mock } func (m *mockRetrievalClient) RetrieveBlob( @@ -66,7 +19,8 @@ func (m *mockRetrievalClient) RetrieveBlob( referenceBlockNumber uint, batchRoot [32]byte, quorumID core.QuorumID) ([]byte, error) { - panic("this method should not be called during this test") + args := m.mock.Called(batchHeaderHash, blobIndex, referenceBlockNumber, batchRoot, quorumID) + return args.Get(0).([]byte), args.Error(1) } func (m *mockRetrievalClient) RetrieveBlobChunks( @@ -77,31 +31,11 @@ func (m *mockRetrievalClient) RetrieveBlobChunks( batchRoot [32]byte, quorumID core.QuorumID) (*clients.BlobChunks, error) { - m.lock.Lock() - defer m.lock.Unlock() - - m.RetrieveBlobChunksCount++ - - chunks, ok := m.blobChunks[uint(blobIndex)] - assert.True(m.t, ok, "blob not found") - - metadata := m.blobMetadata[uint(blobIndex)] - assert.Equal(m.t, metadata.BlobIndex, uint(blobIndex)) - assert.Equal(m.t, metadata.BatchHeaderHash[:32], batchHeaderHash[:32]) - - return chunks, nil - + args := m.mock.Called(batchHeaderHash, blobIndex, referenceBlockNumber, batchRoot, quorumID) + return args.Get(0).(*clients.BlobChunks), args.Error(1) } func (m *mockRetrievalClient) CombineChunks(chunks *clients.BlobChunks) ([]byte, error) { - m.lock.Lock() - defer m.lock.Unlock() - - m.CombineChunksCount++ - - blobIndex := chunks.BlobHeaderLength - data, ok := m.blobData[blobIndex] - assert.True(m.t, ok, "blob not found") - - return *data, nil + args := m.mock.Called(chunks) + return args.Get(0).([]byte), args.Error(1) }