diff --git a/testutil/ctx.go b/testutil/ctx.go index acbf14e5bb6c8..ff974f1557cf0 100644 --- a/testutil/ctx.go +++ b/testutil/ctx.go @@ -6,7 +6,21 @@ import ( "time" ) -func Context(t testing.TB, dur time.Duration) context.Context { +// Context returns a context with a timeout that starts on first use and resets +// when accessed from new lines in test files. Each call to Done, Deadline, or +// Err from a new line resets the deadline. +// +// To prevent resets, store the Done channel or wrap with a child context: +// +// done := ctx.Done() +// <-done // Uses stored channel, no reset. +func Context(t testing.TB, timeout time.Duration) context.Context { + return newLazyTimeoutContext(t, timeout) +} + +// ContextFixed returns a context with a timeout that starts immediately and +// does not reset. +func ContextFixed(t testing.TB, dur time.Duration) context.Context { ctx, cancel := context.WithTimeout(context.Background(), dur) t.Cleanup(cancel) return ctx diff --git a/testutil/lazy_ctx.go b/testutil/lazy_ctx.go new file mode 100644 index 0000000000000..4289009323fc4 --- /dev/null +++ b/testutil/lazy_ctx.go @@ -0,0 +1,170 @@ +package testutil + +import ( + "context" + "fmt" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +var _ context.Context = (*lazyTimeoutContext)(nil) + +// lazyTimeoutContext implements context.Context with a timeout that starts on +// first use and resets when accessed from new locations in test files. +type lazyTimeoutContext struct { + t testing.TB + timeout time.Duration + + mu sync.Mutex // Protects following fields. + testDone bool // True after cancel, prevents post-test logging. + deadline time.Time + timer *time.Timer + done chan struct{} + err error + seenLocations map[string]struct{} +} + +func newLazyTimeoutContext(t testing.TB, timeout time.Duration) context.Context { + ctx := &lazyTimeoutContext{ + t: t, + timeout: timeout, + done: make(chan struct{}), + seenLocations: make(map[string]struct{}), + } + t.Cleanup(ctx.cancel) + return ctx +} + +// Deadline returns the current deadline. The deadline is set on first access +// and may be extended when accessed from new locations in test files. +func (c *lazyTimeoutContext) Deadline() (deadline time.Time, ok bool) { + c.maybeResetForLocation() + + c.mu.Lock() + defer c.mu.Unlock() + return c.deadline, true +} + +// Done returns a channel that's closed when the context is canceled. +func (c *lazyTimeoutContext) Done() <-chan struct{} { + c.maybeResetForLocation() + return c.done +} + +// Err returns the error indicating why this context was canceled. +func (c *lazyTimeoutContext) Err() error { + c.maybeResetForLocation() + + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +// Value returns nil. It does not trigger initialization or reset. +func (*lazyTimeoutContext) Value(any) any { + return nil +} + +// maybeResetForLocation starts the timer on first access and resets the +// deadline when called from a previously unseen location in a test file. +func (c *lazyTimeoutContext) maybeResetForLocation() { + loc := callerLocation() + + c.mu.Lock() + defer c.mu.Unlock() + + // Already canceled. + if c.err != nil { + return + } + + // First access, start timer. + if c.timer == nil { + c.startLocked() + if loc != "" { + c.seenLocations[loc] = struct{}{} + } + if testing.Verbose() && !c.testDone { + c.t.Logf("lazyTimeoutContext: started timeout for location: %s", loc) + } + return + } + + // Non-test location, ignore. + if loc == "" { + return + } + + if _, seen := c.seenLocations[loc]; seen { + return + } + c.seenLocations[loc] = struct{}{} + + // New location, reset deadline. + c.deadline = time.Now().Add(c.timeout) + if c.timer.Stop() { + c.timer.Reset(c.timeout) + } + + if testing.Verbose() && !c.testDone { + c.t.Logf("lazyTimeoutContext: reset timeout for new location: %s", loc) + } +} + +// startLocked initializes the deadline and timer. It must be called with mu held. +func (c *lazyTimeoutContext) startLocked() { + c.deadline = time.Now().Add(c.timeout) + c.timer = time.AfterFunc(c.timeout, func() { + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.err = context.DeadlineExceeded + close(c.done) + } + }) +} + +// cancel stops the timer and marks the context as canceled. It is called by +// t.Cleanup when the test ends. +func (c *lazyTimeoutContext) cancel() { + c.mu.Lock() + defer c.mu.Unlock() + c.testDone = true + if c.timer != nil { + c.timer.Stop() + } + if c.err == nil { + c.err = context.Canceled + close(c.done) + } +} + +// callerLocation returns the file:line of the first caller in a _test.go file, +// or the empty string if none is found. +func callerLocation() string { + // Skip runtime.Callers, callerLocation, maybeResetForLocation, and the + // context method (Done/Deadline/Err). + pc := make([]uintptr, 50) + n := runtime.Callers(4, pc) + if n == 0 { + return "" + } + + frames := runtime.CallersFrames(pc[:n]) + for { + frame, more := frames.Next() + + if strings.HasSuffix(frame.File, "_test.go") { + return fmt.Sprintf("%s:%d", frame.File, frame.Line) + } + + if !more { + break + } + } + + return "" +} diff --git a/testutil/lazy_ctx_test.go b/testutil/lazy_ctx_test.go new file mode 100644 index 0000000000000..c8d5d7e8538fb --- /dev/null +++ b/testutil/lazy_ctx_test.go @@ -0,0 +1,200 @@ +package testutil_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/testutil" +) + +func TestLazyTimeoutContext_LazyStart(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 10*time.Millisecond) + + time.Sleep(50 * time.Millisecond) // Longer than timeout. + + // Timer hasn't started, context should be valid. + select { + case <-ctx.Done(): + t.Fatal("context should not be done yet - timer should not have started") + default: + } + + // First select started the timer, wait for expiration. + select { + case <-ctx.Done(): + case <-time.After(50 * time.Millisecond): + t.Fatal("context should have expired") + } +} + +func TestLazyTimeoutContext_ValueDoesNotTriggerStart(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 10*time.Millisecond) + + _ = ctx.Value("key") // Must not start timer. + + time.Sleep(50 * time.Millisecond) + + select { + case <-ctx.Done(): + t.Fatal("Value() should not start timer") + default: + } +} + +func TestLazyTimeoutContext_Expiration(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 5*time.Millisecond) + + done := ctx.Done() // Store to avoid reset in select. + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fatal("context should have expired") + } + + require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded) +} + +func TestLazyTimeoutContext_ResetOnNewLocation(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 50*time.Millisecond) + + done := ctx.Done() // Store to check expiration. + time.Sleep(30 * time.Millisecond) // 60% of timeout. + _ = ctx.Done() // New line, resets timeout. + time.Sleep(30 * time.Millisecond) // 60% again, would be 120% without reset. + + select { + case <-done: + t.Fatal("timeout should have been reset") + default: + } + + select { + case <-done: + case <-time.After(50 * time.Millisecond): + t.Fatal("context should have expired") + } +} + +func TestLazyTimeoutContext_NoResetOnSameLocation(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 50*time.Millisecond) + + var done <-chan struct{} + // Same line, no reset. 5*15ms = 75ms > 50ms timeout. + for i := 0; i < 5; i++ { + done = ctx.Done() + time.Sleep(15 * time.Millisecond) + } + + select { + case <-done: + default: + t.Fatal("context should be done - same location should not reset") + } +} + +func TestLazyTimeoutContext_AlreadyExpiredNoResurrection(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 5*time.Millisecond) + + <-ctx.Done() + require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded) + + _ = ctx.Err() // New location, must not resurrect. + + select { + case <-ctx.Done(): + default: + t.Fatal("expired context should not be resurrected") + } + + require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded) +} + +func TestLazyTimeoutContext_ThreadSafety(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 100*time.Millisecond) + + var wg sync.WaitGroup + const numGoroutines = 10 + // Relies on -race to detect issues. + for i := range numGoroutines { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + _ = ctx.Done() + _, _ = ctx.Deadline() + _ = ctx.Err() + _ = ctx.Value("key") + } + }() + } + + wg.Wait() +} + +func TestLazyTimeoutContext_WithChildContext(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 50*time.Millisecond) + + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + + select { + case <-childCtx.Done(): + t.Fatal("child context should not be done yet") + default: + } + + cancel() + + select { + case <-childCtx.Done(): + case <-time.After(50 * time.Millisecond): + t.Fatal("child context should be done after cancel") + } +} + +func TestLazyTimeoutContext_ErrBeforeExpiration(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, 50*time.Millisecond) + + err := ctx.Err() + assert.NoError(t, err, "Err() should return nil before expiration") +} + +func TestLazyTimeoutContext_DeadlineReturnsCorrectValue(t *testing.T) { + t.Parallel() + + timeout := 50 * time.Millisecond + before := time.Now() + ctx := testutil.Context(t, timeout) + + deadline, ok := ctx.Deadline() + after := time.Now() + + require.True(t, ok, "deadline should be set after Deadline() call") + require.False(t, deadline.IsZero(), "deadline should not be zero") + require.True(t, deadline.After(before.Add(timeout-time.Millisecond))) + require.True(t, deadline.Before(after.Add(timeout+10*time.Millisecond))) +}