diff --git a/agent/agent.go b/agent/agent.go index 9a23ec521057a..115735bc69407 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -71,6 +71,8 @@ const ( EnvProcOOMScore = "CODER_PROC_OOM_SCORE" ) +var ErrAgentClosing = xerrors.New("agent is closing") + type Options struct { Filesystem afero.Fs LogDir string @@ -401,6 +403,7 @@ func (a *agent) runLoop() { // need to keep retrying up to the hardCtx so that we can send graceful shutdown-related // messages. ctx := a.hardCtx + defer a.logger.Info(ctx, "agent main loop exited") for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { a.logger.Info(ctx, "connecting to coderd") err := a.run() @@ -1348,7 +1351,7 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co a.closeMutex.Unlock() if closing { _ = network.Close() - return xerrors.New("agent is closing") + return xerrors.Errorf("agent closed while creating tailnet: %w", ErrAgentClosing) } } else { // Update the wireguard IPs if the agent ID changed. @@ -1471,7 +1474,7 @@ func (a *agent) trackGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.closing { - return xerrors.New("track conn goroutine: agent is closing") + return xerrors.Errorf("track conn goroutine: %w", ErrAgentClosing) } a.closeWaitGroup.Add(1) go func() { @@ -2152,16 +2155,7 @@ func (a *apiConnRoutineManager) startAgentAPI( a.eg.Go(func() error { logger.Debug(ctx, "starting agent routine") err := f(ctx, a.aAPI) - if xerrors.Is(err, context.Canceled) && ctx.Err() != nil { - logger.Debug(ctx, "swallowing context canceled") - // Don't propagate context canceled errors to the error group, because we don't want the - // graceful context being canceled to halt the work of routines with - // gracefulShutdownBehaviorRemain. Note that we check both that the error is - // context.Canceled and that *our* context is currently canceled, because when Coderd - // unilaterally closes the API connection (for example if the build is outdated), it can - // sometimes show up as context.Canceled in our RPC calls. - return nil - } + err = shouldPropagateError(ctx, logger, err) logger.Debug(ctx, "routine exited", slog.Error(err)) if err != nil { return xerrors.Errorf("error in routine %s: %w", name, err) @@ -2189,21 +2183,7 @@ func (a *apiConnRoutineManager) startTailnetAPI( a.eg.Go(func() error { logger.Debug(ctx, "starting tailnet routine") err := f(ctx, a.tAPI) - if (xerrors.Is(err, context.Canceled) || - xerrors.Is(err, io.EOF)) && - ctx.Err() != nil { - logger.Debug(ctx, "swallowing error because context is canceled", slog.Error(err)) - // Don't propagate context canceled errors to the error group, because we don't want the - // graceful context being canceled to halt the work of routines with - // gracefulShutdownBehaviorRemain. Unfortunately, the dRPC library closes the stream - // when context is canceled on an RPC, so canceling the context can also show up as - // io.EOF. Also, when Coderd unilaterally closes the API connection (for example if the - // build is outdated), it can sometimes show up as context.Canceled in our RPC calls. - // We can't reliably distinguish between a context cancelation and a legit EOF, so we - // also check that *our* context is currently canceled. If it is, we can safely ignore - // the error. - return nil - } + err = shouldPropagateError(ctx, logger, err) logger.Debug(ctx, "routine exited", slog.Error(err)) if err != nil { return xerrors.Errorf("error in routine %s: %w", name, err) @@ -2212,6 +2192,34 @@ func (a *apiConnRoutineManager) startTailnetAPI( }) } +// shouldPropagateError decides whether an error from an API connection routine should be propagated to the +// apiConnRoutineManager. Its purpose is to prevent errors related to shutting down from propagating to the manager's +// error group, which will tear down the API connection and potentially stop graceful shutdown from succeeding. +func shouldPropagateError(ctx context.Context, logger slog.Logger, err error) error { + if (xerrors.Is(err, context.Canceled) || + xerrors.Is(err, io.EOF)) && + ctx.Err() != nil { + logger.Debug(ctx, "swallowing error because context is canceled", slog.Error(err)) + // Don't propagate context canceled errors to the error group, because we don't want the + // graceful context being canceled to halt the work of routines with + // gracefulShutdownBehaviorRemain. Unfortunately, the dRPC library closes the stream + // when context is canceled on an RPC, so canceling the context can also show up as + // io.EOF. Also, when Coderd unilaterally closes the API connection (for example if the + // build is outdated), it can sometimes show up as context.Canceled in our RPC calls. + // We can't reliably distinguish between a context cancelation and a legit EOF, so we + // also check that *our* context is currently canceled. If it is, we can safely ignore + // the error. + return nil + } + if xerrors.Is(err, ErrAgentClosing) { + logger.Debug(ctx, "swallowing error because agent is closing") + // This can only be generated when the agent is closing, so we never want it to propagate to other routines. + // (They are signaled to exit via canceled contexts.) + return nil + } + return err +} + func (a *apiConnRoutineManager) wait() error { return a.eg.Wait() } diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 37d5e6d3b7a85..3046a22d89ea7 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -359,7 +359,16 @@ func (m *agentConnectionMonitor) start(ctx context.Context) { } func (m *agentConnectionMonitor) monitor(ctx context.Context) { + reason := "disconnect" defer func() { + m.logger.Debug(ctx, "agent connection monitor is closing connection", + slog.F("reason", reason)) + _ = m.conn.Close(websocket.StatusGoingAway, reason) + m.disconnectedAt = sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + } + // If connection closed then context will be canceled, try to // ensure our final update is sent. By waiting at most the agent // inactive disconnect timeout we ensure that we don't block but @@ -372,13 +381,6 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { finalCtx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(m.apiCtx), m.disconnectTimeout) defer cancel() - // Only update timestamp if the disconnect is new. - if !m.disconnectedAt.Valid { - m.disconnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } err := m.updateConnectionTimes(finalCtx) if err != nil { // This is a bug with unit tests that cancel the app context and @@ -398,12 +400,6 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { AgentID: &m.workspaceAgent.ID, }) }() - reason := "disconnect" - defer func() { - m.logger.Debug(ctx, "agent connection monitor is closing connection", - slog.F("reason", reason)) - _ = m.conn.Close(websocket.StatusGoingAway, reason) - }() err := m.updateConnectionTimes(ctx) if err != nil { @@ -432,8 +428,7 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { m.logger.Warn(ctx, "connection to agent timed out") return } - connectionStatusChanged := m.disconnectedAt.Valid - m.disconnectedAt = sql.NullTime{} + m.lastConnectedAt = sql.NullTime{ Time: dbtime.Now(), Valid: true, @@ -447,13 +442,9 @@ func (m *agentConnectionMonitor) monitor(ctx context.Context) { } return } - if connectionStatusChanged { - m.updater.publishWorkspaceUpdate(ctx, m.workspace.OwnerID, wspubsub.WorkspaceEvent{ - Kind: wspubsub.WorkspaceEventKindAgentConnectionUpdate, - WorkspaceID: m.workspaceBuild.WorkspaceID, - AgentID: &m.workspaceAgent.ID, - }) - } + // we don't need to publish a workspace update here because we published an update when the workspace first + // connected. Since all we've done is updated lastConnectedAt, the workspace is still connected and hasn't + // changed status. We don't expect to get updates just for the times changing. ctx, err := dbauthz.WithWorkspaceRBAC(ctx, m.workspace.RBACObject()) if err != nil { diff --git a/coderd/workspaceagentsrpc_internal_test.go b/coderd/workspaceagentsrpc_internal_test.go index 5c254b41fe64c..88d08bc4e32fc 100644 --- a/coderd/workspaceagentsrpc_internal_test.go +++ b/coderd/workspaceagentsrpc_internal_test.go @@ -23,76 +23,107 @@ import ( func TestAgentConnectionMonitor_ContextCancel(t *testing.T) { t.Parallel() - ctx := testutil.Context(t, testutil.WaitShort) now := dbtime.Now() - fConn := &fakePingerCloser{} - ctrl := gomock.NewController(t) - mDB := dbmock.NewMockStore(ctrl) - fUpdater := &fakeUpdater{} - logger := testutil.Logger(t) - agent := database.WorkspaceAgent{ - ID: uuid.New(), - FirstConnectedAt: sql.NullTime{ - Time: now.Add(-time.Minute), - Valid: true, + agentID := uuid.UUID{1} + replicaID := uuid.UUID{2} + testCases := []struct { + name string + agent database.WorkspaceAgent + initialMatcher connectionUpdateMatcher + }{ + { + name: "no disconnected at", + agent: database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + }, + initialMatcher: connectionUpdate(agentID, replicaID), + }, + { + name: "disconnected at", + agent: database.WorkspaceAgent{ + ID: agentID, + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + DisconnectedAt: sql.NullTime{ + Time: now.Add(-2 * time.Minute), + Valid: true, + }, + }, + initialMatcher: connectionUpdate(agentID, replicaID, withDisconnectedAt(now.Add(-2*time.Minute))), }, } - build := database.WorkspaceBuild{ - ID: uuid.New(), - WorkspaceID: uuid.New(), + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + fConn := &fakePingerCloser{} + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + fUpdater := &fakeUpdater{} + logger := testutil.Logger(t) + build := database.WorkspaceBuild{ + ID: uuid.New(), + WorkspaceID: uuid.New(), + } + + uut := &agentConnectionMonitor{ + apiCtx: ctx, + workspaceAgent: tc.agent, + workspaceBuild: build, + conn: fConn, + db: mDB, + replicaID: replicaID, + updater: fUpdater, + logger: logger, + pingPeriod: testutil.IntervalFast, + disconnectTimeout: testutil.WaitShort, + } + uut.init() + + connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + tc.initialMatcher, + ). + AnyTimes(). + Return(nil) + mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agentID, replicaID, withDisconnectedAfter(now)), + ). + After(connected). + Times(1). + Return(nil) + mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). + AnyTimes(). + Return(database.WorkspaceBuild{ID: build.ID}, nil) + + closeCtx, cancel := context.WithCancel(ctx) + defer cancel() + done := make(chan struct{}) + go func() { + uut.monitor(closeCtx) + close(done) + }() + // wait a couple intervals, but not long enough for a disconnect + time.Sleep(3 * testutil.IntervalFast) + fConn.requireNotClosed(t) + fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) + n := fUpdater.getUpdates() + cancel() + fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled") + + // make sure we got at least one additional update on close + _ = testutil.TryReceive(ctx, t, done) + m := fUpdater.getUpdates() + require.Greater(t, m, n) + }) } - replicaID := uuid.New() - - uut := &agentConnectionMonitor{ - apiCtx: ctx, - workspaceAgent: agent, - workspaceBuild: build, - conn: fConn, - db: mDB, - replicaID: replicaID, - updater: fUpdater, - logger: logger, - pingPeriod: testutil.IntervalFast, - disconnectTimeout: testutil.WaitShort, - } - uut.init() - - connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( - gomock.Any(), - connectionUpdate(agent.ID, replicaID), - ). - AnyTimes(). - Return(nil) - mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( - gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), - ). - After(connected). - Times(1). - Return(nil) - mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). - AnyTimes(). - Return(database.WorkspaceBuild{ID: build.ID}, nil) - - closeCtx, cancel := context.WithCancel(ctx) - defer cancel() - done := make(chan struct{}) - go func() { - uut.monitor(closeCtx) - close(done) - }() - // wait a couple intervals, but not long enough for a disconnect - time.Sleep(3 * testutil.IntervalFast) - fConn.requireNotClosed(t) - fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) - n := fUpdater.getUpdates() - cancel() - fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled") - - // make sure we got at least one additional update on close - _ = testutil.TryReceive(ctx, t, done) - m := fUpdater.getUpdates() - require.Greater(t, m, n) } func TestAgentConnectionMonitor_PingTimeout(t *testing.T) { @@ -141,7 +172,7 @@ func TestAgentConnectionMonitor_PingTimeout(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -204,7 +235,7 @@ func TestAgentConnectionMonitor_BuildOutdated(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -289,7 +320,7 @@ func TestAgentConnectionMonitor_StartClose(t *testing.T) { Return(nil) mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( gomock.Any(), - connectionUpdate(agent.ID, replicaID, withDisconnected()), + connectionUpdate(agent.ID, replicaID, withDisconnectedAfter(now)), ). After(connected). Times(1). @@ -392,9 +423,10 @@ func (f *fakeUpdater) getUpdates() int { } type connectionUpdateMatcher struct { - agentID uuid.UUID - replicaID uuid.UUID - disconnected bool + agentID uuid.UUID + replicaID uuid.UUID + disconnectedAt sql.NullTime + disconnectedAfter sql.NullTime } type connectionUpdateMatcherOption func(m connectionUpdateMatcher) connectionUpdateMatcher @@ -410,9 +442,22 @@ func connectionUpdate(id, replica uuid.UUID, opts ...connectionUpdateMatcherOpti return m } -func withDisconnected() connectionUpdateMatcherOption { +func withDisconnectedAfter(t time.Time) connectionUpdateMatcherOption { return func(m connectionUpdateMatcher) connectionUpdateMatcher { - m.disconnected = true + m.disconnectedAfter = sql.NullTime{ + Valid: true, + Time: t, + } + return m + } +} + +func withDisconnectedAt(t time.Time) connectionUpdateMatcherOption { + return func(m connectionUpdateMatcher) connectionUpdateMatcher { + m.disconnectedAt = sql.NullTime{ + Valid: true, + Time: t, + } return m } } @@ -431,15 +476,23 @@ func (m connectionUpdateMatcher) Matches(x interface{}) bool { if args.LastConnectedReplicaID.UUID != m.replicaID { return false } - if args.DisconnectedAt.Valid != m.disconnected { + if m.disconnectedAfter.Valid { + if !args.DisconnectedAt.Valid { + return false + } + if !args.DisconnectedAt.Time.After(m.disconnectedAfter.Time) { + return false + } + // disconnectedAfter takes precedence over disconnectedAt + } else if args.DisconnectedAt != m.disconnectedAt { return false } return true } func (m connectionUpdateMatcher) String() string { - return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", - m.agentID.String(), m.replicaID.String(), m.disconnected) + return fmt.Sprintf("{agent=%s, replica=%s, disconnectedAt=%v, disconnectedAfter=%v}", + m.agentID.String(), m.replicaID.String(), m.disconnectedAt, m.disconnectedAfter) } func (connectionUpdateMatcher) Got(x interface{}) string { @@ -447,6 +500,6 @@ func (connectionUpdateMatcher) Got(x interface{}) string { if !ok { return fmt.Sprintf("type=%T", x) } - return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", - args.ID, args.LastConnectedReplicaID.UUID, args.DisconnectedAt.Valid) + return fmt.Sprintf("{agent=%s, replica=%s, disconnectedAt=%v}", + args.ID, args.LastConnectedReplicaID.UUID, args.DisconnectedAt) }