From 2cce8d8c365e7553e11a36d26bdda8aaeb75b5f7 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:03:13 +0200 Subject: [PATCH 1/5] chore: fix test timeout helper using os.Exit(1) kills everything, tests statuses are not always displayed --- lo_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lo_test.go b/lo_test.go index 26db4c25..bc92dc35 100644 --- a/lo_test.go +++ b/lo_test.go @@ -1,7 +1,6 @@ package lo import ( - "os" "testing" "time" ) @@ -18,7 +17,7 @@ func testWithTimeout(t *testing.T, timeout time.Duration) { case <-testFinished: case <-time.After(timeout): t.Errorf("test timed out after %s", timeout) - os.Exit(1) + t.FailNow() } }() } From e26f1b43796eee77974b4b9e8d20d7636c3bfab2 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:07:14 +0200 Subject: [PATCH 2/5] chore: refactor WaitFor unit tests zero-code changes --- concurrency_test.go | 109 ++++++++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 35 deletions(-) diff --git a/concurrency_test.go b/concurrency_test.go index 0ee70dbf..ee339a90 100644 --- a/concurrency_test.go +++ b/concurrency_test.go @@ -215,44 +215,83 @@ func TestAsyncX(t *testing.T) { func TestWaitFor(t *testing.T) { t.Parallel() - testWithTimeout(t, 100*time.Millisecond) - is := assert.New(t) - alwaysTrue := func(i int) bool { return true } - alwaysFalse := func(i int) bool { return false } + testTimeout := 100 * time.Millisecond + longTimeout := 2 * testTimeout + shortTimeout := 4 * time.Millisecond - iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond) - is.Equal(1, iter) - is.Equal(time.Duration(0), duration) - is.True(ok) - iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 4*time.Millisecond) - is.Equal(3, iter) - is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) - is.False(ok) + t.Run("exist condition works", func(t *testing.T) { + t.Parallel() - laterTrue := func(i int) bool { - return i >= 5 - } + testWithTimeout(t, testTimeout) + is := assert.New(t) - iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond) - is.Equal(6, iter) - is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond)) - is.True(ok) - iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond) - is.Equal(2, iter) - is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) - is.False(ok) - - counter := 0 - - alwaysFalse = func(i int) bool { - is.Equal(counter, i) - counter++ - return false - } + laterTrue := func(i int) bool { + return i >= 5 + } + + iter, duration, ok := WaitFor(laterTrue, longTimeout, time.Millisecond) + is.Equal(6, iter, "unexpected iteration count") + is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond)) + is.True(ok) + }) + + t.Run("counter is incremented", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + counter := 0 + alwaysFalse := func(i int) bool { + is.Equal(counter, i) + counter++ + return false + } + + iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 1050*time.Microsecond) + is.Equal(counter, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) - iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 1050*time.Microsecond) - is.Equal(10, iter) - is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) - is.False(ok) + alwaysTrue := func(_ int) bool { return true } + alwaysFalse := func(_ int) bool { return false } + + t.Run("short timeout works", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond) + is.Equal(1, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + t.Run("timeout works", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + shortTimeout := 4 * time.Millisecond + iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond) + is.Equal(1, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + t.Run("exist on first condition", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond) + is.Equal(1, iter, "unexpected iteration count") + is.Zero(duration) + is.True(ok) + }) } From 071a746f6331fe338a895f47b98173e9445292f8 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:15:15 +0200 Subject: [PATCH 3/5] fix: WaitFor on first condition duration must be non-zero if first conditions is true --- README.md | 4 ++-- concurrency.go | 6 +----- concurrency_test.go | 6 +++--- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index efa83fb2..bdd73086 100644 --- a/README.md +++ b/README.md @@ -3068,9 +3068,9 @@ laterTrue := func(i int) bool { return i > 5 } -iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond) +iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond) // 1 -// 0ms +// 1ms // true iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond) diff --git a/concurrency.go b/concurrency.go index 95580661..d907a74a 100644 --- a/concurrency.go +++ b/concurrency.go @@ -99,10 +99,6 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A, // WaitFor runs periodically until a condition is validated. func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) { - if condition(0) { - return 1, 0, true - } - start := time.Now() timer := time.NewTimer(maxDuration) @@ -113,7 +109,7 @@ func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Du ticker.Stop() }() - i := 1 + i := 0 for { select { diff --git a/concurrency_test.go b/concurrency_test.go index ee339a90..8a8c9e5d 100644 --- a/concurrency_test.go +++ b/concurrency_test.go @@ -265,7 +265,7 @@ func TestWaitFor(t *testing.T) { is := assert.New(t) iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond) - is.Equal(1, iter, "unexpected iteration count") + is.Equal(0, iter, "unexpected iteration count") is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) is.False(ok) }) @@ -278,7 +278,7 @@ func TestWaitFor(t *testing.T) { shortTimeout := 4 * time.Millisecond iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond) - is.Equal(1, iter, "unexpected iteration count") + is.Equal(0, iter, "unexpected iteration count") is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) is.False(ok) }) @@ -291,7 +291,7 @@ func TestWaitFor(t *testing.T) { iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond) is.Equal(1, iter, "unexpected iteration count") - is.Zero(duration) + is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond)) is.True(ok) }) } From e1d8c98673fa352825f3fb5d9fd4de31e54053fc Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Sun, 30 Jun 2024 01:42:12 +0200 Subject: [PATCH 4/5] feat: add WaitForWithContext --- README.md | 44 +++++++++++++++++ concurrency.go | 29 +++++++---- concurrency_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index bdd73086..a4cd9478 100644 --- a/README.md +++ b/README.md @@ -276,6 +276,7 @@ Concurrency helpers: - [Async](#async) - [Transaction](#transaction) - [WaitFor](#waitfor) +- [WaitForWithContext](#waitforwithcontext) Error handling: @@ -3089,6 +3090,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi // false ``` + +### WaitForWithContext + +Runs periodically until a condition is validated or context is invalid. + +The condition receives also the context, so it can invalidate the process in the condition checker + +```go +ctx := context.Background() + +alwaysTrue := func(_ context.Context, i int) bool { return true } +alwaysFalse := func(_ context.Context, i int) bool { return false } +laterTrue := func(_ context.Context, i int) bool { + return i >= 5 +} + +iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond) +// 1 +// 1ms +// true + +iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond) +// 10 +// 10ms +// false + +iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond) +// 5 +// 5ms +// true + +iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond) +// 2 +// 10ms +// false + +expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) +iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond) +// 5 +// 5.1ms +// false +``` + ### Validate Helper function that creates an error when a condition is not met. diff --git a/concurrency.go b/concurrency.go index d907a74a..dc16f8df 100644 --- a/concurrency.go +++ b/concurrency.go @@ -1,6 +1,7 @@ package lo import ( + "context" "sync" "time" ) @@ -99,28 +100,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A, // WaitFor runs periodically until a condition is validated. func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) { + conditionWithContext := func(_ context.Context, i int) bool { + return condition(i) + } + return WaitForWithContext(context.Background(), conditionWithContext, maxDuration, tick) +} + +// WaitForWithContext runs periodically until a condition is validated or context is canceled. +func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) { start := time.Now() - timer := time.NewTimer(maxDuration) + i := 0 + if ctx.Err() != nil { + return i, time.Since(start), false + } + + ctx, cleanCtx := context.WithTimeout(ctx, maxDuration) ticker := time.NewTicker(tick) defer func() { - timer.Stop() + cleanCtx() ticker.Stop() }() - i := 0 - for { select { - case <-timer.C: + case <-ctx.Done(): return i, time.Since(start), false case <-ticker.C: - if condition(i) { - return i + 1, time.Since(start), true - } - i++ + if condition(ctx, i-1) { + return i, time.Since(start), true + } } } } diff --git a/concurrency_test.go b/concurrency_test.go index 8a8c9e5d..61f3dd61 100644 --- a/concurrency_test.go +++ b/concurrency_test.go @@ -1,6 +1,7 @@ package lo import ( + "context" "sync" "testing" "time" @@ -295,3 +296,118 @@ func TestWaitFor(t *testing.T) { is.True(ok) }) } + +func TestWaitForWithContext(t *testing.T) { + t.Parallel() + + testTimeout := 100 * time.Millisecond + longTimeout := 2 * testTimeout + shortTimeout := 4 * time.Millisecond + + t.Run("exist condition works", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + laterTrue := func(_ context.Context, i int) bool { + return i >= 5 + } + + iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond) + is.Equal(6, iter, "unexpected iteration count") + is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond)) + is.True(ok) + }) + + t.Run("counter is incremented", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + counter := 0 + alwaysFalse := func(_ context.Context, i int) bool { + is.Equal(counter, i) + counter++ + return false + } + + iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond) + is.Equal(counter, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + alwaysTrue := func(_ context.Context, _ int) bool { return true } + alwaysFalse := func(_ context.Context, _ int) bool { return false } + + t.Run("short timeout works", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond) + is.Equal(0, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + t.Run("timeout works", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + shortTimeout := 4 * time.Millisecond + iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond) + is.Equal(0, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + t.Run("exist on first condition", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond) + is.Equal(1, iter, "unexpected iteration count") + is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond)) + is.True(ok) + }) + + t.Run("context cancellation stops everything", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond) + t.Cleanup(func() { + clean() + }) + + iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond) + is.Equal(2, iter, "unexpected iteration count") + is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond)) + is.False(ok) + }) + + t.Run("canceled context stops everything", func(t *testing.T) { + t.Parallel() + + testWithTimeout(t, testTimeout) + is := assert.New(t) + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + + iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond) + is.Equal(0, iter, "unexpected iteration count") + is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond)) + is.False(ok) + }) +} From 74eb4206786d25ec1df911b29394d60ecbeec287 Mon Sep 17 00:00:00 2001 From: ccoVeille <3875889+ccoVeille@users.noreply.github.com> Date: Wed, 3 Jul 2024 00:35:19 +0200 Subject: [PATCH 5/5] chore: provide meaningful returned values for WaitFor and WaitForWithContext --- concurrency.go | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/concurrency.go b/concurrency.go index dc16f8df..a2ebbce2 100644 --- a/concurrency.go +++ b/concurrency.go @@ -99,24 +99,23 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A, } // WaitFor runs periodically until a condition is validated. -func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) { - conditionWithContext := func(_ context.Context, i int) bool { - return condition(i) +func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) { + conditionWithContext := func(_ context.Context, currentIteration int) bool { + return condition(currentIteration) } - return WaitForWithContext(context.Background(), conditionWithContext, maxDuration, tick) + return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay) } // WaitForWithContext runs periodically until a condition is validated or context is canceled. -func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) { +func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) { start := time.Now() - i := 0 if ctx.Err() != nil { - return i, time.Since(start), false + return totalIterations, time.Since(start), false } - ctx, cleanCtx := context.WithTimeout(ctx, maxDuration) - ticker := time.NewTicker(tick) + ctx, cleanCtx := context.WithTimeout(ctx, timeout) + ticker := time.NewTicker(heartbeatDelay) defer func() { cleanCtx() @@ -126,11 +125,11 @@ func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, for { select { case <-ctx.Done(): - return i, time.Since(start), false + return totalIterations, time.Since(start), false case <-ticker.C: - i++ - if condition(ctx, i-1) { - return i, time.Since(start), true + totalIterations++ + if condition(ctx, totalIterations-1) { + return totalIterations, time.Since(start), true } } }