Skip to content

Commit

Permalink
feat: add WaitForWithContext (#480)
Browse files Browse the repository at this point in the history
* chore: fix test timeout helper

using os.Exit(1) kills everything, tests statuses are not always displayed

* chore: refactor WaitFor unit tests

zero-code changes

* fix: WaitFor on first condition

duration must be non-zero if first conditions is true

* feat: add WaitForWithContext

* chore: provide meaningful returned values for WaitFor and WaitForWithContext
  • Loading branch information
ccoVeille authored Jul 15, 2024
1 parent 9e34397 commit 0f4679b
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 53 deletions.
48 changes: 46 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ Concurrency helpers:
- [Async](#async)
- [Transaction](#transaction)
- [WaitFor](#waitfor)
- [WaitForWithContext](#waitforwithcontext)

Error handling:

Expand Down Expand Up @@ -3104,9 +3105,9 @@ laterTrue := func(i int) bool {
return i > 5
}

iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 0ms
// 1ms
// true

iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond)
Expand All @@ -3125,6 +3126,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi
// false
```


### WaitForWithContext

Runs periodically until a condition is validated or context is invalid.

The condition receives also the context, so it can invalidate the process in the condition checker

```go
ctx := context.Background()

alwaysTrue := func(_ context.Context, i int) bool { return true }
alwaysFalse := func(_ context.Context, i int) bool { return false }
laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 1ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond)
// 10
// 10ms
// false

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond)
// 5
// 5ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond)
// 2
// 10ms
// false

expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond)
// 5
// 5.1ms
// false
```

### Validate

Helper function that creates an error when a condition is not met.
Expand Down
34 changes: 20 additions & 14 deletions concurrency.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"time"
)
Expand Down Expand Up @@ -98,33 +99,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,
}

// WaitFor runs periodically until a condition is validated.
func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
if condition(0) {
return 1, 0, true
func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
conditionWithContext := func(_ context.Context, currentIteration int) bool {
return condition(currentIteration)
}
return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay)
}

// WaitForWithContext runs periodically until a condition is validated or context is canceled.
func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) {
start := time.Now()

timer := time.NewTimer(maxDuration)
ticker := time.NewTicker(tick)
if ctx.Err() != nil {
return totalIterations, time.Since(start), false
}

ctx, cleanCtx := context.WithTimeout(ctx, timeout)
ticker := time.NewTicker(heartbeatDelay)

defer func() {
timer.Stop()
cleanCtx()
ticker.Stop()
}()

i := 1

for {
select {
case <-timer.C:
return i, time.Since(start), false
case <-ctx.Done():
return totalIterations, time.Since(start), false
case <-ticker.C:
if condition(i) {
return i + 1, time.Since(start), true
totalIterations++
if condition(ctx, totalIterations-1) {
return totalIterations, time.Since(start), true
}

i++
}
}
}
225 changes: 190 additions & 35 deletions concurrency_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -215,44 +216,198 @@ func TestAsyncX(t *testing.T) {

func TestWaitFor(t *testing.T) {
t.Parallel()
testWithTimeout(t, 100*time.Millisecond)
is := assert.New(t)

alwaysTrue := func(i int) bool { return true }
alwaysFalse := func(i int) bool { return false }
testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond

iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter)
is.Equal(time.Duration(0), duration)
is.True(ok)
iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 4*time.Millisecond)
is.Equal(3, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
t.Run("exist condition works", func(t *testing.T) {
t.Parallel()

laterTrue := func(i int) bool {
return i >= 5
}
testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(6, iter)
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
iter, duration, ok = WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond)
is.Equal(2, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)

counter := 0

alwaysFalse = func(i int) bool {
is.Equal(counter, i)
counter++
return false
}
laterTrue := func(i int) bool {
return i >= 5
}

iter, duration, ok := WaitFor(laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})

t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok = WaitFor(alwaysFalse, 10*time.Millisecond, 1050*time.Microsecond)
is.Equal(10, iter)
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
counter := 0
alwaysFalse := func(i int) bool {
is.Equal(counter, i)
counter++
return false
}

iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

alwaysTrue := func(_ int) bool { return true }
alwaysFalse := func(_ int) bool { return false }

t.Run("short timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitFor(alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitFor(alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})
}

func TestWaitForWithContext(t *testing.T) {
t.Parallel()

testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond

t.Run("exist condition works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})

t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

counter := 0
alwaysFalse := func(_ context.Context, i int) bool {
is.Equal(counter, i)
counter++
return false
}

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

alwaysTrue := func(_ context.Context, _ int) bool { return true }
alwaysFalse := func(_ context.Context, _ int) bool { return false }

t.Run("short timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})

t.Run("context cancellation stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond)
t.Cleanup(func() {
clean()
})

iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond)
is.Equal(2, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("canceled context stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

canceledCtx, cancel := context.WithCancel(context.Background())
cancel()

iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond))
is.False(ok)
})
}
3 changes: 1 addition & 2 deletions lo_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package lo

import (
"os"
"testing"
"time"
)
Expand All @@ -18,7 +17,7 @@ func testWithTimeout(t *testing.T, timeout time.Duration) {
case <-testFinished:
case <-time.After(timeout):
t.Errorf("test timed out after %s", timeout)
os.Exit(1)
t.FailNow()

Check failure on line 20 in lo_test.go

View workflow job for this annotation

GitHub Actions / lint

testinggoroutine: call to (*testing.T).FailNow from a non-test goroutine (govet)
}
}()
}
Expand Down

0 comments on commit 0f4679b

Please sign in to comment.