Skip to content

Commit

Permalink
Fixed unit tests.
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Littley <cody@eigenlabs.org>
  • Loading branch information
cody-littley committed Sep 5, 2024
1 parent 988f1b3 commit 783a769
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 114 deletions.
54 changes: 29 additions & 25 deletions tools/traffic/workers/blob_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package workers
import (
"context"
"crypto/md5"
"github.com/Layr-Labs/eigenda/api/clients"
"github.com/Layr-Labs/eigenda/common"
tu "github.com/Layr-Labs/eigenda/common/testutils"
"github.com/Layr-Labs/eigenda/tools/traffic/config"
"github.com/Layr-Labs/eigenda/tools/traffic/metrics"
"github.com/Layr-Labs/eigenda/tools/traffic/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/exp/rand"
"sync"
"testing"
Expand All @@ -28,9 +30,15 @@ func TestBlobReader(t *testing.T) {

readerMetrics := metrics.NewMockMetrics()

lock := sync.Mutex{}
chainClient := newMockChainClient(&lock)
retrievalClient := newMockRetrievalClient(t, &lock)
chainClient := &mockChainClient{}
chainClient.mock.On(
"FetchBatchHeader",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything)
retrievalClient := &mockRetrievalClient{}

blobReader := NewBlobReader(
&ctx,
Expand Down Expand Up @@ -82,7 +90,16 @@ func TestBlobReader(t *testing.T) {
readPermits)
assert.Nil(t, err)

retrievalClient.AddBlob(blobMetadata, blobData)
// Simplify tracking by hijacking the BlobHeaderLength field to store the blob index,
// which is used as a unique identifier within this test.
chunks := &clients.BlobChunks{BlobHeaderLength: blobMetadata.BlobIndex}
retrievalClient.mock.On("RetrieveBlobChunks",
blobMetadata.BatchHeaderHash,
uint32(blobMetadata.BlobIndex),
mock.Anything,
mock.Anything,
mock.Anything).Return(chunks, nil)
retrievalClient.mock.On("CombineChunks", chunks).Return(blobData, nil)

blobTable.Add(blobMetadata)
}
Expand All @@ -92,36 +109,26 @@ func TestBlobReader(t *testing.T) {
for i := uint(0); i < expectedTotalReads; i++ {
blobReader.randomRead()

tu.AssertEventuallyTrue(t, func() bool {
return retrievalClient.RetrieveBlobChunksCount == i+1 &&
retrievalClient.CombineChunksCount == i+1 &&
chainClient.Count == i+1
}, time.Second)
chainClient.mock.AssertNumberOfCalls(t, "FetchBatchHeader", int(i+1))
retrievalClient.mock.AssertNumberOfCalls(t, "RetrieveBlobChunks", int(i+1))
retrievalClient.mock.AssertNumberOfCalls(t, "CombineChunks", int(i+1))

remainingPermits := uint(0)
for _, metadata := range blobTable.GetAll() {
remainingPermits += uint(metadata.RemainingReadPermits)
}
assert.Equal(t, remainingPermits, expectedTotalReads-i-1)

tu.AssertEventuallyTrue(t, func() bool {
return uint(readerMetrics.GetCount("read_success")) == i+1 &&
uint(readerMetrics.GetCount("fetch_batch_header_success")) == i+1 &&
uint(readerMetrics.GetCount("recombination_success")) == i+1
}, time.Second)
assert.Equal(t, i+1, uint(readerMetrics.GetCount("read_success")))
assert.Equal(t, i+1, uint(readerMetrics.GetCount("fetch_batch_header_success")))
assert.Equal(t, i+1, uint(readerMetrics.GetCount("recombination_success")))
}

expectedInvalidBlobs := uint(invalidBlobCount * readPermits)
expectedValidBlobs := expectedTotalReads - expectedInvalidBlobs
tu.AssertEventuallyEquals(t, expectedValidBlobs,
func() any {
return uint(readerMetrics.GetCount("valid_blob"))
}, time.Second)
tu.AssertEventuallyEquals(t, expectedInvalidBlobs,
func() any {
return uint(readerMetrics.GetCount("invalid_blob"))
}, time.Second)

assert.Equal(t, expectedValidBlobs, uint(readerMetrics.GetCount("valid_blob")))
assert.Equal(t, expectedInvalidBlobs, uint(readerMetrics.GetCount("invalid_blob")))
assert.Equal(t, uint(0), uint(readerMetrics.GetGaugeValue("required_read_pool_size")))
assert.Equal(t, uint(0), uint(readerMetrics.GetGaugeValue("optional_read_pool_size")))

Expand All @@ -137,7 +144,4 @@ func TestBlobReader(t *testing.T) {
assert.Equal(t, expectedInvalidBlobs, uint(readerMetrics.GetCount("invalid_blob")))

cancel()
tu.ExecuteWithTimeout(func() {
waitGroup.Wait()
}, time.Second)
}
18 changes: 3 additions & 15 deletions tools/traffic/workers/mock_chain_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@ import (
"context"
binding "github.com/Layr-Labs/eigenda/contracts/bindings/EigenDAServiceManager"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/mock"
"math/big"
"sync"
)

type mockChainClient struct {
lock *sync.Mutex
Count uint
}

func newMockChainClient(lock *sync.Mutex) *mockChainClient {
return &mockChainClient{
lock: lock,
}

mock mock.Mock
}

func (m *mockChainClient) FetchBatchHeader(
Expand All @@ -27,10 +19,6 @@ func (m *mockChainClient) FetchBatchHeader(
fromBlock *big.Int,
toBlock *big.Int) (*binding.IEigenDAServiceManagerBatchHeader, error) {

m.lock.Lock()
defer m.lock.Unlock()

m.Count++

m.mock.Called(serviceManagerAddress, batchHeaderHash, fromBlock, toBlock)
return &binding.IEigenDAServiceManagerBatchHeader{}, nil
}
82 changes: 8 additions & 74 deletions tools/traffic/workers/mock_retrieval_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,12 @@ import (
"context"
"github.com/Layr-Labs/eigenda/api/clients"
"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/tools/traffic/table"
"github.com/stretchr/testify/assert"
"sync"
"testing"
"github.com/stretchr/testify/mock"
)

// mockRetrievalClient is a mock implementation of the clients.RetrievalClient interface.
type mockRetrievalClient struct {
t *testing.T

lock *sync.Mutex

// Since it isn't being used during this test, blob index field is used
// as a convenient unique identifier for the blob.

// A map from blob index to the blob data.
blobData map[uint]*[]byte

// A map from blob index to the blob metadata.
blobMetadata map[uint]*table.BlobMetadata

// A map from blob index to the blob chunks corresponding to that blob.
blobChunks map[uint]*clients.BlobChunks

RetrieveBlobChunksCount uint
CombineChunksCount uint
}

func newMockRetrievalClient(t *testing.T, lock *sync.Mutex) *mockRetrievalClient {
return &mockRetrievalClient{
t: t,
lock: lock,
blobData: make(map[uint]*[]byte),
blobMetadata: make(map[uint]*table.BlobMetadata),
blobChunks: make(map[uint]*clients.BlobChunks),
}
}

// AddBlob adds a blob to the mock retrieval client. Once added, the retrieval client will act as if
// it is able to retrieve the blob.
func (m *mockRetrievalClient) AddBlob(metadata *table.BlobMetadata, data []byte) {
m.lock.Lock()
defer m.lock.Unlock()

m.blobData[metadata.BlobIndex] = &data
m.blobMetadata[metadata.BlobIndex] = metadata

// The blob index is used in this test as a convenient unique identifier for the blob.

m.blobChunks[metadata.BlobIndex] = &clients.BlobChunks{
// Since it isn't otherwise used in this field, we can use it to store the unique identifier for the blob.
BlobHeaderLength: metadata.BlobIndex,
}
mock mock.Mock
}

func (m *mockRetrievalClient) RetrieveBlob(
Expand All @@ -66,7 +19,8 @@ func (m *mockRetrievalClient) RetrieveBlob(
referenceBlockNumber uint,
batchRoot [32]byte,
quorumID core.QuorumID) ([]byte, error) {
panic("this method should not be called during this test")
args := m.mock.Called(batchHeaderHash, blobIndex, referenceBlockNumber, batchRoot, quorumID)
return args.Get(0).([]byte), args.Error(1)
}

func (m *mockRetrievalClient) RetrieveBlobChunks(
Expand All @@ -77,31 +31,11 @@ func (m *mockRetrievalClient) RetrieveBlobChunks(
batchRoot [32]byte,
quorumID core.QuorumID) (*clients.BlobChunks, error) {

m.lock.Lock()
defer m.lock.Unlock()

m.RetrieveBlobChunksCount++

chunks, ok := m.blobChunks[uint(blobIndex)]
assert.True(m.t, ok, "blob not found")

metadata := m.blobMetadata[uint(blobIndex)]
assert.Equal(m.t, metadata.BlobIndex, uint(blobIndex))
assert.Equal(m.t, metadata.BatchHeaderHash[:32], batchHeaderHash[:32])

return chunks, nil

args := m.mock.Called(batchHeaderHash, blobIndex, referenceBlockNumber, batchRoot, quorumID)
return args.Get(0).(*clients.BlobChunks), args.Error(1)
}

func (m *mockRetrievalClient) CombineChunks(chunks *clients.BlobChunks) ([]byte, error) {
m.lock.Lock()
defer m.lock.Unlock()

m.CombineChunksCount++

blobIndex := chunks.BlobHeaderLength
data, ok := m.blobData[blobIndex]
assert.True(m.t, ok, "blob not found")

return *data, nil
args := m.mock.Called(chunks)
return args.Get(0).([]byte), args.Error(1)
}

0 comments on commit 783a769

Please sign in to comment.