From 25b32a179fad196d5abea395f77a2cf27e0707e2 Mon Sep 17 00:00:00 2001 From: cszczepaniak Date: Sun, 30 Jul 2023 11:57:42 -0500 Subject: [PATCH] simplify the eventually implementations --- assert/assertions.go | 97 ++++++++++++++++++++------------------- assert/assertions_test.go | 4 +- 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/assert/assertions.go b/assert/assertions.go index 17ae12707..9ea64ee08 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -3,6 +3,7 @@ package assert import ( "bufio" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -13,6 +14,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync" "time" "unicode" "unicode/utf8" @@ -1836,32 +1838,32 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t } ch := make(chan bool, 1) - checkCond := func() { ch <- condition() } - timer := time.NewTimer(waitFor) - defer timer.Stop() - - ticker := time.NewTicker(tick) - defer ticker.Stop() - - var tickC <-chan time.Time + ctx, cancel := context.WithTimeout(context.Background(), waitFor) + defer cancel() - // Check the condition once first on the initial call. - go checkCond() + go func() { + ticker := time.NewTicker(tick) + defer ticker.Stop() + for { + if condition() { + ch <- true + return + } - for { - select { - case <-timer.C: - return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tickC: - tickC = nil - go checkCond() - case v := <-ch: - if v { - return true + select { + case <-ticker.C: + case <-ctx.Done(): + return } - tickC = ticker.C } + }() + + select { + case <-ch: + return true + case <-ctx.Done(): + return Fail(t, "Condition never satisfied", msgAndArgs...) } } @@ -1921,37 +1923,40 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time ch := make(chan bool, 1) collect := new(CollectT) - checkCond := func() { - condition(collect) - ch <- len(collect.errors) == 0 - } - - timer := time.NewTimer(waitFor) - defer timer.Stop() + ctx, cancel := context.WithTimeout(context.Background(), waitFor) + defer cancel() - ticker := time.NewTicker(tick) - defer ticker.Stop() + wg := sync.WaitGroup{} + wg.Add(1) - var tickC <-chan time.Time + go func() { + ticker := time.NewTicker(tick) + defer ticker.Stop() + defer wg.Done() + for { + collect.Reset() - // Check the condition once first on the initial call. - go checkCond() + condition(collect) + if len(collect.errors) == 0 { + ch <- true + return + } - for { - select { - case <-timer.C: - collect.Copy(t) - return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tickC: - tickC = nil - collect.Reset() - go checkCond() - case v := <-ch: - if v { - return true + select { + case <-ticker.C: + case <-ctx.Done(): + return } - tickC = ticker.C } + }() + + select { + case <-ch: + return true + case <-ctx.Done(): + wg.Wait() + collect.Copy(t) + return Fail(t, "Condition never satisfied", msgAndArgs...) } } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index a0ff9adf0..69e99cadf 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -2819,7 +2819,7 @@ func TestEventuallyIssue805(t *testing.T) { func TestEventuallySucceedQuickly(t *testing.T) { mockT := new(testing.T) - condition := func() bool { <-time.After(time.Millisecond); return true } + condition := func() bool { return true } done := make(chan struct{}) go func() { @@ -2837,7 +2837,7 @@ func TestEventuallySucceedQuickly(t *testing.T) { func TestEventuallyWithTSucceedQuickly(t *testing.T) { mockT := new(testing.T) - condition := func(t *CollectT) { <-time.After(time.Millisecond) } + condition := func(t *CollectT) {} done := make(chan struct{}) go func() {