From 22bd9020675b29b727955da53da5fc5d1e1b8c53 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 21 Oct 2025 03:53:14 +0000 Subject: [PATCH 1/2] chore: add offset-based pagination support to aibridge list endpoint --- coderd/apidoc/docs.go | 11 ++- coderd/apidoc/swagger.json | 11 ++- coderd/database/dbauthz/dbauthz.go | 15 ++++ coderd/database/dbauthz/dbauthz_test.go | 14 ++++ coderd/database/dbmetrics/querymetrics.go | 14 ++++ coderd/database/dbmock/dbmock.go | 30 ++++++++ coderd/database/modelqueries.go | 41 +++++++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 60 ++++++++++++++- coderd/database/queries/aibridge.sql | 35 +++++++++ coderd/searchquery/search.go | 4 +- codersdk/aibridge.go | 1 + docs/reference/api/aibridge.md | 6 +- docs/reference/api/schemas.md | 4 +- enterprise/coderd/aibridge.go | 28 +++++-- enterprise/coderd/aibridge_test.go | 89 ++++++++++++++--------- site/src/api/typesGenerated.ts | 1 + 17 files changed, 317 insertions(+), 48 deletions(-) diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index f6a15f4c44d51..bf205b3cfd11c 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -115,9 +115,15 @@ const docTemplate = `{ }, { "type": "string", - "description": "Cursor pagination after ID", + "description": "Cursor pagination after ID (cannot be used with offset)", "name": "after_id", "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_id)", + "name": "offset", + "in": "query" } ], "responses": { @@ -11595,6 +11601,9 @@ const docTemplate = `{ "items": { "$ref": "#/definitions/codersdk.AIBridgeInterception" } + }, + "total": { + "type": "integer" } } }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 780f89d876f28..ba388b16cd9fd 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -91,9 +91,15 @@ }, { "type": "string", - "description": "Cursor pagination after ID", + "description": "Cursor pagination after ID (cannot be used with offset)", "name": "after_id", "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_id)", + "name": "offset", + "in": "query" } ], "responses": { @@ -10307,6 +10313,9 @@ "items": { "$ref": "#/definitions/codersdk.AIBridgeInterception" } + }, + "total": { + "type": "integer" } } }, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index e7899eee30a7c..ccd6151e88bbd 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1436,6 +1436,14 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error { return q.db.CleanTailnetTunnels(ctx) } +func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep) +} + func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { // Shortcut if the user is an owner. The SQL filter is noticeable, // and this is an easy win for owners. Which is the common case. @@ -5868,3 +5876,10 @@ func (q *querier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg d // database.Store interface, so dbauthz needs to implement it. return q.ListAIBridgeInterceptions(ctx, arg) } + +func (q *querier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, _ rbac.PreparedAuthorized) (int64, error) { + // TODO: Delete this function, all CountAIBridgeInterceptions should be authorized. For now just call CountAIBridgeInterceptions on the authz querier. + // This cannot be deleted for now because it's included in the + // database.Store interface, so dbauthz needs to implement it. + return q.CountAIBridgeInterceptions(ctx, arg) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 01e19237295ee..1137aa250d789 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -4549,6 +4549,20 @@ func (s *MethodTestSuite) TestAIBridge() { check.Args(params, emptyPreparedAuthorized{}).Asserts() })) + s.Run("CountAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeInterceptionsParams{} + db.EXPECT().CountAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params).Asserts() + })) + + s.Run("CountAuthorizedAIBridgeInterceptions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeInterceptionsParams{} + db.EXPECT().CountAuthorizedAIBridgeInterceptions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { ids := []uuid.UUID{{1}} db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index d5995222c3d4d..cdcc675c918ae 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -187,6 +187,13 @@ func (m queryMetricsStore) CleanTailnetTunnels(ctx context.Context) error { return r0 } +func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg) + m.queryLatencies.WithLabelValues("CountAIBridgeInterceptions").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { start := time.Now() r0, r1 := m.s.CountAuditLogs(ctx, arg) @@ -3721,3 +3728,10 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeInterceptions(ctx context.Conte m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeInterceptions").Observe(time.Since(start).Seconds()) return r0, r1 } + +func (m queryMetricsStore) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeInterceptions").Observe(time.Since(start).Seconds()) + return r0, r1 +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 51a06fa35e00c..89099b6e7142f 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -248,6 +248,21 @@ func (mr *MockStoreMockRecorder) CleanTailnetTunnels(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanTailnetTunnels", reflect.TypeOf((*MockStore)(nil).CleanTailnetTunnels), ctx) } +// CountAIBridgeInterceptions mocks base method. +func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAIBridgeInterceptions", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAIBridgeInterceptions indicates an expected call of CountAIBridgeInterceptions. +func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg) +} + // CountAuditLogs mocks base method. func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { m.ctrl.T.Helper() @@ -263,6 +278,21 @@ func (mr *MockStoreMockRecorder) CountAuditLogs(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuditLogs", reflect.TypeOf((*MockStore)(nil).CountAuditLogs), ctx, arg) } +// CountAuthorizedAIBridgeInterceptions mocks base method. +func (m *MockStore) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeInterceptions", ctx, arg, prepared) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAuthorizedAIBridgeInterceptions indicates an expected call of CountAuthorizedAIBridgeInterceptions. +func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared) +} + // CountAuthorizedAuditLogs mocks base method. func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 8e2d74fb8ffec..654dfb8b37b1d 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -764,6 +764,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun type aibridgeQuerier interface { ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]AIBridgeInterception, error) + CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) } func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]AIBridgeInterception, error) { @@ -786,6 +787,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar arg.Provider, arg.Model, arg.AfterID, + arg.Offset, arg.Limit, ) if err != nil { @@ -816,6 +818,45 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar return items, nil } +func (q *sqlQuerier) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return 0, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(countAIBridgeInterceptions, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return 0, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeInterceptions :one\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + ) + if err != nil { + return 0, err + } + defer rows.Close() + var count int64 + for rows.Next() { + if err := rows.Scan(&count); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3ea3c912f8ccf..3e2d655f44495 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -65,6 +65,7 @@ type sqlcQuerier interface { CleanTailnetCoordinators(ctx context.Context) error CleanTailnetLostPeers(ctx context.Context) error CleanTailnetTunnels(ctx context.Context) error + CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) // CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f62066954daf8..c0fd6eb64dc6d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -111,6 +111,61 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump return err } +const countAIBridgeInterceptions = `-- name: CountAIBridgeInterceptions :one +SELECT + COUNT(*) +FROM + aibridge_interceptions +WHERE + -- Filter by time frame + CASE + WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz + ELSE true + END + AND CASE + WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text + ELSE true + END + -- Filter model + AND CASE + WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions + -- @authorize_filter +` + +type CountAIBridgeInterceptionsParams struct { + StartedAfter time.Time `db:"started_after" json:"started_after"` + StartedBefore time.Time `db:"started_before" json:"started_before"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` +} + +func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countAIBridgeInterceptions, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one SELECT id, initiator_id, provider, model, started_at, metadata @@ -522,7 +577,8 @@ WHERE ORDER BY aibridge_interceptions.started_at DESC, aibridge_interceptions.id DESC -LIMIT COALESCE(NULLIF($7::integer, 0), 100) +LIMIT COALESCE(NULLIF($8::integer, 0), 100) +OFFSET $7 ` type ListAIBridgeInterceptionsParams struct { @@ -532,6 +588,7 @@ type ListAIBridgeInterceptionsParams struct { Provider string `db:"provider" json:"provider"` Model string `db:"model" json:"model"` AfterID uuid.UUID `db:"after_id" json:"after_id"` + Offset int32 `db:"offset_" json:"offset_"` Limit int32 `db:"limit_" json:"limit_"` } @@ -543,6 +600,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr arg.Provider, arg.Model, arg.AfterID, + arg.Offset, arg.Limit, ) if err != nil { diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 79d41defd54af..32e0246447ce7 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -75,6 +75,40 @@ ORDER BY created_at ASC, id ASC; +-- name: CountAIBridgeInterceptions :one +SELECT + COUNT(*) +FROM + aibridge_interceptions +WHERE + -- Filter by time frame + CASE + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + ELSE true + END + AND CASE + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + ELSE true + END + -- Filter model + AND CASE + WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeInterceptions + -- @authorize_filter +; + -- name: ListAIBridgeInterceptions :many SELECT * @@ -127,6 +161,7 @@ ORDER BY aibridge_interceptions.started_at DESC, aibridge_interceptions.id DESC LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) +OFFSET @offset_ ; -- name: ListAIBridgeTokenUsagesByInterceptionIDs :many diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index d07203a0293c0..3b34edacef28f 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -355,6 +355,8 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string, AfterID: page.AfterID, // #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range Limit: int32(page.Limit), + // #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range + Offset: int32(page.Offset), } if query == "" { @@ -363,7 +365,7 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string, values, errors := searchTerms(query, func(term string, values url.Values) error { // Default to the initiating user - values.Add("user", term) + values.Add("initiator", term) return nil }) if len(errors) > 0 { diff --git a/codersdk/aibridge.go b/codersdk/aibridge.go index 3101dab383ad1..05f1e1a6fc936 100644 --- a/codersdk/aibridge.go +++ b/codersdk/aibridge.go @@ -56,6 +56,7 @@ type AIBridgeToolUsage struct { } type AIBridgeListInterceptionsResponse struct { + Total int64 `json:"total"` Results []AIBridgeInterception `json:"results"` } diff --git a/docs/reference/api/aibridge.md b/docs/reference/api/aibridge.md index 6c929a3fa9383..35af663993f66 100644 --- a/docs/reference/api/aibridge.md +++ b/docs/reference/api/aibridge.md @@ -19,7 +19,8 @@ curl -X GET http://coder-server:8080/api/v2/api/experimental/aibridge/intercepti |------------|-------|---------|----------|------------------------------------------------------------------------------------------------------------------------| | `q` | query | string | false | Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before. | | `limit` | query | integer | false | Page limit | -| `after_id` | query | string | false | Cursor pagination after ID | +| `after_id` | query | string | false | Cursor pagination after ID (cannot be used with offset) | +| `offset` | query | integer | false | Offset pagination (cannot be used with after_id) | ### Example responses @@ -83,7 +84,8 @@ curl -X GET http://coder-server:8080/api/v2/api/experimental/aibridge/intercepti } ] } - ] + ], + "total": 0 } ``` diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 440f2dca0a59a..91a894528ea02 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -566,7 +566,8 @@ } ] } - ] + ], + "total": 0 } ``` @@ -575,6 +576,7 @@ | Name | Type | Required | Restrictions | Description | |-----------|-------------------------------------------------------------------------|----------|--------------|-------------| | `results` | array of [codersdk.AIBridgeInterception](#codersdkaibridgeinterception) | false | | | +| `total` | integer | false | | | ## codersdk.AIBridgeOpenAIConfig diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index 2917603c235d6..878e8ac74163e 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -33,7 +33,8 @@ const ( // @Tags AIBridge // @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, started_after, started_before." // @Param limit query int false "Page limit" -// @Param after_id query string false "Cursor pagination after ID" +// @Param after_id query string false "Cursor pagination after ID (cannot be used with offset)" +// @Param offset query int false "Offset pagination (cannot be used with after_id)" // @Success 200 {object} codersdk.AIBridgeListInterceptionsResponse // @Router /api/experimental/aibridge/interceptions [get] func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Request) { @@ -44,10 +45,10 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques if !ok { return } - if page.Offset != 0 { + if page.AfterID != uuid.Nil && page.Offset != 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Offset pagination is not supported.", - Detail: "Offset pagination is not supported for AIBridge interceptions. Use cursor pagination instead with after_id.", + Message: "Query parameters have invalid values.", + Detail: "Cannot use both after_id and offset pagination in the same request.", }) return } @@ -72,7 +73,10 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques return } - var rows []database.AIBridgeInterception + var ( + count int64 + rows []database.AIBridgeInterception + ) err := api.Database.InTx(func(db database.Store) error { // Ensure the after_id interception exists and is visible to the user. if page.AfterID != uuid.Nil { @@ -83,6 +87,19 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques } var err error + // Get the full count of authorized interceptions matching the filter + // for pagination purposes. + count, err = db.CountAIBridgeInterceptions(ctx, database.CountAIBridgeInterceptionsParams{ + StartedAfter: filter.StartedAfter, + StartedBefore: filter.StartedBefore, + InitiatorID: filter.InitiatorID, + Provider: filter.Provider, + Model: filter.Model, + }) + if err != nil { + return xerrors.Errorf("count authorized aibridge interceptions: %w", err) + } + // This only returns authorized interceptions (when using dbauthz). rows, err = db.ListAIBridgeInterceptions(ctx, filter) if err != nil { @@ -110,6 +127,7 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques } httpapi.Write(ctx, rw, http.StatusOK, codersdk.AIBridgeListInterceptionsResponse{ + Total: count, Results: items, }) } diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 8babf2324deeb..551e9eb748614 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -196,42 +196,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { allInterceptionIDs = append(allInterceptionIDs, interception.ID) } - // Get all interceptions one by one from the API using cursor - // pagination. - getAllInterceptionsOneByOne := func() []uuid.UUID { - interceptionIDs := []uuid.UUID{} - for { - afterID := uuid.Nil - if len(interceptionIDs) > 0 { - afterID = interceptionIDs[len(interceptionIDs)-1] - } - res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{ - Pagination: codersdk.Pagination{ - AfterID: afterID, - Limit: 1, - }, - }) - require.NoError(t, err) - if len(res.Results) == 0 { - break - } - require.Len(t, res.Results, 1) - interceptionIDs = append(interceptionIDs, res.Results[0].ID) - } - return interceptionIDs - } - - // First attempt: get all interceptions one by one. - gotInterceptionIDs1 := getAllInterceptionsOneByOne() - // We should have all of the interceptions returned: - require.ElementsMatch(t, allInterceptionIDs, gotInterceptionIDs1) - - // Second attempt: get all interceptions one by one again. - gotInterceptionIDs2 := getAllInterceptionsOneByOne() - // They should be returned in the exact same order. - require.Equal(t, gotInterceptionIDs1, gotInterceptionIDs2) - - // Try to get an invalid limit. + // Try to fetch with an invalid limit. res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{ Pagination: codersdk.Pagination{ Limit: 1001, @@ -241,6 +206,55 @@ func TestAIBridgeListInterceptions(t *testing.T) { require.ErrorAs(t, err, &sdkErr) require.Contains(t, sdkErr.Message, "Invalid pagination limit value.") require.Empty(t, res.Results) + + // Iterate over all interceptions using both cursor and offset + // pagination modes. + for _, paginationMode := range []string{"after_id", "offset"} { + t.Run(paginationMode, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + // Get all interceptions one by one using the given pagination + // mode. + getAllInterceptionsOneByOne := func() []uuid.UUID { + interceptionIDs := []uuid.UUID{} + for { + pagination := codersdk.Pagination{ + Limit: 1, + } + if paginationMode == "after_id" { + if len(interceptionIDs) > 0 { + pagination.AfterID = interceptionIDs[len(interceptionIDs)-1] + } + } else { + pagination.Offset = len(interceptionIDs) + } + res, err := experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{ + Pagination: pagination, + }) + require.NoError(t, err) + if len(res.Results) == 0 { + break + } + require.EqualValues(t, len(allInterceptionIDs), res.Total) + require.Len(t, res.Results, 1) + interceptionIDs = append(interceptionIDs, res.Results[0].ID) + } + return interceptionIDs + } + + // First attempt: get all interceptions one by one. + gotInterceptionIDs1 := getAllInterceptionsOneByOne() + // We should have all of the interceptions returned: + require.ElementsMatch(t, allInterceptionIDs, gotInterceptionIDs1) + + // Second attempt: get all interceptions one by one again. + gotInterceptionIDs2 := getAllInterceptionsOneByOne() + // They should be returned in the exact same order. + require.Equal(t, gotInterceptionIDs1, gotInterceptionIDs2) + }) + } }) t.Run("Authorized", func(t *testing.T) { @@ -276,6 +290,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Admin can see all interceptions. res, err := adminExperimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{}) require.NoError(t, err) + require.EqualValues(t, 2, res.Total) require.Len(t, res.Results, 2) require.Equal(t, i1.ID, res.Results[0].ID) require.Equal(t, i2.ID, res.Results[1].ID) @@ -283,6 +298,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Second user can only see their own interceptions. res, err = secondUserExperimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{}) require.NoError(t, err) + require.EqualValues(t, 1, res.Total) require.Len(t, res.Results, 1) require.Equal(t, i2.ID, res.Results[0].ID) }) @@ -436,6 +452,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) res, err := experimentalClient.AIBridgeListInterceptions(ctx, tc.filter) require.NoError(t, err) + require.EqualValues(t, len(tc.want), res.Total) // We just compare UUID strings for the sake of this test. wantIDs := make([]string, len(tc.want)) for i, r := range tc.want { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e15ccd917d548..732fbb5af073e 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -39,6 +39,7 @@ export interface AIBridgeInterception { // From codersdk/aibridge.go export interface AIBridgeListInterceptionsResponse { + readonly total: number; readonly results: readonly AIBridgeInterception[]; } From da53587a79b7f65c941fb942ea2add8bfc7bec9f Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 21 Oct 2025 11:37:30 +0000 Subject: [PATCH 2/2] pr comments --- enterprise/coderd/aibridge_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 551e9eb748614..4d8cac139de0c 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -207,6 +207,17 @@ func TestAIBridgeListInterceptions(t *testing.T) { require.Contains(t, sdkErr.Message, "Invalid pagination limit value.") require.Empty(t, res.Results) + // Try to fetch with both after_id and offset pagination. + res, err = experimentalClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{ + Pagination: codersdk.Pagination{ + AfterID: allInterceptionIDs[0], + Offset: 1, + }, + }) + require.ErrorAs(t, err, &sdkErr) + require.Contains(t, sdkErr.Message, "Query parameters have invalid values") + require.Contains(t, sdkErr.Detail, "Cannot use both after_id and offset pagination in the same request.") + // Iterate over all interceptions using both cursor and offset // pagination modes. for _, paginationMode := range []string{"after_id", "offset"} {