diff --git a/jsonrpc/throttling_test.go b/jsonrpc/throttling_test.go index e6b65b2f06..75e4cfef48 100644 --- a/jsonrpc/throttling_test.go +++ b/jsonrpc/throttling_test.go @@ -3,6 +3,7 @@ package jsonrpc import ( "context" "sync" + "sync/atomic" "testing" "time" @@ -13,26 +14,28 @@ import ( func TestThrottling(t *testing.T) { t.Parallel() + const maxRequests = 5 + + var requests atomic.Int32 var attempts = []struct { duration time.Duration delay time.Duration - isError bool }{ // 1st 5 starts immediately, no error, order of execution is irrelevant - {200 * time.Millisecond, 0, false}, - {1000 * time.Millisecond, 0, false}, - {1000 * time.Millisecond, 0, false}, - {1000 * time.Millisecond, 0, false}, - {1000 * time.Millisecond, 0, false}, + {200 * time.Millisecond, 0}, + {1000 * time.Millisecond, 0}, + {1000 * time.Millisecond, 0}, + {1000 * time.Millisecond, 0}, + {1000 * time.Millisecond, 0}, // 6th & 8th attempt should fail, from now on order of execution is relevant, hence delay > 0 - {20 * time.Millisecond, 100 * time.Millisecond, true}, - {200 * time.Millisecond, 300 * time.Millisecond, false}, - {20 * time.Millisecond, 400 * time.Millisecond, true}, - {200 * time.Millisecond, 600 * time.Millisecond, false}, + {20 * time.Millisecond, 100 * time.Millisecond}, + {200 * time.Millisecond, 300 * time.Millisecond}, + {20 * time.Millisecond, 400 * time.Millisecond}, + {200 * time.Millisecond, 600 * time.Millisecond}, } - th := NewThrottling(5, 20*time.Millisecond) + th := NewThrottling(maxRequests, 20*time.Millisecond) sfn := func(value int, sleep time.Duration) func() (interface{}, error) { return func() (interface{}, error) { time.Sleep(sleep) @@ -42,15 +45,21 @@ func TestThrottling(t *testing.T) { } wg := sync.WaitGroup{} - wg.Add(9) + wg.Add(len(attempts)) for i := 0; i < len(attempts); i++ { - go func(value int, duration time.Duration, delay time.Duration, isError bool) { + go func(value int, duration time.Duration, delay time.Duration) { defer wg.Done() time.Sleep(delay) + var isError bool + if requests.Add(1) > maxRequests { + isError = true + } res, err := th.AttemptRequest(context.Background(), sfn(value, duration)) + requests.Add(-1) + if isError { require.ErrorIs(t, err, errRequestLimitExceeded) assert.Nil(t, res) @@ -58,7 +67,7 @@ func TestThrottling(t *testing.T) { require.NoError(t, err) assert.Equal(t, value, res.(int)) } - }(i, attempts[i].duration, attempts[i].delay, attempts[i].isError) + }(i, attempts[i].duration, attempts[i].delay) } wg.Wait()