diff --git a/.github/.golangci.yml b/.github/.golangci.yml index dec1fe6911c..46e9021bddf 100644 --- a/.github/.golangci.yml +++ b/.github/.golangci.yml @@ -33,7 +33,7 @@ linters: forbidigo: forbid: - pattern: time.Sleep - msg: "Please use require.Eventually or assert.Eventually instead unless you've no other option" + msg: "Please use await.Require / s.Await unless there's no better option" - pattern: "^panic$" msg: "Please avoid using panic in application code" - pattern: time\.Now @@ -48,8 +48,10 @@ linters: msg: "FunctionalTestBase is deprecated. Use testcore.NewEnv(t) instead. See docs/development/testing.md for details." - pattern: context\.Background\(\) msg: "Avoid context.Background() in tests; use t.Context() to respect test timeouts and cancellation" + - pattern: '(^|\.)(Eventually|Eventuallyf|EventuallyWithT|EventuallyWithTf)(\(|$)' + msg: "Use await.Require / s.Await for assertion conditions, or await.RequireTrue / s.AwaitTrue for bool predicates, instead of testify Eventually helpers" - pattern: 'assert\.\w+' - msg: "Use require.X / protorequire.X instead of assert.X / protoassert.X — assert doesn't stop the test on failure. assert.CollectT is still allowed for EventuallyWithT callbacks." + msg: "Use require.X / protorequire.X instead of assert.X / protoassert.X — assert doesn't stop the test on failure." depguard: rules: main: @@ -202,7 +204,14 @@ linters: text: "context.Background" linters: - forbidigo - - text: "use of `assert\\.CollectT`" # allowed for EventuallyWithT callbacks + # Existing legacy call sites are tracked separately; keep this PR scoped + # to preventing new usage while migrating touched tests. 
+ - path: tests/(nexus_standalone|nexus_workflow|schedule|schedule_migration)_test\.go$ + text: "Eventually" + linters: + - forbidigo + - path: tests/(nexus_standalone|nexus_workflow)_test\.go$ + text: "assert\\.CollectT" linters: - forbidigo - text: "use of `softassert\\.\\w+`" diff --git a/common/testing/await/doc.go b/common/testing/await/doc.go new file mode 100644 index 00000000000..fa86bad7d5f --- /dev/null +++ b/common/testing/await/doc.go @@ -0,0 +1,39 @@ +// Package await provides polling-based test assertions as a replacement +// for testify's Eventually, EventuallyWithT, and their formatted variants. +// +// Improvements over testify's eventually functions: +// +// - Misuse detection: accidentally using the real *testing.T (e.g. s.T() or +// suite assertion methods) instead of the callback's collect T is a +// common mistake. This package detects it and fails with a clear message. +// +// - Safer bool predicates: unlike testify's Eventually, [RequireTrue] only +// accepts func() bool, so returning false is the sole retry signal. If the +// predicate accidentally marks the real test failed, it reports that +// immediately instead of polling until timeout. +// +// - Timeout-aware callbacks: callbacks receive a context derived from the +// parent context and canceled when the await timeout or test deadline is +// reached, so RPCs and blocking waits can exit instead of continuing after +// the retry window has expired. +// +// - Panic propagation: if the condition panics (e.g. nil dereference), the +// panic is propagated immediately rather than being silently swallowed +// or retried until timeout. +// See https://github.com/stretchr/testify/issues/1810 +// +// - Bounded goroutine lifetime: each attempt completes before the next +// starts, avoiding the overlapping-attempt data races and "panic: Fail +// in goroutine after Test has completed" crashes seen with testify's +// Eventually. 
+// See https://github.com/stretchr/testify/issues/1611 +// +// - Deadlock detection: a condition that ignores t.Context() is abandoned +// after a grace period, producing a clear "does it honor t.Context()?" +// failure instead of hanging until go test -timeout. +// +// - Condition always runs: testify's Eventually can fail without ever +// running the condition due to a timer/ticker race with short timeouts. +// This package runs the condition immediately on the first iteration. +// See https://github.com/stretchr/testify/issues/1652 +package await diff --git a/common/testing/await/report.go b/common/testing/await/report.go new file mode 100644 index 00000000000..094ebe73fd7 --- /dev/null +++ b/common/testing/await/report.go @@ -0,0 +1,65 @@ +package await + +import ( + "fmt" + "strings" + "testing" + "time" +) + +// reportAttemptErrors emits the collected attempt failures. When there are +// many, only the first and the last few are shown — long polls would +// otherwise produce hundreds of duplicate lines. +const ( + reportHeadAttempts = 1 + reportTailAttempts = 3 +) + +type attemptFailure struct { + attempt int + errors []string +} + +// reportTimeout reports the timeout failure plus collected attempt errors. 
+func reportTimeout(tb testing.TB, failures []attemptFailure, funcName, timeoutMsg string, effectiveTimeout time.Duration, polls int) { + reportAttemptErrors(tb, failures) + if timeoutMsg != "" { + tb.Fatalf("%s: %s (not satisfied after %v, %d polls)", funcName, timeoutMsg, effectiveTimeout, polls) + } else { + tb.Fatalf("%s: condition not satisfied after %v (%d polls)", funcName, effectiveTimeout, polls) + } +} + +func reportAttemptErrors(tb testing.TB, failures []attemptFailure) { + if len(failures) == 0 { + return + } + + var b strings.Builder + b.WriteString("attempt errors:") + if len(failures) <= reportHeadAttempts+reportTailAttempts { + for _, f := range failures { + writeAttemptFailure(&b, f) + } + } else { + for _, f := range failures[:reportHeadAttempts] { + writeAttemptFailure(&b, f) + } + omitted := len(failures) - reportHeadAttempts - reportTailAttempts + fmt.Fprintf(&b, "\n ... %d attempts omitted ...", omitted) + for _, f := range failures[len(failures)-reportTailAttempts:] { + writeAttemptFailure(&b, f) + } + } + tb.Errorf("%s", b.String()) +} + +func writeAttemptFailure(b *strings.Builder, f attemptFailure) { + fmt.Fprintf(b, "\n attempt %d:", f.attempt) + for _, e := range f.errors { + for line := range strings.SplitSeq(e, "\n") { + b.WriteString("\n ") + b.WriteString(line) + } + } +} diff --git a/common/testing/await/require_ctx.go b/common/testing/await/require_ctx.go new file mode 100644 index 00000000000..a830d551d27 --- /dev/null +++ b/common/testing/await/require_ctx.go @@ -0,0 +1,282 @@ +package await + +import ( + "context" + "fmt" + "os" + "testing" + "time" +) + +const requireMisuseHint = "use the *await.T passed to the callback, not s.T() or suite assertion methods" + +// softDeadlockTimeoutEnvVar overrides the default soft-deadlock timeout. +// Parsed as a Go duration, e.g. "10s". 
const softDeadlockTimeoutEnvVar = "TEMPORAL_AWAIT_SOFT_DEADLOCK_TIMEOUT"

// defaultSoftDeadlockTimeout caps how long a single attempt can run before its
// context is cancelled (soft deadlock). Capped further by the overall await
// deadline. Each new attempt gets a fresh context with this same cap.
const defaultSoftDeadlockTimeout = 30 * time.Second

// softDeadlockTimeout returns the soft-deadlock timeout: the value of
// TEMPORAL_AWAIT_SOFT_DEADLOCK_TIMEOUT when it holds a valid, positive Go
// duration, and defaultSoftDeadlockTimeout otherwise.
func softDeadlockTimeout() time.Duration {
	return durationFromEnv(softDeadlockTimeoutEnvVar, defaultSoftDeadlockTimeout)
}

// hardDeadlockTimeoutEnvVar overrides the default hard-deadlock timeout.
// Parsed as a Go duration, e.g. "100ms".
const hardDeadlockTimeoutEnvVar = "TEMPORAL_AWAIT_HARD_DEADLOCK_TIMEOUT"

// defaultHardDeadlockTimeout is how long runAttempt waits AFTER cancelling the
// attempt context (soft deadlock) for the condition goroutine to honor the
// cancellation. If it doesn't terminate by then, the goroutine is declared
// hard-deadlocked and abandoned. Without it, a condition that ignores
// t.Context() would hang the test until go test -timeout fires.
const defaultHardDeadlockTimeout = 10 * time.Second

// hardDeadlockTimeout returns the hard-deadlock timeout: the value of
// TEMPORAL_AWAIT_HARD_DEADLOCK_TIMEOUT when it holds a valid, positive Go
// duration, and defaultHardDeadlockTimeout otherwise.
func hardDeadlockTimeout() time.Duration {
	return durationFromEnv(hardDeadlockTimeoutEnvVar, defaultHardDeadlockTimeout)
}

// durationFromEnv reads the environment variable name and parses it as a Go
// duration. It returns fallback when the variable is unset or empty, when it
// does not parse, or when the parsed duration is zero or negative.
//
// Non-positive overrides are rejected because the result is handed straight
// to time.NewTimer: such a timer fires immediately, which would cancel every
// attempt on entry (soft timeout) or race the condition's result against an
// already-expired timer and report a bogus deadlock (hard timeout).
func durationFromEnv(name string, fallback time.Duration) time.Duration {
	raw := os.Getenv(name)
	if raw == "" {
		return fallback
	}
	d, err := time.ParseDuration(raw)
	if err != nil || d <= 0 {
		// Best effort: a malformed or unusable override must not break the
		// test run, so fall back to the built-in default.
		return fallback
	}
	return d
}

// Require polls condition until it returns without assertion failures, or
// until ctx is canceled or timeout expires (whichever is earliest).
//
// Pass the *await.T to require.*/assert.* — failures cause a retry, not a
// test failure. Use t.Context() inside the callback to honor the timeout.
+func Require(ctx context.Context, tb testing.TB, condition func(*T), timeout, pollInterval time.Duration) { + tb.Helper() + run(ctx, tb, condition, timeout, pollInterval, "", "Require", requireMisuseHint, true) +} + +// Requiref is like [Require] but adds a formatted message to the timeout +// failure. +func Requiref(ctx context.Context, tb testing.TB, condition func(*T), timeout, pollInterval time.Duration, msg string, args ...any) { + tb.Helper() + run(ctx, tb, condition, timeout, pollInterval, fmt.Sprintf(msg, args...), "Requiref", requireMisuseHint, true) +} + +func run( + parentCtx context.Context, + tb testing.TB, + condition func(*T), + timeout, + pollInterval time.Duration, + timeoutMsg string, + funcName string, + misuseHint string, + cancellable bool, +) { + tb.Helper() + + // Skip if the test already failed — no point polling. + if tb.Failed() { + tb.Logf("%s: skipping (test already failed)", funcName) + return + } + // Guard: context.WithDeadline panics on a nil parent. + if parentCtx == nil { + tb.Fatalf("%s: nil context", funcName) + return + } + + deadline := time.Now().Add(timeout) + + // Cap at the parent context's deadline if it's earlier than our timeout. + if parentDeadline, hasDeadline := parentCtx.Deadline(); hasDeadline && parentDeadline.Before(deadline) { + deadline = parentDeadline + } + + // Cap at the test's deadline if it's earlier than our deadline. + // Ideally, the parent context already accounts for the test's deadline - but we are being defensive. + if d, ok := tb.(interface{ Deadline() (time.Time, bool) }); ok { + if testDeadline, hasDeadline := d.Deadline(); hasDeadline && testDeadline.Before(deadline) { + deadline = testDeadline + } + } + + effectiveTimeout := max(0, time.Until(deadline)) + awaitCtx, awaitCancel := context.WithDeadline(parentCtx, deadline) + defer awaitCancel() + + var failures []attemptFailure + polls := 0 + + for { + // Parent context was canceled while we were sleeping (not our deadline). 
+ if err := awaitCtx.Err(); err != nil && !deadlineReached(deadline) { + reportAttemptErrors(tb, failures) + tb.Fatalf("%s: context canceled before condition was satisfied: %v", funcName, err) + return + } + + polls++ + + // Fresh context per attempt, scoped to the run-level ctx. runAttempt + // owns the soft timeout and the corresponding cancel. + attemptCtx, attemptCancel := context.WithCancel(awaitCtx) + t := &T{tb: tb, ctx: attemptCtx} + + // Run attempt. + res := runAttempt(t, condition, attemptCancel, funcName, cancellable) + attemptCancel() + if res.panicVal != nil { + panic(res.panicVal) // propagate to caller + } + if res.deadlocked { + reportAttemptErrors(tb, failures) + if cancellable { + tb.Fatalf("%s: condition still running %v past context cancellation — does it honor t.Context()? (%d polls)", + funcName, hardDeadlockTimeout(), polls) + } else { + tb.Fatalf("%s: condition still running %v past deadline (%d polls)", + funcName, hardDeadlockTimeout(), polls) + } + return + } + if len(t.errors) > 0 { + failures = append(failures, attemptFailure{attempt: polls, errors: t.errors}) + } + + // Check misuse where the real test failed instead of just the attempt. + if tb.Failed() { + tb.Fatalf("%s: the test was marked failed directly — %s", funcName, misuseHint) + return + } + + // Parent context was canceled during the attempt (not our deadline). + if err := awaitCtx.Err(); err != nil && !deadlineReached(deadline) { + reportAttemptErrors(tb, failures) + tb.Fatalf("%s: context canceled before condition was satisfied: %v", funcName, err) + return + } + + // Our deadline expired. + if deadlineReached(deadline) { + reportTimeout(tb, failures, funcName, timeoutMsg, effectiveTimeout, polls) + return + } + + // Success: attempt completed without failures. + if !res.stopped && !t.Failed() { + return + } + + // Wait for pollInterval, or context is canceled or deadline is reached. 
+ sleep(awaitCtx, deadline, pollInterval) + } +} + +// attemptResult describes how an attempt terminated. Exactly one of the +// following fields is set: +// - panicVal != nil: condition panicked with a non-attemptFailed value; +// caller should re-panic with panicVal. +// - deadlocked: condition did not honor context cancellation within +// [hardDeadlockTimeout]; the goroutine is abandoned and leaks until the +// process exits. +// - stopped: condition stopped via attemptFailed (FailNow on *T) or +// runtime.Goexit (real-test FailNow misuse). +// - none: condition returned normally. +type attemptResult struct { + panicVal any + stopped bool + deadlocked bool +} + +// runAttempt runs condition in a goroutine so that an accidental call to the +// real test's FailNow (runtime.Goexit) terminates only this goroutine. +// +// Termination is detected in two phases: +// - Soft (cancellable only): if condition hasn't returned within +// [softDeadlockTimeout], log a warning and cancel ctx. Skipped if the +// parent ctx was already cancelled. +// - Hard: if condition still hasn't returned within [hardDeadlockTimeout], +// declare it deadlocked and abandon the goroutine. +func runAttempt( + t *T, + condition func(*T), + cancel context.CancelFunc, + funcName string, + cancellable bool, +) attemptResult { + done := make(chan attemptResult, 1) + + go func() { + completed := false + defer func() { + if r := recover(); r != nil { + if _, ok := r.(attemptFailed); ok { + done <- attemptResult{stopped: true} + return + } + done <- attemptResult{panicVal: r} + return + } + // recover returned nil: either normal return (completed=true) or + // runtime.Goexit (completed=false; Goexit is not a panic). + done <- attemptResult{stopped: !completed} + }() + condition(t) + completed = true + }() + + if cancellable { + // Soft phase: wait for the condition, our soft timer, or parent cancel. 
+ softTimer := time.NewTimer(softDeadlockTimeout()) + defer softTimer.Stop() + + select { + case r := <-done: + return r + case <-softTimer.C: + // Soft deadlock: log a warning. + t.tb.Logf("%s: soft deadlock — condition still running after %v; waiting %v before declaring hard deadlock", + funcName, softDeadlockTimeout(), hardDeadlockTimeout()) + + // Cancel so the condition can observe ctx.Done(). + cancel() + case <-t.ctx.Done(): + // Parent cancelled (await deadline reached or upstream cancel). + // Proceed to hard phase quietly. + } + } + + // Hard phase: wait for the condition or the hard timer. + hardTimer := time.NewTimer(hardDeadlockTimeout()) + defer hardTimer.Stop() + + select { + case r := <-done: + return r + case <-hardTimer.C: + return attemptResult{deadlocked: true} + } +} + +func sleep(ctx context.Context, deadline time.Time, pollInterval time.Duration) { + remaining := time.Until(deadline) + if remaining < pollInterval { + pollInterval = remaining + } + + timer := time.NewTimer(pollInterval) + defer timer.Stop() + + select { + case <-ctx.Done(): + case <-timer.C: + } +} + +func deadlineReached(deadline time.Time) bool { + return !time.Now().Before(deadline) +} diff --git a/common/testing/await/require_ctx_test.go b/common/testing/await/require_ctx_test.go new file mode 100644 index 00000000000..7b6a62a7e6f --- /dev/null +++ b/common/testing/await/require_ctx_test.go @@ -0,0 +1,536 @@ +package await_test + +import ( + "context" + "fmt" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/testing/await" + "go.temporal.io/server/common/testing/testcontext" +) + +func TestRequire_ImmediateSuccess(t *testing.T) { + t.Parallel() + + attempts := 0 + + await.Require(t.Context(), t, func(t *await.T) { + attempts++ + }, time.Second, 100*time.Millisecond) + + require.Equal(t, 1, attempts, "condition should be called exactly once") +} + +func 
TestRequire_RetriesUntilAttemptPasses(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + fail func(*await.T, int32) + stops bool + }{ + { + name: "Errorf", + fail: func(t *await.T, attempt int32) { + t.Errorf("not ready: %d", attempt) + }, + }, + { + name: "FailNow", + stops: true, + fail: func(t *await.T, _ int32) { + t.FailNow() + }, + }, + { + name: "Fatal", + stops: true, + fail: func(t *await.T, _ int32) { + t.Fatal("not ready") + }, + }, + { + name: "Fatalf", + stops: true, + fail: func(t *await.T, attempt int32) { + t.Fatalf("not ready: %d", attempt) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var attempts atomic.Int32 + var continuedAfterFailure atomic.Bool + await.Require(t.Context(), t, func(t *await.T) { + attempt := attempts.Add(1) + if attempt < 3 { + tc.fail(t, attempt) + continuedAfterFailure.Store(true) + } + }, time.Second, 100*time.Millisecond) + + require.Equal(t, int32(3), attempts.Load()) + require.Equal(t, !tc.stops, continuedAfterFailure.Load()) + }) + } +} + +func TestRequire_PropagatesParentContextValues(t *testing.T) { + t.Parallel() + + type contextKey struct{} + ctx := context.WithValue(t.Context(), contextKey{}, "value") + + var got any + await.Require(ctx, t, func(t *await.T) { + got = t.Context().Value(contextKey{}) + }, time.Second, 100*time.Millisecond) + + require.Equal(t, "value", got) +} + +func TestRequire_SetsTimeoutContextDeadline(t *testing.T) { + t.Parallel() + + longCtx, cancel := context.WithTimeout(testcontext.New(t), time.Minute) + defer cancel() + longDeadline, ok := longCtx.Deadline() + require.True(t, ok) + + shortTimeout := 1 * time.Second + + var shortCtx context.Context + await.Require(longCtx, t, func(t *await.T) { + shortCtx = t.Context() + }, shortTimeout, 100*time.Millisecond) + + require.NotNil(t, shortCtx) + require.NotSame(t, longCtx, shortCtx) + + shortDeadline, ok := shortCtx.Deadline() + require.True(t, ok) + require.True(t, 
shortDeadline.Before(longDeadline)) + require.LessOrEqual(t, time.Until(shortDeadline), shortTimeout) + require.Greater(t, time.Until(shortDeadline), shortTimeout-200*time.Millisecond) +} + +func TestRequire_PollIntervalStartsAfterAttemptFinishes(t *testing.T) { + t.Parallel() + + var attempts atomic.Int32 + var attemptStarts []time.Time + var attemptEnds []time.Time + attemptDuration := 60 * time.Millisecond + pollInterval := 100 * time.Millisecond + + await.Require(t.Context(), t, func(t *await.T) { + attemptStarts = append(attemptStarts, time.Now()) + defer func() { attemptEnds = append(attemptEnds, time.Now()) }() + + time.Sleep(attemptDuration) //nolint:forbidigo // simulate attempt work to distinguish poll-after-start vs poll-after-end + + if attempts.Add(1) < 3 { + t.Error("not ready") + } + }, time.Second, pollInterval) + + require.Equal(t, int32(3), attempts.Load()) + require.Len(t, attemptStarts, 3) + require.Len(t, attemptEnds, 3) + for i := 1; i < len(attemptStarts); i++ { + gap := attemptStarts[i].Sub(attemptEnds[i-1]) + require.GreaterOrEqual(t, gap, pollInterval, + "poll interval should run after attempt finishes (gap=%v < %v)", gap, pollInterval) + } +} + +func TestRequire_FailureScenarios(t *testing.T) { + t.Parallel() + + t.Run("reports timeout", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + t.Error("not ready") + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + }) + + t.Run("cancels attempt context on timeout", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() + if t.Context().Err() != context.DeadlineExceeded { + t.Errorf("context error = %v", t.Context().Err()) + } + }, 2*time.Second, time.Second) + }) + require.True(t, 
tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + }) + + t.Run("does not poll again after attempt consumes timeout", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + var attempts atomic.Int32 + + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + attempts.Add(1) + <-t.Context().Done() // block until timeout + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + require.Equal(t, int32(1), attempts.Load()) + }) + + t.Run("caps attempt context with parent deadline", func(t *testing.T) { + t.Parallel() + + parentCtx, cancel := context.WithTimeout(testcontext.New(t), time.Second) + defer cancel() + + tb := newRecordingTB() + tb.run(func() { + await.Require(parentCtx, tb, func(t *await.T) { + deadline, ok := t.Context().Deadline() + if !ok { + t.Error("missing deadline") + } + if time.Until(deadline) > time.Second { + t.Errorf("deadline = %v", deadline) + } + <-t.Context().Done() + if t.Context().Err() != context.DeadlineExceeded { + t.Errorf("context error = %v", t.Context().Err()) + } + }, 2*time.Second, time.Second) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + }) + + t.Run("parent context cancellation stops polling", func(t *testing.T) { + t.Parallel() + + parentCtx, cancel := context.WithCancel(testcontext.New(t)) + defer cancel() + var attempts atomic.Int32 + + tb := newRecordingTB() + tb.run(func() { + await.Require(parentCtx, tb, func(t *await.T) { + attempts.Add(1) + t.Error("not ready") + cancel() + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "context canceled before condition was satisfied") + + require.Equal(t, int32(1), attempts.Load(), "expected cancellation to stop polling") + }) + + t.Run("reports all attempt errors on timeout", func(t *testing.T) { + t.Parallel() + + ctx := 
testcontext.New(t) + var attempts atomic.Int32 + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + if attempts.Add(1) == 1 { + t.Error("first attempt error") + return + } + <-t.Context().Done() + t.Error("last attempt error") + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + require.Equal(t, "attempt errors:\n attempt 1:\n first attempt error\n attempt 2:\n last attempt error", tb.errors()) + require.Equal(t, int32(2), attempts.Load()) + }) + + t.Run("truncates middle attempts when many fail", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + var attempts atomic.Int32 + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + n := attempts.Add(1) + t.Errorf("attempt %d failed", n) + }, 400*time.Millisecond, 50*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + + n := attempts.Load() + require.Greater(t, n, int32(4), "need >4 attempts to exercise truncation") + + errs := tb.errors() + require.Contains(t, errs, "attempt errors:\n attempt 1:\n attempt 1 failed\n") + require.Contains(t, errs, fmt.Sprintf("... %d attempts omitted ...", n-4)) + // Last three attempts present in order. 
+ for i := n - 2; i <= n; i++ { + require.Contains(t, errs, fmt.Sprintf("attempt %d:\n attempt %d failed", i, i)) + } + }) + + t.Run("Requiref includes message on timeout", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + tb := newRecordingTB() + tb.run(func() { + await.Requiref(ctx, tb, func(t *await.T) { + t.Error("not ready") + }, time.Second, 100*time.Millisecond, "workflow %s not ready", "wf-123") + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "workflow wf-123 not ready") + }) + + t.Run("panic propagates", func(t *testing.T) { + t.Parallel() + + require.PanicsWithValue(t, "unexpected nil pointer", func() { + await.Require(t.Context(), t, func(_ *await.T) { + panic("unexpected nil pointer") + }, time.Second, 100*time.Millisecond) + }) + }) + + t.Run("reports real TB misuse", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + misuse func(*recordingTB) + }{ + {"Fatal stops real TB", func(tb *recordingTB) { tb.Fatal("wrong t used") }}, + {"Errorf marks real TB failed", func(tb *recordingTB) { tb.Errorf("assert-style misuse") }}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(_ *await.T) { + tc.misuse(tb) + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "use the *await.T") + }) + } + }) + + t.Run("does not poll after prior failure", func(t *testing.T) { + t.Parallel() + + ctx := testcontext.New(t) + conditionCalled := false + tb := newRecordingTB() + tb.run(func() { + tb.Errorf("previous failure") + await.Require(ctx, tb, func(_ *await.T) { + conditionCalled = true + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Empty(t, tb.fatals()) + require.False(t, conditionCalled, "condition should not run when test already failed") + }) +} + +func TestRequire_SoftDeadlockLogsAndCancels(t 
*testing.T) { + // not using T.Parallel() so it can use t.Setenv to override the deadlock timeouts + t.Setenv("TEMPORAL_AWAIT_SOFT_DEADLOCK_TIMEOUT", "50ms") + t.Setenv("TEMPORAL_AWAIT_HARD_DEADLOCK_TIMEOUT", "5s") + + const awaitTimeout = 10 * time.Second + + ctx := testcontext.New(t) + tb := newRecordingTB() + start := time.Now() + tb.run(func() { + // Await timeout is long so parent-cancel doesn't beat the soft timer. + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() // exits as soon as soft cancel fires + }, awaitTimeout, 100*time.Millisecond) + }) + elapsed := time.Since(start) + require.False(t, tb.Failed(), "soft deadlock + clean exit should succeed") + require.Contains(t, tb.logs(), "soft deadlock") + require.NotContains(t, tb.fatals(), "still running") + require.Less(t, elapsed, awaitTimeout, + "should return shortly after soft cancel, not wait the full await timeout (elapsed=%v)", elapsed) +} + +func TestRequire_DeadlockDetected(t *testing.T) { + // not using T.Parallel() so it can use t.Setenv to override the deadlock timeouts. + // Await timeout is long enough that the soft timer fires before parent cancellation, + // so the path is: soft fires → log + cancel → condition still running → hard fires. 
+ t.Setenv("TEMPORAL_AWAIT_SOFT_DEADLOCK_TIMEOUT", "50ms") + t.Setenv("TEMPORAL_AWAIT_HARD_DEADLOCK_TIMEOUT", "100ms") + + const awaitTimeout = 10 * time.Second + + ctx := testcontext.New(t) + tb := newRecordingTB() + start := time.Now() + tb.run(func() { + await.Require(ctx, tb, func(*await.T) { + select {} // ignores t.Context() + }, awaitTimeout, 50*time.Millisecond) + }) + elapsed := time.Since(start) + require.True(t, tb.Failed()) + require.Contains(t, tb.logs(), "soft deadlock") + require.Contains(t, tb.fatals(), "still running") + require.Contains(t, tb.fatals(), "does it honor t.Context()") + require.Less(t, elapsed, awaitTimeout, + "should fail at hard deadlock, not wait the full await timeout (elapsed=%v)", elapsed) +} + +func TestRequire_WaitsForInFlightAttemptOnTimeout(t *testing.T) { + t.Parallel() + + var finished atomic.Bool + ctx := testcontext.New(t) + tb := newRecordingTB() + tb.run(func() { + await.Require(ctx, tb, func(t *await.T) { + <-t.Context().Done() + finished.Store(true) + }, time.Second, time.Second) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + require.True(t, finished.Load(), "Require returned before the running attempt exited") +} + +// recordingTB is a minimal testing.TB implementation for testing failure scenarios. 
+type recordingTB struct { + testing.TB // embed for interface satisfaction + mu sync.Mutex + failed atomic.Bool + errorMessages []string + fatalMessages []string + logMessages []string + cleanups []func() +} + +func newRecordingTB() *recordingTB { + return &recordingTB{} +} + +func (r *recordingTB) Helper() {} +func (r *recordingTB) Failed() bool { return r.failed.Load() } +func (r *recordingTB) Logf(format string, args ...any) { + r.mu.Lock() + defer r.mu.Unlock() + r.logMessages = append(r.logMessages, fmt.Sprintf(format, args...)) +} +func (r *recordingTB) Context() context.Context { + return context.Background() +} + +func (r *recordingTB) Cleanup(fn func()) { + r.mu.Lock() + defer r.mu.Unlock() + r.cleanups = append(r.cleanups, fn) +} + +func (r *recordingTB) Errorf(format string, args ...any) { + r.mu.Lock() + defer r.mu.Unlock() + r.failed.Store(true) + r.errorMessages = append(r.errorMessages, fmt.Sprintf(format, args...)) +} + +func (r *recordingTB) Fatalf(format string, args ...any) { + r.mu.Lock() + r.failed.Store(true) + r.fatalMessages = append(r.fatalMessages, fmt.Sprintf(format, args...)) + r.mu.Unlock() + runtime.Goexit() +} + +func (r *recordingTB) Fatal(args ...any) { + r.mu.Lock() + r.failed.Store(true) + r.fatalMessages = append(r.fatalMessages, strings.TrimSuffix(fmt.Sprintln(args...), "\n")) + r.mu.Unlock() + runtime.Goexit() +} + +func (r *recordingTB) FailNow() { + r.failed.Store(true) + runtime.Goexit() +} + +func (r *recordingTB) run(fn func()) { + done := make(chan struct{}) + go func() { + defer func() { + r.runCleanups() + close(done) + }() + fn() + }() + <-done +} + +func (r *recordingTB) runCleanups() { + r.mu.Lock() + cleanups := r.cleanups + r.cleanups = nil + r.mu.Unlock() + + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } +} + +func (r *recordingTB) fatals() string { + r.mu.Lock() + defer r.mu.Unlock() + return strings.Join(r.fatalMessages, "\n") +} + +func (r *recordingTB) errors() string { + r.mu.Lock() + defer 
r.mu.Unlock() + return strings.Join(r.errorMessages, "\n") +} + +func (r *recordingTB) logs() string { + r.mu.Lock() + defer r.mu.Unlock() + return strings.Join(r.logMessages, "\n") +} diff --git a/common/testing/await/require_true.go b/common/testing/await/require_true.go new file mode 100644 index 00000000000..48be79674f9 --- /dev/null +++ b/common/testing/await/require_true.go @@ -0,0 +1,36 @@ +package await + +import ( + "fmt" + "testing" + "time" + + "go.temporal.io/server/common/testing/testcontext" +) + +const requireTrueMisuseHint = "do not use test assertions inside the predicate - return false to retry or use await.Require for assertions" + +// RequireTrue runs `condition` repeatedly until it returns true, or until the +// timeout expires. The timeout is capped at the test's deadline, if one is set. +// +// Use [RequireTrue] for simple local predicates only. Do not use assertions or +// side effects in the predicate - use [Require] for these. +func RequireTrue(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration) { + tb.Helper() + run(testcontext.New(tb), tb, func(t *T) { + if !condition() { + t.Fail() + } + }, timeout, pollInterval, "", "RequireTrue", requireTrueMisuseHint, false) +} + +// RequireTruef is like [RequireTrue] but accepts a format string that is included +// in the failure message when the condition is not satisfied before the timeout. 
+func RequireTruef(tb testing.TB, condition func() bool, timeout, pollInterval time.Duration, msg string, args ...any) { + tb.Helper() + run(testcontext.New(tb), tb, func(t *T) { + if !condition() { + t.Fail() + } + }, timeout, pollInterval, fmt.Sprintf(msg, args...), "RequireTruef", requireTrueMisuseHint, false) +} diff --git a/common/testing/await/require_true_test.go b/common/testing/await/require_true_test.go new file mode 100644 index 00000000000..0a749d549b6 --- /dev/null +++ b/common/testing/await/require_true_test.go @@ -0,0 +1,133 @@ +package await_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/testing/await" +) + +// RequireTrue is a thin bool-predicate adapter over the same polling runner +// covered by require_ctx_test.go, so these tests focus on adapter behavior. + +func TestRequireTrue_ImmediateSuccess(t *testing.T) { + t.Parallel() + + attempts := 0 + + await.RequireTrue(t, func() bool { + attempts++ + return true + }, time.Second, 100*time.Millisecond) + + require.Equal(t, 1, attempts, "condition should be called exactly once") +} + +func TestRequireTrue_RetriesFalseUntilTrue(t *testing.T) { + t.Parallel() + + var attempts atomic.Int32 + + await.RequireTrue(t, func() bool { + return attempts.Add(1) >= 3 + }, time.Second, 100*time.Millisecond) + + require.Equal(t, int32(3), attempts.Load()) +} + +func TestRequireTrue_FailureScenarios(t *testing.T) { + t.Parallel() + + t.Run("reports timeout", func(t *testing.T) { + t.Parallel() + + tb := newRecordingTB() + tb.run(func() { + await.RequireTrue(tb, func() bool { + return false + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "not satisfied after") + }) + + t.Run("RequireTruef includes message on timeout", func(t *testing.T) { + t.Parallel() + + tb := newRecordingTB() + tb.run(func() { + await.RequireTruef(tb, func() bool { + return false + }, time.Second, 
100*time.Millisecond, "workflow %s not ready", "wf-123") + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "workflow wf-123 not ready") + }) + + t.Run("reports real TB misuse", func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + misuse func(*recordingTB) + }{ + {"Fatal stops real TB", func(tb *recordingTB) { tb.Fatal("wrong t used") }}, + {"Errorf marks real TB failed", func(tb *recordingTB) { tb.Errorf("assert-style misuse") }}, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + tb := newRecordingTB() + tb.run(func() { + await.RequireTrue(tb, func() bool { + tc.misuse(tb) + return true + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "do not use test assertions") + }) + } + }) + + t.Run("does not poll after prior failure", func(t *testing.T) { + t.Parallel() + + conditionCalled := false + + tb := newRecordingTB() + tb.run(func() { + tb.Errorf("previous failure") + await.RequireTrue(tb, func() bool { + conditionCalled = true + return true + }, time.Second, 100*time.Millisecond) + }) + require.True(t, tb.Failed()) + require.Empty(t, tb.fatals()) + require.False(t, conditionCalled, "condition should not run when test already failed") + }) +} + +func TestRequireTrue_DeadlockDetected(t *testing.T) { + // not using T.Parallel() so it can use t.Setenv to override the deadlock timeouts + t.Setenv("TEMPORAL_AWAIT_HARD_DEADLOCK_TIMEOUT", "100ms") + + const awaitTimeout = 10 * time.Second + + tb := newRecordingTB() + start := time.Now() + tb.run(func() { + await.RequireTrue(tb, func() bool { + select {} // never returns; predicate has no way to honor cancellation + }, awaitTimeout, 50*time.Millisecond) + }) + elapsed := time.Since(start) + require.True(t, tb.Failed()) + require.Contains(t, tb.fatals(), "still running") + require.Contains(t, tb.fatals(), "past deadline") + require.Less(t, elapsed, awaitTimeout, + "should fail at hard deadlock, not 
wait the full await timeout (elapsed=%v)", elapsed) +} diff --git a/common/testing/await/t.go b/common/testing/await/t.go new file mode 100644 index 00000000000..c8f66196e7a --- /dev/null +++ b/common/testing/await/t.go @@ -0,0 +1,78 @@ +package await + +import ( + "context" + "fmt" + "strings" + "testing" +) + +type attemptFailed struct{} + +// T is passed to the condition callback. It intercepts assertion failures +// so the polling loop can retry. +// +// Only use T for assertions (require.*, assert.*, t.Errorf, t.Fatal, t.FailNow). +type T struct { + tb testing.TB + ctx context.Context + errors []string + failed bool +} + +// Context returns the await-scoped context for the current attempt, falling +// back to the underlying test's context when no await context was set. +func (t *T) Context() context.Context { + if t.ctx != nil { + return t.ctx + } + return t.tb.Context() +} + +// Fail marks the current attempt as failed without stopping it. +func (t *T) Fail() { + t.failed = true +} + +// Error records an error message for reporting on timeout. +func (t *T) Error(args ...any) { + t.Fail() + t.errors = append(t.errors, strings.TrimSuffix(fmt.Sprintln(args...), "\n")) +} + +// Errorf records an error message for reporting on timeout. +func (t *T) Errorf(format string, args ...any) { + t.Fail() + t.errors = append(t.errors, fmt.Sprintf(format, args...)) +} + +// FailNow is called by require.* on failure. It stops the current attempt. +// Unlike testing.TB.FailNow(), this does NOT mark the test as failed. +func (t *T) FailNow() { + t.Fail() + panic(attemptFailed{}) +} + +// Fatal records an error message and stops this attempt. +func (t *T) Fatal(args ...any) { + t.Error(args...) + t.FailNow() +} + +// Fatalf records an error message and stops this attempt. +func (t *T) Fatalf(format string, args ...any) { + t.Errorf(format, args...) + t.FailNow() +} + +// Failed reports whether this attempt has failed. 
+func (t *T) Failed() bool { + return t.failed +} + +// Helper marks the calling function as a test helper. +func (t *T) Helper() { + if t.tb != nil { + t.tb.Helper() + } +} diff --git a/common/testing/await/t_test.go b/common/testing/await/t_test.go new file mode 100644 index 00000000000..31b51ae5767 --- /dev/null +++ b/common/testing/await/t_test.go @@ -0,0 +1,89 @@ +package await_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/testing/await" +) + +// TestT_CollectsAssertionFailures checks that every failure entry point on +// *await.T marks the attempt as failed, and that only the FailNow-based ones +// (FailNow, Fatal, Fatalf, require.*) stop the attempt by panicking. +func TestT_CollectsAssertionFailures(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + fail func(*await.T) + stops bool + }{ + { + name: "assert", + fail: func(t *await.T) { + assert.Equal(t, "expected", "actual") //nolint:forbidigo // intentionally testing that assert.* works with *await.T + }, + }, + { + name: "Error", + fail: func(t *await.T) { + t.Error("not ready") + }, + }, + { + name: "Errorf", + fail: func(t *await.T) { + t.Errorf("not ready: %d", 1) + }, + }, + { + name: "FailNow", + fail: func(t *await.T) { + t.FailNow() + }, + stops: true, + }, + { + name: "Fatal", + fail: func(t *await.T) { + t.Fatal("not ready") + }, + stops: true, + }, + { + name: "Fatalf", + fail: func(t *await.T) { + t.Fatalf("not ready: %d", 1) + }, + stops: true, + }, + { + name: "require", + fail: func(t *await.T) { + require.Equal(t, "expected", "actual") + }, + stops: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + at := &await.T{} + continuedAfterFailure := false + run := func() { + tc.fail(at) + continuedAfterFailure = true + } + + if tc.stops { + require.Panics(t, run) + require.False(t, continuedAfterFailure) + } else { + require.NotPanics(t, run) + require.True(t, continuedAfterFailure) + } + require.True(t, at.Failed()) + }) + } +} diff --git a/common/testing/parallelsuite/suite.go b/common/testing/parallelsuite/suite.go index 91937617dab..cf84dcb2141 100644 --- a/common/testing/parallelsuite/suite.go +++ b/common/testing/parallelsuite/suite.go @@ -1,24 +1,31 @@ 
package parallelsuite import ( + "context" "flag" "fmt" "reflect" "regexp" "strings" + "sync" "testing" + "time" "github.com/stretchr/testify/require" testifysuite "github.com/stretchr/testify/suite" + "go.temporal.io/server/common/testing/await" "go.temporal.io/server/common/testing/historyrequire" "go.temporal.io/server/common/testing/protorequire" + "go.temporal.io/server/common/testing/testcontext" ) // testingSuite is the constraint for suite types. type testingSuite interface { testifysuite.TestingSuite - copySuite(t *testing.T) testingSuite - initSuite(t *testing.T) + //nolint:revive // ctx is last so callers can pass nil to mean "no override"; SA1012 forbids passing nil as the first ctx arg. + copySuite(t *testing.T, assertT require.TestingT, ctx context.Context) testingSuite + //nolint:revive // see copySuite above. + initSuite(t *testing.T, assertT require.TestingT, ctx context.Context) } // Suite provides parallel test execution with require-style (fail-fast) assertions. @@ -31,24 +38,34 @@ type Suite[T testingSuite] struct { protorequire.ProtoAssertions historyrequire.HistoryRequire - guardT guardT + guardT guardT + ctx context.Context // override set in initSuite; lazy-filled by Context() under ctxOnce when nil + ctxOnce sync.Once } // copySuite creates a fresh suite instance initialized for the given *testing.T. -func (s *Suite[T]) copySuite(t *testing.T) testingSuite { +// assertT overrides which TestingT assertions are bound to; nil means use the copy's own guardT. +// ctx overrides the suite's context; nil means use the default (lazy testcontext.New). +// +//nolint:revive // ctx is last so callers can pass nil to mean "no override"; SA1012 forbids passing nil as the first ctx arg. 
+func (s *Suite[T]) copySuite(t *testing.T, assertT require.TestingT, ctx context.Context) testingSuite { cp := reflect.New(reflect.TypeFor[T]().Elem()).Interface().(T) - cp.initSuite(t) + cp.initSuite(t, assertT, ctx) return cp } -func (s *Suite[T]) initSuite(t *testing.T) { +//nolint:revive // see copySuite above. +func (s *Suite[T]) initSuite(t *testing.T, assertT require.TestingT, ctx context.Context) { g := &s.guardT g.name = t.Name() g.T = t - g.hasSubtests.Store(false) - s.Assertions = require.New(g) - s.ProtoAssertions = protorequire.New(g) - s.HistoryRequire = historyrequire.New(g) + s.ctx = ctx + if assertT == nil { + assertT = g + } + s.Assertions = require.New(assertT) + s.ProtoAssertions = protorequire.New(assertT) + s.HistoryRequire = historyrequire.New(assertT) } // T returns the *testing.T, panicking if the guard has been sealed. @@ -59,6 +76,17 @@ func (s *Suite[T]) T() *testing.T { return s.guardT.T } +// Context returns the test-scoped context (created from [testcontext]). +// Inside an [Await] callback, it returns the await-scoped context. +func (s *Suite[T]) Context() context.Context { + s.ctxOnce.Do(func() { + if s.ctx == nil { + s.ctx = testcontext.New(s.T()) + } + }) + return s.ctx +} + // Run creates a parallel subtest. The callback receives a fresh copy of the // concrete suite type, initialized for the subtest's *testing.T. func (s *Suite[T]) Run(name string, fn func(T)) bool { @@ -66,10 +94,35 @@ func (s *Suite[T]) Run(name string, fn func(T)) bool { s.guardT.markHasSubtests() return pt.Run(name, func(t *testing.T) { t.Parallel() //nolint:testifylint // parallelsuite intentionally supports parallel subtests - fn(s.copySuite(t).(T)) + fn(s.copySuite(t, nil, nil).(T)) }) } +// Await calls fn repeatedly until all assertions pass or timeout is reached. 
+func (s *Suite[T]) Await(fn func(T), timeout, interval time.Duration) { + s.Awaitf(fn, timeout, interval, "") +} + +// Awaitf is like [Await] but includes a format string appended to the failure message. +func (s *Suite[T]) Awaitf(fn func(T), timeout, interval time.Duration, msg string, args ...any) { + t := s.T() + await.Requiref(s.Context(), t, func(at *await.T) { + fn(s.copySuite(t, at, at.Context()).(T)) + }, timeout, interval, msg, args...) +} + +// AwaitTrue calls fn repeatedly until it returns true or timeout is reached. +// +// Use it for simple local predicates only. Do not use assertions or side effects; use [Await] instead. +func (s *Suite[T]) AwaitTrue(fn func() bool, timeout, interval time.Duration) { + s.AwaitTruef(fn, timeout, interval, "") +} + +// AwaitTruef is like [AwaitTrue] but includes a format string appended to the failure message. +func (s *Suite[T]) AwaitTruef(fn func() bool, timeout, interval time.Duration, msg string, args ...any) { + await.RequireTruef(s.T(), fn, timeout, interval, msg, args...) +} + // Run discovers and runs all exported Test* methods on the given suite in parallel. // // Each method gets its own fresh suite instance initialized for the subtest's @@ -81,7 +134,7 @@ func Run[T testingSuite](t *testing.T, s T, args ...any) { t.Helper() typ := reflect.TypeFor[T]() - if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct { + if typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct { panic(fmt.Sprintf("parallelsuite.Run: suite must be a pointer to a struct, got %v", typ)) } structType := typ.Elem() @@ -109,7 +162,7 @@ func Run[T testingSuite](t *testing.T, s T, args ...any) { t.Run(method.Name, func(t *testing.T) { t.Parallel() - cpS := s.copySuite(t) + cpS := s.copySuite(t, nil, nil) callArgs := append([]reflect.Value{reflect.ValueOf(cpS)}, argVals...) 
method.Func.Call(callArgs) }) @@ -122,8 +175,8 @@ func init() { type ds struct{ Suite[*ds] } ptrType := reflect.TypeFor[*ds]() inheritedMethods = make(map[string]bool, ptrType.NumMethod()) - for i := range ptrType.NumMethod() { - inheritedMethods[ptrType.Method(i).Name] = true + for method := range ptrType.Methods() { + inheritedMethods[method.Name] = true } } @@ -178,8 +231,8 @@ func applyTestifyMFilter(methods []reflect.Method) []reflect.Method { func discoverTestMethods(ptrType, structType reflect.Type, args []any) []reflect.Method { expectedNumIn := 1 + len(args) - for i := range ptrType.NumMethod() { - name := ptrType.Method(i).Name + for method := range ptrType.Methods() { + name := method.Name if !strings.HasPrefix(name, "Test") && !inheritedMethods[name] { panic(fmt.Sprintf( "parallelsuite.Run: suite %s has exported method %s that does not start with Test; "+ @@ -190,8 +243,7 @@ func discoverTestMethods(ptrType, structType reflect.Type, args []any) []reflect } var methods []reflect.Method - for i := range ptrType.NumMethod() { - method := ptrType.Method(i) + for method := range ptrType.Methods() { if !strings.HasPrefix(method.Name, "Test") { continue } diff --git a/common/testing/parallelsuite/suite_test.go b/common/testing/parallelsuite/suite_test.go index f13a6ce8b8a..ebf49e66a52 100644 --- a/common/testing/parallelsuite/suite_test.go +++ b/common/testing/parallelsuite/suite_test.go @@ -1,11 +1,15 @@ package parallelsuite import ( + "context" "flag" "reflect" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" + "go.temporal.io/server/common/testing/testcontext" ) type validSuite struct{ Suite[*validSuite] } @@ -48,6 +52,54 @@ type setupTestSuite struct{ Suite[*setupTestSuite] } func (s *setupTestSuite) TestA() {} func (s *setupTestSuite) SetupTest() {} //nolint:unused +type awaitTrueSuite struct{ Suite[*awaitTrueSuite] } + +func (s *awaitTrueSuite) TestAwaitTrue() { + var attempts atomic.Int32 + s.AwaitTrue(func() 
bool { + attempts.Add(1) + return true + }, time.Second, time.Millisecond) + s.Equal(int32(1), attempts.Load()) +} + +func (s *awaitTrueSuite) TestAwaitTrueFalseRetry() { + var attempts atomic.Int32 + s.AwaitTrue(func() bool { + return attempts.Add(1) == 2 + }, time.Second, time.Millisecond) + s.Equal(int32(2), attempts.Load()) +} + +func (s *awaitTrueSuite) TestAwaitTruef() { + s.AwaitTruef(func() bool { + return true + }, time.Second, time.Millisecond, "condition should pass") +} + +type contextSuite struct{ Suite[*contextSuite] } + +func (s *contextSuite) TestContextHasDeadline() { + deadline, ok := s.Context().Deadline() + s.True(ok) + s.Positive(time.Until(deadline)) +} + +func (s *contextSuite) TestAwaitUsesSuiteContext() { + type key struct{} + + testcontext.New(s.T(), testcontext.WithContextDecorator(key{}, func(ctx context.Context) context.Context { + return context.WithValue(ctx, key{}, "decorated") + })) + + s.Await(func(s *contextSuite) { + s.Equal("decorated", s.Context().Value(key{})) + deadline, ok := s.Context().Deadline() + s.True(ok) + s.Less(time.Until(deadline), 200*time.Millisecond) + }, 100*time.Millisecond, time.Millisecond) +} + type sealAfterRunSuite struct{ Suite[*sealAfterRunSuite] } func (s *sealAfterRunSuite) TestAssertionAfterRun() { @@ -72,6 +124,12 @@ func TestRun_AcceptsSuite(t *testing.T) { t.Run("with args", func(t *testing.T) { require.NotPanics(t, func() { Run(t, &validWithArgsSuite{}, "hello", 42) }) }) + t.Run("await true", func(t *testing.T) { + require.NotPanics(t, func() { Run(t, &awaitTrueSuite{}) }) + }) + t.Run("context", func(t *testing.T) { + require.NotPanics(t, func() { Run(t, &contextSuite{}) }) + }) } func TestRun_RejectsSuite(t *testing.T) { diff --git a/common/testing/testcontext/context.go b/common/testing/testcontext/context.go index 4b881fd7b22..96d3d30e176 100644 --- a/common/testing/testcontext/context.go +++ b/common/testing/testcontext/context.go @@ -14,13 +14,13 @@ const defaultTimeout = 90 * time.Second 
type contextStore struct { sync.Mutex - byTest map[*testing.T]*contextState + byTest map[testing.TB]*contextState } // testContexts is process-global so repeated helpers in the same test share // one context and one cleanup. var testContexts = contextStore{ - byTest: make(map[*testing.T]*contextState), + byTest: make(map[testing.TB]*contextState), } type config struct { @@ -34,22 +34,22 @@ type contextDecorator struct { decorate func(context.Context) context.Context } -// New returns the test-scoped context for t. The context is canceled when the +// New returns the test-scoped context for tb. The context is canceled when the // test ends or when the configured test timeout expires. // // The first call creates the per-test context and fixes its timeout. Later calls // may add decorators, but an explicit different timeout fails instead of being // silently ignored. -func New(t *testing.T, opts ...Option) context.Context { - t.Helper() +func New(tb testing.TB, opts ...Option) context.Context { + tb.Helper() cfg := config{timeout: effectiveTimeout(0)} for _, opt := range opts { opt(&cfg) } - st := getContextState(t, cfg.timeout) - st.configure(t, cfg) + st := getContextState(tb, cfg.timeout) + st.configure(tb, cfg) return st.context() } @@ -86,55 +86,55 @@ type contextState struct { decorators map[any]struct{} } -func getContextState(t *testing.T, timeout time.Duration) *contextState { - t.Helper() +func getContextState(tb testing.TB, timeout time.Duration) *contextState { + tb.Helper() testContexts.Lock() defer testContexts.Unlock() - if st, ok := testContexts.byTest[t]; ok { + if st, ok := testContexts.byTest[tb]; ok { return st } - ctx, cancel := context.WithTimeout(t.Context(), timeout) + ctx, cancel := context.WithTimeout(tb.Context(), timeout) st := &contextState{ ctx: ctx, cancel: cancel, timeout: timeout, decorators: make(map[any]struct{}), } - testContexts.byTest[t] = st + testContexts.byTest[tb] = st - t.Cleanup(func() { + tb.Cleanup(func() { st.cancel() 
testContexts.Lock() - delete(testContexts.byTest, t) + delete(testContexts.byTest, tb) testContexts.Unlock() if st.err() == context.DeadlineExceeded { - t.Errorf("Test exceeded timeout of %v", st.timeout) + tb.Errorf("Test exceeded timeout of %v", st.timeout) } }) return st } -func (s *contextState) configure(t *testing.T, cfg config) { - t.Helper() +func (s *contextState) configure(tb testing.TB, cfg config) { + tb.Helper() s.mu.Lock() defer s.mu.Unlock() if cfg.timeoutSet && cfg.timeout != s.timeout { - t.Fatalf("testcontext: test context already exists with timeout %v; cannot change it to %v", s.timeout, cfg.timeout) + tb.Fatalf("testcontext: test context already exists with timeout %v; cannot change it to %v", s.timeout, cfg.timeout) } // Decorators may be registered by independent helpers, so apply each keyed // decorator at most once while preserving call order. for _, decorator := range cfg.decorators { if decorator.key == nil { - t.Fatal("testcontext: context decorator key must not be nil") + tb.Fatal("testcontext: context decorator key must not be nil") } if decorator.decorate == nil { - t.Fatal("testcontext: context decorator must not be nil") + tb.Fatal("testcontext: context decorator must not be nil") } if _, ok := s.decorators[decorator.key]; ok { continue diff --git a/docs/development/testing.md b/docs/development/testing.md index ca74abbd93c..a918ee3ed0f 100644 --- a/docs/development/testing.md +++ b/docs/development/testing.md @@ -41,6 +41,33 @@ Always use `require.X` (and `protorequire.X`) instead of `assert.X` (and `protoa `assert` records a failure but lets the test continue, which often leads to confusing cascading errors. +### Polling with await.Require + +For polling/retry loops in tests, use `await.Require` (or `await.Requiref`) +from `common/testing/await` instead of testify's `EventuallyWithT`. 
+ +Use `t.Context()` inside the callback for a context derived from the parent +context and canceled when the parent context is canceled or the await timeout +expires. + +```go +await.Require(ctx, t, func(t *await.T) { + resp, err := client.GetStatus(t.Context()) + require.NoError(t, err) + require.Equal(t, "ready", resp.Status) +}, 5*time.Second, 200*time.Millisecond) +``` + +Use `RequireTrue` instead of testify's `Eventually` for simple local bool-returning predicates. + +```go +await.RequireTrue(t, func() bool { + return cache.Ready() +}, 5*time.Second, 200*time.Millisecond) +``` + +`RequireTrue` is the wrong tool when dealing with errors or assertions; use `Require` instead. + ### Parallelization All tests (and subtests!) should use `t.Parallel()` to be run concurrently; @@ -60,6 +87,18 @@ and provides assertion helpers and safety mechanisms. It replaces all use of `testify`'s `Suite`. +#### Await shorthand + +```go +s.Await(func(s *MySuite) { + resp, err := client.GetStatus(s.Context()) + s.NoError(err) + s.Equal("ready", resp.Status) +}, 5*time.Second, 200*time.Millisecond) +``` + +Inside an `s.Await` callback, `s.Context()` is capped to that await's timeout. + ### testvars package Instead of creating identifiers like task queue name, namespace or worker identity by hand, diff --git a/tests/premature_eos_test.go b/tests/premature_eos_test.go index ee8258c9294..b58f290273d 100644 --- a/tests/premature_eos_test.go +++ b/tests/premature_eos_test.go @@ -92,13 +92,15 @@ func (s *PrematureEosTestSuite) Test_SpeculativeWFTEventsLostAfterSignalMidHisto // Without this wait there is a race: if the update hasn't been processed yet, the signal // would only add event 8 (SignalReceived) with freshNextEventId=9, producing 8 events // instead of the expected 9 and causing a false test failure. 
- s.Eventually(func() bool { - desc, descErr := env.FrontendClient().DescribeWorkflowExecution(testcore.NewContext(), + s.Awaitf(func(s *PrematureEosTestSuite) { + desc, descErr := env.FrontendClient().DescribeWorkflowExecution(s.Context(), &workflowservice.DescribeWorkflowExecutionRequest{ Namespace: env.Namespace().String(), Execution: wfExecution, }) - return descErr == nil && desc.GetPendingWorkflowTask() != nil + s.NoError(descErr) + s.NotNil(desc) + s.NotNil(desc.GetPendingWorkflowTask()) }, 5*time.Second, 250*time.Millisecond, "speculative WFT should be scheduled after sending update") // Fetch page 1 via GetWorkflowExecutionHistory — mimicking what the SDK does when a diff --git a/tests/query_workflow_test.go b/tests/query_workflow_test.go index 76373b52b94..35f28662ee8 100644 --- a/tests/query_workflow_test.go +++ b/tests/query_workflow_test.go @@ -305,7 +305,7 @@ func (s *QueryWorkflowSuite) TestQueryWorkflow_QueryFailedWorkflowTask() { s.NotNil(workflowRun) s.NotEmpty(workflowRun.GetRunID()) - s.Eventually(func() bool { + s.AwaitTrue(func() bool { // wait for workflow task to fail 3 times return atomic.LoadInt32(&failures) >= 3 }, 10*time.Second, 50*time.Millisecond)