Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add WaitForWithContext #480

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ Concurrency helpers:
- [Async](#async)
- [Transaction](#transaction)
- [WaitFor](#waitfor)
- [WaitForWithContext](#waitforwithcontext)

Error handling:

Expand Down Expand Up @@ -3068,9 +3069,9 @@ laterTrue := func(i int) bool {
return i > 5
}

iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 0ms
// 1ms
// true

iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond)
Expand All @@ -3089,6 +3090,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi
// false
```


### WaitForWithContext

Runs periodically until a condition is validated or context is invalid.

The condition receives also the context, so it can invalidate the process in the condition checker

```go
ctx := context.Background()

alwaysTrue := func(_ context.Context, i int) bool { return true }
alwaysFalse := func(_ context.Context, i int) bool { return false }
laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 1ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond)
// 10
// 10ms
// false

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond)
// 5
// 5ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond)
// 2
// 10ms
// false

expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond)
// 5
// 5.1ms
// false
```

### Validate

Helper function that creates an error when a condition is not met.
Expand Down
34 changes: 20 additions & 14 deletions concurrency.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"time"
)
Expand Down Expand Up @@ -98,33 +99,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
}

// WaitFor runs periodically until a condition is validated.
func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
if condition(0) {
return 1, 0, true
func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
conditionWithContext := func(_ context.Context, currentIteration int) bool {
return condition(currentIteration)
}
return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay)
}

// WaitForWithContext runs periodically until a condition is validated or context is canceled.
func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
start := time.Now()

timer := time.NewTimer(maxDuration)
ticker := time.NewTicker(tick)
if ctx.Err() != nil {
return totalIterations, time.Since(start), false
}

ctx, cleanCtx := context.WithTimeout(ctx, timeout)
ticker := time.NewTicker(heartbeatDelay)

defer func() {
timer.Stop()
cleanCtx()
ticker.Stop()
}()

i := 1

for {
select {
case <-timer.C:
return i, time.Since(start), false
case <-ctx.Done():
return totalIterations, time.Since(start), false
case <-ticker.C:
if condition(i) {
return i + 1, time.Since(start), true
totalIterations++
if condition(ctx, totalIterations-1) {
return totalIterations, time.Since(start), true
}

i++
}
}
}
225 changes: 190 additions & 35 deletions concurrency_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -215,44 +216,198 @@ func TestAsyncX(t *testing.T) {

func TestWaitFor(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)

alwaysTrue := func(i int) bool { return true }
alwaysFalse := func(i int) bool { return false }
testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond

iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter)
is.Equal(time.Duration(0), duration)
is.True(ok)
iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 4*time.Millisecond)
is.Equal(3, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
t.Run("exist condition works", func(t *testing.T) {
t.Parallel()

laterTrue := func(i int) bool {
return i >= 5
}
testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(6, iter)
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond)
is.Equal(2, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)

counter := 0

alwaysFalse = func(i int) bool {
is.Equal(counter, i)
counter++
return false
}
laterTrue := func(i int) bool {
return i >= 5
}

iter, duration, ok := WaitFor(laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})

t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 1050*time.Microsecond)
is.Equal(10, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
counter := 0
alwaysFalse := func(i int) bool {
is.Equal(counter, i)
counter++
return false
}

iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

alwaysTrue := func(_ int) bool { return true }
alwaysFalse := func(_ int) bool { return false }

t.Run("short timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})
}

func TestWaitForWithContext(t *testing.T) {
t.Parallel()

testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond

t.Run("exist condition works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})

t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

counter := 0
alwaysFalse := func(_ context.Context, i int) bool {
is.Equal(counter, i)
counter++
return false
}

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

alwaysTrue := func(_ context.Context, _ int) bool { return true }
alwaysFalse := func(_ context.Context, _ int) bool { return false }

t.Run("short timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})

t.Run("context cancellation stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond)
t.Cleanup(func() {
clean()
})

iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond)
is.Equal(2, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("canceled context stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

canceledCtx, cancel := context.WithCancel(context.Background())
cancel()

iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond))
is.False(ok)
})
}
3 changes: 1 addition & 2 deletions lo_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package lo

import (
"os"
"testing"
"time"
)
Expand All @@ -13,12 +12,12 @@
testFinished := make(chan struct{})
t.Cleanup(func() { close(testFinished) })

go func() {

Check failure on line 15 in lo_test.go

View workflow job for this annotation

GitHub Actions / lint

SA2002: the goroutine calls T.FailNow, which must be called in the same goroutine as the test (staticcheck)
select {
case <-testFinished:
case <-time.After(timeout):
t.Errorf("test timed out after %s", timeout)
os.Exit(1)
t.FailNow()

Check failure on line 20 in lo_test.go

View workflow job for this annotation

GitHub Actions / lint

testinggoroutine: call to (*testing.T).FailNow from a non-test goroutine (govet)
}
}()
}
Expand Down
Loading