diff --git a/retry.go b/retry.go index e14d9f3..8e40cf2 100644 --- a/retry.go +++ b/retry.go @@ -97,12 +97,15 @@ func Do(retryableFunc RetryableFunc, opts ...Option) error { // Setting attempts to 0 means we'll retry until we succeed if config.attempts == 0 { for err := retryableFunc(); err != nil; err = retryableFunc() { - n++ - if !IsRecoverable(err) { return err } + if !config.retryIf(err) { + return err + } + + n++ config.onRetry(n, err) select { case <-config.timer.After(delay(config, n, err)): diff --git a/retry_test.go b/retry_test.go index 6f7844a..5b03462 100644 --- a/retry_test.go +++ b/retry_test.go @@ -72,7 +72,29 @@ func TestRetryIf(t *testing.T) { assert.Len(t, err, 3) assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") assert.Equal(t, uint(2), retryCount, "right count of retry") +} + +func TestRetryIf_ZeroAttempts(t *testing.T) { + var retryCount uint + err := Do( + func() error { + if retryCount >= 2 { + return errors.New("special") + } else { + return errors.New("test") + } + }, + OnRetry(func(n uint, err error) { retryCount++ }), + RetryIf(func(err error) bool { + return err.Error() != "special" + }), + Delay(time.Nanosecond), + Attempts(0), + ) + assert.Error(t, err) + assert.Equal(t, "special", err.Error(), "retry error format") + assert.Equal(t, uint(2), retryCount, "right count of retry") } func TestZeroAttemptsWithError(t *testing.T) {