diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index dab8583f0..cd0b3f1dd 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,11 +1,44 @@ -Closes: +## Summary + + +## Why + +Fixes # + +## What changed + +- +- + +## MCP impact + +- [ ] No tool or API changes +- [ ] Tool schema or behavior changed +- [ ] New tool added + +## Prompts tested (tool changes only) + + + +- + +## Security / limits + +- [ ] No security or limits impact +- [ ] Auth / permissions considered +- [ ] Data exposure, filtering, or token/size limits considered + +## Lint & tests + +- [ ] Linted locally with `./script/lint` +- [ ] Tested locally with `./script/test` + +## Docs + +- [ ] Not needed +- [ ] Updated (README / docs / examples) diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml new file mode 100644 index 000000000..92524ea17 --- /dev/null +++ b/.github/workflows/conformance.yml @@ -0,0 +1,69 @@ +name: Conformance Test + +on: + pull_request: + +permissions: + contents: read + +jobs: + conformance: + runs-on: ubuntu-latest + + steps: + - name: Check out code + uses: actions/checkout@v6 + with: + # Fetch full history to access merge-base + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "go.mod" + + - name: Download dependencies + run: go mod download + + - name: Run conformance test + id: conformance + run: | + # Run conformance test, capture stdout for summary + script/conformance-test > conformance-summary.txt 2>&1 || true + + # Output the summary + cat conformance-summary.txt + + # Check result + if grep -q "RESULT: ALL TESTS PASSED" conformance-summary.txt; then + echo "status=passed" >> $GITHUB_OUTPUT + else + echo "status=differences" >> $GITHUB_OUTPUT + fi + + - name: Generate Job Summary + run: | + # Add the full markdown report to the job summary + echo "# MCP Server Conformance Report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Comparing PR branch against merge-base with \`origin/main\`" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Extract and append the report content (skip the header since we added our own) + tail -n +5 conformance-report/CONFORMANCE_REPORT.md >> $GITHUB_STEP_SUMMARY + + echo "" >> $GITHUB_STEP_SUMMARY + echo "---" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + # Add interpretation note + if [ "${{ steps.conformance.outputs.status }}" = "passed" ]; then + echo "✅ **All conformance tests passed** - No behavioral differences detected." >> $GITHUB_STEP_SUMMARY + else + echo "⚠️ **Differences detected** - Review the diffs above to ensure changes are intentional." >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Common expected differences:" >> $GITHUB_STEP_SUMMARY + echo "- New tools/toolsets added" >> $GITHUB_STEP_SUMMARY + echo "- Tool descriptions updated" >> $GITHUB_STEP_SUMMARY + echo "- Capability changes (intentional improvements)" >> $GITHUB_STEP_SUMMARY + fi diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index af5fd5bbf..ee63b9a87 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -87,7 +87,7 @@ jobs: type=raw,value=latest,enable=${{ github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') }} - name: Go Build Cache for Docker - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: go-build-cache key: ${{ runner.os }}-go-build-cache-${{ hashFiles('**/go.sum') }} diff --git a/.gitignore b/.gitignore index b018fafac..5684108b0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ bin/ # binary github-mcp-server -.history \ No newline at end of file +.history +conformance-report/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4ad4ece12..6c16cd27d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,6 +39,8 @@ These are one time installations required to be able to test your changes locall - Run linter: `script/lint` - Update snapshots and run tests: `UPDATE_TOOLSNAPS=true go test ./...` - Update readme documentation: `script/generate-docs` + - If renaming a tool, add a deprecation alias (see [Tool Renaming Guide](docs/tool-renaming.md)) + - For toolset and icon configuration, see [Toolsets and Icons Guide](docs/toolsets-and-icons.md) 6. Push to your fork and [submit a pull request][pr] targeting the `main` branch 7. Pat yourself on the back and wait for your pull request to be reviewed and merged. diff --git a/README.md b/README.md index 3364d6f8c..ce6eb81cb 100644 --- a/README.md +++ b/README.md @@ -384,6 +384,7 @@ You can also configure specific tools using the `--tools` flag. Tools can be use - Tools, toolsets, and dynamic toolsets can all be used together - Read-only mode takes priority: write tools are skipped if `--read-only` is set, even if explicitly requested via `--tools` - Tool names must match exactly (e.g., `get_file_contents`, not `getFileContents`). Invalid tool names will cause the server to fail at startup with an error message +- When tools are renamed, old names are preserved as aliases for backward compatibility. See [Deprecated Tool Aliases](docs/deprecated-tool-aliases.md) for details. ### Using Toolsets With Docker @@ -452,27 +453,26 @@ GITHUB_TOOLSETS="default,stargazers" ./github-mcp-server The following sets of tools are available: -| Toolset | Description | -| ----------------------- | ------------------------------------------------------------- | -| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in | -| `actions` | GitHub Actions workflows and CI/CD operations | -| `code_security` | Code security related tools, such as GitHub Code Scanning | -| `dependabot` | Dependabot tools | -| `discussions` | GitHub Discussions related tools | -| `experiments` | Experimental features that are not considered stable yet | -| `gists` | GitHub Gist related tools | -| `git` | GitHub Git API related tools for low-level Git operations | -| `issues` | GitHub Issues related tools | -| `labels` | GitHub Labels related tools | -| `notifications` | GitHub Notifications related tools | -| `orgs` | GitHub Organization related tools | -| `projects` | GitHub Projects related tools | -| `pull_requests` | GitHub Pull Request related tools | -| `repos` | GitHub Repository related tools | -| `secret_protection` | Secret protection related tools, such as GitHub Secret Scanning | -| `security_advisories` | Security advisories related tools | -| `stargazers` | GitHub Stargazers related tools | -| `users` | GitHub User related tools | +| | Toolset | Description | +| --- | ----------------------- | ------------------------------------------------------------- | +| person | `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in | +| workflow | `actions` | GitHub Actions workflows and CI/CD operations | +| codescan | `code_security` | Code security related tools, such as GitHub Code Scanning | +| dependabot | `dependabot` | Dependabot tools | +| comment-discussion | `discussions` | GitHub Discussions related tools | +| logo-gist | `gists` | GitHub Gist related tools | +| git-branch | `git` | GitHub Git API related tools for low-level Git operations | +| issue-opened | `issues` | GitHub Issues related tools | +| tag | `labels` | GitHub Labels related tools | +| bell | `notifications` | GitHub Notifications related tools | +| organization | `orgs` | GitHub Organization related tools | +| project | `projects` | GitHub Projects related tools | +| git-pull-request | `pull_requests` | GitHub Pull Request related tools | +| repo | `repos` | GitHub Repository related tools | +| shield-lock | `secret_protection` | Secret protection related tools, such as GitHub Secret Scanning | +| shield | `security_advisories` | Security advisories related tools | +| star | `stargazers` | GitHub Stargazers related tools | +| people | `users` | GitHub User related tools | ### Additional Toolsets in Remote GitHub MCP Server @@ -488,7 +488,41 @@ The following sets of tools are available:
-Actions +workflow Actions + +- **actions_get** - Get details of GitHub Actions resources (workflows, workflow runs, jobs, and artifacts) + - `method`: The method to execute (string, required) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `resource_id`: The unique identifier of the resource. This will vary based on the "method" provided, so ensure you provide the correct ID: + - Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'get_workflow' method. + - Provide a workflow run ID for 'get_workflow_run', 'get_workflow_run_usage', and 'get_workflow_run_logs_url' methods. + - Provide an artifact ID for 'download_workflow_run_artifact' method. + - Provide a job ID for 'get_workflow_job' method. + (string, required) + +- **actions_list** - List GitHub Actions workflows in a repository + - `method`: The action to perform (string, required) + - `owner`: Repository owner (string, required) + - `page`: Page number for pagination (default: 1) (number, optional) + - `per_page`: Results per page for pagination (default: 30, max: 100) (number, optional) + - `repo`: Repository name (string, required) + - `resource_id`: The unique identifier of the resource. This will vary based on the "method" provided, so ensure you provide the correct ID: + - Do not provide any resource ID for 'list_workflows' method. + - Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'list_workflow_runs' method. + - Provide a workflow run ID for 'list_workflow_jobs' and 'list_workflow_run_artifacts' methods. + (string, optional) + - `workflow_jobs_filter`: Filters for workflow jobs. **ONLY** used when method is 'list_workflow_jobs' (object, optional) + - `workflow_runs_filter`: Filters for workflow runs. **ONLY** used when method is 'list_workflow_runs' (object, optional) + +- **actions_run_trigger** - Trigger GitHub Actions workflow actions + - `inputs`: Inputs the workflow accepts. Only used for 'run_workflow' method. (object, optional) + - `method`: The method to execute (string, required) + - `owner`: Repository owner (string, required) + - `ref`: The git reference for the workflow. The reference can be a branch or tag name. Required for 'run_workflow' method. (string, optional) + - `repo`: Repository name (string, required) + - `run_id`: The ID of the workflow run. Required for all methods except 'run_workflow'. (number, optional) + - `workflow_id`: The workflow ID (numeric) or workflow file name (e.g., main.yml, ci.yaml). Required for 'run_workflow' method. (string, optional) - **cancel_workflow_run** - Cancel workflow run - `owner`: Repository owner (string, required) @@ -514,6 +548,15 @@ The following sets of tools are available: - `run_id`: Workflow run ID (required when using failed_only) (number, optional) - `tail_lines`: Number of lines to return from the end of the log (number, optional) +- **get_job_logs** - Get GitHub Actions workflow job logs + - `failed_only`: When true, gets logs for all failed jobs in the workflow run specified by run_id. Requires run_id to be provided. (boolean, optional) + - `job_id`: The unique identifier of the workflow job. Required when getting logs for a single job. (number, optional) + - `owner`: Repository owner (string, required) + - `repo`: Repository name (string, required) + - `return_content`: Returns actual log content instead of URLs (boolean, optional) + - `run_id`: The unique identifier of the workflow run. Required when failed_only is true to get logs for all failed jobs in the run. (number, optional) + - `tail_lines`: Number of lines to return from the end of the log (number, optional) + - **get_workflow_run** - Get workflow run - `owner`: Repository owner (string, required) - `repo`: Repository name (string, required) @@ -582,7 +625,7 @@ The following sets of tools are available:
-Code Security +codescan Code Security - **get_code_scanning_alert** - Get code scanning alert - `alertNumber`: The number of the alert. (number, required) @@ -601,7 +644,7 @@ The following sets of tools are available:
-Context +person Context - **get_me** - Get my user profile - No parameters required @@ -617,7 +660,7 @@ The following sets of tools are available:
-Dependabot +dependabot Dependabot - **get_dependabot_alert** - Get dependabot alert - `alertNumber`: The number of the alert. (number, required) @@ -634,7 +677,7 @@ The following sets of tools are available:
-Discussions +comment-discussion Discussions - **get_discussion** - Get discussion - `discussionNumber`: Discussion Number (number, required) @@ -665,7 +708,7 @@ The following sets of tools are available:
-Gists +logo-gist Gists - **create_gist** - Create Gist - `content`: Content for simple single-file gist creation (string, required) @@ -692,7 +735,7 @@ The following sets of tools are available:
-Git +git-branch Git - **get_repository_tree** - Get repository tree - `owner`: Repository owner (username or organization) (string, required) @@ -705,7 +748,7 @@ The following sets of tools are available:
-Issues +issue-opened Issues - **add_issue_comment** - Add comment to issue - `body`: Comment content (string, required) @@ -798,7 +841,7 @@ The following sets of tools are available:
-Labels +tag Labels - **get_label** - Get a specific label from a repository. - `name`: Label name. (string, required) @@ -822,7 +865,7 @@ The following sets of tools are available:
-Notifications +bell Notifications - **dismiss_notification** - Dismiss notification - `state`: The new state of the notification (read/done) (string, required) @@ -858,7 +901,7 @@ The following sets of tools are available:
-Organizations +organization Organizations - **search_orgs** - Search organizations - `order`: Sort order (string, optional) @@ -871,7 +914,7 @@ The following sets of tools are available:
-Projects +project Projects - **add_project_item** - Add project item - `item_id`: The numeric ID of the issue or pull request to add to the project. (number, required) @@ -941,7 +984,7 @@ The following sets of tools are available:
-Pull Requests +git-pull-request Pull Requests - **add_comment_to_pending_review** - Add review comment to the requester's latest pending pull request review - `body`: The text of the review comment (string, required) @@ -1046,7 +1089,7 @@ The following sets of tools are available:
-Repositories +repo Repositories - **create_branch** - Create branch - `branch`: Name for new branch (string, required) @@ -1061,7 +1104,7 @@ The following sets of tools are available: - `owner`: Repository owner (username or organization) (string, required) - `path`: Path where to create/update the file (string, required) - `repo`: Repository name (string, required) - - `sha`: Required if updating an existing file. The blob SHA of the file being replaced. (string, optional) + - `sha`: The blob SHA of the file being replaced. (string, optional) - **create_repository** - Create repository - `autoInit`: Initialize with README (boolean, optional) @@ -1163,7 +1206,7 @@ The following sets of tools are available:
-Secret Protection +shield-lock Secret Protection - **get_secret_scanning_alert** - Get secret scanning alert - `alertNumber`: The number of the alert. (number, required) @@ -1181,7 +1224,7 @@ The following sets of tools are available:
-Security Advisories +shield Security Advisories - **get_global_security_advisory** - Get a global security advisory - `ghsaId`: GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx). (string, required) @@ -1216,7 +1259,7 @@ The following sets of tools are available:
-Stargazers +star Stargazers - **list_starred_repositories** - List starred repositories - `direction`: The direction to sort the results by. (string, optional) @@ -1237,7 +1280,7 @@ The following sets of tools are available:
-Users +people Users - **search_users** - Search users - `order`: Sort order (string, optional) diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 61459d7f0..b40e3e2f4 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -1,23 +1,17 @@ package main import ( - "context" "fmt" "net/url" "os" - "regexp" "sort" "strings" "github.com/github/github-mcp-server/pkg/github" - "github.com/github/github-mcp-server/pkg/lockdown" - "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" - gogithub "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" - "github.com/shurcooL/githubv4" "github.com/spf13/cobra" ) @@ -34,30 +28,21 @@ func init() { rootCmd.AddCommand(generateDocsCmd) } -// mockGetClient returns a mock GitHub client for documentation generation -func mockGetClient(_ context.Context) (*gogithub.Client, error) { - return gogithub.NewClient(nil), nil -} - -// mockGetGQLClient returns a mock GraphQL client for documentation generation -func mockGetGQLClient(_ context.Context) (*githubv4.Client, error) { - return githubv4.NewClient(nil), nil -} - -// mockGetRawClient returns a mock raw client for documentation generation -func mockGetRawClient(_ context.Context) (*raw.Client, error) { - return nil, nil -} - func generateAllDocs() error { - if err := generateReadmeDocs("README.md"); err != nil { - return fmt.Errorf("failed to generate README docs: %w", err) - } - - if err := generateRemoteServerDocs("docs/remote-server.md"); err != nil { - return fmt.Errorf("failed to generate remote-server docs: %w", err) + for _, doc := range []struct { + path string + fn func(string) error + }{ + // File to edit, function to generate its docs + {"README.md", generateReadmeDocs}, + {"docs/remote-server.md", generateRemoteServerDocs}, + {"docs/tool-renaming.md", generateDeprecatedAliasesDocs}, + } { + if err := doc.fn(doc.path); err != nil { + return fmt.Errorf("failed to generate docs for %s: %w", doc.path, err) + } + fmt.Printf("Successfully updated %s with automated documentation\n", doc.path) } - return nil } @@ -65,15 +50,14 @@ func generateReadmeDocs(readmePath string) error { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + // Build inventory - stateless, no dependencies needed for doc generation + r := github.NewInventory(t).Build() // Generate toolsets documentation - toolsetsDoc := generateToolsetsDoc(tsg) + toolsetsDoc := generateToolsetsDoc(r) // Generate tools documentation - toolsDoc := generateToolsDoc(tsg) + toolsDoc := generateToolsDoc(r) // Read the current README.md // #nosec G304 - readmePath is controlled by command line flag, not user input @@ -83,10 +67,16 @@ func generateReadmeDocs(readmePath string) error { } // Replace toolsets section - updatedContent := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err + } // Replace tools section - updatedContent = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + updatedContent, err = replaceSection(updatedContent, "START AUTOMATED TOOLS", "END AUTOMATED TOOLS", toolsDoc) + if err != nil { + return err + } // Write back to file err = os.WriteFile(readmePath, []byte(updatedContent), 0600) @@ -94,7 +84,6 @@ func generateReadmeDocs(readmePath string) error { return fmt.Errorf("failed to write README.md: %w", err) } - fmt.Println("Successfully updated README.md with automated documentation") return nil } @@ -107,93 +96,108 @@ func generateRemoteServerDocs(docsPath string) error { toolsetsDoc := generateRemoteToolsetsDoc() // Replace content between markers - startMarker := "" - endMarker := "" - - contentStr := string(content) - startIndex := strings.Index(contentStr, startMarker) - endIndex := strings.Index(contentStr, endMarker) - - if startIndex == -1 || endIndex == -1 { - return fmt.Errorf("automation markers not found in %s", docsPath) + updatedContent, err := replaceSection(string(content), "START AUTOMATED TOOLSETS", "END AUTOMATED TOOLSETS", toolsetsDoc) + if err != nil { + return err } - newContent := contentStr[:startIndex] + startMarker + "\n" + toolsetsDoc + "\n" + endMarker + contentStr[endIndex+len(endMarker):] + // Also generate remote-only toolsets section + remoteOnlyDoc := generateRemoteOnlyToolsetsDoc() + updatedContent, err = replaceSection(updatedContent, "START AUTOMATED REMOTE TOOLSETS", "END AUTOMATED REMOTE TOOLSETS", remoteOnlyDoc) + if err != nil { + return err + } - return os.WriteFile(docsPath, []byte(newContent), 0600) //#nosec G306 + return os.WriteFile(docsPath, []byte(updatedContent), 0600) //#nosec G306 } -func generateToolsetsDoc(tsg *toolsets.ToolsetGroup) string { - var lines []string - - // Add table header and separator - lines = append(lines, "| Toolset | Description |") - lines = append(lines, "| ----------------------- | ------------------------------------------------------------- |") +// octiconImg returns an img tag for an Octicon that works with GitHub's light/dark theme. +// Uses picture element with prefers-color-scheme for automatic theme switching. +// References icons from the repo's pkg/octicons/icons directory. +// Optional pathPrefix for files in subdirectories (e.g., "../" for docs/). +func octiconImg(name string, pathPrefix ...string) string { + if name == "" { + return "" + } + prefix := "" + if len(pathPrefix) > 0 { + prefix = pathPrefix[0] + } + // Use picture element with media queries for light/dark mode support + // GitHub renders these correctly in markdown + lightIcon := fmt.Sprintf("%spkg/octicons/icons/%s-light.png", prefix, name) + darkIcon := fmt.Sprintf("%spkg/octicons/icons/%s-dark.png", prefix, name) + return fmt.Sprintf(`%s`, darkIcon, lightIcon, lightIcon, name) +} - // Add the context toolset row (handled separately in README) - lines = append(lines, "| `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |") +func generateToolsetsDoc(i *inventory.Inventory) string { + var buf strings.Builder - // Get all toolsets except context (which is handled separately above) - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } - } + // Add table header and separator (with icon column) + buf.WriteString("| | Toolset | Description |\n") + buf.WriteString("| --- | ----------------------- | ------------------------------------------------------------- |\n") - // Sort toolset names for consistent output - sort.Strings(toolsetNames) + // Add the context toolset row with custom description (strongly recommended) + // Get context toolset for its icon + contextIcon := octiconImg("person") + fmt.Fprintf(&buf, "| %s | `context` | **Strongly recommended**: Tools that provide context about the current user and GitHub context you are operating in |\n", contextIcon) - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] - lines = append(lines, fmt.Sprintf("| `%s` | %s |", name, toolset.Description)) + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (custom description above) and dynamic (internal only) + for _, ts := range i.AvailableToolsets("context", "dynamic") { + icon := octiconImg(ts.Icon) + fmt.Fprintf(&buf, "| %s | `%s` | %s |\n", icon, ts.ID, ts.Description) } - return strings.Join(lines, "\n") + return strings.TrimSuffix(buf.String(), "\n") } -func generateToolsDoc(tsg *toolsets.ToolsetGroup) string { - var sections []string - - // Get all toolset names and sort them alphabetically for deterministic order - var toolsetNames []string - for name := range tsg.Toolsets { - if name != "dynamic" { // Skip dynamic toolset as it's handled separately - toolsetNames = append(toolsetNames, name) - } +func generateToolsDoc(r *inventory.Inventory) string { + // AllTools() returns tools sorted by toolset ID then tool name. + // We iterate once, grouping by toolset as we encounter them. + tools := r.AllTools() + if len(tools) == 0 { + return "" } - sort.Strings(toolsetNames) - - for _, toolsetName := range toolsetNames { - toolset := tsg.Toolsets[toolsetName] - tools := toolset.GetAvailableTools() - if len(tools) == 0 { - continue + var buf strings.Builder + var toolBuf strings.Builder + var currentToolsetID inventory.ToolsetID + var currentToolsetIcon string + firstSection := true + + writeSection := func() { + if toolBuf.Len() == 0 { + return } - - // Sort tools by name for deterministic order - sort.Slice(tools, func(i, j int) bool { - return tools[i].Tool.Name < tools[j].Tool.Name - }) - - // Generate section header - capitalize first letter and replace underscores - sectionName := formatToolsetName(toolsetName) - - var toolDocs []string - for _, serverTool := range tools { - toolDoc := generateToolDoc(serverTool.Tool) - toolDocs = append(toolDocs, toolDoc) + if !firstSection { + buf.WriteString("\n\n") + } + firstSection = false + sectionName := formatToolsetName(string(currentToolsetID)) + icon := octiconImg(currentToolsetIcon) + if icon != "" { + icon += " " } + fmt.Fprintf(&buf, "
\n\n%s%s\n\n%s\n\n
", icon, sectionName, strings.TrimSuffix(toolBuf.String(), "\n\n")) + toolBuf.Reset() + } - if len(toolDocs) > 0 { - section := fmt.Sprintf("
\n\n%s\n\n%s\n\n
", - sectionName, strings.Join(toolDocs, "\n\n")) - sections = append(sections, section) + for _, tool := range tools { + // When toolset changes, emit the previous section + if tool.Toolset.ID != currentToolsetID { + writeSection() + currentToolsetID = tool.Toolset.ID + currentToolsetIcon = tool.Toolset.Icon } + writeToolDoc(&toolBuf, tool.Tool) + toolBuf.WriteString("\n\n") } - return strings.Join(sections, "\n\n") + // Emit the last section + writeSection() + + return buf.String() } func formatToolsetName(name string) string { @@ -220,21 +224,19 @@ func formatToolsetName(name string) string { } } -func generateToolDoc(tool mcp.Tool) string { - var lines []string - - // Tool name only (using annotation name instead of verbose description) - lines = append(lines, fmt.Sprintf("- **%s** - %s", tool.Name, tool.Annotations.Title)) +func writeToolDoc(buf *strings.Builder, tool mcp.Tool) { + // Tool name (no icon - section header already has the toolset icon) + fmt.Fprintf(buf, "- **%s** - %s\n", tool.Name, tool.Annotations.Title) // Parameters if tool.InputSchema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } schema, ok := tool.InputSchema.(*jsonschema.Schema) if !ok || schema == nil { - lines = append(lines, " - No parameters required") - return strings.Join(lines, "\n") + buf.WriteString(" - No parameters required") + return } if len(schema.Properties) > 0 { @@ -245,7 +247,7 @@ func generateToolDoc(tool mcp.Tool) string { } sort.Strings(paramNames) - for _, propName := range paramNames { + for i, propName := range paramNames { prop := schema.Properties[propName] required := contains(schema.Required, propName) requiredStr := "optional" @@ -253,7 +255,7 @@ func generateToolDoc(tool mcp.Tool) string { requiredStr = "required" } - var typeStr, description string + var typeStr string // Get the type and description switch prop.Type { @@ -267,19 +269,17 @@ func generateToolDoc(tool mcp.Tool) string { typeStr = prop.Type } - description = prop.Description - // Indent any continuation lines in the description to maintain markdown formatting - description = indentMultilineDescription(description, " ") + description := indentMultilineDescription(prop.Description, " ") - paramLine := fmt.Sprintf(" - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) - lines = append(lines, paramLine) + fmt.Fprintf(buf, " - `%s`: %s (%s, %s)", propName, description, typeStr, requiredStr) + if i < len(paramNames)-1 { + buf.WriteString("\n") + } } } else { - lines = append(lines, " - No parameters required") + buf.WriteString(" - No parameters required") } - - return strings.Join(lines, "\n") } func contains(slice []string, item string) bool { @@ -294,25 +294,38 @@ func contains(slice []string, item string) bool { // indentMultilineDescription adds the specified indent to all lines after the first line. // This ensures that multi-line descriptions maintain proper markdown list formatting. func indentMultilineDescription(description, indent string) string { - lines := strings.Split(description, "\n") - if len(lines) <= 1 { + if !strings.Contains(description, "\n") { return description } + var buf strings.Builder + lines := strings.Split(description, "\n") + buf.WriteString(lines[0]) for i := 1; i < len(lines); i++ { - lines[i] = indent + lines[i] + buf.WriteString("\n") + buf.WriteString(indent) + buf.WriteString(lines[i]) } - return strings.Join(lines, "\n") + return buf.String() } -func replaceSection(content, startMarker, endMarker, newContent string) string { - startPattern := fmt.Sprintf(``, regexp.QuoteMeta(startMarker)) - endPattern := fmt.Sprintf(``, regexp.QuoteMeta(endMarker)) - - re := regexp.MustCompile(fmt.Sprintf(`(?s)%s.*?%s`, startPattern, endPattern)) +func replaceSection(content, startMarker, endMarker, newContent string) (string, error) { + start := fmt.Sprintf("", startMarker) + end := fmt.Sprintf("", endMarker) - replacement := fmt.Sprintf("\n%s\n", startMarker, newContent, endMarker) + startIdx := strings.Index(content, start) + endIdx := strings.Index(content, end) + if startIdx == -1 || endIdx == -1 { + return "", fmt.Errorf("markers not found: %s / %s", start, end) + } - return re.ReplaceAllString(content, replacement) + var buf strings.Builder + buf.WriteString(content[:startIdx]) + buf.WriteString(start) + buf.WriteString("\n") + buf.WriteString(newContent) + buf.WriteString("\n") + buf.WriteString(content[endIdx:]) + return buf.String(), nil } func generateRemoteToolsetsDoc() string { @@ -321,34 +334,66 @@ func generateRemoteToolsetsDoc() string { // Create translation helper t, _ := translations.TranslationHelper() - // Create toolset group with mock clients - repoAccessCache := lockdown.GetInstance(nil) - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache) + // Build inventory - stateless + r := github.NewInventory(t).Build() - // Generate table header - buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") - buf.WriteString("|----------------|--------------------------------------------------|-------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n") + // Generate table header (icon is combined with Name column) + buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") + buf.WriteString("| ---- | ----------- | ------- | ------------------------- | -------------- | ----------------------------------- |\n") - // Get all toolsets - toolsetNames := make([]string, 0, len(tsg.Toolsets)) - for name := range tsg.Toolsets { - if name != "context" && name != "dynamic" { // Skip context and dynamic toolsets as they're handled separately - toolsetNames = append(toolsetNames, name) - } + // Add "all" toolset first (special case) + allIcon := octiconImg("apps", "../") + fmt.Fprintf(&buf, "| %s
all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%%7B%%22type%%22%%3A%%20%%22http%%22%%2C%%22url%%22%%3A%%20%%22https%%3A%%2F%%2Fapi.githubcopilot.com%%2Fmcp%%2F%%22%%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%%7B%%22type%%22%%3A%%20%%22http%%22%%2C%%22url%%22%%3A%%20%%22https%%3A%%2F%%2Fapi.githubcopilot.com%%2Fmcp%%2Freadonly%%22%%7D) |\n", allIcon) + + // AvailableToolsets() returns toolsets that have tools, sorted by ID + // Exclude context (handled separately) and dynamic (internal only) + for _, ts := range r.AvailableToolsets("context", "dynamic") { + idStr := string(ts.ID) + + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) + + // Create install config JSON (URL encoded) + installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) + readonlyConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, readonlyURL)) + + // Fix URL encoding to use %20 instead of + for spaces + installConfig = strings.ReplaceAll(installConfig, "+", "%20") + readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") + + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) + + icon := octiconImg(ts.Icon, "../") + fmt.Fprintf(&buf, "| %s
%s | %s | %s | %s | [read-only](%s) | %s |\n", + icon, + formattedName, + ts.Description, + apiURL, + installLink, + readonlyURL, + readonlyInstallLink, + ) } - sort.Strings(toolsetNames) - // Add "all" toolset first (special case) - buf.WriteString("| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) |\n") + return strings.TrimSuffix(buf.String(), "\n") +} + +func generateRemoteOnlyToolsetsDoc() string { + var buf strings.Builder + + // Generate table header (icon is combined with Name column) + buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") + buf.WriteString("| ---- | ----------- | ------- | ------------------------- | -------------- | ----------------------------------- |\n") - // Add individual toolsets - for _, name := range toolsetNames { - toolset := tsg.Toolsets[name] + // Use RemoteOnlyToolsets from github package + for _, ts := range github.RemoteOnlyToolsets() { + idStr := string(ts.ID) - formattedName := formatToolsetName(name) - description := toolset.Description - apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", name) - readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", name) + formattedName := formatToolsetName(idStr) + apiURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s", idStr) + readonlyURL := fmt.Sprintf("https://api.githubcopilot.com/mcp/x/%s/readonly", idStr) // Create install config JSON (URL encoded) installConfig := url.QueryEscape(fmt.Sprintf(`{"type": "http","url": "%s"}`, apiURL)) @@ -358,17 +403,73 @@ func generateRemoteToolsetsDoc() string { installConfig = strings.ReplaceAll(installConfig, "+", "%20") readonlyConfig = strings.ReplaceAll(readonlyConfig, "+", "%20") - installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, installConfig) - readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", name, readonlyConfig) + installLink := fmt.Sprintf("[Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, installConfig) + readonlyInstallLink := fmt.Sprintf("[Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-%s&config=%s)", idStr, readonlyConfig) - buf.WriteString(fmt.Sprintf("| %-14s | %-48s | %-53s | %-218s | %-110s | %-288s |\n", + icon := octiconImg(ts.Icon, "../") + fmt.Fprintf(&buf, "| %s
%s | %s | %s | %s | [read-only](%s) | %s |\n", + icon, formattedName, - description, + ts.Description, apiURL, installLink, - fmt.Sprintf("[read-only](%s)", readonlyURL), + readonlyURL, readonlyInstallLink, - )) + ) + } + + return strings.TrimSuffix(buf.String(), "\n") +} +func generateDeprecatedAliasesDocs(docsPath string) error { + // Read the current file + content, err := os.ReadFile(docsPath) //#nosec G304 + if err != nil { + return fmt.Errorf("failed to read docs file: %w", err) + } + + // Generate the table + aliasesDoc := generateDeprecatedAliasesTable() + + // Replace content between markers + updatedContent, err := replaceSection(string(content), "START AUTOMATED ALIASES", "END AUTOMATED ALIASES", aliasesDoc) + if err != nil { + return err + } + + // Write back to file + err = os.WriteFile(docsPath, []byte(updatedContent), 0600) + if err != nil { + return fmt.Errorf("failed to write deprecated aliases docs: %w", err) + } + + return nil +} + +func generateDeprecatedAliasesTable() string { + var buf strings.Builder + + // Add table header + buf.WriteString("| Old Name | New Name |\n") + buf.WriteString("|----------|----------|\n") + + aliases := github.DeprecatedToolAliases + if len(aliases) == 0 { + buf.WriteString("| *(none currently)* | |") + } else { + // Sort keys for deterministic output + var oldNames []string + for oldName := range aliases { + oldNames = append(oldNames, oldName) + } + sort.Strings(oldNames) + + for i, oldName := range oldNames { + newName := aliases[oldName] + fmt.Fprintf(&buf, "| `%s` | `%s` |", oldName, newName) + if i < len(oldNames)-1 { + buf.WriteString("\n") + } + } } return buf.String() diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 87eeedd2e..cfb68be4e 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -41,20 +41,31 @@ var ( // it's because viper doesn't handle comma-separated values correctly for env // vars when using GetStringSlice. // https://github.com/spf13/viper/issues/380 + // + // Additionally, viper.UnmarshalKey returns an empty slice even when the flag + // is not set, but we need nil to indicate "use defaults". So we check IsSet first. var enabledToolsets []string - if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { - return fmt.Errorf("failed to unmarshal toolsets: %w", err) + if viper.IsSet("toolsets") { + if err := viper.UnmarshalKey("toolsets", &enabledToolsets); err != nil { + return fmt.Errorf("failed to unmarshal toolsets: %w", err) + } } + // else: enabledToolsets stays nil, meaning "use defaults" // Parse tools (similar to toolsets) var enabledTools []string - if err := viper.UnmarshalKey("tools", &enabledTools); err != nil { - return fmt.Errorf("failed to unmarshal tools: %w", err) + if viper.IsSet("tools") { + if err := viper.UnmarshalKey("tools", &enabledTools); err != nil { + return fmt.Errorf("failed to unmarshal tools: %w", err) + } } - // If neither toolset config nor tools config is passed we enable the default toolset - if len(enabledToolsets) == 0 && len(enabledTools) == 0 { - enabledToolsets = []string{github.ToolsetMetadataDefault.ID} + // Parse enabled features (similar to toolsets) + var enabledFeatures []string + if viper.IsSet("features") { + if err := viper.UnmarshalKey("features", &enabledFeatures); err != nil { + return fmt.Errorf("failed to unmarshal features: %w", err) + } } ttl := viper.GetDuration("repo-access-cache-ttl") @@ -64,6 +75,7 @@ var ( Token: token, EnabledToolsets: enabledToolsets, EnabledTools: enabledTools, + EnabledFeatures: enabledFeatures, DynamicToolsets: viper.GetBool("dynamic_toolsets"), ReadOnly: viper.GetBool("read-only"), ExportTranslations: viper.GetBool("export-translations"), @@ -87,6 +99,7 @@ func init() { // Add global flags that will be shared by all commands rootCmd.PersistentFlags().StringSlice("toolsets", nil, github.GenerateToolsetsHelp()) rootCmd.PersistentFlags().StringSlice("tools", nil, "Comma-separated list of specific tools to enable") + rootCmd.PersistentFlags().StringSlice("features", nil, "Comma-separated list of feature flags to enable") rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") @@ -100,6 +113,7 @@ func init() { // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("tools", rootCmd.PersistentFlags().Lookup("tools")) + _ = viper.BindPFlag("features", rootCmd.PersistentFlags().Lookup("features")) _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) diff --git a/docs/installation-guides/README.md b/docs/installation-guides/README.md index 097d97b02..be967f81d 100644 --- a/docs/installation-guides/README.md +++ b/docs/installation-guides/README.md @@ -4,6 +4,7 @@ This directory contains detailed installation instructions for the GitHub MCP Se ## Installation Guides by Host Application - **[GitHub Copilot in other IDEs](install-other-copilot-ides.md)** - Installation for JetBrains, Visual Studio, Eclipse, and Xcode with GitHub Copilot +- **[Antigravity](install-antigravity.md)** - Installation for Google Antigravity IDE - **[Claude Applications](install-claude.md)** - Installation guide for Claude Web, Claude Desktop and Claude Code CLI - **[Cursor](install-cursor.md)** - Installation guide for Cursor IDE - **[Google Gemini CLI](install-gemini-cli.md)** - Installation guide for Google Gemini CLI diff --git a/docs/installation-guides/install-antigravity.md b/docs/installation-guides/install-antigravity.md new file mode 100644 index 000000000..c24d8e01d --- /dev/null +++ b/docs/installation-guides/install-antigravity.md @@ -0,0 +1,143 @@ +# Installing GitHub MCP Server in Antigravity + +This guide covers setting up the GitHub MCP Server in Google's Antigravity IDE. + +## Prerequisites + +- Antigravity IDE installed (latest version) +- GitHub Personal Access Token with appropriate scopes + +## Installation Methods + +### Option 1: Remote Server (Recommended) + +Uses GitHub's hosted server at `https://api.githubcopilot.com/mcp/`. + +> [!NOTE] +> We recommend this manual configuration method because the "official" installation via the Antigravity MCP Store currently has known issues (often resulting in Docker errors). This direct remote connection is more reliable. + +#### Step 1: Access MCP Configuration + +1. Open Antigravity +2. Click the "..." (Additional Options) menu in the Agent panel +3. Select "MCP Servers" +4. Click "Manage MCP Servers" +5. Click "View raw config" + +This will open your `mcp_config.json` file at: +- **Windows**: `C:\Users\\.gemini\antigravity\mcp_config.json` +- **macOS/Linux**: `~/.gemini/antigravity/mcp_config.json` + +#### Step 2: Add Configuration + +Add the following to your `mcp_config.json`: + +```json +{ + "mcpServers": { + "github": { + "serverUrl": "https://api.githubcopilot.com/mcp/", + "headers": { + "Authorization": "Bearer YOUR_GITHUB_PAT" + } + } + } +} +``` + +**Important**: Note that Antigravity uses `serverUrl` instead of `url` for HTTP-based MCP servers. + +#### Step 3: Configure Your Token + +Replace `YOUR_GITHUB_PAT` with your actual GitHub Personal Access Token. + +Create a token here: https://github.com/settings/tokens + +Recommended scopes: +- `repo` - Full control of private repositories +- `read:org` - Read org and team membership +- `read:user` - Read user profile data + +#### Step 4: Restart Antigravity + +Close and reopen Antigravity for the changes to take effect. + +#### Step 5: Verify Installation + +1. Open the MCP Servers panel (... menu → MCP Servers) +2. You should see "github" with a list of available tools +3. You can now use GitHub tools in your conversations + +> [!NOTE] +> The status indicator in the MCP Servers panel might not immediately turn green in some versions, but the tools will still function if configured correctly. + +### Option 2: Local Docker Server + +If you prefer running the server locally with Docker: + +```json +{ + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "ghcr.io/github/github-mcp-server" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "YOUR_GITHUB_PAT" + } + } + } +} +``` + +**Requirements**: +- Docker Desktop installed and running +- Docker must be in your system PATH + +## Troubleshooting + +### "Error: serverUrl or command must be specified" + +Make sure you're using `serverUrl` (not `url`) for the remote server configuration. Antigravity requires `serverUrl` for HTTP-based MCP servers. + +### Server not appearing in MCP list + +- Verify JSON syntax in your config file +- Check that your PAT hasn't expired +- Restart Antigravity completely + +### Tools not working + +- Ensure your PAT has the correct scopes +- Check the MCP Servers panel for error messages +- Verify internet connection for remote server + +## Available Tools + +Once installed, you'll have access to tools like: +- `create_repository` - Create new GitHub repositories +- `push_files` - Push files to repositories +- `search_repositories` - Search for repositories +- `create_or_update_file` - Manage file content +- `get_file_contents` - Read file content +- And many more... + +For a complete list of available tools and features, see the [main README](../../README.md). + +## Differences from Other IDEs + +- **Configuration key**: Antigravity uses `serverUrl` instead of `url` for HTTP servers +- **Config location**: `.gemini/antigravity/mcp_config.json` instead of `.cursor/mcp.json` +- **Tool limits**: Antigravity recommends keeping total enabled tools under 50 for optimal performance + +## Next Steps + +- Explore the [Server Configuration Guide](../server-configuration.md) for advanced options +- Check out [toolsets documentation](../../README.md#available-toolsets) to customize available tools +- See the [Remote Server Documentation](../remote-server.md) for more details diff --git a/docs/remote-server.md b/docs/remote-server.md index e06d41a75..d7d0f72b1 100644 --- a/docs/remote-server.md +++ b/docs/remote-server.md @@ -17,39 +17,39 @@ The remote server has [additional tools](#toolsets-only-available-in-the-remote- Below is a table of available toolsets for the remote GitHub MCP Server. Each toolset is provided as a distinct URL so you can mix and match to create the perfect combination of tools for your use-case. Add `/readonly` to the end of any URL to restrict the tools in the toolset to only those that enable read access. We also provide the option to use [headers](#headers) instead. -| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) | -|----------------|--------------------------------------------------|-------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) | -| Actions | GitHub Actions workflows and CI/CD operations | https://api.githubcopilot.com/mcp/x/actions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-actions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Factions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/actions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-actions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Factions%2Freadonly%22%7D) | -| Code Security | Code security related tools, such as GitHub Code Scanning | https://api.githubcopilot.com/mcp/x/code_security | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/code_security/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%2Freadonly%22%7D) | -| Dependabot | Dependabot tools | https://api.githubcopilot.com/mcp/x/dependabot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/dependabot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%2Freadonly%22%7D) | -| Discussions | GitHub Discussions related tools | https://api.githubcopilot.com/mcp/x/discussions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/discussions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%2Freadonly%22%7D) | -| Experiments | Experimental features that are not considered stable yet | https://api.githubcopilot.com/mcp/x/experiments | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/experiments/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-experiments&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fexperiments%2Freadonly%22%7D) | -| Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | -| Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | -| Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | -| Labels | GitHub Labels related tools | https://api.githubcopilot.com/mcp/x/labels | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/labels/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%2Freadonly%22%7D) | -| Notifications | GitHub Notifications related tools | https://api.githubcopilot.com/mcp/x/notifications | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/notifications/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%2Freadonly%22%7D) | -| Organizations | GitHub Organization related tools | https://api.githubcopilot.com/mcp/x/orgs | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/orgs/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%2Freadonly%22%7D) | -| Projects | GitHub Projects related tools | https://api.githubcopilot.com/mcp/x/projects | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/projects/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%2Freadonly%22%7D) | -| Pull Requests | GitHub Pull Request related tools | https://api.githubcopilot.com/mcp/x/pull_requests | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-pull_requests&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fpull_requests%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/pull_requests/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-pull_requests&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fpull_requests%2Freadonly%22%7D) | -| Repositories | GitHub Repository related tools | https://api.githubcopilot.com/mcp/x/repos | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-repos&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Frepos%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/repos/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-repos&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Frepos%2Freadonly%22%7D) | -| Secret Protection | Secret protection related tools, such as GitHub Secret Scanning | https://api.githubcopilot.com/mcp/x/secret_protection | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-secret_protection&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecret_protection%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/secret_protection/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-secret_protection&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecret_protection%2Freadonly%22%7D) | -| Security Advisories | Security advisories related tools | https://api.githubcopilot.com/mcp/x/security_advisories | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/security_advisories/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%2Freadonly%22%7D) | -| Stargazers | GitHub Stargazers related tools | https://api.githubcopilot.com/mcp/x/stargazers | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/stargazers/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%2Freadonly%22%7D) | -| Users | GitHub User related tools | https://api.githubcopilot.com/mcp/x/users | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/users/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%2Freadonly%22%7D) | - +| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) | +| ---- | ----------- | ------- | ------------------------- | -------------- | ----------------------------------- | +| apps
all | All available GitHub MCP tools | https://api.githubcopilot.com/mcp/ | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2F%22%7D) | [read-only](https://api.githubcopilot.com/mcp/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=github&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Freadonly%22%7D) | +| workflow
Actions | GitHub Actions workflows and CI/CD operations | https://api.githubcopilot.com/mcp/x/actions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-actions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Factions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/actions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-actions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Factions%2Freadonly%22%7D) | +| codescan
Code Security | Code security related tools, such as GitHub Code Scanning | https://api.githubcopilot.com/mcp/x/code_security | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/code_security/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-code_security&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcode_security%2Freadonly%22%7D) | +| dependabot
Dependabot | Dependabot tools | https://api.githubcopilot.com/mcp/x/dependabot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/dependabot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-dependabot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdependabot%2Freadonly%22%7D) | +| comment-discussion
Discussions | GitHub Discussions related tools | https://api.githubcopilot.com/mcp/x/discussions | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/discussions/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-discussions&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fdiscussions%2Freadonly%22%7D) | +| logo-gist
Gists | GitHub Gist related tools | https://api.githubcopilot.com/mcp/x/gists | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/gists/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-gists&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgists%2Freadonly%22%7D) | +| git-branch
Git | GitHub Git API related tools for low-level Git operations | https://api.githubcopilot.com/mcp/x/git | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/git/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-git&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgit%2Freadonly%22%7D) | +| issue-opened
Issues | GitHub Issues related tools | https://api.githubcopilot.com/mcp/x/issues | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/issues/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-issues&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fissues%2Freadonly%22%7D) | +| tag
Labels | GitHub Labels related tools | https://api.githubcopilot.com/mcp/x/labels | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/labels/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-labels&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Flabels%2Freadonly%22%7D) | +| bell
Notifications | GitHub Notifications related tools | https://api.githubcopilot.com/mcp/x/notifications | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/notifications/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-notifications&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fnotifications%2Freadonly%22%7D) | +| organization
Organizations | GitHub Organization related tools | https://api.githubcopilot.com/mcp/x/orgs | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/orgs/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-orgs&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Forgs%2Freadonly%22%7D) | +| project
Projects | GitHub Projects related tools | https://api.githubcopilot.com/mcp/x/projects | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/projects/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-projects&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fprojects%2Freadonly%22%7D) | +| git-pull-request
Pull Requests | GitHub Pull Request related tools | https://api.githubcopilot.com/mcp/x/pull_requests | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-pull_requests&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fpull_requests%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/pull_requests/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-pull_requests&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fpull_requests%2Freadonly%22%7D) | +| repo
Repositories | GitHub Repository related tools | https://api.githubcopilot.com/mcp/x/repos | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-repos&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Frepos%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/repos/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-repos&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Frepos%2Freadonly%22%7D) | +| shield-lock
Secret Protection | Secret protection related tools, such as GitHub Secret Scanning | https://api.githubcopilot.com/mcp/x/secret_protection | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-secret_protection&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecret_protection%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/secret_protection/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-secret_protection&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecret_protection%2Freadonly%22%7D) | +| shield
Security Advisories | Security advisories related tools | https://api.githubcopilot.com/mcp/x/security_advisories | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/security_advisories/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-security_advisories&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fsecurity_advisories%2Freadonly%22%7D) | +| star
Stargazers | GitHub Stargazers related tools | https://api.githubcopilot.com/mcp/x/stargazers | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/stargazers/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-stargazers&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fstargazers%2Freadonly%22%7D) | +| people
Users | GitHub User related tools | https://api.githubcopilot.com/mcp/x/users | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/users/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-users&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fusers%2Freadonly%22%7D) | ### Additional _Remote_ Server Toolsets These toolsets are only available in the remote GitHub MCP Server and are not included in the local MCP server. -| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) | -| -------------------- | --------------------------------------------- | ------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Copilot | Copilot related tools | https://api.githubcopilot.com/mcp/x/copilot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/copilot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot%2Freadonly%22%7D) | -| Copilot Spaces | Copilot Spaces tools | https://api.githubcopilot.com/mcp/x/copilot_spaces | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot_spaces&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot_spaces%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/copilot_spaces/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot_spaces&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot_spaces%2Freadonly%22%7D) | -| GitHub support docs search | Retrieve documentation to answer GitHub product and support questions. Topics include: GitHub Actions Workflows, Authentication, ... | https://api.githubcopilot.com/mcp/x/github_support_docs_search | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-support&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgithub_support_docs_search%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/github_support_docs_search/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-support&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgithub_support_docs_search%2Freadonly%22%7D) | + +| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) | +| ---- | ----------- | ------- | ------------------------- | -------------- | ----------------------------------- | +| copilot
Copilot | Copilot related tools | https://api.githubcopilot.com/mcp/x/copilot | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/copilot/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot%2Freadonly%22%7D) | +| copilot
Copilot Spaces | Copilot Spaces tools | https://api.githubcopilot.com/mcp/x/copilot_spaces | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot_spaces&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot_spaces%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/copilot_spaces/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-copilot_spaces&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fcopilot_spaces%2Freadonly%22%7D) | +| book
Github Support Docs Search | Retrieve documentation to answer GitHub product and support questions. Topics include: GitHub Actions Workflows, Authentication, ... | https://api.githubcopilot.com/mcp/x/github_support_docs_search | [Install](https://insiders.vscode.dev/redirect/mcp/install?name=gh-github_support_docs_search&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgithub_support_docs_search%22%7D) | [read-only](https://api.githubcopilot.com/mcp/x/github_support_docs_search/readonly) | [Install read-only](https://insiders.vscode.dev/redirect/mcp/install?name=gh-github_support_docs_search&config=%7B%22type%22%3A%20%22http%22%2C%22url%22%3A%20%22https%3A%2F%2Fapi.githubcopilot.com%2Fmcp%2Fx%2Fgithub_support_docs_search%2Freadonly%22%7D) | + ### Optional Headers diff --git a/docs/tool-renaming.md b/docs/tool-renaming.md new file mode 100644 index 000000000..66d3ff410 --- /dev/null +++ b/docs/tool-renaming.md @@ -0,0 +1,50 @@ +# Tool Renaming Guide + +How to safely rename MCP tools without breaking existing user configurations. + +## Overview + +When tools are renamed, users who have the old tool name in their MCP configuration (for example, in `X-MCP-Tools` headers for the remote MCP server or `--tools` flags for the local MCP server) would normally get errors. +The deprecation alias system allows us to maintain backward compatibility by silently resolving old tool names to their new canonical names. + +This allows us to rename tools safely, without introducing breaking changes for users that have a hard reference to those tools in their server configuration. + +## Quick Steps + +1. **Rename the tool** in your code (as usual, this will imply a range of changes like updating the tool registration, the tests and the toolsnaps). +2. **Add a deprecation alias** in [pkg/github/deprecated_tool_aliases.go](../pkg/github/deprecated_tool_aliases.go): + ```go + var DeprecatedToolAliases = map[string]string{ + "old_tool_name": "new_tool_name", + } + ``` +3. **Update documentation** (README, etc.) to reference the new canonical name + +That's it. The server will silently resolve old names to new ones. This will work across both local and remote MCP servers. + +## Example + +If renaming `get_issue` to `issue_read`: + +```go +var DeprecatedToolAliases = map[string]string{ + "get_issue": "issue_read", +} +``` + +A user with this configuration: +```json +{ + "--tools": "get_issue,get_file_contents" +} +``` + +Will get `issue_read` and `get_file_contents` tools registered, with no errors. + +## Current Deprecations + + +| Old Name | New Name | +|----------|----------| +| *(none currently)* | | + diff --git a/docs/toolsets-and-icons.md b/docs/toolsets-and-icons.md new file mode 100644 index 000000000..9c26b4aa1 --- /dev/null +++ b/docs/toolsets-and-icons.md @@ -0,0 +1,201 @@ +# Toolsets and Icons + +This document explains how to work with toolsets and icons in the GitHub MCP Server. + +## Toolset Overview + +Toolsets are logical groupings of related tools. Each toolset has metadata defined in `pkg/github/tools.go`: + +```go +ToolsetMetadataRepos = inventory.ToolsetMetadata{ + ID: "repos", + Description: "GitHub Repository related tools", + Default: true, + Icon: "repo", +} +``` + +### Toolset Fields + +| Field | Type | Description | +|-------|------|-------------| +| `ID` | `ToolsetID` | Unique identifier used in URLs and CLI flags (e.g., `repos`, `issues`) | +| `Description` | `string` | Human-readable description shown in documentation | +| `Default` | `bool` | Whether this toolset is enabled by default | +| `Icon` | `string` | Octicon name for visual representation in MCP clients | + +## Adding Icons to Toolsets + +Icons help users quickly identify toolsets in MCP-compatible clients. We use [Primer Octicons](https://primer.style/foundations/icons) for all icons. + +### Step 1: Choose an Octicon + +Browse the [Octicon gallery](https://primer.style/foundations/icons) and select an appropriate icon. Use the base name without size suffix (e.g., `repo` not `repo-16`). + +### Step 2: Add Icon to Required Icons List + +Icons are defined in `pkg/octicons/required_icons.txt`, which is the single source of truth for which icons should be embedded: + +``` +# Required icons for the GitHub MCP Server +# Add new icons below (one per line) +repo +issue-opened +git-pull-request +your-new-icon # Add your icon here +``` + +### Step 3: Fetch the Icon Files + +Run the fetch-icons script to download and convert the icon: + +```bash +# Fetch a specific icon +script/fetch-icons your-new-icon + +# Or fetch all required icons +script/fetch-icons +``` + +This script: +- Downloads the 24px SVG from [Primer Octicons](https://github.com/primer/octicons) +- Converts to PNG with light theme (dark icons for light backgrounds) +- Converts to PNG with dark theme (white icons for dark backgrounds) +- Saves both variants to `pkg/octicons/icons/` + +**Requirements:** The script requires `rsvg-convert`: +- Ubuntu/Debian: `sudo apt-get install librsvg2-bin` +- macOS: `brew install librsvg` + +### Step 4: Update the Toolset Metadata + +Add or update the `Icon` field in the toolset definition: + +```go +// In pkg/github/tools.go +ToolsetMetadataRepos = inventory.ToolsetMetadata{ + ID: "repos", + Description: "GitHub Repository related tools", + Default: true, + Icon: "repo", // Add this line +} +``` + +### Step 5: Regenerate Documentation + +Run the documentation generator to update all markdown files: + +```bash +go run ./cmd/github-mcp-server generate-docs +``` + +This updates icons in: +- `README.md` - Toolsets table and tool section headers +- `docs/remote-server.md` - Remote toolsets table + +## Remote-Only Toolsets + +Some toolsets are only available in the remote GitHub MCP Server (hosted at `api.githubcopilot.com`). These are defined in `pkg/github/tools.go` with their icons, but are not registered with the local server: + +```go +// Remote-only toolsets +ToolsetMetadataCopilot = inventory.ToolsetMetadata{ + ID: "copilot", + Description: "Copilot related tools", + Icon: "copilot", +} +``` + +The `RemoteOnlyToolsets()` function returns the list of these toolsets for documentation generation. + +To add a new remote-only toolset: + +1. Add the metadata definition in `pkg/github/tools.go` +2. Add it to the slice returned by `RemoteOnlyToolsets()` +3. Regenerate documentation + +## Tool Icon Inheritance + +Individual tools inherit icons from their parent toolset. When a tool is registered with a toolset, its icons are automatically set: + +```go +// In pkg/inventory/server_tool.go +toolCopy.Icons = tool.Toolset.Icons() +``` + +This means you only need to set the icon once on the toolset, and all tools in that toolset will display the same icon. + +## How Icons Work in MCP + +The MCP protocol supports tool icons via the `icons` field. We provide icons in two formats: + +1. **Data URIs** - Base64-encoded PNG images embedded in the tool definition +2. **Light/Dark variants** - Both theme variants are provided for proper display + +The `octicons.Icons()` function generates the MCP-compatible icon objects: + +```go +// Returns []mcp.Icon with both light and dark variants +icons := octicons.Icons("repo") +``` + +## Existing Toolset Icons + +| Toolset | Octicon Name | +|---------|--------------| +| Context | `person` | +| Repositories | `repo` | +| Issues | `issue-opened` | +| Pull Requests | `git-pull-request` | +| Git | `git-branch` | +| Users | `people` | +| Organizations | `organization` | +| Actions | `workflow` | +| Code Security | `codescan` | +| Secret Protection | `shield-lock` | +| Dependabot | `dependabot` | +| Discussions | `comment-discussion` | +| Gists | `logo-gist` | +| Security Advisories | `shield` | +| Projects | `project` | +| Labels | `tag` | +| Stargazers | `star` | +| Notifications | `bell` | +| Dynamic | `tools` | +| Copilot | `copilot` | +| Support Search | `book` | + +## Troubleshooting + +### Icons not appearing in documentation + +1. Ensure PNG files exist in `pkg/octicons/icons/` with `-light.png` and `-dark.png` suffixes +2. Run `go run ./cmd/github-mcp-server generate-docs` to regenerate +3. Check that the `Icon` field is set on the toolset metadata + +### Icons not appearing in MCP clients + +1. Verify the client supports MCP tool icons +2. Check that the octicons package is properly generating base64 data URIs +3. Ensure the icon name matches a file in `pkg/octicons/icons/` + +## CI Validation + +The following tests run in CI to catch icon issues early: + +### `pkg/octicons.TestEmbeddedIconsExist` + +Verifies that all icons listed in `pkg/octicons/required_icons.txt` have corresponding PNG files embedded. + +### `pkg/github.TestAllToolsetIconsExist` + +Verifies that all toolset `Icon` fields reference icons that are properly embedded. + +### `pkg/github.TestToolsetMetadataHasIcons` + +Ensures all toolsets have an `Icon` field set. + +If any of these tests fail: +1. Add the missing icon to `pkg/octicons/required_icons.txt` +2. Run `script/fetch-icons` to download the icon +3. Commit the new icon files diff --git a/e2e.test b/e2e.test new file mode 100755 index 000000000..58505b3a2 Binary files /dev/null and b/e2e.test differ diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 5f67fb84c..86ff45b29 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -33,8 +33,15 @@ var ( buildOnce sync.Once buildError error + + // Rate limit management + rateLimitMu sync.Mutex ) +// minRateLimitRemaining is the minimum number of API requests we want to have +// remaining before we start waiting for the rate limit to reset. +const minRateLimitRemaining = 50 + // getE2EToken ensures the environment variable is checked only once and returns the token func getE2EToken(t *testing.T) string { getTokenOnce.Do(func() { @@ -72,6 +79,36 @@ func getRESTClient(t *testing.T) *gogithub.Client { return ghClient } +// waitForRateLimit checks the current rate limit and waits if necessary. +// It ensures we have at least minRateLimitRemaining requests available before proceeding. +func waitForRateLimit(t *testing.T) { + rateLimitMu.Lock() + defer rateLimitMu.Unlock() + + ghClient := getRESTClient(t) + ctx := context.Background() + + rateLimits, _, err := ghClient.RateLimit.Get(ctx) + if err != nil { + t.Logf("Warning: failed to check rate limit: %v", err) + return + } + + core := rateLimits.Core + if core.Remaining < minRateLimitRemaining { + waitDuration := time.Until(core.Reset.Time) + time.Second // Add 1 second buffer + if waitDuration > 0 { + t.Logf("Rate limit low (%d/%d remaining). Waiting %v until reset...", + core.Remaining, core.Limit, waitDuration.Round(time.Second)) + time.Sleep(waitDuration) + t.Log("Rate limit reset, continuing...") + } + } else { + t.Logf("Rate limit OK: %d/%d remaining (reset in %v)", + core.Remaining, core.Limit, time.Until(core.Reset.Time).Round(time.Second)) + } +} + // ensureDockerImageBuilt makes sure the Docker image is built only once across all tests func ensureDockerImageBuilt(t *testing.T) { buildOnce.Do(func() { @@ -107,6 +144,9 @@ func withToolsets(toolsets []string) clientOption { } func setupMCPClient(t *testing.T, options ...clientOption) *mcp.ClientSession { + // Check rate limit before setting up the client + waitForRateLimit(t) + // Get token and ensure Docker image is built token := getE2EToken(t) @@ -219,7 +259,7 @@ func TestGetMe(t *testing.T) { response, err := mcpClient.CallTool(ctx, &mcp.CallToolParams{Name: "get_me"}) require.NoError(t, err, "expected to call 'get_me' tool successfully") - require.False(t, response.IsError, "expected result not to be an error") + require.False(t, response.IsError, fmt.Sprintf("expected result not to be an error: %+v", response)) require.Len(t, response.Content, 1, "expected content to have one item") textContent, ok := response.Content[0].(*mcp.TextContent) @@ -926,7 +966,16 @@ func TestRequestCopilotReview(t *testing.T) { }, }) require.NoError(t, err, "expected to call 'request_copilot_review' tool successfully") - require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Check if Copilot is available - skip if not + if resp.IsError { + if tc, ok := resp.Content[0].(*mcp.TextContent); ok { + if strings.Contains(tc.Text, "copilot") || strings.Contains(tc.Text, "Copilot") { + t.Skip("skipping because copilot isn't available as a reviewer on this repository") + } + } + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + } textContent, ok = resp.Content[0].(*mcp.TextContent) require.True(t, ok, "expected content to be of type TextContent") @@ -939,6 +988,11 @@ func TestRequestCopilotReview(t *testing.T) { reviewRequests, _, err := ghClient.PullRequests.ListReviewers(context.Background(), currentOwner, repoName, 1, nil) require.NoError(t, err, "expected to get review requests successfully") + // Check if Copilot was added as a reviewer - skip if not available + if len(reviewRequests.Users) == 0 { + t.Skip("skipping because copilot wasn't added as a reviewer (likely not enabled for this account)") + } + // Check that there is one review request from copilot require.Len(t, reviewRequests.Users, 1, "expected to find one review request") require.Equal(t, "Copilot", *reviewRequests.Users[0].Login, "expected review request to be for Copilot") @@ -1278,16 +1332,17 @@ func TestPullRequestReviewCommentSubmit(t *testing.T) { require.NoError(t, err, "expected to call 'create_branch' tool successfully") require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) - // Create a commit with a new file + // Create a commit with a new file (multi-line content to support multi-line review comments) t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName) + multiLineContent := fmt.Sprintf("Line 1: Created by e2e test %s\nLine 2: Additional content for multi-line comments\nLine 3: More content", t.Name()) resp, err = mcpClient.CallTool(ctx, &mcp.CallToolParams{ Name: "create_or_update_file", Arguments: map[string]any{ "owner": currentOwner, "repo": repoName, "path": "test-file.txt", - "content": fmt.Sprintf("Created by e2e test %s", t.Name()), + "content": multiLineContent, "message": "Add test file", "branch": "test-branch", }, diff --git a/go.mod b/go.mod index 661778fc3..9423ce557 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/gorilla/mux v1.8.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect @@ -36,7 +37,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/go-querystring v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/modelcontextprotocol/go-sdk v1.1.0 + github.com/modelcontextprotocol/go-sdk v1.2.0-pre.1 github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect diff --git a/go.sum b/go.sum index e422a548c..fc0980ab1 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/go-openapi/swag v0.21.1 h1:wm0rhTb5z7qpJRHBdPOMuY4QjVUMbF6/kwoYeRAOrK github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -55,8 +57,8 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= -github.com/modelcontextprotocol/go-sdk v1.1.0 h1:Qjayg53dnKC4UZ+792W21e4BpwEZBzwgRW6LrjLWSwA= -github.com/modelcontextprotocol/go-sdk v1.1.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= +github.com/modelcontextprotocol/go-sdk v1.2.0-pre.1 h1:14+JrlEIFvUmbu5+iJzWPLk8CkpvegfKr42oXyjp3O4= +github.com/modelcontextprotocol/go-sdk v1.2.0-pre.1/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 h1:31Y+Yu373ymebRdJN1cWLLooHH8xAr0MhKTEJGV/87g= github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021/go.mod h1:WERUkUryfUWlrHnFSO/BEUZ+7Ns8aZy7iVOGewxKzcc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -88,6 +90,8 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 41f9016a2..4842b2c64 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -15,6 +15,7 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/raw" @@ -42,6 +43,10 @@ type MCPServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -64,138 +69,195 @@ type MCPServerConfig struct { RepoAccessTTL *time.Duration } -func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { - apiHost, err := parseAPIHost(cfg.Host) - if err != nil { - return nil, fmt.Errorf("failed to parse API host: %w", err) - } +// githubClients holds all the GitHub API clients created for a server instance. +type githubClients struct { + rest *gogithub.Client + gql *githubv4.Client + gqlHTTP *http.Client // retained for middleware to modify transport + raw *raw.Client + repoAccess *lockdown.RepoAccessCache +} - // Construct our REST client +// createGitHubClients creates all the GitHub API clients needed by the server. +func createGitHubClients(cfg MCPServerConfig, apiHost apiHost) (*githubClients, error) { + // Construct REST client restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) restClient.BaseURL = apiHost.baseRESTURL restClient.UploadURL = apiHost.uploadURL - // Construct our GraphQL client - // We're using NewEnterpriseClient here unconditionally as opposed to NewClient because we already - // did the necessary API host parsing so that github.com will return the correct URL anyway. + // Construct GraphQL client + // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &bearerAuthTransport{ transport: http.DefaultTransport, token: cfg.Token, }, - } // We're going to wrap the Transport later in beforeInit - gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) - repoAccessOpts := []lockdown.RepoAccessOption{} - if cfg.RepoAccessTTL != nil { - repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } + gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient) + + // Create raw content client (shares REST client's HTTP transport) + rawClient := raw.NewClient(restClient, apiHost.rawURL) - repoAccessLogger := cfg.Logger.With("component", "lockdown") - repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger)) + // Set up repo access cache for lockdown mode var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { - repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) + opts := []lockdown.RepoAccessOption{ + lockdown.WithLogger(cfg.Logger.With("component", "lockdown")), + } + if cfg.RepoAccessTTL != nil { + opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL)) + } + repoAccessCache = lockdown.GetInstance(gqlClient, opts...) } + return &githubClients{ + rest: restClient, + gql: gqlClient, + gqlHTTP: gqlHTTPClient, + raw: rawClient, + repoAccess: repoAccessCache, + }, nil +} + +// resolveEnabledToolsets determines which toolsets should be enabled based on config. +// Returns nil for "use defaults", empty slice for "none", or explicit list. +func resolveEnabledToolsets(cfg MCPServerConfig) []string { enabledToolsets := cfg.EnabledToolsets - // If dynamic toolsets are enabled, remove "all" and "default" from the enabled toolsets - if cfg.DynamicToolsets { - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataAll.ID) - enabledToolsets = github.RemoveToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) + // In dynamic mode, remove "all" and "default" since users enable toolsets on demand + if cfg.DynamicToolsets && enabledToolsets != nil { + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataAll.ID)) + enabledToolsets = github.RemoveToolset(enabledToolsets, string(github.ToolsetMetadataDefault.ID)) } - // Clean up the passed toolsets - enabledToolsets, invalidToolsets := github.CleanToolsets(enabledToolsets) - - // If "all" is present, override all other toolsets - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataAll.ID) { - enabledToolsets = []string{github.ToolsetMetadataAll.ID} + if enabledToolsets != nil { + return enabledToolsets + } + if cfg.DynamicToolsets { + // Dynamic mode with no toolsets specified: start empty so users enable on demand + return []string{} } - // If "default" is present, expand to real toolset IDs - if github.ContainsToolset(enabledToolsets, github.ToolsetMetadataDefault.ID) { - enabledToolsets = github.AddDefaultToolset(enabledToolsets) + if len(cfg.EnabledTools) > 0 { + // When specific tools are requested but no toolsets, don't use default toolsets + // This matches the original behavior: --tools=X alone registers only X + return []string{} + } + // nil means "use defaults" in WithToolsets + return nil +} + +func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) { + apiHost, err := parseAPIHost(cfg.Host) + if err != nil { + return nil, fmt.Errorf("failed to parse API host: %w", err) } - if len(invalidToolsets) > 0 { - fmt.Fprintf(os.Stderr, "Invalid toolsets ignored: %s\n", strings.Join(invalidToolsets, ", ")) + clients, err := createGitHubClients(cfg, apiHost) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } - // Generate instructions based on enabled toolsets - instructions := github.GenerateInstructions(enabledToolsets) + enabledToolsets := resolveEnabledToolsets(cfg) - getClient := func(_ context.Context) (*gogithub.Client, error) { - return restClient, nil // closing over client + // For instruction generation, we need actual toolset names (not nil). + // nil means "use defaults" in inventory, so expand it for instructions. + instructionToolsets := enabledToolsets + if instructionToolsets == nil { + instructionToolsets = github.GetDefaultToolsetIDs() } - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client + // Create the MCP server + serverOpts := &mcp.ServerOptions{ + Instructions: github.GenerateInstructions(instructionToolsets), + Logger: cfg.Logger, + CompletionHandler: github.CompletionsHandler(func(_ context.Context) (*gogithub.Client, error) { + return clients.rest, nil + }), } - getRawClient := func(ctx context.Context) (*raw.Client, error) { - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + // In dynamic mode, explicitly advertise capabilities since tools/resources/prompts + // may be enabled at runtime even if none are registered initially. + if cfg.DynamicToolsets { + serverOpts.Capabilities = &mcp.ServerCapabilities{ + Tools: &mcp.ToolCapabilities{}, + Resources: &mcp.ResourceCapabilities{}, + Prompts: &mcp.PromptCapabilities{}, } - return raw.NewClient(client, apiHost.rawURL), nil // closing over client } - ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{ - Instructions: instructions, - Logger: cfg.Logger, - CompletionHandler: github.CompletionsHandler(getClient), - }) + ghServer := github.NewServer(cfg.Version, serverOpts) // Add middlewares ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext) - ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient)) - - // Create default toolsets - tsg := github.DefaultToolsetGroup( - cfg.ReadOnly, - getClient, - getGQLClient, - getRawClient, + ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, clients.rest, clients.gqlHTTP)) + + // Create dependencies for tool handlers + deps := github.NewBaseDeps( + clients.rest, + clients.gql, + clients.raw, + clients.repoAccess, cfg.Translator, - cfg.ContentWindowSize, github.FeatureFlags{LockdownMode: cfg.LockdownMode}, - repoAccessCache, + cfg.ContentWindowSize, ) - // Enable and register toolsets if configured - // This always happens if toolsets are specified, regardless of whether tools are also specified - if len(enabledToolsets) > 0 { - err = tsg.EnableToolsets(enabledToolsets, nil) - if err != nil { - return nil, fmt.Errorf("failed to enable toolsets: %w", err) - } + // Build and register the tool/resource/prompt inventory + inventory := github.NewInventory(cfg.Translator). + WithDeprecatedAliases(github.DeprecatedToolAliases). + WithReadOnly(cfg.ReadOnly). + WithToolsets(enabledToolsets). + WithTools(github.CleanTools(cfg.EnabledTools)). + WithFeatureChecker(createFeatureChecker(cfg.EnabledFeatures)). + Build() - // Register all mcp functionality with the server - tsg.RegisterAll(ghServer) + if unrecognized := inventory.UnrecognizedToolsets(); len(unrecognized) > 0 { + fmt.Fprintf(os.Stderr, "Warning: unrecognized toolsets ignored: %s\n", strings.Join(unrecognized, ", ")) } - // Register specific tools if configured - if len(cfg.EnabledTools) > 0 { - enabledTools := github.CleanTools(cfg.EnabledTools) - enabledTools, _ = tsg.ResolveToolAliases(enabledTools) + // Register GitHub tools/resources/prompts from the inventory. + // In dynamic mode with no explicit toolsets, this is a no-op since enabledToolsets + // is empty - users enable toolsets at runtime via the dynamic tools below (but can + // enable toolsets or tools explicitly that do need registration). + inventory.RegisterAll(context.Background(), ghServer, deps) - // Register the specified tools (additive to any toolsets already enabled) - err = tsg.RegisterSpecificTools(ghServer, enabledTools, cfg.ReadOnly) - if err != nil { - return nil, fmt.Errorf("failed to register tools: %w", err) - } - } - - // Register dynamic toolsets if configured (additive to toolsets and tools) + // Register dynamic toolset management tools (enable/disable) - these are separate + // meta-tools that control the inventory, not part of the inventory itself if cfg.DynamicToolsets { - dynamic := github.InitDynamicToolset(ghServer, tsg, cfg.Translator) - dynamic.RegisterTools(ghServer) + registerDynamicTools(ghServer, inventory, deps, cfg.Translator) } return ghServer, nil } +// registerDynamicTools adds the dynamic toolset enable/disable tools to the server. +func registerDynamicTools(server *mcp.Server, inventory *inventory.Inventory, deps *github.BaseDeps, t translations.TranslationHelperFunc) { + dynamicDeps := github.DynamicToolDependencies{ + Server: server, + Inventory: inventory, + ToolDeps: deps, + T: t, + } + for _, tool := range github.DynamicTools(inventory) { + tool.RegisterFunc(server, dynamicDeps) + } +} + +// createFeatureChecker returns a FeatureFlagChecker that checks if a flag name +// is present in the provided list of enabled features. For the local server, +// this is populated from the --features CLI flag. +func createFeatureChecker(enabledFeatures []string) inventory.FeatureFlagChecker { + // Build a set for O(1) lookup + featureSet := make(map[string]bool, len(enabledFeatures)) + for _, f := range enabledFeatures { + featureSet[f] = true + } + return func(_ context.Context, flagName string) (bool, error) { + return featureSet[flagName], nil + } +} + type StdioServerConfig struct { // Version of the server Version string @@ -214,6 +276,10 @@ type StdioServerConfig struct { // When specified, these tools are registered in addition to any specified toolset tools EnabledTools []string + // EnabledFeatures is a list of feature flags that are enabled + // Items with FeatureFlagEnable matching an entry in this list will be available + EnabledFeatures []string + // Whether to enable dynamic toolsets // See: https://github.com/github/github-mcp-server?tab=readme-ov-file#dynamic-tool-discovery DynamicToolsets bool @@ -271,6 +337,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Token: cfg.Token, EnabledToolsets: cfg.EnabledToolsets, EnabledTools: cfg.EnabledTools, + EnabledFeatures: cfg.EnabledFeatures, DynamicToolsets: cfg.DynamicToolsets, ReadOnly: cfg.ReadOnly, Translator: t, diff --git a/pkg/errors/error.go b/pkg/errors/error.go index be2cf58f9..095f8d5b7 100644 --- a/pkg/errors/error.go +++ b/pkg/errors/error.go @@ -124,3 +124,11 @@ func NewGitHubGraphQLErrorResponse(ctx context.Context, message string, err erro } return utils.NewToolResultErrorFromErr(message, err) } + +// NewGitHubAPIStatusErrorResponse handles cases where the API call succeeds (err == nil) +// but returns an unexpected HTTP status code. It creates a synthetic error from the +// status code and response body, then records it in context for observability tracking. +func NewGitHubAPIStatusErrorResponse(ctx context.Context, message string, resp *github.Response, body []byte) *mcp.CallToolResult { + err := fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + return NewGitHubAPIErrorResponse(ctx, message, resp, err) +} diff --git a/pkg/errors/error_test.go b/pkg/errors/error_test.go index 0d7aa6afa..b5ef40596 100644 --- a/pkg/errors/error_test.go +++ b/pkg/errors/error_test.go @@ -231,6 +231,33 @@ func TestGitHubErrorContext(t *testing.T) { assert.Equal(t, originalErr, gqlError.Err) }) + t.Run("NewGitHubAPIStatusErrorResponse creates MCP error result from status code", func(t *testing.T) { + // Given a context with GitHub error tracking enabled + ctx := ContextWithGitHubErrors(context.Background()) + + resp := &github.Response{Response: &http.Response{StatusCode: 422}} + body := []byte(`{"message": "Validation Failed"}`) + + // When we create a status error response + result := NewGitHubAPIStatusErrorResponse(ctx, "failed to create issue", resp, body) + + // Then it should return an MCP error result + require.NotNil(t, result) + assert.True(t, result.IsError) + + // And the error should be stored in the context + apiErrors, err := GetGitHubAPIErrors(ctx) + require.NoError(t, err) + require.Len(t, apiErrors, 1) + + apiError := apiErrors[0] + assert.Equal(t, "failed to create issue", apiError.Message) + assert.Equal(t, resp, apiError.Response) + // The synthetic error should contain the status code and body + assert.Contains(t, apiError.Err.Error(), "unexpected status 422") + assert.Contains(t, apiError.Err.Error(), "Validation Failed") + }) + t.Run("NewGitHubAPIErrorToCtx with uninitialized context does not error", func(t *testing.T) { // Given a regular context without GitHub error tracking initialized ctx := context.Background() diff --git a/pkg/github/__toolsnaps__/actions_get.snap b/pkg/github/__toolsnaps__/actions_get.snap new file mode 100644 index 000000000..b5f3b85bd --- /dev/null +++ b/pkg/github/__toolsnaps__/actions_get.snap @@ -0,0 +1,43 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Get details of GitHub Actions resources (workflows, workflow runs, jobs, and artifacts)" + }, + "description": "Get details about specific GitHub Actions resources.\nUse this tool to get details about individual workflows, workflow runs, jobs, and artifacts by their unique IDs.\n", + "inputSchema": { + "type": "object", + "required": [ + "method", + "owner", + "repo", + "resource_id" + ], + "properties": { + "method": { + "type": "string", + "description": "The method to execute", + "enum": [ + "get_workflow", + "get_workflow_run", + "get_workflow_job", + "download_workflow_run_artifact", + "get_workflow_run_usage", + "get_workflow_run_logs_url" + ] + }, + "owner": { + "type": "string", + "description": "Repository owner" + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "resource_id": { + "type": "string", + "description": "The unique identifier of the resource. This will vary based on the \"method\" provided, so ensure you provide the correct ID:\n- Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'get_workflow' method.\n- Provide a workflow run ID for 'get_workflow_run', 'get_workflow_run_usage', and 'get_workflow_run_logs_url' methods.\n- Provide an artifact ID for 'download_workflow_run_artifact' method.\n- Provide a job ID for 'get_workflow_job' method.\n" + } + } + }, + "name": "actions_get" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/actions_list.snap b/pkg/github/__toolsnaps__/actions_list.snap new file mode 100644 index 000000000..3968a6eae --- /dev/null +++ b/pkg/github/__toolsnaps__/actions_list.snap @@ -0,0 +1,128 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "List GitHub Actions workflows in a repository" + }, + "description": "Tools for listing GitHub Actions resources.\nUse this tool to list workflows in a repository, or list workflow runs, jobs, and artifacts for a specific workflow or workflow run.\n", + "inputSchema": { + "type": "object", + "required": [ + "method", + "owner", + "repo" + ], + "properties": { + "method": { + "type": "string", + "description": "The action to perform", + "enum": [ + "list_workflows", + "list_workflow_runs", + "list_workflow_jobs", + "list_workflow_run_artifacts" + ] + }, + "owner": { + "type": "string", + "description": "Repository owner" + }, + "page": { + "type": "number", + "description": "Page number for pagination (default: 1)", + "minimum": 1 + }, + "per_page": { + "type": "number", + "description": "Results per page for pagination (default: 30, max: 100)", + "minimum": 1, + "maximum": 100 + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "resource_id": { + "type": "string", + "description": "The unique identifier of the resource. This will vary based on the \"method\" provided, so ensure you provide the correct ID:\n- Do not provide any resource ID for 'list_workflows' method.\n- Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'list_workflow_runs' method.\n- Provide a workflow run ID for 'list_workflow_jobs' and 'list_workflow_run_artifacts' methods.\n" + }, + "workflow_jobs_filter": { + "type": "object", + "description": "Filters for workflow jobs. **ONLY** used when method is 'list_workflow_jobs'", + "properties": { + "filter": { + "type": "string", + "description": "Filters jobs by their completed_at timestamp", + "enum": [ + "latest", + "all" + ] + } + } + }, + "workflow_runs_filter": { + "type": "object", + "description": "Filters for workflow runs. **ONLY** used when method is 'list_workflow_runs'", + "properties": { + "actor": { + "type": "string", + "description": "Filter to a specific GitHub user's workflow runs." + }, + "branch": { + "type": "string", + "description": "Filter workflow runs to a specific Git branch. Use the name of the branch." + }, + "event": { + "type": "string", + "description": "Filter workflow runs to a specific event type", + "enum": [ + "branch_protection_rule", + "check_run", + "check_suite", + "create", + "delete", + "deployment", + "deployment_status", + "discussion", + "discussion_comment", + "fork", + "gollum", + "issue_comment", + "issues", + "label", + "merge_group", + "milestone", + "page_build", + "public", + "pull_request", + "pull_request_review", + "pull_request_review_comment", + "pull_request_target", + "push", + "registry_package", + "release", + "repository_dispatch", + "schedule", + "status", + "watch", + "workflow_call", + "workflow_dispatch", + "workflow_run" + ] + }, + "status": { + "type": "string", + "description": "Filter workflow runs to only runs with a specific status", + "enum": [ + "queued", + "in_progress", + "completed", + "requested", + "waiting" + ] + } + } + } + } + }, + "name": "actions_list" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/actions_run_trigger.snap b/pkg/github/__toolsnaps__/actions_run_trigger.snap new file mode 100644 index 000000000..4e16f8958 --- /dev/null +++ b/pkg/github/__toolsnaps__/actions_run_trigger.snap @@ -0,0 +1,53 @@ +{ + "annotations": { + "destructiveHint": true, + "title": "Trigger GitHub Actions workflow actions" + }, + "description": "Trigger GitHub Actions workflow operations, including running, re-running, cancelling workflow runs, and deleting workflow run logs.", + "inputSchema": { + "type": "object", + "required": [ + "method", + "owner", + "repo" + ], + "properties": { + "inputs": { + "type": "object", + "description": "Inputs the workflow accepts. Only used for 'run_workflow' method." + }, + "method": { + "type": "string", + "description": "The method to execute", + "enum": [ + "run_workflow", + "rerun_workflow_run", + "rerun_failed_jobs", + "cancel_workflow_run", + "delete_workflow_run_logs" + ] + }, + "owner": { + "type": "string", + "description": "Repository owner" + }, + "ref": { + "type": "string", + "description": "The git reference for the workflow. The reference can be a branch or tag name. Required for 'run_workflow' method." + }, + "repo": { + "type": "string", + "description": "Repository name" + }, + "run_id": { + "type": "number", + "description": "The ID of the workflow run. Required for all methods except 'run_workflow'." + }, + "workflow_id": { + "type": "string", + "description": "The workflow ID (numeric) or workflow file name (e.g., main.yml, ci.yaml). Required for 'run_workflow' method." + } + } + }, + "name": "actions_run_trigger" +} \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/assign_copilot_to_issue.snap b/pkg/github/__toolsnaps__/assign_copilot_to_issue.snap index e250ca9c1..aff4aa597 100644 --- a/pkg/github/__toolsnaps__/assign_copilot_to_issue.snap +++ b/pkg/github/__toolsnaps__/assign_copilot_to_issue.snap @@ -26,5 +26,23 @@ } } }, - "name": "assign_copilot_to_issue" + "name": "assign_copilot_to_issue", + "icons": [ + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "light" + }, + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "dark" + } + ] } \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/create_or_update_file.snap b/pkg/github/__toolsnaps__/create_or_update_file.snap index 4ec2ae914..2d9ae1144 100644 --- a/pkg/github/__toolsnaps__/create_or_update_file.snap +++ b/pkg/github/__toolsnaps__/create_or_update_file.snap @@ -2,7 +2,7 @@ "annotations": { "title": "Create or update file" }, - "description": "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations.", + "description": "Create or update a single file in a GitHub repository. \nIf updating, you should provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations.\n\nIn order to obtain the SHA of original file version before updating, use the following git command:\ngit ls-tree HEAD \u003cpath to file\u003e\n\nIf the SHA is not provided, the tool will attempt to acquire it by fetching the current file contents from the repository, which may lead to rewriting latest committed changes if the file has changed since last retrieval.\n", "inputSchema": { "type": "object", "required": [ @@ -40,7 +40,7 @@ }, "sha": { "type": "string", - "description": "Required if updating an existing file. The blob SHA of the file being replaced." + "description": "The blob SHA of the file being replaced." } } }, diff --git a/pkg/github/__toolsnaps__/fork_repository.snap b/pkg/github/__toolsnaps__/fork_repository.snap index c195bd7d2..ad48688b1 100644 --- a/pkg/github/__toolsnaps__/fork_repository.snap +++ b/pkg/github/__toolsnaps__/fork_repository.snap @@ -24,5 +24,23 @@ } } }, - "name": "fork_repository" + "name": "fork_repository", + "icons": [ + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "light" + }, + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "dark" + } + ] } \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/merge_pull_request.snap b/pkg/github/__toolsnaps__/merge_pull_request.snap index 50d040f2a..2801b6a53 100644 --- a/pkg/github/__toolsnaps__/merge_pull_request.snap +++ b/pkg/github/__toolsnaps__/merge_pull_request.snap @@ -42,5 +42,23 @@ } } }, - "name": "merge_pull_request" + "name": "merge_pull_request", + "icons": [ + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "light" + }, + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "dark" + } + ] } \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/request_copilot_review.snap b/pkg/github/__toolsnaps__/request_copilot_review.snap index b967b51cc..06828776c 100644 --- a/pkg/github/__toolsnaps__/request_copilot_review.snap +++ b/pkg/github/__toolsnaps__/request_copilot_review.snap @@ -25,5 +25,23 @@ } } }, - "name": "request_copilot_review" + "name": "request_copilot_review", + "icons": [ + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "light" + }, + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "dark" + } + ] } \ No newline at end of file diff --git a/pkg/github/__toolsnaps__/star_repository.snap b/pkg/github/__toolsnaps__/star_repository.snap index 382d40395..165a011bd 100644 --- a/pkg/github/__toolsnaps__/star_repository.snap +++ b/pkg/github/__toolsnaps__/star_repository.snap @@ -20,5 +20,23 @@ } } }, - "name": "star_repository" + "name": "star_repository", + "icons": [ + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "light" + }, + { + "src": "", + "mimeType": "image/png", + "sizes": [ + "24x24" + ], + "theme": "dark" + } + ] } \ No newline at end of file diff --git a/pkg/github/actions.go b/pkg/github/actions.go index 81ed55296..6c7cdc367 100644 --- a/pkg/github/actions.go +++ b/pkg/github/actions.go @@ -3,6 +3,7 @@ package github import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strconv" @@ -11,6 +12,7 @@ import ( "github.com/github/github-mcp-server/internal/profiler" buffer "github.com/github/github-mcp-server/pkg/buffer" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -23,9 +25,34 @@ const ( DescriptionRepositoryName = "Repository name" ) +// FeatureFlagConsolidatedActions is the feature flag that disables individual actions tools +// in favor of the consolidated actions tools. +const FeatureFlagConsolidatedActions = "remote_mcp_consolidated_actions" + +// Method constants for consolidated actions tools +const ( + actionsMethodListWorkflows = "list_workflows" + actionsMethodListWorkflowRuns = "list_workflow_runs" + actionsMethodListWorkflowJobs = "list_workflow_jobs" + actionsMethodListWorkflowArtifacts = "list_workflow_run_artifacts" + actionsMethodGetWorkflow = "get_workflow" + actionsMethodGetWorkflowRun = "get_workflow_run" + actionsMethodGetWorkflowJob = "get_workflow_job" + actionsMethodGetWorkflowRunUsage = "get_workflow_run_usage" + actionsMethodGetWorkflowRunLogsURL = "get_workflow_run_logs_url" + actionsMethodDownloadWorkflowArtifact = "download_workflow_run_artifact" + actionsMethodRunWorkflow = "run_workflow" + actionsMethodRerunWorkflowRun = "rerun_workflow_run" + actionsMethodRerunFailedJobs = "rerun_failed_jobs" + actionsMethodCancelWorkflowRun = "cancel_workflow_run" + actionsMethodDeleteWorkflowRunLogs = "delete_workflow_run_logs" +) + // ListWorkflows creates a tool to list workflows in a repository -func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflows(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "list_workflows", Description: t("TOOL_LIST_WORKFLOWS_DESCRIPTION", "List workflows in a repository"), Annotations: &mcp.ToolAnnotations{ @@ -47,7 +74,12 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner", "repo"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -63,11 +95,6 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Set up list options opts := &github.ListOptions{ PerPage: pagination.PerPage, @@ -86,12 +113,17 @@ func ListWorkflows(getClient GetClientFn, t translations.TranslationHelperFunc) } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // ListWorkflowRuns creates a tool to list workflow runs for a specific workflow -func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowRuns(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "list_workflow_runs", Description: t("TOOL_LIST_WORKFLOW_RUNS_DESCRIPTION", "List workflow runs for a specific workflow"), Annotations: &mcp.ToolAnnotations{ @@ -168,7 +200,12 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "workflow_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -206,11 +243,6 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Set up list options opts := &github.ListWorkflowRunsOptions{ Actor: actor, @@ -235,12 +267,17 @@ func ListWorkflowRuns(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // RunWorkflow creates a tool to run an Actions workflow -func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RunWorkflow(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "run_workflow", Description: t("TOOL_RUN_WORKFLOW_DESCRIPTION", "Run an Actions workflow by workflow ID or filename"), Annotations: &mcp.ToolAnnotations{ @@ -274,7 +311,12 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (m Required: []string{"owner", "repo", "workflow_id", "ref"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -300,11 +342,6 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (m } } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - event := github.CreateWorkflowDispatchEventRequest{ Ref: ref, Inputs: inputs, @@ -342,12 +379,17 @@ func RunWorkflow(getClient GetClientFn, t translations.TranslationHelperFunc) (m } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // GetWorkflowRun creates a tool to get details of a specific workflow run -func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "get_workflow_run", Description: t("TOOL_GET_WORKFLOW_RUN_DESCRIPTION", "Get details of a specific workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -373,7 +415,12 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -388,11 +435,6 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, runID) if err != nil { return nil, nil, fmt.Errorf("failed to get workflow run: %w", err) @@ -405,12 +447,17 @@ func GetWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // GetWorkflowRunLogs creates a tool to download logs for a specific workflow run -func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRunLogs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "get_workflow_run_logs", Description: t("TOOL_GET_WORKFLOW_RUN_LOGS_DESCRIPTION", "Download logs for a specific workflow run (EXPENSIVE: downloads ALL logs as ZIP. Consider using get_job_logs with failed_only=true for debugging failed jobs)"), Annotations: &mcp.ToolAnnotations{ @@ -436,7 +483,12 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -451,11 +503,6 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Get the download URL for the logs url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) if err != nil { @@ -478,12 +525,17 @@ func GetWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperF } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // ListWorkflowJobs creates a tool to list jobs for a specific workflow run -func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowJobs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "list_workflow_jobs", Description: t("TOOL_LIST_WORKFLOW_JOBS_DESCRIPTION", "List jobs for a specific workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -514,7 +566,12 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "run_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -541,11 +598,6 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Set up list options opts := &github.ListWorkflowJobsOptions{ Filter: filter, @@ -573,12 +625,17 @@ func ListWorkflowJobs(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // GetJobLogs creates a tool to download logs for a specific workflow job or efficiently get all failed job logs for a workflow run -func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, contentWindowSize int) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetJobLogs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "get_job_logs", Description: t("TOOL_GET_JOB_LOGS_DESCRIPTION", "Download logs for a specific workflow job or efficiently get all failed job logs for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -621,7 +678,12 @@ func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, con Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -657,11 +719,6 @@ func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, con tailLines = 500 } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Validate parameters if failedOnly && runID == 0 { return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil @@ -672,14 +729,17 @@ func GetJobLogs(getClient GetClientFn, t translations.TranslationHelperFunc, con if failedOnly && runID > 0 { // Handle failed-only mode: get logs for all failed jobs in the workflow run - return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, contentWindowSize) + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) } else if jobID > 0 { // Handle single job mode - return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, contentWindowSize) + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) } return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // handleFailedJobLogs gets logs for all failed jobs in a workflow run @@ -837,8 +897,10 @@ func downloadLogContent(ctx context.Context, logURL string, tailLines int, maxLi } // RerunWorkflowRun creates a tool to re-run an entire workflow run -func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RerunWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "rerun_workflow_run", Description: t("TOOL_RERUN_WORKFLOW_RUN_DESCRIPTION", "Re-run an entire workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -864,7 +926,12 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -879,11 +946,6 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil @@ -903,12 +965,17 @@ func RerunWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // RerunFailedJobs creates a tool to re-run only the failed jobs in a workflow run -func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func RerunFailedJobs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "rerun_failed_jobs", Description: t("TOOL_RERUN_FAILED_JOBS_DESCRIPTION", "Re-run only the failed jobs in a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -934,7 +1001,12 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -949,11 +1021,6 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil @@ -973,12 +1040,17 @@ func RerunFailedJobs(getClient GetClientFn, t translations.TranslationHelperFunc } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // CancelWorkflowRun creates a tool to cancel a workflow run -func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func CancelWorkflowRun(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "cancel_workflow_run", Description: t("TOOL_CANCEL_WORKFLOW_RUN_DESCRIPTION", "Cancel a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1004,7 +1076,12 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1019,11 +1096,6 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) if err != nil { if _, ok := err.(*github.AcceptedError); !ok { @@ -1045,12 +1117,17 @@ func CancelWorkflowRun(getClient GetClientFn, t translations.TranslationHelperFu } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // ListWorkflowRunArtifacts creates a tool to list artifacts for a workflow run -func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListWorkflowRunArtifacts(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "list_workflow_run_artifacts", Description: t("TOOL_LIST_WORKFLOW_RUN_ARTIFACTS_DESCRIPTION", "List artifacts for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1076,7 +1153,12 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH Required: []string{"owner", "repo", "run_id"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1097,11 +1179,6 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Set up list options opts := &github.ListOptions{ PerPage: pagination.PerPage, @@ -1120,12 +1197,17 @@ func ListWorkflowRunArtifacts(getClient GetClientFn, t translations.TranslationH } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // DownloadWorkflowRunArtifact creates a tool to download a workflow run artifact -func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DownloadWorkflowRunArtifact(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "download_workflow_run_artifact", Description: t("TOOL_DOWNLOAD_WORKFLOW_RUN_ARTIFACT_DESCRIPTION", "Get download URL for a workflow run artifact"), Annotations: &mcp.ToolAnnotations{ @@ -1151,7 +1233,12 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati Required: []string{"owner", "repo", "artifact_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1166,11 +1253,6 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati } artifactID := int64(artifactIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - // Get the download URL for the artifact url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, artifactID, 1) if err != nil { @@ -1192,12 +1274,17 @@ func DownloadWorkflowRunArtifact(getClient GetClientFn, t translations.Translati } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // DeleteWorkflowRunLogs creates a tool to delete logs for a workflow run -func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DeleteWorkflowRunLogs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "delete_workflow_run_logs", Description: t("TOOL_DELETE_WORKFLOW_RUN_LOGS_DESCRIPTION", "Delete logs for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1224,7 +1311,12 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1239,11 +1331,6 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil @@ -1263,12 +1350,17 @@ func DeleteWorkflowRunLogs(getClient GetClientFn, t translations.TranslationHelp } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool } // GetWorkflowRunUsage creates a tool to get usage metrics for a workflow run -func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetWorkflowRunUsage(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ Name: "get_workflow_run_usage", Description: t("TOOL_GET_WORKFLOW_RUN_USAGE_DESCRIPTION", "Get usage metrics for a workflow run"), Annotations: &mcp.ToolAnnotations{ @@ -1294,7 +1386,12 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper Required: []string{"owner", "repo", "run_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1309,11 +1406,6 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper } runID := int64(runIDInt) - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, runID) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil @@ -1326,5 +1418,921 @@ func GetWorkflowRunUsage(getClient GetClientFn, t translations.TranslationHelper } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) + tool.FeatureFlagDisable = FeatureFlagConsolidatedActions + return tool +} + +// ActionsList returns the tool and handler for listing GitHub Actions resources. +func ActionsList(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ + Name: "actions_list", + Description: t("TOOL_ACTIONS_LIST_DESCRIPTION", + `Tools for listing GitHub Actions resources. +Use this tool to list workflows in a repository, or list workflow runs, jobs, and artifacts for a specific workflow or workflow run. +`), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_ACTIONS_LIST_USER_TITLE", "List GitHub Actions workflows in a repository"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "The action to perform", + Enum: []any{ + actionsMethodListWorkflows, + actionsMethodListWorkflowRuns, + actionsMethodListWorkflowJobs, + actionsMethodListWorkflowArtifacts, + }, + }, + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "resource_id": { + Type: "string", + Description: `The unique identifier of the resource. This will vary based on the "method" provided, so ensure you provide the correct ID: +- Do not provide any resource ID for 'list_workflows' method. +- Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'list_workflow_runs' method. +- Provide a workflow run ID for 'list_workflow_jobs' and 'list_workflow_run_artifacts' methods. +`, + }, + "workflow_runs_filter": { + Type: "object", + Description: "Filters for workflow runs. **ONLY** used when method is 'list_workflow_runs'", + Properties: map[string]*jsonschema.Schema{ + "actor": { + Type: "string", + Description: "Filter to a specific GitHub user's workflow runs.", + }, + "branch": { + Type: "string", + Description: "Filter workflow runs to a specific Git branch. Use the name of the branch.", + }, + "event": { + Type: "string", + Description: "Filter workflow runs to a specific event type", + Enum: []any{ + "branch_protection_rule", + "check_run", + "check_suite", + "create", + "delete", + "deployment", + "deployment_status", + "discussion", + "discussion_comment", + "fork", + "gollum", + "issue_comment", + "issues", + "label", + "merge_group", + "milestone", + "page_build", + "public", + "pull_request", + "pull_request_review", + "pull_request_review_comment", + "pull_request_target", + "push", + "registry_package", + "release", + "repository_dispatch", + "schedule", + "status", + "watch", + "workflow_call", + "workflow_dispatch", + "workflow_run", + }, + }, + "status": { + Type: "string", + Description: "Filter workflow runs to only runs with a specific status", + Enum: []any{"queued", "in_progress", "completed", "requested", "waiting"}, + }, + }, + }, + "workflow_jobs_filter": { + Type: "object", + Description: "Filters for workflow jobs. **ONLY** used when method is 'list_workflow_jobs'", + Properties: map[string]*jsonschema.Schema{ + "filter": { + Type: "string", + Description: "Filters jobs by their completed_at timestamp", + Enum: []any{"latest", "all"}, + }, + }, + }, + "page": { + Type: "number", + Description: "Page number for pagination (default: 1)", + Minimum: jsonschema.Ptr(1.0), + }, + "per_page": { + Type: "number", + Description: "Results per page for pagination (default: 30, max: 100)", + Minimum: jsonschema.Ptr(1.0), + Maximum: jsonschema.Ptr(100.0), + }, + }, + Required: []string{"method", "owner", "repo"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + resourceID, err := OptionalParam[string](args, "resource_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + var resourceIDInt int64 + var parseErr error + switch method { + case actionsMethodListWorkflows: + // Do nothing, no resource ID needed + default: + if resourceID == "" { + return utils.NewToolResultError(fmt.Sprintf("missing required parameter for method %s: resource_id", method)), nil, nil + } + + // For list_workflow_runs, resource_id could be a filename or numeric ID + // For other actions, resource ID must be an integer + if method != actionsMethodListWorkflowRuns { + resourceIDInt, parseErr = strconv.ParseInt(resourceID, 10, 64) + if parseErr != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid resource_id, must be an integer for method %s: %v", method, parseErr)), nil, nil + } + } + } + + switch method { + case actionsMethodListWorkflows: + return listWorkflows(ctx, client, owner, repo, pagination) + case actionsMethodListWorkflowRuns: + return listWorkflowRuns(ctx, client, args, owner, repo, resourceID, pagination) + case actionsMethodListWorkflowJobs: + return listWorkflowJobs(ctx, client, args, owner, repo, resourceIDInt, pagination) + case actionsMethodListWorkflowArtifacts: + return listWorkflowArtifacts(ctx, client, owner, repo, resourceIDInt, pagination) + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } + }, + ) + tool.FeatureFlagEnable = FeatureFlagConsolidatedActions + return tool +} + +// ActionsGet returns the tool and handler for getting GitHub Actions resources. +func ActionsGet(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ + Name: "actions_get", + Description: t("TOOL_ACTIONS_GET_DESCRIPTION", `Get details about specific GitHub Actions resources. +Use this tool to get details about individual workflows, workflow runs, jobs, and artifacts by their unique IDs. +`), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_ACTIONS_GET_USER_TITLE", "Get details of GitHub Actions resources (workflows, workflow runs, jobs, and artifacts)"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "The method to execute", + Enum: []any{ + actionsMethodGetWorkflow, + actionsMethodGetWorkflowRun, + actionsMethodGetWorkflowJob, + actionsMethodDownloadWorkflowArtifact, + actionsMethodGetWorkflowRunUsage, + actionsMethodGetWorkflowRunLogsURL, + }, + }, + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "resource_id": { + Type: "string", + Description: `The unique identifier of the resource. This will vary based on the "method" provided, so ensure you provide the correct ID: +- Provide a workflow ID or workflow file name (e.g. ci.yaml) for 'get_workflow' method. +- Provide a workflow run ID for 'get_workflow_run', 'get_workflow_run_usage', and 'get_workflow_run_logs_url' methods. +- Provide an artifact ID for 'download_workflow_run_artifact' method. +- Provide a job ID for 'get_workflow_job' method. +`, + }, + }, + Required: []string{"method", "owner", "repo", "resource_id"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + resourceID, err := RequiredParam[string](args, "resource_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + var resourceIDInt int64 + var parseErr error + switch method { + case actionsMethodGetWorkflow: + // Do nothing, we accept both a string workflow ID or filename + default: + // For other methods, resource ID must be an integer + resourceIDInt, parseErr = strconv.ParseInt(resourceID, 10, 64) + if parseErr != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid resource_id, must be an integer for method %s: %v", method, parseErr)), nil, nil + } + } + + switch method { + case actionsMethodGetWorkflow: + return getWorkflow(ctx, client, owner, repo, resourceID) + case actionsMethodGetWorkflowRun: + return getWorkflowRun(ctx, client, owner, repo, resourceIDInt) + case actionsMethodGetWorkflowJob: + return getWorkflowJob(ctx, client, owner, repo, resourceIDInt) + case actionsMethodDownloadWorkflowArtifact: + return downloadWorkflowArtifact(ctx, client, owner, repo, resourceIDInt) + case actionsMethodGetWorkflowRunUsage: + return getWorkflowRunUsage(ctx, client, owner, repo, resourceIDInt) + case actionsMethodGetWorkflowRunLogsURL: + return getWorkflowRunLogsURL(ctx, client, owner, repo, resourceIDInt) + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } + }, + ) + tool.FeatureFlagEnable = FeatureFlagConsolidatedActions + return tool +} + +// ActionsRunTrigger returns the tool and handler for triggering GitHub Actions workflows. +func ActionsRunTrigger(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ + Name: "actions_run_trigger", + Description: t("TOOL_ACTIONS_RUN_TRIGGER_DESCRIPTION", "Trigger GitHub Actions workflow operations, including running, re-running, cancelling workflow runs, and deleting workflow run logs."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_ACTIONS_RUN_TRIGGER_USER_TITLE", "Trigger GitHub Actions workflow actions"), + ReadOnlyHint: false, + DestructiveHint: jsonschema.Ptr(true), + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "The method to execute", + Enum: []any{ + actionsMethodRunWorkflow, + actionsMethodRerunWorkflowRun, + actionsMethodRerunFailedJobs, + actionsMethodCancelWorkflowRun, + actionsMethodDeleteWorkflowRunLogs, + }, + }, + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "workflow_id": { + Type: "string", + Description: "The workflow ID (numeric) or workflow file name (e.g., main.yml, ci.yaml). Required for 'run_workflow' method.", + }, + "ref": { + Type: "string", + Description: "The git reference for the workflow. The reference can be a branch or tag name. Required for 'run_workflow' method.", + }, + "inputs": { + Type: "object", + Description: "Inputs the workflow accepts. Only used for 'run_workflow' method.", + }, + "run_id": { + Type: "number", + Description: "The ID of the workflow run. Required for all methods except 'run_workflow'.", + }, + }, + Required: []string{"method", "owner", "repo"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + method, err := RequiredParam[string](args, "method") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + // Get optional parameters + workflowID, _ := OptionalParam[string](args, "workflow_id") + ref, _ := OptionalParam[string](args, "ref") + runID, _ := OptionalIntParam(args, "run_id") + + // Get optional inputs parameter + var inputs map[string]interface{} + if requestInputs, ok := args["inputs"]; ok { + if inputsMap, ok := requestInputs.(map[string]interface{}); ok { + inputs = inputsMap + } + } + + // Validate required parameters based on action type + if method == actionsMethodRunWorkflow { + if workflowID == "" { + return utils.NewToolResultError("workflow_id is required for run_workflow action"), nil, nil + } + if ref == "" { + return utils.NewToolResultError("ref is required for run_workflow action"), nil, nil + } + } else if runID == 0 { + return utils.NewToolResultError("missing required parameter: run_id"), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + switch method { + case actionsMethodRunWorkflow: + return runWorkflow(ctx, client, owner, repo, workflowID, ref, inputs) + case actionsMethodRerunWorkflowRun: + return rerunWorkflowRun(ctx, client, owner, repo, int64(runID)) + case actionsMethodRerunFailedJobs: + return rerunFailedJobs(ctx, client, owner, repo, int64(runID)) + case actionsMethodCancelWorkflowRun: + return cancelWorkflowRun(ctx, client, owner, repo, int64(runID)) + case actionsMethodDeleteWorkflowRunLogs: + return deleteWorkflowRunLogs(ctx, client, owner, repo, int64(runID)) + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil + } + }, + ) + tool.FeatureFlagEnable = FeatureFlagConsolidatedActions + return tool +} + +// ActionsGetJobLogs returns the tool and handler for getting workflow job logs. +func ActionsGetJobLogs(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := NewTool( + ToolsetMetadataActions, + mcp.Tool{ + Name: "get_job_logs", + Description: t("TOOL_GET_JOB_LOGS_CONSOLIDATED_DESCRIPTION", `Get logs for GitHub Actions workflow jobs. +Use this tool to retrieve logs for a specific job or all failed jobs in a workflow run. +For single job logs, provide job_id. For all failed jobs in a run, provide run_id with failed_only=true. +`), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_JOB_LOGS_CONSOLIDATED_USER_TITLE", "Get GitHub Actions workflow job logs"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "job_id": { + Type: "number", + Description: "The unique identifier of the workflow job. Required when getting logs for a single job.", + }, + "run_id": { + Type: "number", + Description: "The unique identifier of the workflow run. Required when failed_only is true to get logs for all failed jobs in the run.", + }, + "failed_only": { + Type: "boolean", + Description: "When true, gets logs for all failed jobs in the workflow run specified by run_id. Requires run_id to be provided.", + }, + "return_content": { + Type: "boolean", + Description: "Returns actual log content instead of URLs", + }, + "tail_lines": { + Type: "number", + Description: "Number of lines to return from the end of the log", + Default: json.RawMessage(`500`), + }, + }, + Required: []string{"owner", "repo"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + jobID, err := OptionalIntParam(args, "job_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + runID, err := OptionalIntParam(args, "run_id") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + failedOnly, err := OptionalParam[bool](args, "failed_only") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + returnContent, err := OptionalParam[bool](args, "return_content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + tailLines, err := OptionalIntParam(args, "tail_lines") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // Default to 500 lines if not specified + if tailLines == 0 { + tailLines = 500 + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + // Validate parameters + if failedOnly && runID == 0 { + return utils.NewToolResultError("run_id is required when failed_only is true"), nil, nil + } + if !failedOnly && jobID == 0 { + return utils.NewToolResultError("job_id is required when failed_only is false"), nil, nil + } + + if failedOnly && runID > 0 { + // Handle failed-only mode: get logs for all failed jobs in the workflow run + return handleFailedJobLogs(ctx, client, owner, repo, int64(runID), returnContent, tailLines, deps.GetContentWindowSize()) + } else if jobID > 0 { + // Handle single job mode + return handleSingleJobLogs(ctx, client, owner, repo, int64(jobID), returnContent, tailLines, deps.GetContentWindowSize()) + } + + return utils.NewToolResultError("Either job_id must be provided for single job logs, or run_id with failed_only=true for failed job logs"), nil, nil + }, + ) + tool.FeatureFlagEnable = FeatureFlagConsolidatedActions + return tool +} + +// Helper functions for consolidated actions tools + +func getWorkflow(ctx context.Context, client *github.Client, owner, repo, resourceID string) (*mcp.CallToolResult, any, error) { + var workflow *github.Workflow + var resp *github.Response + var err error + + if workflowIDInt, parseErr := strconv.ParseInt(resourceID, 10, 64); parseErr == nil { + workflow, resp, err = client.Actions.GetWorkflowByID(ctx, owner, repo, workflowIDInt) + } else { + workflow, resp, err = client.Actions.GetWorkflowByFileName(ctx, owner, repo, resourceID) + } + + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow", resp, err), nil, nil + } + + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(workflow) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflow: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func getWorkflowRun(ctx context.Context, client *github.Client, owner, repo string, resourceID int64) (*mcp.CallToolResult, any, error) { + workflowRun, resp, err := client.Actions.GetWorkflowRunByID(ctx, owner, repo, resourceID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(workflowRun) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflow run: %w", err) + } + return utils.NewToolResultText(string(r)), nil, nil +} + +func getWorkflowJob(ctx context.Context, client *github.Client, owner, repo string, resourceID int64) (*mcp.CallToolResult, any, error) { + workflowJob, resp, err := client.Actions.GetWorkflowJobByID(ctx, owner, repo, resourceID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow job", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(workflowJob) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflow job: %w", err) + } + return utils.NewToolResultText(string(r)), nil, nil +} + +func listWorkflows(ctx context.Context, client *github.Client, owner, repo string, pagination PaginationParams) (*mcp.CallToolResult, any, error) { + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } + + workflows, resp, err := client.Actions.ListWorkflows(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflows", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + r, err := json.Marshal(workflows) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflows: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func listWorkflowRuns(ctx context.Context, client *github.Client, args map[string]any, owner, repo, resourceID string, pagination PaginationParams) (*mcp.CallToolResult, any, error) { + filterArgs, err := OptionalParam[map[string]any](args, "workflow_runs_filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filterArgsTyped := make(map[string]string) + for k, v := range filterArgs { + if strVal, ok := v.(string); ok { + filterArgsTyped[k] = strVal + } else { + filterArgsTyped[k] = "" + } + } + + listWorkflowRunsOptions := &github.ListWorkflowRunsOptions{ + Actor: filterArgsTyped["actor"], + Branch: filterArgsTyped["branch"], + Event: filterArgsTyped["event"], + Status: filterArgsTyped["status"], + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + + var workflowRuns *github.WorkflowRuns + var resp *github.Response + + if workflowIDInt, parseErr := strconv.ParseInt(resourceID, 10, 64); parseErr == nil { + workflowRuns, resp, err = client.Actions.ListWorkflowRunsByID(ctx, owner, repo, workflowIDInt, listWorkflowRunsOptions) + } else { + workflowRuns, resp, err = client.Actions.ListWorkflowRunsByFileName(ctx, owner, repo, resourceID, listWorkflowRunsOptions) + } + + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow runs", resp, err), nil, nil + } + + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(workflowRuns) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflow runs: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func listWorkflowJobs(ctx context.Context, client *github.Client, args map[string]any, owner, repo string, resourceID int64, pagination PaginationParams) (*mcp.CallToolResult, any, error) { + filterArgs, err := OptionalParam[map[string]any](args, "workflow_jobs_filter") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + filterArgsTyped := make(map[string]string) + for k, v := range filterArgs { + if strVal, ok := v.(string); ok { + filterArgsTyped[k] = strVal + } else { + filterArgsTyped[k] = "" + } + } + + workflowJobs, resp, err := client.Actions.ListWorkflowJobs(ctx, owner, repo, resourceID, &github.ListWorkflowJobsOptions{ + Filter: filterArgsTyped["filter"], + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow jobs", resp, err), nil, nil + } + + response := map[string]any{ + "jobs": workflowJobs, + } + + defer func() { _ = resp.Body.Close() }() + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal workflow jobs: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func listWorkflowArtifacts(ctx context.Context, client *github.Client, owner, repo string, resourceID int64, pagination PaginationParams) (*mcp.CallToolResult, any, error) { + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } + + artifacts, resp, err := client.Actions.ListWorkflowRunArtifacts(ctx, owner, repo, resourceID, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list workflow run artifacts", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + r, err := json.Marshal(artifacts) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func downloadWorkflowArtifact(ctx context.Context, client *github.Client, owner, repo string, resourceID int64) (*mcp.CallToolResult, any, error) { + // Get the download URL for the artifact + url, resp, err := client.Actions.DownloadArtifact(ctx, owner, repo, resourceID, 1) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get artifact download URL", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the download URL and information + result := map[string]any{ + "download_url": url.String(), + "message": "Artifact is available for download", + "note": "The download_url provides a download link for the artifact as a ZIP archive. The link is temporary and expires after a short time.", + "artifact_id": resourceID, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func getWorkflowRunLogsURL(ctx context.Context, client *github.Client, owner, repo string, runID int64) (*mcp.CallToolResult, any, error) { + // Get the download URL for the logs + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, owner, repo, runID, 1) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run logs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + // Create response with the logs URL and information + result := map[string]any{ + "logs_url": url.String(), + "message": "Workflow run logs are available for download", + "note": "The logs_url provides a download link for the complete workflow run logs as a ZIP archive. You can download this archive to extract and examine individual job logs.", + "warning": "This downloads ALL logs as a ZIP file which can be large and expensive. For debugging failed jobs, consider using get_job_logs with failed_only=true and run_id instead.", + "optimization_tip": "Use: get_job_logs with parameters {run_id: " + fmt.Sprintf("%d", runID) + ", failed_only: true} for more efficient failed job debugging", + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func getWorkflowRunUsage(ctx context.Context, client *github.Client, owner, repo string, resourceID int64) (*mcp.CallToolResult, any, error) { + usage, resp, err := client.Actions.GetWorkflowRunUsageByID(ctx, owner, repo, resourceID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get workflow run usage", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + r, err := json.Marshal(usage) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func runWorkflow(ctx context.Context, client *github.Client, owner, repo, workflowID, ref string, inputs map[string]interface{}) (*mcp.CallToolResult, any, error) { + event := github.CreateWorkflowDispatchEventRequest{ + Ref: ref, + Inputs: inputs, + } + + var resp *github.Response + var err error + var workflowType string + + if workflowIDInt, parseErr := strconv.ParseInt(workflowID, 10, 64); parseErr == nil { + resp, err = client.Actions.CreateWorkflowDispatchEventByID(ctx, owner, repo, workflowIDInt, event) + workflowType = "workflow_id" + } else { + resp, err = client.Actions.CreateWorkflowDispatchEventByFileName(ctx, owner, repo, workflowID, event) + workflowType = "workflow_file" + } + + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to run workflow", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run has been queued", + "workflow_type": workflowType, + "workflow_id": workflowID, + "ref": ref, + "inputs": inputs, + "status": resp.Status, + "status_code": resp.StatusCode, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func rerunWorkflowRun(ctx context.Context, client *github.Client, owner, repo string, runID int64) (*mcp.CallToolResult, any, error) { + resp, err := client.Actions.RerunWorkflowByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun workflow run", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run has been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func rerunFailedJobs(ctx context.Context, client *github.Client, owner, repo string, runID int64) (*mcp.CallToolResult, any, error) { + resp, err := client.Actions.RerunFailedJobsByID(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to rerun failed jobs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Failed jobs have been queued for re-run", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func cancelWorkflowRun(ctx context.Context, client *github.Client, owner, repo string, runID int64) (*mcp.CallToolResult, any, error) { + resp, err := client.Actions.CancelWorkflowRunByID(ctx, owner, repo, runID) + if err != nil { + var acceptedErr *github.AcceptedError + if !errors.As(err, &acceptedErr) { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to cancel workflow run", resp, err), nil, nil + } + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run has been cancelled", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil +} + +func deleteWorkflowRunLogs(ctx context.Context, client *github.Client, owner, repo string, runID int64) (*mcp.CallToolResult, any, error) { + resp, err := client.Actions.DeleteWorkflowRunLogs(ctx, owner, repo, runID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to delete workflow run logs", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + result := map[string]any{ + "message": "Workflow run logs have been deleted", + "run_id": runID, + "status": resp.Status, + "status_code": resp.StatusCode, + } + + r, err := json.Marshal(result) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil } diff --git a/pkg/github/actions_test.go b/pkg/github/actions_test.go index 6d9921f2e..f2d336e21 100644 --- a/pkg/github/actions_test.go +++ b/pkg/github/actions_test.go @@ -25,13 +25,12 @@ import ( func Test_ListWorkflows(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflows(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflows(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflows", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflows", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "perPage") @@ -47,44 +46,41 @@ func Test_ListWorkflows(t *testing.T) { }{ { name: "successful workflow listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsWorkflowsByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - workflows := &github.Workflows{ - TotalCount: github.Ptr(2), - Workflows: []*github.Workflow{ - { - ID: github.Ptr(int64(123)), - Name: github.Ptr("CI"), - Path: github.Ptr(".github/workflows/ci.yml"), - State: github.Ptr("active"), - CreatedAt: &github.Timestamp{}, - UpdatedAt: &github.Timestamp{}, - URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/workflows/123"), - HTMLURL: github.Ptr("https://github.com/owner/repo/actions/workflows/ci.yml"), - BadgeURL: github.Ptr("https://github.com/owner/repo/workflows/CI/badge.svg"), - NodeID: github.Ptr("W_123"), - }, - { - ID: github.Ptr(int64(456)), - Name: github.Ptr("Deploy"), - Path: github.Ptr(".github/workflows/deploy.yml"), - State: github.Ptr("active"), - CreatedAt: &github.Timestamp{}, - UpdatedAt: &github.Timestamp{}, - URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/workflows/456"), - HTMLURL: github.Ptr("https://github.com/owner/repo/actions/workflows/deploy.yml"), - BadgeURL: github.Ptr("https://github.com/owner/repo/workflows/Deploy/badge.svg"), - NodeID: github.Ptr("W_456"), - }, + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsWorkflowsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + workflows := &github.Workflows{ + TotalCount: github.Ptr(2), + Workflows: []*github.Workflow{ + { + ID: github.Ptr(int64(123)), + Name: github.Ptr("CI"), + Path: github.Ptr(".github/workflows/ci.yml"), + State: github.Ptr("active"), + CreatedAt: &github.Timestamp{}, + UpdatedAt: &github.Timestamp{}, + URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/workflows/123"), + HTMLURL: github.Ptr("https://github.com/owner/repo/actions/workflows/ci.yml"), + BadgeURL: github.Ptr("https://github.com/owner/repo/workflows/CI/badge.svg"), + NodeID: github.Ptr("W_123"), }, - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(workflows) - }), - ), - ), + { + ID: github.Ptr(int64(456)), + Name: github.Ptr("Deploy"), + Path: github.Ptr(".github/workflows/deploy.yml"), + State: github.Ptr("active"), + CreatedAt: &github.Timestamp{}, + UpdatedAt: &github.Timestamp{}, + URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/workflows/456"), + HTMLURL: github.Ptr("https://github.com/owner/repo/actions/workflows/deploy.yml"), + BadgeURL: github.Ptr("https://github.com/owner/repo/workflows/Deploy/badge.svg"), + NodeID: github.Ptr("W_456"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(workflows) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -93,7 +89,7 @@ func Test_ListWorkflows(t *testing.T) { }, { name: "missing required parameter owner", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "repo": "repo", }, @@ -106,13 +102,16 @@ func Test_ListWorkflows(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListWorkflows(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -138,18 +137,17 @@ func Test_ListWorkflows(t *testing.T) { func Test_RunWorkflow(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RunWorkflow(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "run_workflow", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "workflow_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "ref") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "inputs") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "workflow_id", "ref"}) + toolDef := RunWorkflow(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "run_workflow", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "workflow_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "ref") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "inputs") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "workflow_id", "ref"}) tests := []struct { name string @@ -160,14 +158,11 @@ func Test_RunWorkflow(t *testing.T) { }{ { name: "successful workflow run", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNoContent) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -178,7 +173,7 @@ func Test_RunWorkflow(t *testing.T) { }, { name: "missing required parameter workflow_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -193,13 +188,16 @@ func Test_RunWorkflow(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := RunWorkflow(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -224,6 +222,8 @@ func Test_RunWorkflow(t *testing.T) { func Test_RunWorkflow_WithFilename(t *testing.T) { // Test the unified RunWorkflow function with filenames + toolDef := RunWorkflow(translations.NullTranslationHelper) + tests := []struct { name string mockedClient *http.Client @@ -233,14 +233,11 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { }{ { name: "successful workflow run by filename", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNoContent) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -251,14 +248,11 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { }, { name: "successful workflow run by numeric ID as string", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNoContent) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -269,7 +263,7 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { }, { name: "missing required parameter workflow_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -284,13 +278,16 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := RunWorkflow(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -315,16 +312,15 @@ func Test_RunWorkflow_WithFilename(t *testing.T) { func Test_CancelWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CancelWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := CancelWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "cancel_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "cancel_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -335,17 +331,11 @@ func Test_CancelWorkflowRun(t *testing.T) { }{ { name: "successful workflow run cancellation", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.EndpointPattern{ - Pattern: "/repos/owner/repo/actions/runs/12345/cancel", - Method: "POST", - }, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusAccepted) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + "POST /repos/owner/repo/actions/runs/12345/cancel": http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -355,17 +345,11 @@ func Test_CancelWorkflowRun(t *testing.T) { }, { name: "conflict when cancelling a workflow run", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.EndpointPattern{ - Pattern: "/repos/owner/repo/actions/runs/12345/cancel", - Method: "POST", - }, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusConflict) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + "POST /repos/owner/repo/actions/runs/12345/cancel": http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusConflict) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -376,7 +360,7 @@ func Test_CancelWorkflowRun(t *testing.T) { }, { name: "missing required parameter run_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -390,13 +374,16 @@ func Test_CancelWorkflowRun(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CancelWorkflowRun(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -421,18 +408,17 @@ func Test_CancelWorkflowRun(t *testing.T) { func Test_ListWorkflowRunArtifacts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowRunArtifacts(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "list_workflow_run_artifacts", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "perPage") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "page") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + toolDef := ListWorkflowRunArtifacts(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "list_workflow_run_artifacts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "perPage") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "page") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -443,58 +429,55 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { }{ { name: "successful artifacts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsRunsArtifactsByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - artifacts := &github.ArtifactList{ - TotalCount: github.Ptr(int64(2)), - Artifacts: []*github.Artifact{ - { - ID: github.Ptr(int64(1)), - NodeID: github.Ptr("A_1"), - Name: github.Ptr("build-artifacts"), - SizeInBytes: github.Ptr(int64(1024)), - URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/1"), - ArchiveDownloadURL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/1/zip"), - Expired: github.Ptr(false), - CreatedAt: &github.Timestamp{}, - UpdatedAt: &github.Timestamp{}, - ExpiresAt: &github.Timestamp{}, - WorkflowRun: &github.ArtifactWorkflowRun{ - ID: github.Ptr(int64(12345)), - RepositoryID: github.Ptr(int64(1)), - HeadRepositoryID: github.Ptr(int64(1)), - HeadBranch: github.Ptr("main"), - HeadSHA: github.Ptr("abc123"), - }, + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsRunsArtifactsByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + artifacts := &github.ArtifactList{ + TotalCount: github.Ptr(int64(2)), + Artifacts: []*github.Artifact{ + { + ID: github.Ptr(int64(1)), + NodeID: github.Ptr("A_1"), + Name: github.Ptr("build-artifacts"), + SizeInBytes: github.Ptr(int64(1024)), + URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/1"), + ArchiveDownloadURL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/1/zip"), + Expired: github.Ptr(false), + CreatedAt: &github.Timestamp{}, + UpdatedAt: &github.Timestamp{}, + ExpiresAt: &github.Timestamp{}, + WorkflowRun: &github.ArtifactWorkflowRun{ + ID: github.Ptr(int64(12345)), + RepositoryID: github.Ptr(int64(1)), + HeadRepositoryID: github.Ptr(int64(1)), + HeadBranch: github.Ptr("main"), + HeadSHA: github.Ptr("abc123"), }, - { - ID: github.Ptr(int64(2)), - NodeID: github.Ptr("A_2"), - Name: github.Ptr("test-results"), - SizeInBytes: github.Ptr(int64(512)), - URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/2"), - ArchiveDownloadURL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/2/zip"), - Expired: github.Ptr(false), - CreatedAt: &github.Timestamp{}, - UpdatedAt: &github.Timestamp{}, - ExpiresAt: &github.Timestamp{}, - WorkflowRun: &github.ArtifactWorkflowRun{ - ID: github.Ptr(int64(12345)), - RepositoryID: github.Ptr(int64(1)), - HeadRepositoryID: github.Ptr(int64(1)), - HeadBranch: github.Ptr("main"), - HeadSHA: github.Ptr("abc123"), - }, + }, + { + ID: github.Ptr(int64(2)), + NodeID: github.Ptr("A_2"), + Name: github.Ptr("test-results"), + SizeInBytes: github.Ptr(int64(512)), + URL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/2"), + ArchiveDownloadURL: github.Ptr("https://api.github.com/repos/owner/repo/actions/artifacts/2/zip"), + Expired: github.Ptr(false), + CreatedAt: &github.Timestamp{}, + UpdatedAt: &github.Timestamp{}, + ExpiresAt: &github.Timestamp{}, + WorkflowRun: &github.ArtifactWorkflowRun{ + ID: github.Ptr(int64(12345)), + RepositoryID: github.Ptr(int64(1)), + HeadRepositoryID: github.Ptr(int64(1)), + HeadBranch: github.Ptr("main"), + HeadSHA: github.Ptr("abc123"), }, }, - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(artifacts) - }), - ), - ), + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(artifacts) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -504,7 +487,7 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { }, { name: "missing required parameter run_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -518,13 +501,16 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListWorkflowRunArtifacts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -550,16 +536,15 @@ func Test_ListWorkflowRunArtifacts(t *testing.T) { func Test_DownloadWorkflowRunArtifact(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := DownloadWorkflowRunArtifact(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := DownloadWorkflowRunArtifact(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "download_workflow_run_artifact", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "artifact_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "artifact_id"}) + assert.Equal(t, "download_workflow_run_artifact", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "artifact_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "artifact_id"}) tests := []struct { name string @@ -570,19 +555,13 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { }{ { name: "successful artifact download URL", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.EndpointPattern{ - Pattern: "/repos/owner/repo/actions/artifacts/123/zip", - Method: "GET", - }, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - // GitHub returns a 302 redirect to the download URL - w.Header().Set("Location", "https://api.github.com/repos/owner/repo/actions/artifacts/123/download") - w.WriteHeader(http.StatusFound) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + "GET /repos/owner/repo/actions/artifacts/123/zip": http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // GitHub returns a 302 redirect to the download URL + w.Header().Set("Location", "https://api.github.com/repos/owner/repo/actions/artifacts/123/download") + w.WriteHeader(http.StatusFound) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -592,7 +571,7 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { }, { name: "missing required parameter artifact_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -606,13 +585,16 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := DownloadWorkflowRunArtifact(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -639,16 +621,15 @@ func Test_DownloadWorkflowRunArtifact(t *testing.T) { func Test_DeleteWorkflowRunLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := DeleteWorkflowRunLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := DeleteWorkflowRunLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "delete_workflow_run_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "delete_workflow_run_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -659,14 +640,11 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { }{ { name: "successful logs deletion", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.DeleteReposActionsRunsLogsByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNoContent) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + DeleteReposActionsRunsLogsByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -676,7 +654,7 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { }, { name: "missing required parameter run_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -690,13 +668,16 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := DeleteWorkflowRunLogs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -721,16 +702,15 @@ func Test_DeleteWorkflowRunLogs(t *testing.T) { func Test_GetWorkflowRunUsage(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRunUsage(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRunUsage(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run_usage", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) + assert.Equal(t, "get_workflow_run_usage", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo", "run_id"}) tests := []struct { name string @@ -741,34 +721,31 @@ func Test_GetWorkflowRunUsage(t *testing.T) { }{ { name: "successful workflow run usage", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsRunsTimingByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - usage := &github.WorkflowRunUsage{ - Billable: &github.WorkflowRunBillMap{ - "UBUNTU": &github.WorkflowRunBill{ - TotalMS: github.Ptr(int64(120000)), - Jobs: github.Ptr(2), - JobRuns: []*github.WorkflowRunJobRun{ - { - JobID: github.Ptr(1), - DurationMS: github.Ptr(int64(60000)), - }, - { - JobID: github.Ptr(2), - DurationMS: github.Ptr(int64(60000)), - }, + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsRunsTimingByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + usage := &github.WorkflowRunUsage{ + Billable: &github.WorkflowRunBillMap{ + "UBUNTU": &github.WorkflowRunBill{ + TotalMS: github.Ptr(int64(120000)), + Jobs: github.Ptr(2), + JobRuns: []*github.WorkflowRunJobRun{ + { + JobID: github.Ptr(1), + DurationMS: github.Ptr(int64(60000)), + }, + { + JobID: github.Ptr(2), + DurationMS: github.Ptr(int64(60000)), }, }, }, - RunDurationMS: github.Ptr(int64(120000)), - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(usage) - }), - ), - ), + }, + RunDurationMS: github.Ptr(int64(120000)), + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(usage) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -778,7 +755,7 @@ func Test_GetWorkflowRunUsage(t *testing.T) { }, { name: "missing required parameter run_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -792,13 +769,16 @@ func Test_GetWorkflowRunUsage(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetWorkflowRunUsage(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -823,19 +803,18 @@ func Test_GetWorkflowRunUsage(t *testing.T) { func Test_GetJobLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetJobLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper, 5000) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - assert.Equal(t, "get_job_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "owner") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "repo") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "job_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "failed_only") - assert.Contains(t, tool.InputSchema.(*jsonschema.Schema).Properties, "return_content") - assert.ElementsMatch(t, tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo"}) + toolDef := GetJobLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "get_job_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "owner") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "repo") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "job_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "run_id") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "failed_only") + assert.Contains(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Properties, "return_content") + assert.ElementsMatch(t, toolDef.Tool.InputSchema.(*jsonschema.Schema).Required, []string{"owner", "repo"}) tests := []struct { name string @@ -847,15 +826,12 @@ func Test_GetJobLogs(t *testing.T) { }{ { name: "successful single job logs with URL", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Location", "https://github.com/logs/job/123") - w.WriteHeader(http.StatusFound) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "https://github.com/logs/job/123") + w.WriteHeader(http.StatusFound) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -871,42 +847,36 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "successful failed jobs logs", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - jobs := &github.Jobs{ - TotalCount: github.Ptr(3), - Jobs: []*github.WorkflowJob{ - { - ID: github.Ptr(int64(1)), - Name: github.Ptr("test-job-1"), - Conclusion: github.Ptr("success"), - }, - { - ID: github.Ptr(int64(2)), - Name: github.Ptr("test-job-2"), - Conclusion: github.Ptr("failure"), - }, - { - ID: github.Ptr(int64(3)), - Name: github.Ptr("test-job-3"), - Conclusion: github.Ptr("failure"), - }, + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsRunsJobsByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jobs := &github.Jobs{ + TotalCount: github.Ptr(3), + Jobs: []*github.WorkflowJob{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("test-job-1"), + Conclusion: github.Ptr("success"), }, - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(jobs) - }), - ), - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Location", "https://github.com/logs/job/"+r.URL.Path[len(r.URL.Path)-1:]) - w.WriteHeader(http.StatusFound) - }), - ), - ), + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("test-job-2"), + Conclusion: github.Ptr("failure"), + }, + { + ID: github.Ptr(int64(3)), + Name: github.Ptr("test-job-3"), + Conclusion: github.Ptr("failure"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(jobs) + }), + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", "https://github.com/logs/job/"+r.URL.Path[len(r.URL.Path)-1:]) + w.WriteHeader(http.StatusFound) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -928,30 +898,27 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "no failed jobs found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - jobs := &github.Jobs{ - TotalCount: github.Ptr(2), - Jobs: []*github.WorkflowJob{ - { - ID: github.Ptr(int64(1)), - Name: github.Ptr("test-job-1"), - Conclusion: github.Ptr("success"), - }, - { - ID: github.Ptr(int64(2)), - Name: github.Ptr("test-job-2"), - Conclusion: github.Ptr("success"), - }, + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsRunsJobsByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jobs := &github.Jobs{ + TotalCount: github.Ptr(2), + Jobs: []*github.WorkflowJob{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("test-job-1"), + Conclusion: github.Ptr("success"), }, - } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(jobs) - }), - ), - ), + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("test-job-2"), + Conclusion: github.Ptr("success"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(jobs) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -968,7 +935,7 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "missing job_id when not using failed_only", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -978,7 +945,7 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "missing run_id when using failed_only", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -989,7 +956,7 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "missing required parameter owner", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "repo": "repo", "job_id": float64(123), @@ -999,7 +966,7 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "missing required parameter repo", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]any{ "owner": "owner", "job_id": float64(123), @@ -1009,17 +976,14 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "API error when getting single job logs", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _ = json.NewEncoder(w).Encode(map[string]string{ - "message": "Not Found", - }) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{ + "message": "Not Found", + }) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -1029,17 +993,14 @@ func Test_GetJobLogs(t *testing.T) { }, { name: "API error when listing workflow jobs for failed_only", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _ = json.NewEncoder(w).Encode(map[string]string{ - "message": "Not Found", - }) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsRunsJobsByOwnerByRepoByRunID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{ + "message": "Not Found", + }) + }), + }), requestArgs: map[string]any{ "owner": "owner", "repo": "repo", @@ -1054,13 +1015,17 @@ func Test_GetJobLogs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.Equal(t, tc.expectError, result.IsError) @@ -1102,18 +1067,20 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { })) defer testServer.Close() - mockedClient := mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Location", testServer.URL) - w.WriteHeader(http.StatusFound) - }), - ), - ) + mockedClient := MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", testServer.URL) + w.WriteHeader(http.StatusFound) + }), + }) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1121,14 +1088,8 @@ func Test_GetJobLogs_WithContentReturn(t *testing.T) { "job_id": float64(123), "return_content": true, }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1155,18 +1116,20 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { })) defer testServer.Close() - mockedClient := mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Location", testServer.URL) - w.WriteHeader(http.StatusFound) - }), - ), - ) + mockedClient := MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", testServer.URL) + w.WriteHeader(http.StatusFound) + }), + }) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1175,15 +1138,8 @@ func Test_GetJobLogs_WithContentReturnAndTailLines(t *testing.T) { "return_content": true, "tail_lines": float64(1), // Requesting last 1 line }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - "tail_lines": float64(1), - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1209,18 +1165,20 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { })) defer testServer.Close() - mockedClient := mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Location", testServer.URL) - w.WriteHeader(http.StatusFound) - }), - ), - ) + mockedClient := MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposActionsJobsLogsByOwnerByRepoByJobID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", testServer.URL) + w.WriteHeader(http.StatusFound) + }), + }) client := github.NewClient(mockedClient) - _, handler := GetJobLogs(stubGetClientFn(client), translations.NullTranslationHelper, 5000) + toolDef := GetJobLogs(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) request := createMCPRequest(map[string]any{ "owner": "owner", @@ -1229,15 +1187,8 @@ func Test_GetJobLogs_WithContentReturnAndLargeTailLines(t *testing.T) { "return_content": true, "tail_lines": float64(100), }) - args := map[string]any{ - "owner": "owner", - "repo": "repo", - "job_id": float64(123), - "return_content": true, - "tail_lines": float64(100), - } - result, _, err := handler(context.Background(), &request, args) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -1353,13 +1304,12 @@ func Test_MemoryUsage_SlidingWindow_vs_NoWindow(t *testing.T) { func Test_ListWorkflowRuns(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowRuns(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflowRuns(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflow_runs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflow_runs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "workflow_id") @@ -1368,13 +1318,12 @@ func Test_ListWorkflowRuns(t *testing.T) { func Test_GetWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1383,13 +1332,12 @@ func Test_GetWorkflowRun(t *testing.T) { func Test_GetWorkflowRunLogs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetWorkflowRunLogs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetWorkflowRunLogs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_workflow_run_logs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_workflow_run_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1398,13 +1346,12 @@ func Test_GetWorkflowRunLogs(t *testing.T) { func Test_ListWorkflowJobs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListWorkflowJobs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListWorkflowJobs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_workflow_jobs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_workflow_jobs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1413,13 +1360,12 @@ func Test_ListWorkflowJobs(t *testing.T) { func Test_RerunWorkflowRun(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RerunWorkflowRun(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := RerunWorkflowRun(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "rerun_workflow_run", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "rerun_workflow_run", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") @@ -1428,15 +1374,1136 @@ func Test_RerunWorkflowRun(t *testing.T) { func Test_RerunFailedJobs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := RerunFailedJobs(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := RerunFailedJobs(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "rerun_failed_jobs", tool.Name) - assert.NotEmpty(t, tool.Description) - inputSchema := tool.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "rerun_failed_jobs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") assert.Contains(t, inputSchema.Properties, "run_id") assert.ElementsMatch(t, inputSchema.Required, []string{"owner", "repo", "run_id"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful rerun of failed jobs", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/actions/runs/12345/rerun-failed-jobs", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }, + expectError: false, + }, + { + name: "missing required parameter run_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: run_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "Failed jobs have been queued for re-run", response["message"]) + assert.Equal(t, float64(12345), response["run_id"]) + }) + } +} + +func Test_RerunWorkflowRun_Behavioral(t *testing.T) { + toolDef := RerunWorkflowRun(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful rerun of workflow run", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/actions/runs/12345/rerun", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }, + expectError: false, + }, + { + name: "missing required parameter run_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: run_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "Workflow run has been queued for re-run", response["message"]) + assert.Equal(t, float64(12345), response["run_id"]) + }) + } +} + +func Test_ListWorkflowRuns_Behavioral(t *testing.T) { + toolDef := ListWorkflowRuns(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful workflow runs listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsWorkflowsRunsByOwnerByRepoByWorkflowId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + runs := &github.WorkflowRuns{ + TotalCount: github.Ptr(2), + WorkflowRuns: []*github.WorkflowRun{ + { + ID: github.Ptr(int64(123)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + }, + { + ID: github.Ptr(int64(456)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("failure"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(runs) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "workflow_id": "ci.yml", + }, + expectError: false, + }, + { + name: "missing required parameter workflow_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: workflow_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response github.WorkflowRuns + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.TotalCount) + assert.Greater(t, *response.TotalCount, 0) + }) + } +} + +func Test_GetWorkflowRun_Behavioral(t *testing.T) { + toolDef := GetWorkflowRun(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful get workflow run", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + run := &github.WorkflowRun{ + ID: github.Ptr(int64(12345)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(run) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }, + expectError: false, + }, + { + name: "missing required parameter run_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: run_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response github.WorkflowRun + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.ID) + assert.Equal(t, int64(12345), *response.ID) + }) + } +} + +func Test_GetWorkflowRunLogs_Behavioral(t *testing.T) { + toolDef := GetWorkflowRunLogs(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful get workflow run logs", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsLogsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "https://github.com/logs/run/12345") + w.WriteHeader(http.StatusFound) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }, + expectError: false, + }, + { + name: "missing required parameter run_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: run_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Contains(t, response, "logs_url") + assert.Equal(t, "Workflow run logs are available for download", response["message"]) + }) + } +} + +func Test_ListWorkflowJobs_Behavioral(t *testing.T) { + toolDef := ListWorkflowJobs(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful list workflow jobs", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jobs := &github.Jobs{ + TotalCount: github.Ptr(2), + Jobs: []*github.WorkflowJob{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("build"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("test"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("failure"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(jobs) + }), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }, + expectError: false, + }, + { + name: "missing required parameter run_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: run_id", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Contains(t, response, "jobs") + }) + } +} + +// Tests for consolidated actions tools + +func Test_ActionsList(t *testing.T) { + // Verify tool definition once + toolDef := ActionsList(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "actions_list", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, inputSchema.Properties, "method") + assert.Contains(t, inputSchema.Properties, "owner") + assert.Contains(t, inputSchema.Properties, "repo") + assert.ElementsMatch(t, inputSchema.Required, []string{"method", "owner", "repo"}) +} + +func Test_ActionsList_ListWorkflows(t *testing.T) { + toolDef := ActionsList(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful workflow list", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsWorkflowsByOwnerByRepo, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + workflows := &github.Workflows{ + TotalCount: github.Ptr(2), + Workflows: []*github.Workflow{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("CI"), + Path: github.Ptr(".github/workflows/ci.yml"), + State: github.Ptr("active"), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("Deploy"), + Path: github.Ptr(".github/workflows/deploy.yml"), + State: github.Ptr("active"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(workflows) + }), + ), + ), + requestArgs: map[string]any{ + "method": "list_workflows", + "owner": "owner", + "repo": "repo", + }, + expectError: false, + }, + { + name: "missing required parameter method", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + }, + expectError: true, + expectedErrMsg: "missing required parameter: method", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response github.Workflows + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.TotalCount) + assert.Greater(t, *response.TotalCount, 0) + }) + } +} + +func Test_ActionsList_ListWorkflowRuns(t *testing.T) { + toolDef := ActionsList(translations.NullTranslationHelper) + + t.Run("successful workflow runs list", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsWorkflowsRunsByOwnerByRepoByWorkflowId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + runs := &github.WorkflowRuns{ + TotalCount: github.Ptr(1), + WorkflowRuns: []*github.WorkflowRun{ + { + ID: github.Ptr(int64(123)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(runs) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "list_workflow_runs", + "owner": "owner", + "repo": "repo", + "resource_id": "ci.yml", + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response github.WorkflowRuns + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.TotalCount) + }) + + t.Run("missing resource_id for list_workflow_runs", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient() + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "list_workflow_runs", + "owner": "owner", + "repo": "repo", + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, "missing required parameter") + }) +} + +func Test_ActionsGet(t *testing.T) { + // Verify tool definition once + toolDef := ActionsGet(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "actions_get", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, inputSchema.Properties, "method") + assert.Contains(t, inputSchema.Properties, "owner") + assert.Contains(t, inputSchema.Properties, "repo") + assert.Contains(t, inputSchema.Properties, "resource_id") + assert.ElementsMatch(t, inputSchema.Required, []string{"method", "owner", "repo", "resource_id"}) +} + +func Test_ActionsGet_GetWorkflow(t *testing.T) { + toolDef := ActionsGet(translations.NullTranslationHelper) + + t.Run("successful workflow get", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsWorkflowsByOwnerByRepoByWorkflowId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + workflow := &github.Workflow{ + ID: github.Ptr(int64(1)), + Name: github.Ptr("CI"), + Path: github.Ptr(".github/workflows/ci.yml"), + State: github.Ptr("active"), + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(workflow) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "get_workflow", + "owner": "owner", + "repo": "repo", + "resource_id": "ci.yml", + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response github.Workflow + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.ID) + assert.Equal(t, "CI", *response.Name) + }) +} + +func Test_ActionsGet_GetWorkflowRun(t *testing.T) { + toolDef := ActionsGet(translations.NullTranslationHelper) + + t.Run("successful workflow run get", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + run := &github.WorkflowRun{ + ID: github.Ptr(int64(12345)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(run) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "get_workflow_run", + "owner": "owner", + "repo": "repo", + "resource_id": "12345", + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response github.WorkflowRun + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.NotNil(t, response.ID) + assert.Equal(t, int64(12345), *response.ID) + }) +} + +func Test_ActionsRunTrigger(t *testing.T) { + // Verify tool definition once + toolDef := ActionsRunTrigger(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) + + assert.Equal(t, "actions_run_trigger", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, inputSchema.Properties, "method") + assert.Contains(t, inputSchema.Properties, "owner") + assert.Contains(t, inputSchema.Properties, "repo") + assert.Contains(t, inputSchema.Properties, "workflow_id") + assert.Contains(t, inputSchema.Properties, "ref") + assert.Contains(t, inputSchema.Properties, "run_id") + assert.ElementsMatch(t, inputSchema.Required, []string{"method", "owner", "repo"}) +} + +func Test_ActionsRunTrigger_RunWorkflow(t *testing.T) { + toolDef := ActionsRunTrigger(translations.NullTranslationHelper) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + }{ + { + name: "successful workflow run", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }), + ), + ), + requestArgs: map[string]any{ + "method": "run_workflow", + "owner": "owner", + "repo": "repo", + "workflow_id": "12345", + "ref": "main", + }, + expectError: false, + }, + { + name: "missing required parameter workflow_id", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "method": "run_workflow", + "owner": "owner", + "repo": "repo", + "ref": "main", + }, + expectError: true, + expectedErrMsg: "workflow_id is required for run_workflow action", + }, + { + name: "missing required parameter ref", + mockedClient: mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "method": "run_workflow", + "owner": "owner", + "repo": "repo", + "workflow_id": "12345", + }, + expectError: true, + expectedErrMsg: "ref is required for run_workflow action", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.Equal(t, tc.expectError, result.IsError) + + textContent := getTextResult(t, result) + + if tc.expectedErrMsg != "" { + assert.Equal(t, tc.expectedErrMsg, textContent.Text) + return + } + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "Workflow run has been queued", response["message"]) + }) + } +} + +func Test_ActionsRunTrigger_CancelWorkflowRun(t *testing.T) { + toolDef := ActionsRunTrigger(translations.NullTranslationHelper) + + t.Run("successful workflow run cancellation", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/actions/runs/12345/cancel", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "cancel_workflow_run", + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "Workflow run has been cancelled", response["message"]) + }) + + t.Run("conflict when cancelling a workflow run", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/actions/runs/12345/cancel", + Method: "POST", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusConflict) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "cancel_workflow_run", + "owner": "owner", + "repo": "repo", + "run_id": float64(12345), + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + + textContent := getTextResult(t, result) + assert.Contains(t, textContent.Text, "failed to cancel workflow run") + }) + + t.Run("missing run_id for non-run_workflow methods", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient() + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "method": "cancel_workflow_run", + "owner": "owner", + "repo": "repo", + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.True(t, result.IsError) + + textContent := getTextResult(t, result) + assert.Equal(t, "missing required parameter: run_id", textContent.Text) + }) +} + +func Test_ActionsGetJobLogs(t *testing.T) { + // Verify tool definition once + toolDef := ActionsGetJobLogs(translations.NullTranslationHelper) + + // Note: consolidated ActionsGetJobLogs has same tool name "get_job_logs" as the individual tool + // but with different descriptions. We skip toolsnap validation here since the individual + // tool's toolsnap already exists and is tested in Test_GetJobLogs. + // The consolidated tool has FeatureFlagEnable set, so only one will be active at a time. + assert.Equal(t, "get_job_logs", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) + inputSchema := toolDef.Tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, inputSchema.Properties, "owner") + assert.Contains(t, inputSchema.Properties, "repo") + assert.Contains(t, inputSchema.Properties, "job_id") + assert.Contains(t, inputSchema.Properties, "run_id") + assert.Contains(t, inputSchema.Properties, "failed_only") + assert.Contains(t, inputSchema.Properties, "return_content") + assert.ElementsMatch(t, inputSchema.Required, []string{"owner", "repo"}) +} + +func Test_ActionsGetJobLogs_SingleJob(t *testing.T) { + toolDef := ActionsGetJobLogs(translations.NullTranslationHelper) + + t.Run("successful single job logs with URL", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Location", "https://github.com/logs/job/123") + w.WriteHeader(http.StatusFound) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "job_id": float64(123), + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, float64(123), response["job_id"]) + assert.Contains(t, response, "logs_url") + assert.Equal(t, "Job logs are available for download", response["message"]) + }) +} + +func Test_ActionsGetJobLogs_FailedJobs(t *testing.T) { + toolDef := ActionsGetJobLogs(translations.NullTranslationHelper) + + t.Run("successful failed jobs logs", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jobs := &github.Jobs{ + TotalCount: github.Ptr(3), + Jobs: []*github.WorkflowJob{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("test-job-1"), + Conclusion: github.Ptr("success"), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("test-job-2"), + Conclusion: github.Ptr("failure"), + }, + { + ID: github.Ptr(int64(3)), + Name: github.Ptr("test-job-3"), + Conclusion: github.Ptr("failure"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(jobs) + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposActionsJobsLogsByOwnerByRepoByJobId, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", "https://github.com/logs/job/"+r.URL.Path[len(r.URL.Path)-1:]) + w.WriteHeader(http.StatusFound) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(456), + "failed_only": true, + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, float64(456), response["run_id"]) + assert.Contains(t, response, "logs") + assert.Contains(t, response["message"], "Retrieved logs for") + }) + + t.Run("no failed jobs found", func(t *testing.T) { + mockedClient := mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposActionsRunsJobsByOwnerByRepoByRunId, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + jobs := &github.Jobs{ + TotalCount: github.Ptr(2), + Jobs: []*github.WorkflowJob{ + { + ID: github.Ptr(int64(1)), + Name: github.Ptr("test-job-1"), + Conclusion: github.Ptr("success"), + }, + { + ID: github.Ptr(int64(2)), + Name: github.Ptr("test-job-2"), + Conclusion: github.Ptr("success"), + }, + }, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(jobs) + }), + ), + ) + + client := github.NewClient(mockedClient) + deps := BaseDeps{ + Client: client, + ContentWindowSize: 5000, + } + handler := toolDef.Handler(deps) + + request := createMCPRequest(map[string]any{ + "owner": "owner", + "repo": "repo", + "run_id": float64(456), + "failed_only": true, + }) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + + require.NoError(t, err) + require.False(t, result.IsError) + + textContent := getTextResult(t, result) + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "No failed jobs found in this workflow run", response["message"]) + }) } diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 0f8e2780b..5e25d0501 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -3,11 +3,11 @@ package github import ( "context" "encoding/json" - "fmt" "io" "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,8 +15,10 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetCodeScanningAlert(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataCodeSecurity, + mcp.Tool{ Name: "get_code_scanning_alert", Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -42,7 +44,7 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -56,7 +58,7 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -76,7 +78,7 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil } r, err := json.Marshal(alert) @@ -85,11 +87,14 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListCodeScanningAlerts(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataCodeSecurity, + mcp.Tool{ Name: "list_code_scanning_alerts", Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -130,7 +135,7 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -156,7 +161,7 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -175,7 +180,7 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil } r, err := json.Marshal(alerts) @@ -184,5 +189,6 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index 13e89fc30..59972fe52 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -10,22 +10,20 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetCodeScanningAlert(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetCodeScanningAlert(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_code_scanning_alert", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_code_scanning_alert", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // InputSchema is of type any, need to cast to *jsonschema.Schema - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -50,12 +48,9 @@ func Test_GetCodeScanningAlert(t *testing.T) { }{ { name: "successful alert fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, - mockAlert, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber: mockResponse(t, http.StatusOK, mockAlert), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -66,15 +61,12 @@ func Test_GetCodeScanningAlert(t *testing.T) { }, { name: "alert fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -89,13 +81,16 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -127,15 +122,14 @@ func Test_GetCodeScanningAlert(t *testing.T) { func Test_ListCodeScanningAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := ListCodeScanningAlerts(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_code_scanning_alerts", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "list_code_scanning_alerts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // InputSchema is of type any, need to cast to *jsonschema.Schema - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -171,19 +165,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }{ { name: "successful alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{ - "ref": "main", - "state": "open", - "severity": "high", - "tool_name": "codeql", - }).andThen( - mockResponse(t, http.StatusOK, mockAlerts), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{ + "ref": "main", + "state": "open", + "severity": "high", + "tool_name": "codeql", + }).andThen( + mockResponse(t, http.StatusOK, mockAlerts), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -197,15 +188,12 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }, { name: "alerts listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposCodeScanningAlertsByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposCodeScanningAlertsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -219,13 +207,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler with new signature - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 3fe622379..e0df82c88 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -6,6 +6,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -36,8 +37,10 @@ type UserDetails struct { } // GetMe creates a tool to get details of the authenticated user. -func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetMe(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataContext, + mcp.Tool{ Name: "get_me", Description: t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request is about the user's own profile for GitHub. Or when information is missing to build other tool calls."), Annotations: &mcp.ToolAnnotations{ @@ -48,8 +51,8 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too // OpenAI strict mode requires the properties field to be present. InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -60,7 +63,7 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too "failed to get user", res, err, - ), nil, err + ), nil, nil } // Create minimal user representation instead of returning full user object @@ -91,7 +94,8 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Too } return MarshalledTextResult(minimalUser), nil, nil - }) + }, + ) } type TeamInfo struct { @@ -105,8 +109,10 @@ type OrganizationTeams struct { Teams []TeamInfo `json:"teams"` } -func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeams(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataContext, + mcp.Tool{ Name: "get_teams", Description: t("TOOL_GET_TEAMS_DESCRIPTION", "Get details of the teams the user is a member of. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -123,7 +129,7 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations }, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { user, err := OptionalParam[string](args, "user") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -133,7 +139,7 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations if user != "" { username = user } else { - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -149,7 +155,7 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations username = userResp.GetLogin() } - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } @@ -196,11 +202,14 @@ func GetTeams(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations } return MarshalledTextResult(organizations), nil, nil - } + }, + ) } -func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetTeamMembers(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataContext, + mcp.Tool{ Name: "get_team_members", Description: t("TOOL_GET_TEAM_MEMBERS_DESCRIPTION", "Get member usernames of a specific team in an organization. Limited to organizations accessible with current credentials"), Annotations: &mcp.ToolAnnotations{ @@ -222,7 +231,7 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe Required: []string{"org", "team_slug"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { org, err := RequiredParam[string](args, "org") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -233,7 +242,7 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe return utils.NewToolResultError(err.Error()), nil, nil } - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } @@ -263,5 +272,6 @@ func GetTeamMembers(getGQLClient GetGQLClientFn, t translations.TranslationHelpe } return MarshalledTextResult(members), nil, nil - } + }, + ) } diff --git a/pkg/github/context_tools_test.go b/pkg/github/context_tools_test.go index 96e21c233..3f4261e71 100644 --- a/pkg/github/context_tools_test.go +++ b/pkg/github/context_tools_test.go @@ -3,7 +3,7 @@ package github import ( "context" "encoding/json" - "fmt" + "net/http" "testing" "time" @@ -11,7 +11,6 @@ import ( "github.com/github/github-mcp-server/internal/toolsnaps" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,7 +19,8 @@ import ( func Test_GetMe(t *testing.T) { t.Parallel() - tool, _ := GetMe(nil, translations.NullTranslationHelper) + serverTool := GetMe(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Verify some basic very important properties @@ -47,7 +47,8 @@ func Test_GetMe(t *testing.T) { tests := []struct { name string - stubbedGetClientFn GetClientFn + mockedClient *http.Client + clientErr string // if set, GetClient returns this error requestArgs map[string]any expectToolError bool expectedUser *github.User @@ -55,28 +56,18 @@ func Test_GetMe(t *testing.T) { }{ { name: "successful get user", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUser: mockResponse(t, http.StatusOK, mockUser), + }), requestArgs: map[string]any{}, expectToolError: false, expectedUser: mockUser, }, { name: "successful get user with reason", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUser: mockResponse(t, http.StatusOK, mockUser), + }), requestArgs: map[string]any{ "reason": "Testing API", }, @@ -85,21 +76,16 @@ func Test_GetMe(t *testing.T) { }, { name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), + clientErr: "expected test error", requestArgs: map[string]any{}, expectToolError: true, expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUser: badRequestHandler("expected test failure"), + }), requestArgs: map[string]any{}, expectToolError: true, expectedToolErrMsg: "expected test failure", @@ -108,21 +94,31 @@ func Test_GetMe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetMe(tc.stubbedGetClientFn, translations.NullTranslationHelper) + var deps ToolDependencies + if tc.clientErr != "" { + deps = stubDeps{clientFn: stubClientFnErr(tc.clientErr)} + } else { + deps = BaseDeps{Client: github.NewClient(tc.mockedClient)} + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, _ := handler(context.Background(), &request, tc.requestArgs) - textContent := getTextResult(t, result) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + // Unmarshal and verify the result var returnedUser MinimalUser - err := json.Unmarshal([]byte(textContent.Text), &returnedUser) + err = json.Unmarshal([]byte(textContent.Text), &returnedUser) require.NoError(t, err) // Verify minimal user details @@ -145,7 +141,8 @@ func Test_GetMe(t *testing.T) { func Test_GetTeams(t *testing.T) { t.Parallel() - tool, _ := GetTeams(nil, nil, translations.NullTranslationHelper) + serverTool := GetTeams(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_teams", tool.Name) @@ -214,49 +211,77 @@ func Test_GetTeams(t *testing.T) { }, }) + // Create GQL clients for different test scenarios - these are factory functions + // to ensure each test gets a fresh client + gqlClientForTestuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientForSpecificuser := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "specificuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoTeams := func() *githubv4.Client { + queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" + vars := map[string]interface{}{ + "login": "testuser", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + // Factory function for mock HTTP clients with user response + httpClientWithUser := func() *http.Client { + return MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUser: mockResponse(t, http.StatusOK, mockUser), + }) + } + + httpClientUserFails := func() *http.Client { + return MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUser: badRequestHandler("expected test failure"), + }) + } + tests := []struct { - name string - stubbedGetClientFn GetClientFn - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedTeamsCount int + name string + makeDeps func() ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedTeamsCount int }{ { name: "successful get teams", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientForTestuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 2, }, { - name: "successful get teams for specific user", - stubbedGetClientFn: nil, - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "specificuser", + name: "successful get teams for specific user", + makeDeps: func() ToolDependencies { + return BaseDeps{ + GQLClient: gqlClientForSpecificuser(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{ "user": "specificuser", @@ -266,62 +291,43 @@ func Test_GetTeams(t *testing.T) { }, { name: "no teams found", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($login:String!){user(login: $login){organizations(first: 100){nodes{login,teams(first: 100, userLogins: [$login]){nodes{name,slug,description}}}}}}" - vars := map[string]interface{}{ - "login": "testuser", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientWithUser()), + GQLClient: gqlClientNoTeams(), } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoTeamsResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil }, requestArgs: map[string]any{}, expectToolError: false, expectedTeamsCount: 0, }, { - name: "getting client fails", - stubbedGetClientFn: stubGetClientFnErr("expected test error"), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "failed to get GitHub client: expected test error", + name: "getting client fails", + makeDeps: func() ToolDependencies { + return stubDeps{clientFn: stubClientFnErr("expected test error")} + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "failed to get GitHub client: expected test error", }, { name: "get user fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUser, - badRequestHandler("expected test failure"), - ), - ), - ), - stubbedGetGQLClientFn: nil, - requestArgs: map[string]any{}, - expectToolError: true, - expectedToolErrMsg: "expected test failure", + makeDeps: func() ToolDependencies { + return BaseDeps{ + Client: github.NewClient(httpClientUserFails()), + } + }, + requestArgs: map[string]any{}, + expectToolError: true, + expectedToolErrMsg: "expected test failure", }, { name: "getting GraphQL client fails", - stubbedGetClientFn: stubGetClientFromHTTPFn( - mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetUser, - mockUser, - ), - ), - ), - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") + makeDeps: func() ToolDependencies { + return stubDeps{ + clientFn: stubClientFnFromHTTP(httpClientWithUser()), + gqlClientFn: stubGQLClientFnErr("GraphQL client error"), + } }, requestArgs: map[string]any{}, expectToolError: true, @@ -331,19 +337,23 @@ func Test_GetTeams(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeams(tc.stubbedGetClientFn, tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + deps := tc.makeDeps() + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var organizations []OrganizationTeams err = json.Unmarshal([]byte(textContent.Text), &organizations) require.NoError(t, err) @@ -372,7 +382,8 @@ func Test_GetTeams(t *testing.T) { func Test_GetTeamMembers(t *testing.T) { t.Parallel() - tool, _ := GetTeamMembers(nil, translations.NullTranslationHelper) + serverTool := GetTeamMembers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_team_members", tool.Name) @@ -405,26 +416,40 @@ func Test_GetTeamMembers(t *testing.T) { }, }) + // Create GQL clients for different test scenarios + gqlClientWithMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "testteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + + gqlClientNoMembers := func() *githubv4.Client { + queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" + vars := map[string]interface{}{ + "org": "testorg", + "teamSlug": "emptyteam", + } + matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) + httpClient := githubv4mock.NewMockedHTTPClient(matcher) + return githubv4.NewClient(httpClient) + } + tests := []struct { - name string - stubbedGetGQLClientFn GetGQLClientFn - requestArgs map[string]any - expectToolError bool - expectedToolErrMsg string - expectedMembersCount int + name string + deps ToolDependencies + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedMembersCount int }{ { name: "successful get team members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "testteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockTeamMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientWithMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -434,16 +459,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "team with no members", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - queryStr := "query($org:String!$teamSlug:String!){organization(login: $org){team(slug: $teamSlug){members(first: 100){nodes{login}}}}}" - vars := map[string]interface{}{ - "org": "testorg", - "teamSlug": "emptyteam", - } - matcher := githubv4mock.NewQueryMatcher(queryStr, vars, mockNoMembersResponse) - httpClient := githubv4mock.NewMockedHTTPClient(matcher) - return githubv4.NewClient(httpClient), nil - }, + deps: BaseDeps{GQLClient: gqlClientNoMembers()}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "emptyteam", @@ -453,9 +469,7 @@ func Test_GetTeamMembers(t *testing.T) { }, { name: "getting GraphQL client fails", - stubbedGetGQLClientFn: func(_ context.Context) (*githubv4.Client, error) { - return nil, fmt.Errorf("GraphQL client error") - }, + deps: stubDeps{gqlClientFn: stubGQLClientFnErr("GraphQL client error")}, requestArgs: map[string]any{ "org": "testorg", "team_slug": "testteam", @@ -467,19 +481,22 @@ func Test_GetTeamMembers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetTeamMembers(tc.stubbedGetGQLClientFn, translations.NullTranslationHelper) + handler := serverTool.Handler(tc.deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), tc.deps), &request) require.NoError(t, err) - textContent := getTextResult(t, result) if tc.expectToolError { - assert.True(t, result.IsError, "expected tool call result to be an error") - assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + require.True(t, result.IsError, "expected tool call result to be an error") + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedToolErrMsg) return } + require.False(t, result.IsError) + textContent := getTextResult(t, result) + var members []string err = json.Unmarshal([]byte(textContent.Text), &members) require.NoError(t, err) diff --git a/pkg/github/dependabot.go b/pkg/github/dependabot.go index 351cbdb37..db6352dab 100644 --- a/pkg/github/dependabot.go +++ b/pkg/github/dependabot.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,168 +16,168 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetDependabotAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_dependabot_alert", - Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_DEPENDABOT_ALERT_USER_TITLE", "Get dependabot alert"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "alertNumber": { - Type: "number", - Description: "The number of the alert.", +func GetDependabotAlert(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDependabot, + mcp.Tool{ + Name: "get_dependabot_alert", + Description: t("TOOL_GET_DEPENDABOT_ALERT_DESCRIPTION", "Get details of a specific dependabot alert in a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_DEPENDABOT_ALERT_USER_TITLE", "Get dependabot alert"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "alertNumber": { + Type: "number", + Description: "The number of the alert.", + }, }, + Required: []string{"owner", "repo", "alertNumber"}, }, - Required: []string{"owner", "repo", "alertNumber"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - alertNumber, err := RequiredInt(args, "alertNumber") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get alert with number '%d'", alertNumber), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + alertNumber, err := RequiredInt(args, "alertNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil - } - r, err := json.Marshal(alert) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err - } + alert, resp, err := client.Dependabot.GetRepoAlert(ctx, owner, repo, alertNumber) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get alert with number '%d'", alertNumber), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(alert) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, err + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } -func ListDependabotAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_dependabot_alerts", - Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_DEPENDABOT_ALERTS_USER_TITLE", "List dependabot alerts"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "state": { - Type: "string", - Description: "Filter dependabot alerts by state. Defaults to open", - Enum: []any{"open", "fixed", "dismissed", "auto_dismissed"}, - Default: json.RawMessage(`"open"`), - }, - "severity": { - Type: "string", - Description: "Filter dependabot alerts by severity", - Enum: []any{"low", "medium", "high", "critical"}, +func ListDependabotAlerts(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDependabot, + mcp.Tool{ + Name: "list_dependabot_alerts", + Description: t("TOOL_LIST_DEPENDABOT_ALERTS_DESCRIPTION", "List dependabot alerts in a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_DEPENDABOT_ALERTS_USER_TITLE", "List dependabot alerts"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "state": { + Type: "string", + Description: "Filter dependabot alerts by state. Defaults to open", + Enum: []any{"open", "fixed", "dismissed", "auto_dismissed"}, + Default: json.RawMessage(`"open"`), + }, + "severity": { + Type: "string", + Description: "Filter dependabot alerts by severity", + Enum: []any{"low", "medium", "high", "critical"}, + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - severity, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err - } - - alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ - State: ToStringPtr(state), - Severity: ToStringPtr(severity), - }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + severity, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil - } - r, err := json.Marshal(alerts) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err - } + alerts, resp, err := client.Dependabot.ListRepoAlerts(ctx, owner, repo, &github.ListAlertsOptions{ + State: ToStringPtr(state), + Severity: ToStringPtr(severity), + }) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list alerts for repository '%s/%s'", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(alerts) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, err + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } diff --git a/pkg/github/dependabot_test.go b/pkg/github/dependabot_test.go index 24e5130e9..e57405a8c 100644 --- a/pkg/github/dependabot_test.go +++ b/pkg/github/dependabot_test.go @@ -9,15 +9,14 @@ import ( "github.com/github/github-mcp-server/internal/toolsnaps" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetDependabotAlert(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetDependabotAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetDependabotAlert(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) // Validate tool schema @@ -42,12 +41,9 @@ func Test_GetDependabotAlert(t *testing.T) { }{ { name: "successful alert fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposDependabotAlertsByOwnerByRepoByAlertNumber, - mockAlert, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepoByAlertNumber: mockResponse(t, http.StatusOK, mockAlert), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -58,15 +54,12 @@ func Test_GetDependabotAlert(t *testing.T) { }, { name: "alert fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposDependabotAlertsByOwnerByRepoByAlertNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepoByAlertNumber: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -81,13 +74,14 @@ func Test_GetDependabotAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetDependabotAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -117,8 +111,8 @@ func Test_GetDependabotAlert(t *testing.T) { func Test_ListDependabotAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListDependabotAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListDependabotAlerts(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_dependabot_alerts", tool.Name) @@ -153,16 +147,13 @@ func Test_ListDependabotAlerts(t *testing.T) { }{ { name: "successful open alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposDependabotAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{ - "state": "open", - }).andThen( - mockResponse(t, http.StatusOK, []*github.DependabotAlert{&criticalAlert}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{ + "state": "open", + }).andThen( + mockResponse(t, http.StatusOK, []*github.DependabotAlert{&criticalAlert}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -173,16 +164,13 @@ func Test_ListDependabotAlerts(t *testing.T) { }, { name: "successful severity filtered listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposDependabotAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{ - "severity": "high", - }).andThen( - mockResponse(t, http.StatusOK, []*github.DependabotAlert{&highSeverityAlert}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{ + "severity": "high", + }).andThen( + mockResponse(t, http.StatusOK, []*github.DependabotAlert{&highSeverityAlert}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -193,14 +181,11 @@ func Test_ListDependabotAlerts(t *testing.T) { }, { name: "successful all alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposDependabotAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{}).andThen( - mockResponse(t, http.StatusOK, []*github.DependabotAlert{&criticalAlert, &highSeverityAlert}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{}).andThen( + mockResponse(t, http.StatusOK, []*github.DependabotAlert{&criticalAlert, &highSeverityAlert}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -210,15 +195,12 @@ func Test_ListDependabotAlerts(t *testing.T) { }, { name: "alerts listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposDependabotAlertsByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposDependabotAlertsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -231,11 +213,12 @@ func Test_ListDependabotAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListDependabotAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go new file mode 100644 index 000000000..d23e993c3 --- /dev/null +++ b/pkg/github/dependencies.go @@ -0,0 +1,168 @@ +package github + +import ( + "context" + "errors" + + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" + gogithub "github.com/google/go-github/v79/github" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/shurcooL/githubv4" +) + +// depsContextKey is the context key for ToolDependencies. +// Using a private type prevents collisions with other packages. +type depsContextKey struct{} + +// ErrDepsNotInContext is returned when ToolDependencies is not found in context. +var ErrDepsNotInContext = errors.New("ToolDependencies not found in context; use ContextWithDeps to inject") + +// ContextWithDeps returns a new context with the ToolDependencies stored in it. +// This is used to inject dependencies at request time rather than at registration time, +// avoiding expensive closure creation during server initialization. +// +// For the local server, this is called once at startup since deps don't change. +// For the remote server, this is called per-request with request-specific deps. +func ContextWithDeps(ctx context.Context, deps ToolDependencies) context.Context { + return context.WithValue(ctx, depsContextKey{}, deps) +} + +// DepsFromContext retrieves ToolDependencies from the context. +// Returns the deps and true if found, or nil and false if not present. +// Use MustDepsFromContext if you want to panic on missing deps (for handlers +// that require deps to function). +func DepsFromContext(ctx context.Context) (ToolDependencies, bool) { + deps, ok := ctx.Value(depsContextKey{}).(ToolDependencies) + return deps, ok +} + +// MustDepsFromContext retrieves ToolDependencies from the context. +// Panics if deps are not found - use this in handlers where deps are required. +func MustDepsFromContext(ctx context.Context) ToolDependencies { + deps, ok := DepsFromContext(ctx) + if !ok { + panic(ErrDepsNotInContext) + } + return deps +} + +// ToolDependencies defines the interface for dependencies that tool handlers need. +// This is an interface to allow different implementations: +// - Local server: stores closures that create clients on demand +// - Remote server: can store pre-created clients per-request for efficiency +// +// The toolsets package uses `any` for deps and tool handlers type-assert to this interface. +type ToolDependencies interface { + // GetClient returns a GitHub REST API client + GetClient(ctx context.Context) (*gogithub.Client, error) + + // GetGQLClient returns a GitHub GraphQL client + GetGQLClient(ctx context.Context) (*githubv4.Client, error) + + // GetRawClient returns a raw content client for GitHub + GetRawClient(ctx context.Context) (*raw.Client, error) + + // GetRepoAccessCache returns the lockdown mode repo access cache + GetRepoAccessCache() *lockdown.RepoAccessCache + + // GetT returns the translation helper function + GetT() translations.TranslationHelperFunc + + // GetFlags returns feature flags + GetFlags() FeatureFlags + + // GetContentWindowSize returns the content window size for log truncation + GetContentWindowSize() int +} + +// BaseDeps is the standard implementation of ToolDependencies for the local server. +// It stores pre-created clients. The remote server can create its own struct +// implementing ToolDependencies with different client creation strategies. +type BaseDeps struct { + // Pre-created clients + Client *gogithub.Client + GQLClient *githubv4.Client + RawClient *raw.Client + + // Static dependencies + RepoAccessCache *lockdown.RepoAccessCache + T translations.TranslationHelperFunc + Flags FeatureFlags + ContentWindowSize int +} + +// NewBaseDeps creates a BaseDeps with the provided clients and configuration. +func NewBaseDeps( + client *gogithub.Client, + gqlClient *githubv4.Client, + rawClient *raw.Client, + repoAccessCache *lockdown.RepoAccessCache, + t translations.TranslationHelperFunc, + flags FeatureFlags, + contentWindowSize int, +) *BaseDeps { + return &BaseDeps{ + Client: client, + GQLClient: gqlClient, + RawClient: rawClient, + RepoAccessCache: repoAccessCache, + T: t, + Flags: flags, + ContentWindowSize: contentWindowSize, + } +} + +// GetClient implements ToolDependencies. +func (d BaseDeps) GetClient(_ context.Context) (*gogithub.Client, error) { + return d.Client, nil +} + +// GetGQLClient implements ToolDependencies. +func (d BaseDeps) GetGQLClient(_ context.Context) (*githubv4.Client, error) { + return d.GQLClient, nil +} + +// GetRawClient implements ToolDependencies. +func (d BaseDeps) GetRawClient(_ context.Context) (*raw.Client, error) { + return d.RawClient, nil +} + +// GetRepoAccessCache implements ToolDependencies. +func (d BaseDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return d.RepoAccessCache } + +// GetT implements ToolDependencies. +func (d BaseDeps) GetT() translations.TranslationHelperFunc { return d.T } + +// GetFlags implements ToolDependencies. +func (d BaseDeps) GetFlags() FeatureFlags { return d.Flags } + +// GetContentWindowSize implements ToolDependencies. +func (d BaseDeps) GetContentWindowSize() int { return d.ContentWindowSize } + +// NewTool creates a ServerTool that retrieves ToolDependencies from context at call time. +// This avoids creating closures at registration time, which is important for performance +// in servers that create a new server instance per request (like the remote server). +// +// The handler function receives deps extracted from context via MustDepsFromContext. +// Ensure ContextWithDeps is called to inject deps before any tool handlers are invoked. +func NewTool[In, Out any](toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(ctx context.Context, deps ToolDependencies, req *mcp.CallToolRequest, args In) (*mcp.CallToolResult, Out, error)) inventory.ServerTool { + return inventory.NewServerToolWithContextHandler(tool, toolset, func(ctx context.Context, req *mcp.CallToolRequest, args In) (*mcp.CallToolResult, Out, error) { + deps := MustDepsFromContext(ctx) + return handler(ctx, deps, req, args) + }) +} + +// NewToolFromHandler creates a ServerTool that retrieves ToolDependencies from context at call time. +// Use this when you have a handler that conforms to mcp.ToolHandler directly. +// +// The handler function receives deps extracted from context via MustDepsFromContext. +// Ensure ContextWithDeps is called to inject deps before any tool handlers are invoked. +func NewToolFromHandler(toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(ctx context.Context, deps ToolDependencies, req *mcp.CallToolRequest) (*mcp.CallToolResult, error)) inventory.ServerTool { + return inventory.NewServerToolWithRawContextHandler(tool, toolset, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + deps := MustDepsFromContext(ctx) + return handler(ctx, deps, req) + }) +} diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 8a5019701..c891ba294 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/go-viper/mapstructure/v2" @@ -121,8 +122,10 @@ func getQueryType(useOrdering bool, categoryID *githubv4.ID) any { return &BasicNoOrder{} } -func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListDiscussions(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDiscussions, + mcp.Tool{ Name: "list_discussions", Description: t("TOOL_LIST_DISCUSSIONS_DESCRIPTION", "List discussions for a repository or organisation."), Annotations: &mcp.ToolAnnotations{ @@ -158,7 +161,7 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp Required: []string{"owner"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -198,7 +201,7 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp return nil, nil, err } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -267,11 +270,14 @@ func ListDiscussions(getGQLClient GetGQLClientFn, t translations.TranslationHelp return nil, nil, fmt.Errorf("failed to marshal discussions: %w", err) } return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } -func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDiscussions, + mcp.Tool{ Name: "get_discussion", Description: t("TOOL_GET_DISCUSSION_DESCRIPTION", "Get a specific discussion by ID"), Annotations: &mcp.ToolAnnotations{ @@ -297,7 +303,7 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper Required: []string{"owner", "repo", "discussionNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { // Decode params var params struct { Owner string @@ -307,7 +313,7 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper if err := mapstructure.Decode(args, ¶ms); err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -367,11 +373,14 @@ func GetDiscussion(getGQLClient GetGQLClientFn, t translations.TranslationHelper } return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } -func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDiscussions, + mcp.Tool{ Name: "get_discussion_comments", Description: t("TOOL_GET_DISCUSSION_COMMENTS_DESCRIPTION", "Get comments from a discussion"), Annotations: &mcp.ToolAnnotations{ @@ -397,7 +406,7 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati Required: []string{"owner", "repo", "discussionNumber"}, }), }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { // Decode params var params struct { Owner string @@ -429,7 +438,7 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati paginationParams.First = &defaultFirst } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -490,11 +499,14 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati } return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } -func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListDiscussionCategories(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataDiscussions, + mcp.Tool{ Name: "list_discussion_categories", Description: t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation."), Annotations: &mcp.ToolAnnotations{ @@ -516,7 +528,7 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl Required: []string{"owner"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -531,7 +543,7 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl repo = ".github" } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -587,5 +599,6 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl return nil, nil, fmt.Errorf("failed to marshal discussion categories: %w", err) } return utils.NewToolResultText(string(out)), nil, nil - } + }, + ) } diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 1a73d523e..0ec998280 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -213,13 +213,13 @@ var ( ) func Test_ListDiscussions(t *testing.T) { - mockClient := githubv4.NewClient(nil) - toolDef, _ := ListDiscussions(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := ListDiscussions(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "list_discussions", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "list_discussions", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -447,10 +447,11 @@ func Test_ListDiscussions(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - _, handler := ListDiscussions(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{GQLClient: gqlClient} + handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -494,12 +495,13 @@ func Test_ListDiscussions(t *testing.T) { func Test_GetDiscussion(t *testing.T) { // Verify tool definition and schema - toolDef, _ := GetDiscussion(nil, translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := GetDiscussion(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "get_discussion", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_discussion", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -557,11 +559,12 @@ func Test_GetDiscussion(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetDiscussion, vars, tc.response) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := GetDiscussion(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{GQLClient: gqlClient} + handler := toolDef.Handler(deps) reqParams := map[string]interface{}{"owner": "owner", "repo": "repo", "discussionNumber": int32(1)} req := createMCPRequest(reqParams) - res, _, err := handler(context.Background(), &req, reqParams) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -589,12 +592,13 @@ func Test_GetDiscussion(t *testing.T) { func Test_GetDiscussionComments(t *testing.T) { // Verify tool definition and schema - toolDef, _ := GetDiscussionComments(nil, translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) + toolDef := GetDiscussionComments(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) - assert.Equal(t, "get_discussion_comments", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + assert.Equal(t, "get_discussion_comments", tool.Name) + assert.NotEmpty(t, tool.Description) + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -635,7 +639,8 @@ func Test_GetDiscussionComments(t *testing.T) { matcher := githubv4mock.NewQueryMatcher(qGetComments, vars, mockResponse) httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := GetDiscussionComments(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{GQLClient: gqlClient} + handler := toolDef.Handler(deps) reqParams := map[string]interface{}{ "owner": "owner", @@ -644,7 +649,7 @@ func Test_GetDiscussionComments(t *testing.T) { } request := createMCPRequest(reqParams) - result, _, err := handler(context.Background(), &request, reqParams) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -671,14 +676,14 @@ func Test_GetDiscussionComments(t *testing.T) { } func Test_ListDiscussionCategories(t *testing.T) { - mockClient := githubv4.NewClient(nil) - toolDef, _ := ListDiscussionCategories(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(toolDef.Name, toolDef)) - - assert.Equal(t, "list_discussion_categories", toolDef.Name) - assert.NotEmpty(t, toolDef.Description) - assert.Contains(t, toolDef.Description, "or organisation") - schema, ok := toolDef.InputSchema.(*jsonschema.Schema) + toolDef := ListDiscussionCategories(translations.NullTranslationHelper) + tool := toolDef.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "list_discussion_categories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.Description, "or organisation") + schema, ok := tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -786,10 +791,11 @@ func Test_ListDiscussionCategories(t *testing.T) { httpClient := githubv4mock.NewMockedHTTPClient(matcher) gqlClient := githubv4.NewClient(httpClient) - _, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{GQLClient: gqlClient} + handler := toolDef.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index c65510246..5c7d31d4e 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -5,28 +5,67 @@ import ( "encoding/json" "fmt" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) []any { - toolsetNames := make([]any, 0, len(toolsetGroup.Toolsets)) - for name := range toolsetGroup.Toolsets { - toolsetNames = append(toolsetNames, name) +// DynamicToolDependencies contains dependencies for dynamic toolset management tools. +// It includes the managed Inventory, the server for registration, and the deps +// that will be passed to tools when they are dynamically enabled. +type DynamicToolDependencies struct { + // Server is the MCP server to register tools with + Server *mcp.Server + // Inventory contains all available tools, resources and prompts that can be enabled dynamically + Inventory *inventory.Inventory + // ToolDeps are the dependencies passed to tools when they are registered + ToolDeps any + // T is the translation helper function + T translations.TranslationHelperFunc +} + +// NewDynamicTool creates a ServerTool with fully-typed DynamicToolDependencies. +// Dynamic tools use a different dependency structure (DynamicToolDependencies) than regular +// tools (ToolDependencies), so they intentionally use the closure pattern. +func NewDynamicTool(toolset inventory.ToolsetMetadata, tool mcp.Tool, handler func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any]) inventory.ServerTool { + //nolint:staticcheck // SA1019: Dynamic tools use a different deps structure, closure pattern is intentional + return inventory.NewServerTool(tool, toolset, func(d any) mcp.ToolHandlerFor[map[string]any, any] { + return handler(d.(DynamicToolDependencies)) + }) +} + +// toolsetIDsEnum returns the list of toolset IDs as an enum for JSON Schema. +func toolsetIDsEnum(r *inventory.Inventory) []any { + toolsetIDs := r.ToolsetIDs() + result := make([]any, len(toolsetIDs)) + for i, id := range toolsetIDs { + result[i] = id + } + return result +} + +// DynamicTools returns the tools for dynamic toolset management. +// These tools allow runtime discovery and enablement of inventory. +// The r parameter provides the available toolset IDs for JSON Schema enums. +func DynamicTools(r *inventory.Inventory) []inventory.ServerTool { + return []inventory.ServerTool{ + ListAvailableToolsets(), + GetToolsetsTools(r), + EnableToolset(r), } - return toolsetNames } -func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// EnableToolset creates a tool that enables a toolset at runtime. +func EnableToolset(r *inventory.Inventory) inventory.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "enable_toolset", - Description: t("TOOL_ENABLE_TOOLSET_DESCRIPTION", "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable"), + Description: "Enable one of the sets of tools the GitHub MCP server provides, use get_toolset_tools and list_available_toolsets first to see what this will enable", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_ENABLE_TOOLSET_USER_TITLE", "Enable a toolset"), - // Not modifying GitHub data so no need to show a warning + Title: "Enable a toolset", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -35,46 +74,53 @@ func EnableToolset(s *mcp.Server, toolsetGroup *toolsets.ToolsetGroup, t transla "toolset": { Type: "string", Description: "The name of the toolset to enable", - Enum: ToolsetEnum(toolsetGroup), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsets back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - if toolset.Enabled { - return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil - } + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - toolset.Enabled = true + toolsetID := inventory.ToolsetID(toolsetName) - // caution: this currently affects the global tools and notifies all clients: - // - // Send notification to all initialized sessions - // s.sendNotificationToAllClients("notifications/tools/list_changed", nil) - for _, serverTool := range toolset.GetActiveTools() { - serverTool.RegisterFunc(s) - } + if !deps.Inventory.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } + + if deps.Inventory.IsToolsetEnabled(toolsetID) { + return utils.NewToolResultText(fmt.Sprintf("Toolset %s is already enabled", toolsetName)), nil, nil + } - return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled", toolsetName)), nil, nil - }) + // Mark the toolset as enabled so IsToolsetEnabled returns true + deps.Inventory.EnableToolset(toolsetID) + + // Get tools for this toolset and register them with the managed deps + toolsForToolset := deps.Inventory.ToolsForToolset(toolsetID) + for _, st := range toolsForToolset { + st.RegisterFunc(deps.Server, deps.ToolDeps) + } + + return utils.NewToolResultText(fmt.Sprintf("Toolset %s enabled with %d tools", toolsetName, len(toolsForToolset))), nil, nil + } + }, + ) } -func ListAvailableToolsets(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// ListAvailableToolsets creates a tool that lists all available inventory. +func ListAvailableToolsets() inventory.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "list_available_toolsets", - Description: t("TOOL_LIST_AVAILABLE_TOOLSETS_DESCRIPTION", "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call"), + Description: "List all available toolsets this GitHub MCP server can offer, providing the enabled status of each. Use this when a task could be achieved with a GitHub tool and the currently available tools aren't enough. Call get_toolset_tools with these toolset names to discover specific tools you can call", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_AVAILABLE_TOOLSETS_USER_TITLE", "List available toolsets"), + Title: "List available toolsets", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -82,38 +128,42 @@ func ListAvailableToolsets(toolsetGroup *toolsets.ToolsetGroup, t translations.T Properties: map[string]*jsonschema.Schema{}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, _ map[string]any) (*mcp.CallToolResult, any, error) { + toolsetIDs := deps.Inventory.ToolsetIDs() + descriptions := deps.Inventory.ToolsetDescriptions() - payload := []map[string]string{} - - for name, ts := range toolsetGroup.Toolsets { - { + payload := make([]map[string]string, 0, len(toolsetIDs)) + for _, id := range toolsetIDs { t := map[string]string{ - "name": name, - "description": ts.Description, + "name": string(id), + "description": descriptions[id], "can_enable": "true", - "currently_enabled": fmt.Sprintf("%t", ts.Enabled), + "currently_enabled": fmt.Sprintf("%t", deps.Inventory.IsToolsetEnabled(id)), } payload = append(payload, t) } - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } - return utils.NewToolResultText(string(r)), nil, nil - }) + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } -func GetToolsetsTools(toolsetGroup *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +// GetToolsetsTools creates a tool that lists all tools in a specific toolset. +func GetToolsetsTools(r *inventory.Inventory) inventory.ServerTool { + return NewDynamicTool( + ToolsetMetadataDynamic, + mcp.Tool{ Name: "get_toolset_tools", - Description: t("TOOL_GET_TOOLSET_TOOLS_DESCRIPTION", "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task"), + Description: "Lists all the capabilities that are enabled with the specified toolset, use this to get clarity on whether enabling a toolset would help you to complete a task", Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_TOOLSET_TOOLS_USER_TITLE", "List all tools in a toolset"), + Title: "List all tools in a toolset", ReadOnlyHint: true, }, InputSchema: &jsonschema.Schema{ @@ -122,39 +172,46 @@ func GetToolsetsTools(toolsetGroup *toolsets.ToolsetGroup, t translations.Transl "toolset": { Type: "string", Description: "The name of the toolset you want to get the tools for", - Enum: ToolsetEnum(toolsetGroup), + Enum: toolsetIDsEnum(r), }, }, Required: []string{"toolset"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // We need to convert the toolsetGroup back to a map for JSON serialization - toolsetName, err := RequiredParam[string](args, "toolset") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - toolset := toolsetGroup.Toolsets[toolsetName] - if toolset == nil { - return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil - } - payload := []map[string]string{} - - for _, st := range toolset.GetAvailableTools() { - tool := map[string]string{ - "name": st.Tool.Name, - "description": st.Tool.Description, - "can_enable": "true", - "toolset": toolsetName, + func(deps DynamicToolDependencies) mcp.ToolHandlerFor[map[string]any, any] { + return func(_ context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + toolsetName, err := RequiredParam[string](args, "toolset") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - payload = append(payload, tool) - } - r, err := json.Marshal(payload) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal features: %w", err) - } + toolsetID := inventory.ToolsetID(toolsetName) - return utils.NewToolResultText(string(r)), nil, nil - }) + if !deps.Inventory.HasToolset(toolsetID) { + return utils.NewToolResultError(fmt.Sprintf("Toolset %s not found", toolsetName)), nil, nil + } + + // Get all tools for this toolset (ignoring current filters for discovery) + toolsInToolset := deps.Inventory.ToolsForToolset(toolsetID) + payload := make([]map[string]string, 0, len(toolsInToolset)) + + for _, st := range toolsInToolset { + tool := map[string]string{ + "name": st.Tool.Name, + "description": st.Tool.Description, + "can_enable": "true", + "toolset": toolsetName, + } + payload = append(payload, tool) + } + + r, err := json.Marshal(payload) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal features: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + } + }, + ) } diff --git a/pkg/github/dynamic_tools_test.go b/pkg/github/dynamic_tools_test.go new file mode 100644 index 000000000..8d12b78c2 --- /dev/null +++ b/pkg/github/dynamic_tools_test.go @@ -0,0 +1,231 @@ +package github + +import ( + "context" + "encoding/json" + "testing" + + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createDynamicRequest creates an MCP request with the given arguments for dynamic tools. +func createDynamicRequest(args map[string]any) *mcp.CallToolRequest { + argsJSON, _ := json.Marshal(args) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Arguments: json.RawMessage(argsJSON), + }, + } +} + +func TestDynamicTools_ListAvailableToolsets(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewInventory(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Inventory: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the list_available_toolsets tool + tool := ListAvailableToolsets() + handler := tool.Handler(deps) + + // Call the handler + result, err := handler(context.Background(), createDynamicRequest(map[string]any{})) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var toolsets []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &toolsets) + require.NoError(t, err) + + // Verify we got toolsets + assert.NotEmpty(t, toolsets, "should have available toolsets") + + // Find the repos toolset and verify it's not enabled + var reposToolset map[string]string + for _, ts := range toolsets { + if ts["name"] == "repos" { + reposToolset = ts + break + } + } + require.NotNil(t, reposToolset, "repos toolset should exist") + assert.Equal(t, "false", reposToolset["currently_enabled"], "repos should not be enabled initially") +} + +func TestDynamicTools_GetToolsetTools(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewInventory(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Inventory: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the get_toolset_tools tool + tool := GetToolsetsTools(reg) + handler := tool.Handler(deps) + + // Call the handler for repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Parse the result + var tools []map[string]string + textContent := result.Content[0].(*mcp.TextContent) + err = json.Unmarshal([]byte(textContent.Text), &tools) + require.NoError(t, err) + + // Verify we got tools for the repos toolset + assert.NotEmpty(t, tools, "repos toolset should have tools") + + // Verify at least get_commit is there (a repos toolset tool) + var foundGetCommit bool + for _, tool := range tools { + if tool["name"] == "get_commit" { + foundGetCommit = true + break + } + } + assert.True(t, foundGetCommit, "get_commit should be in repos toolset") +} + +func TestDynamicTools_EnableToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewInventory(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Inventory: reg, + ToolDeps: NewBaseDeps(nil, nil, nil, nil, translations.NullTranslationHelper, FeatureFlags{}, 0), + T: translations.NullTranslationHelper, + } + + // Verify repos is not enabled initially + assert.False(t, reg.IsToolsetEnabled(inventory.ToolsetID("repos"))) + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Enable the repos toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Content, 1) + + // Verify the toolset is now enabled + assert.True(t, reg.IsToolsetEnabled(inventory.ToolsetID("repos")), "repos should be enabled after enable_toolset") + + // Verify the success message + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "enabled") + + // Try enabling again - should say already enabled + result2, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "repos", + })) + require.NoError(t, err) + textContent2 := result2.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent2.Text, "already enabled") +} + +func TestDynamicTools_EnableToolset_InvalidToolset(t *testing.T) { + // Build a registry with no toolsets enabled (dynamic mode) + reg := NewInventory(translations.NullTranslationHelper). + WithToolsets([]string{}). + Build() + + // Create a mock server + server := mcp.NewServer(&mcp.Implementation{Name: "test"}, nil) + + // Create dynamic tool dependencies + deps := DynamicToolDependencies{ + Server: server, + Inventory: reg, + ToolDeps: nil, + T: translations.NullTranslationHelper, + } + + // Get the enable_toolset tool + tool := EnableToolset(reg) + handler := tool.Handler(deps) + + // Try to enable a non-existent toolset + result, err := handler(context.Background(), createDynamicRequest(map[string]any{ + "toolset": "nonexistent", + })) + require.NoError(t, err) + require.NotNil(t, result) + + // Should be an error result + textContent := result.Content[0].(*mcp.TextContent) + assert.Contains(t, textContent.Text, "not found") +} + +func TestDynamicTools_ToolsetsEnum(t *testing.T) { + // Build a registry + reg := NewInventory(translations.NullTranslationHelper).Build() + + // Get tools to verify they have proper enum values + tools := DynamicTools(reg) + + // Find enable_toolset and get_toolset_tools + for _, tool := range tools { + if tool.Tool.Name == "enable_toolset" || tool.Tool.Name == "get_toolset_tools" { + // Verify the toolset property has an enum + schema := tool.Tool.InputSchema.(*jsonschema.Schema) + toolsetProp := schema.Properties["toolset"] + require.NotNil(t, toolsetProp, "toolset property should exist") + assert.NotEmpty(t, toolsetProp.Enum, "toolset property should have enum values") + + // Verify repos is in the enum + var foundRepos bool + for _, v := range toolsetProp.Enum { + if v == inventory.ToolsetID("repos") { + foundRepos = true + break + } + } + assert.True(t, foundRepos, "repos should be in toolset enum for %s", tool.Tool.Name) + } + } +} diff --git a/pkg/github/gists.go b/pkg/github/gists.go index b54553aac..4d741b88d 100644 --- a/pkg/github/gists.go +++ b/pkg/github/gists.go @@ -7,6 +7,8 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,346 +17,346 @@ import ( ) // ListGists creates a tool to list gists for a user -func ListGists(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_gists", - Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GISTS", "List Gists"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "username": { - Type: "string", - Description: "GitHub username (omit for authenticated user's gists)", - }, - "since": { - Type: "string", - Description: "Only gists updated after this time (ISO 8601 timestamp)", - }, +func ListGists(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataGists, + mcp.Tool{ + Name: "list_gists", + Description: t("TOOL_LIST_GISTS_DESCRIPTION", "List gists for a user"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GISTS", "List Gists"), + ReadOnlyHint: true, }, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - since, err := OptionalParam[string](args, "since") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.GistListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "username": { + Type: "string", + Description: "GitHub username (omit for authenticated user's gists)", + }, + "since": { + Type: "string", + Description: "Only gists updated after this time (ISO 8601 timestamp)", + }, + }, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + since, err := OptionalParam[string](args, "since") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Parse since timestamp if provided - if since != "" { - sinceTime, err := parseISOTimestamp(since) + pagination, err := OptionalPaginationParams(args) if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil + return utils.NewToolResultError(err.Error()), nil, nil } - opts.Since = sinceTime - } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.GistListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - gists, resp, err := client.Gists.List(ctx, username, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list gists: %w", err) - } - defer func() { _ = resp.Body.Close() }() + // Parse since timestamp if provided + if since != "" { + sinceTime, err := parseISOTimestamp(since) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid since timestamp: %v", err)), nil, nil + } + opts.Since = sinceTime + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list gists: %s", string(body))), nil, nil - } - r, err := json.Marshal(gists) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + gists, resp, err := client.Gists.List(ctx, username, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to list gists", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list gists", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(gists) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // GetGist creates a tool to get the content of a gist -func GetGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_gist", - Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GIST", "Get Gist Content"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "The ID of the gist", +func GetGist(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataGists, + mcp.Tool{ + Name: "get_gist", + Description: t("TOOL_GET_GIST_DESCRIPTION", "Get gist content of a particular gist, by gist ID"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GIST", "Get Gist Content"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "The ID of the gist", + }, }, + Required: []string{"gist_id"}, }, - Required: []string{"gist_id"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - gist, resp, err := client.Gists.Get(ctx, gistID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return utils.NewToolResultError(err.Error()), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get gist: %s", string(body))), nil, nil - } - r, err := json.Marshal(gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + gist, resp, err := client.Gists.Get(ctx, gistID) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get gist", resp, body), nil, nil + } - return tool, handler + r, err := json.Marshal(gist) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // CreateGist creates a tool to create a new gist -func CreateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_gist", - Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_GIST", "Create Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "description": { - Type: "string", - Description: "Description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename for simple single-file gist creation", - }, - "content": { - Type: "string", - Description: "Content for simple single-file gist creation", - }, - "public": { - Type: "boolean", - Description: "Whether the gist is public", - Default: json.RawMessage(`false`), +func CreateGist(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataGists, + mcp.Tool{ + Name: "create_gist", + Description: t("TOOL_CREATE_GIST_DESCRIPTION", "Create a new gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_GIST", "Create Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "description": { + Type: "string", + Description: "Description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename for simple single-file gist creation", + }, + "content": { + Type: "string", + Description: "Content for simple single-file gist creation", + }, + "public": { + Type: "boolean", + Description: "Whether the gist is public", + Default: json.RawMessage(`false`), + }, }, + Required: []string{"filename", "content"}, }, - Required: []string{"filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - public, err := OptionalParam[bool](args, "public") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Public: github.Ptr(public), - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - createdGist, resp, err := client.Gists.Create(ctx, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to create gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + description, err := OptionalParam[string](args, "description") if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return utils.NewToolResultError(err.Error()), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create gist: %s", string(body))), nil, nil - } - minimalResponse := MinimalResponse{ - ID: createdGist.GetID(), - URL: createdGist.GetHTMLURL(), - } + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + public, err := OptionalParam[bool](args, "public") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } - return tool, handler + gist := &github.Gist{ + Files: files, + Public: github.Ptr(public), + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + createdGist, resp, err := client.Gists.Create(ctx, gist) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to create gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create gist", resp, body), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: createdGist.GetID(), + URL: createdGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // UpdateGist creates a tool to edit an existing gist -func UpdateGist(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "update_gist", - Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_UPDATE_GIST", "Update Gist"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "gist_id": { - Type: "string", - Description: "ID of the gist to update", - }, - "description": { - Type: "string", - Description: "Updated description of the gist", - }, - "filename": { - Type: "string", - Description: "Filename to update or create", - }, - "content": { - Type: "string", - Description: "Content for the file", +func UpdateGist(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataGists, + mcp.Tool{ + Name: "update_gist", + Description: t("TOOL_UPDATE_GIST_DESCRIPTION", "Update an existing gist"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_UPDATE_GIST", "Update Gist"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "gist_id": { + Type: "string", + Description: "ID of the gist to update", + }, + "description": { + Type: "string", + Description: "Updated description of the gist", + }, + "filename": { + Type: "string", + Description: "Filename to update or create", + }, + "content": { + Type: "string", + Description: "Content for the file", + }, }, + Required: []string{"gist_id", "filename", "content"}, }, - Required: []string{"gist_id", "filename", "content"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - gistID, err := RequiredParam[string](args, "gist_id") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - filename, err := RequiredParam[string](args, "filename") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - files := make(map[github.GistFilename]github.GistFile) - files[github.GistFilename(filename)] = github.GistFile{ - Filename: github.Ptr(filename), - Content: github.Ptr(content), - } - - gist := &github.Gist{ - Files: files, - Description: github.Ptr(description), - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) - if err != nil { - return nil, nil, fmt.Errorf("failed to update gist: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + gistID, err := RequiredParam[string](args, "gist_id") if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return utils.NewToolResultError(err.Error()), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update gist: %s", string(body))), nil, nil - } - minimalResponse := MinimalResponse{ - ID: updatedGist.GetID(), - URL: updatedGist.GetHTMLURL(), - } + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + filename, err := RequiredParam[string](args, "filename") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + files := make(map[github.GistFilename]github.GistFile) + files[github.GistFilename(filename)] = github.GistFile{ + Filename: github.Ptr(filename), + Content: github.Ptr(content), + } + + gist := &github.Gist{ + Files: files, + Description: github.Ptr(description), + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - return tool, handler + updatedGist, resp, err := client.Gists.Edit(ctx, gistID, gist) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to update gist", resp, err), nil, nil + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update gist", resp, body), nil, nil + } + + minimalResponse := MinimalResponse{ + ID: updatedGist.GetID(), + URL: updatedGist.GetHTMLURL(), + } + + r, err := json.Marshal(minimalResponse) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } diff --git a/pkg/github/gists_test.go b/pkg/github/gists_test.go index f0f62f420..0dd112afb 100644 --- a/pkg/github/gists_test.go +++ b/pkg/github/gists_test.go @@ -11,15 +11,14 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_ListGists(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := ListGists(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListGists(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -77,24 +76,18 @@ func Test_ListGists(t *testing.T) { }{ { name: "list authenticated user's gists", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetGists, - mockGists, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGists: mockResponse(t, http.StatusOK, mockGists), + }), requestArgs: map[string]interface{}{}, expectError: false, expectedGists: mockGists, }, { name: "list specific user's gists", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetUsersGistsByUsername, - mockResponse(t, http.StatusOK, mockGists), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetUsersGistsByUsername: mockResponse(t, http.StatusOK, mockGists), + }), requestArgs: map[string]interface{}{ "username": "testuser", }, @@ -103,18 +96,15 @@ func Test_ListGists(t *testing.T) { }, { name: "list gists with pagination and since parameter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetGists, - expectQueryParams(t, map[string]string{ - "since": "2023-01-01T00:00:00Z", - "page": "2", - "per_page": "5", - }).andThen( - mockResponse(t, http.StatusOK, mockGists), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGists: expectQueryParams(t, map[string]string{ + "since": "2023-01-01T00:00:00Z", + "page": "2", + "per_page": "5", + }).andThen( + mockResponse(t, http.StatusOK, mockGists), ), - ), + }), requestArgs: map[string]interface{}{ "since": "2023-01-01T00:00:00Z", "page": float64(2), @@ -125,12 +115,9 @@ func Test_ListGists(t *testing.T) { }, { name: "invalid since parameter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetGists, - mockGists, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGists: mockResponse(t, http.StatusOK, mockGists), + }), requestArgs: map[string]interface{}{ "since": "invalid-date", }, @@ -139,15 +126,12 @@ func Test_ListGists(t *testing.T) { }, { name: "list gists fails with error", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetGists, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGists: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + }), requestArgs: map[string]interface{}{}, expectError: true, expectedErrMsg: "failed to list gists", @@ -158,28 +142,27 @@ func Test_ListGists(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGists(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -202,8 +185,8 @@ func Test_ListGists(t *testing.T) { func Test_GetGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := GetGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -243,12 +226,9 @@ func Test_GetGist(t *testing.T) { }{ { name: "Successful fetching different gist", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetGistsByGistId, - mockResponse(t, http.StatusOK, mockGist), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGistsByGistID: mockResponse(t, http.StatusOK, mockGist), + }), requestArgs: map[string]interface{}{ "gist_id": "gist1", }, @@ -257,15 +237,12 @@ func Test_GetGist(t *testing.T) { }, { name: "gist_id parameter missing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetGistsByGistId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnprocessableEntity) - _, _ = w.Write([]byte(`{"message": "Invalid Request"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetGistsByGistID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"message": "Invalid Request"}`)) + }), + }), requestArgs: map[string]interface{}{}, expectError: true, expectedErrMsg: "missing required parameter: gist_id", @@ -276,28 +253,27 @@ func Test_GetGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) // Parse the result and get the text content if no error textContent := getTextResult(t, result) @@ -317,8 +293,8 @@ func Test_GetGist(t *testing.T) { func Test_CreateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := CreateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -363,12 +339,9 @@ func Test_CreateGist(t *testing.T) { }{ { name: "create gist successfully", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostGists, - mockResponse(t, http.StatusCreated, createdGist), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PostGists: mockResponse(t, http.StatusCreated, createdGist), + }), requestArgs: map[string]interface{}{ "filename": "test.go", "content": "package main\n\nfunc main() {\n\tfmt.Println(\"Hello, Gist!\")\n}", @@ -380,7 +353,7 @@ func Test_CreateGist(t *testing.T) { }, { name: "missing required filename", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "content": "test content", "description": "Test Gist", @@ -390,7 +363,7 @@ func Test_CreateGist(t *testing.T) { }, { name: "missing required content", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "filename": "test.go", "description": "Test Gist", @@ -400,15 +373,12 @@ func Test_CreateGist(t *testing.T) { }, { name: "api returns error", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostGists, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PostGists: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Requires authentication"}`)) + }), + }), requestArgs: map[string]interface{}{ "filename": "test.go", "content": "package main", @@ -423,28 +393,27 @@ func Test_CreateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content @@ -462,8 +431,8 @@ func Test_CreateGist(t *testing.T) { func Test_UpdateGist(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - tool, _ := UpdateGist(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateGist(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -509,12 +478,9 @@ func Test_UpdateGist(t *testing.T) { }{ { name: "update gist successfully", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PatchGistsByGistId, - mockResponse(t, http.StatusOK, updatedGist), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PatchGistsByGistID: mockResponse(t, http.StatusOK, updatedGist), + }), requestArgs: map[string]interface{}{ "gist_id": "existing-gist-id", "filename": "updated.go", @@ -526,7 +492,7 @@ func Test_UpdateGist(t *testing.T) { }, { name: "missing required gist_id", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "filename": "updated.go", "content": "updated content", @@ -537,7 +503,7 @@ func Test_UpdateGist(t *testing.T) { }, { name: "missing required filename", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "gist_id": "existing-gist-id", "content": "updated content", @@ -548,7 +514,7 @@ func Test_UpdateGist(t *testing.T) { }, { name: "missing required content", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "gist_id": "existing-gist-id", "filename": "updated.go", @@ -559,15 +525,12 @@ func Test_UpdateGist(t *testing.T) { }, { name: "api returns error", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PatchGistsByGistId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PatchGistsByGistID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "gist_id": "nonexistent-gist-id", "filename": "updated.go", @@ -583,28 +546,27 @@ func Test_UpdateGist(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateGist(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) // Verify results if tc.expectError { - if err != nil { - assert.Contains(t, err.Error(), tc.expectedErrMsg) - } else { - // For errors returned as part of the result, not as an error - assert.NotNil(t, result) - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - } + require.True(t, result.IsError) + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) return } - require.NoError(t, err) + require.False(t, result.IsError) assert.NotNil(t, result) // Parse the result and get the text content diff --git a/pkg/github/git.go b/pkg/github/git.go index c2a839132..7b93c3675 100644 --- a/pkg/github/git.go +++ b/pkg/github/git.go @@ -7,6 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -37,45 +38,45 @@ type TreeResponse struct { } // GetRepositoryTree creates a tool to get the tree structure of a GitHub repository. -func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_repository_tree", - Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "tree_sha": { - Type: "string", - Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", - }, - "recursive": { - Type: "boolean", - Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", - Default: json.RawMessage(`false`), - }, - "path_filter": { - Type: "string", - Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", +func GetRepositoryTree(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataGit, + mcp.Tool{ + Name: "get_repository_tree", + Description: t("TOOL_GET_REPOSITORY_TREE_DESCRIPTION", "Get the tree structure (files and directories) of a GitHub repository at a specific ref or SHA"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_REPOSITORY_TREE_USER_TITLE", "Get repository tree"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "tree_sha": { + Type: "string", + Description: "The SHA1 value or ref (branch or tag) name of the tree. Defaults to the repository's default branch", + }, + "recursive": { + Type: "boolean", + Description: "Setting this parameter to true returns the objects or subtrees referenced by the tree. Default is false", + Default: json.RawMessage(`false`), + }, + "path_filter": { + Type: "string", + Description: "Optional path prefix to filter the tree results (e.g., 'src/' to only show files in the src directory)", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any]( - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -97,7 +98,7 @@ func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFu return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError("failed to get GitHub client"), nil, nil } @@ -171,6 +172,4 @@ func GetRepositoryTree(getClient GetClientFn, t translations.TranslationHelperFu return utils.NewToolResultText(string(r)), nil, nil }, ) - - return tool, handler } diff --git a/pkg/github/git_test.go b/pkg/github/git_test.go index 66cbccd6e..d60aed092 100644 --- a/pkg/github/git_test.go +++ b/pkg/github/git_test.go @@ -11,22 +11,20 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetRepositoryTree(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + toolDef := GetRepositoryTree(translations.NullTranslationHelper) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_repository_tree", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_repository_tree", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Type assert the InputSchema to access its properties - inputSchema, ok := tool.InputSchema.(*jsonschema.Schema) + inputSchema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "expected InputSchema to be *jsonschema.Schema") assert.Contains(t, inputSchema.Properties, "owner") assert.Contains(t, inputSchema.Properties, "repo") @@ -71,16 +69,10 @@ func Test_GetRepositoryTree(t *testing.T) { }{ { name: "successfully get repository tree", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -88,16 +80,10 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "successfully get repository tree with path filter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: mockResponse(t, http.StatusOK, mockTree), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -106,15 +92,12 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "repository not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "nonexistent", @@ -124,19 +107,13 @@ func Test_GetRepositoryTree(t *testing.T) { }, { name: "tree not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposByOwnerByRepo: mockResponse(t, http.StatusOK, mockRepo), + GetReposGitTreesByOwnerByRepoByTree: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -148,12 +125,16 @@ func Test_GetRepositoryTree(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper) + client := github.NewClient(tc.mockedClient) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create the tool request request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 9c55ba841..56a236660 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -1,15 +1,148 @@ package github import ( + "bytes" "encoding/json" + "io" "net/http" + "net/url" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +// GitHub API endpoint patterns for testing +// These constants define the URL patterns used in HTTP mocking for tests +const ( + // User endpoints + GetUser = "GET /user" + GetUserStarred = "GET /user/starred" + GetUsersGistsByUsername = "GET /users/{username}/gists" + GetUsersStarredByUsername = "GET /users/{username}/starred" + PutUserStarredByOwnerByRepo = "PUT /user/starred/{owner}/{repo}" + DeleteUserStarredByOwnerByRepo = "DELETE /user/starred/{owner}/{repo}" + + // Repository endpoints + GetReposByOwnerByRepo = "GET /repos/{owner}/{repo}" + GetReposBranchesByOwnerByRepo = "GET /repos/{owner}/{repo}/branches" + GetReposTagsByOwnerByRepo = "GET /repos/{owner}/{repo}/tags" + GetReposCommitsByOwnerByRepo = "GET /repos/{owner}/{repo}/commits" + GetReposCommitsByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}" + GetReposContentsByOwnerByRepoByPath = "GET /repos/{owner}/{repo}/contents/{path}" + PutReposContentsByOwnerByRepoByPath = "PUT /repos/{owner}/{repo}/contents/{path}" + PostReposForksByOwnerByRepo = "POST /repos/{owner}/{repo}/forks" + GetReposSubscriptionByOwnerByRepo = "GET /repos/{owner}/{repo}/subscription" + PutReposSubscriptionByOwnerByRepo = "PUT /repos/{owner}/{repo}/subscription" + DeleteReposSubscriptionByOwnerByRepo = "DELETE /repos/{owner}/{repo}/subscription" + + // Git endpoints + GetReposGitTreesByOwnerByRepoByTree = "GET /repos/{owner}/{repo}/git/trees/{tree}" + GetReposGitRefByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/git/ref/{ref}" + PostReposGitRefsByOwnerByRepo = "POST /repos/{owner}/{repo}/git/refs" + PatchReposGitRefsByOwnerByRepoByRef = "PATCH /repos/{owner}/{repo}/git/refs/{ref}" + GetReposGitCommitsByOwnerByRepoByCommitSHA = "GET /repos/{owner}/{repo}/git/commits/{commit_sha}" + PostReposGitCommitsByOwnerByRepo = "POST /repos/{owner}/{repo}/git/commits" + GetReposGitTagsByOwnerByRepoByTagSHA = "GET /repos/{owner}/{repo}/git/tags/{tag_sha}" + PostReposGitTreesByOwnerByRepo = "POST /repos/{owner}/{repo}/git/trees" + GetReposCommitsStatusByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}/status" + GetReposCommitsStatusesByOwnerByRepoByRef = "GET /repos/{owner}/{repo}/commits/{ref}/statuses" + + // Issues endpoints + GetReposIssuesByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}" + GetReposIssuesCommentsByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}/comments" + PostReposIssuesByOwnerByRepo = "POST /repos/{owner}/{repo}/issues" + PostReposIssuesCommentsByOwnerByRepoByIssueNumber = "POST /repos/{owner}/{repo}/issues/{issue_number}/comments" + PatchReposIssuesByOwnerByRepoByIssueNumber = "PATCH /repos/{owner}/{repo}/issues/{issue_number}" + GetReposIssuesSubIssuesByOwnerByRepoByIssueNumber = "GET /repos/{owner}/{repo}/issues/{issue_number}/sub_issues" + PostReposIssuesSubIssuesByOwnerByRepoByIssueNumber = "POST /repos/{owner}/{repo}/issues/{issue_number}/sub_issues" + DeleteReposIssuesSubIssueByOwnerByRepoByIssueNumber = "DELETE /repos/{owner}/{repo}/issues/{issue_number}/sub_issues" + PatchReposIssuesSubIssuesPriorityByOwnerByRepoByIssueNumber = "PATCH /repos/{owner}/{repo}/issues/{issue_number}/sub_issues/priority" + + // Pull request endpoints + GetReposPullsByOwnerByRepo = "GET /repos/{owner}/{repo}/pulls" + GetReposPullsByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}" + GetReposPullsFilesByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}/files" + GetReposPullsReviewsByOwnerByRepoByPullNumber = "GET /repos/{owner}/{repo}/pulls/{pull_number}/reviews" + PostReposPullsByOwnerByRepo = "POST /repos/{owner}/{repo}/pulls" + PatchReposPullsByOwnerByRepoByPullNumber = "PATCH /repos/{owner}/{repo}/pulls/{pull_number}" + PutReposPullsMergeByOwnerByRepoByPullNumber = "PUT /repos/{owner}/{repo}/pulls/{pull_number}/merge" + PutReposPullsUpdateBranchByOwnerByRepoByPullNumber = "PUT /repos/{owner}/{repo}/pulls/{pull_number}/update-branch" + PostReposPullsRequestedReviewersByOwnerByRepoByPullNumber = "POST /repos/{owner}/{repo}/pulls/{pull_number}/requested_reviewers" + + // Notifications endpoints + GetNotifications = "GET /notifications" + PutNotifications = "PUT /notifications" + GetReposNotificationsByOwnerByRepo = "GET /repos/{owner}/{repo}/notifications" + PutReposNotificationsByOwnerByRepo = "PUT /repos/{owner}/{repo}/notifications" + GetNotificationsThreadsByThreadID = "GET /notifications/threads/{thread_id}" + PatchNotificationsThreadsByThreadID = "PATCH /notifications/threads/{thread_id}" + DeleteNotificationsThreadsByThreadID = "DELETE /notifications/threads/{thread_id}" + PutNotificationsThreadsSubscriptionByThreadID = "PUT /notifications/threads/{thread_id}/subscription" + DeleteNotificationsThreadsSubscriptionByThreadID = "DELETE /notifications/threads/{thread_id}/subscription" + + // Gists endpoints + GetGists = "GET /gists" + GetGistsByGistID = "GET /gists/{gist_id}" + PostGists = "POST /gists" + PatchGistsByGistID = "PATCH /gists/{gist_id}" + + // Releases endpoints + GetReposReleasesByOwnerByRepo = "GET /repos/{owner}/{repo}/releases" + GetReposReleasesLatestByOwnerByRepo = "GET /repos/{owner}/{repo}/releases/latest" + GetReposReleasesTagsByOwnerByRepoByTag = "GET /repos/{owner}/{repo}/releases/tags/{tag}" + + // Code scanning endpoints + GetReposCodeScanningAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/code-scanning/alerts" + GetReposCodeScanningAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/code-scanning/alerts/{alert_number}" + + // Secret scanning endpoints + GetReposSecretScanningAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/secret-scanning/alerts" //nolint:gosec // False positive - this is an API endpoint pattern, not a credential + GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/secret-scanning/alerts/{alert_number}" //nolint:gosec // False positive - this is an API endpoint pattern, not a credential + + // Dependabot endpoints + GetReposDependabotAlertsByOwnerByRepo = "GET /repos/{owner}/{repo}/dependabot/alerts" + GetReposDependabotAlertsByOwnerByRepoByAlertNumber = "GET /repos/{owner}/{repo}/dependabot/alerts/{alert_number}" + + // Security advisories endpoints + GetAdvisories = "GET /advisories" + GetAdvisoriesByGhsaID = "GET /advisories/{ghsa_id}" + GetReposSecurityAdvisoriesByOwnerByRepo = "GET /repos/{owner}/{repo}/security-advisories" + GetOrgsSecurityAdvisoriesByOrg = "GET /orgs/{org}/security-advisories" + + // Actions endpoints + GetReposActionsWorkflowsByOwnerByRepo = "GET /repos/{owner}/{repo}/actions/workflows" + GetReposActionsWorkflowsByOwnerByRepoByWorkflowID = "GET /repos/{owner}/{repo}/actions/workflows/{workflow_id}" + PostReposActionsWorkflowsDispatchesByOwnerByRepoByWorkflowID = "POST /repos/{owner}/{repo}/actions/workflows/{workflow_id}/dispatches" + GetReposActionsWorkflowsRunsByOwnerByRepoByWorkflowID = "GET /repos/{owner}/{repo}/actions/workflows/{workflow_id}/runs" + GetReposActionsRunsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}" + GetReposActionsRunsLogsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/logs" + GetReposActionsRunsJobsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/jobs" + GetReposActionsRunsArtifactsByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/artifacts" + GetReposActionsRunsTimingByOwnerByRepoByRunID = "GET /repos/{owner}/{repo}/actions/runs/{run_id}/timing" + PostReposActionsRunsRerunByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/rerun" + PostReposActionsRunsRerunFailedJobsByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/rerun-failed-jobs" + PostReposActionsRunsCancelByOwnerByRepoByRunID = "POST /repos/{owner}/{repo}/actions/runs/{run_id}/cancel" + GetReposActionsJobsLogsByOwnerByRepoByJobID = "GET /repos/{owner}/{repo}/actions/jobs/{job_id}/logs" + DeleteReposActionsRunsLogsByOwnerByRepoByRunID = "DELETE /repos/{owner}/{repo}/actions/runs/{run_id}/logs" + + // Search endpoints + GetSearchCode = "GET /search/code" + GetSearchIssues = "GET /search/issues" + GetSearchRepositories = "GET /search/repositories" + GetSearchUsers = "GET /search/users" + + // Raw content endpoints (used for GitHub raw content API, not standard API) + // These are used with the raw content client that interacts with raw.githubusercontent.com + GetRawReposContentsByOwnerByRepoByPath = "GET /{owner}/{repo}/HEAD/{path:.*}" + GetRawReposContentsByOwnerByRepoByBranchByPath = "GET /{owner}/{repo}/refs/heads/{branch}/{path:.*}" + GetRawReposContentsByOwnerByRepoByTagByPath = "GET /{owner}/{repo}/refs/tags/{tag}/{path:.*}" + GetRawReposContentsByOwnerByRepoBySHAByPath = "GET /{owner}/{repo}/{sha}/{path:.*}" +) + type expectations struct { path string queryParams map[string]string @@ -216,7 +349,7 @@ func TestOptionalParamOK(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Test with string type assertion if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { - val, ok, err := OptionalParamOK[string, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[string](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -231,7 +364,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with bool type assertion if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { - val, ok, err := OptionalParamOK[bool, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[bool](tc.args, tc.paramName) if tc.expectError { require.Error(t, err) assert.Contains(t, err.Error(), tc.errorMsg) @@ -246,7 +379,7 @@ func TestOptionalParamOK(t *testing.T) { // Test with float64 type assertion (for number case) if _, isFloat := tc.expectedVal.(float64); isFloat { - val, ok, err := OptionalParamOK[float64, map[string]any](tc.args, tc.paramName) + val, ok, err := OptionalParamOK[float64](tc.args, tc.paramName) if tc.expectError { // This case shouldn't happen for float64 in the defined tests require.Fail(t, "Unexpected error case for float64") @@ -272,3 +405,257 @@ func getResourceResult(t *testing.T, result *mcp.CallToolResult) *mcp.ResourceCo require.IsType(t, &mcp.ResourceContents{}, resource.Resource) return resource.Resource } + +// MockRoundTripper is a mock HTTP transport using testify/mock +type MockRoundTripper struct { + mock.Mock + handlers map[string]http.HandlerFunc +} + +// NewMockRoundTripper creates a new mock round tripper +func NewMockRoundTripper() *MockRoundTripper { + return &MockRoundTripper{ + handlers: make(map[string]http.HandlerFunc), + } +} + +// RoundTrip implements the http.RoundTripper interface +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Normalize the request path and method for matching + key := req.Method + " " + req.URL.Path + + // Check if we have a specific handler for this request + if handler, ok := m.handlers[key]; ok { + // Use httptest.ResponseRecorder to capture the handler's response + recorder := &responseRecorder{ + header: make(http.Header), + body: &bytes.Buffer{}, + } + handler(recorder, req) + + return &http.Response{ + StatusCode: recorder.statusCode, + Header: recorder.header, + Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), + Request: req, + }, nil + } + + // Fall back to mock.Mock assertions if defined + args := m.Called(req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*http.Response), args.Error(1) +} + +// On registers an expectation using testify/mock +func (m *MockRoundTripper) OnRequest(method, path string, handler http.HandlerFunc) *MockRoundTripper { + key := method + " " + path + m.handlers[key] = handler + return m +} + +// NewMockHTTPClient creates an HTTP client with a mock transport +func NewMockHTTPClient() (*http.Client, *MockRoundTripper) { + transport := NewMockRoundTripper() + client := &http.Client{Transport: transport} + return client, transport +} + +// responseRecorder is a simple response recorder for the mock transport +type responseRecorder struct { + statusCode int + header http.Header + body *bytes.Buffer +} + +func (r *responseRecorder) Header() http.Header { + return r.header +} + +func (r *responseRecorder) Write(data []byte) (int, error) { + if r.statusCode == 0 { + r.statusCode = http.StatusOK + } + return r.body.Write(data) +} + +func (r *responseRecorder) WriteHeader(statusCode int) { + r.statusCode = statusCode +} + +// matchPath checks if a request path matches a pattern (supports simple wildcards) +func matchPath(pattern, path string) bool { + // Simple exact match for now + if pattern == path { + return true + } + + // Support for path parameters like /repos/{owner}/{repo}/issues/{issue_number} + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + // Handle patterns with wildcard path like {path:.*} + if len(patternParts) > 0 { + lastPart := patternParts[len(patternParts)-1] + if strings.HasPrefix(lastPart, "{") && strings.Contains(lastPart, ":") && strings.HasSuffix(lastPart, "}") { + // This is a wildcard pattern like {path:.*} + // Check if all parts before the wildcard match + if len(pathParts) < len(patternParts)-1 { + return false + } + for i := 0; i < len(patternParts)-1; i++ { + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + continue // Path parameter matches anything + } + if patternParts[i] != pathParts[i] { + return false + } + } + return true + } + } + + if len(patternParts) != len(pathParts) { + return false + } + + for i := range patternParts { + // Check if this is a path parameter (enclosed in {}) + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + continue // Path parameters match anything + } + if patternParts[i] != pathParts[i] { + return false + } + } + + return true +} + +// executeHandler executes an HTTP handler and returns the response +func executeHandler(handler http.HandlerFunc, req *http.Request) *http.Response { + recorder := &responseRecorder{ + header: make(http.Header), + body: &bytes.Buffer{}, + } + handler(recorder, req) + + return &http.Response{ + StatusCode: recorder.statusCode, + Header: recorder.header, + Body: io.NopCloser(bytes.NewReader(recorder.body.Bytes())), + Request: req, + } +} + +// MockHTTPClientWithHandler creates an HTTP client with a single handler function +func MockHTTPClientWithHandler(handler http.HandlerFunc) *http.Client { + handlers := map[string]http.HandlerFunc{ + "": handler, // Empty key acts as catch-all + } + return MockHTTPClientWithHandlers(handlers) +} + +// MockHTTPClientWithHandlers creates an HTTP client with multiple handlers for different paths +func MockHTTPClientWithHandlers(handlers map[string]http.HandlerFunc) *http.Client { + transport := &multiHandlerTransport{handlers: handlers} + return &http.Client{Transport: transport} +} + +type multiHandlerTransport struct { + handlers map[string]http.HandlerFunc +} + +func (m *multiHandlerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Check for catch-all handler + if handler, ok := m.handlers[""]; ok { + return executeHandler(handler, req), nil + } + + // Try to find a handler for this request + key := req.Method + " " + req.URL.Path + + // First try exact match + if handler, ok := m.handlers[key]; ok { + return executeHandler(handler, req), nil + } + + // Then try pattern matching, prioritizing patterns without wildcards + // This is important because wildcard patterns like /{owner}/{repo}/{sha}/{path:.*} + // can incorrectly match API paths like /repos/owner/repo/pulls/42 + var wildcardPattern string + var wildcardHandler http.HandlerFunc + + for pattern, handler := range m.handlers { + if pattern == "" { + continue // Skip catch-all + } + parts := strings.SplitN(pattern, " ", 2) + if len(parts) != 2 { + continue + } + method, pathPattern := parts[0], parts[1] + if req.Method != method { + continue + } + + // Check if this pattern contains a wildcard like {path:.*} + isWildcard := strings.Contains(pathPattern, ":.*}") + + if matchPath(pathPattern, req.URL.Path) { + if isWildcard { + // Save wildcard match for later, prefer non-wildcard patterns + wildcardPattern = pattern + wildcardHandler = handler + } else { + // Non-wildcard pattern takes priority + return executeHandler(handler, req), nil + } + } + } + + // If we found a wildcard match but no specific match, use it + if wildcardPattern != "" && wildcardHandler != nil { + return executeHandler(wildcardHandler, req), nil + } + + // No handler found + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(bytes.NewReader([]byte("not found"))), + Request: req, + }, nil +} + +// extractPathParams extracts path parameters from a URL path given a pattern +func extractPathParams(pattern, path string) map[string]string { + params := make(map[string]string) + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return params + } + + for i := range patternParts { + if strings.HasPrefix(patternParts[i], "{") && strings.HasSuffix(patternParts[i], "}") { + paramName := strings.Trim(patternParts[i], "{}") + params[paramName] = pathParts[i] + } + } + + return params +} + +// ParseRequestPath is a helper to extract path parameters +func ParseRequestPath(t *testing.T, req *http.Request, pattern string) url.Values { + t.Helper() + params := extractPathParams(pattern, req.URL.Path) + values := url.Values{} + for k, v := range params { + values.Set(k, v) + } + return values +} diff --git a/pkg/github/inventory.go b/pkg/github/inventory.go new file mode 100644 index 000000000..38c936d86 --- /dev/null +++ b/pkg/github/inventory.go @@ -0,0 +1,18 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" +) + +// NewInventory creates an Inventory with all available tools, resources, and prompts. +// Tools, resources, and prompts are self-describing with their toolset metadata embedded. +// This function is stateless - no dependencies are captured. +// Handlers are generated on-demand during registration via RegisterAll(ctx, server, deps). +// The "default" keyword in WithToolsets will expand to toolsets marked with Default: true. +func NewInventory(t translations.TranslationHelperFunc) *inventory.Builder { + return inventory.NewBuilder(). + SetTools(AllTools(t)). + SetResources(AllResources(t)). + SetPrompts(AllPrompts(t)) +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index ec83e4efa..f06dc2d9d 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -10,7 +10,9 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" @@ -229,7 +231,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // IssueRead creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func IssueRead(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -261,7 +263,9 @@ Options are: } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "issue_read", Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -270,7 +274,7 @@ Options are: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { method, err := RequiredParam[string](args, "method") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -294,25 +298,25 @@ Options are: return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil } switch method { case "get": - result, err := GetIssue(ctx, client, cache, owner, repo, issueNumber, flags) + result, err := GetIssue(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_sub_issues": - result, err := GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags) + result, err := GetSubIssues(ctx, client, deps.GetRepoAccessCache(), owner, repo, issueNumber, pagination, deps.GetFlags()) return result, nil, err case "get_labels": result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) @@ -320,7 +324,7 @@ Options are: default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } - } + }) } func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { @@ -335,7 +339,7 @@ func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAc if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get issue", resp, body), nil } if flags.LockdownMode { @@ -391,7 +395,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdow if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get issue comments", resp, body), nil } if flags.LockdownMode { if cache == nil { @@ -450,7 +454,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.Re if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list sub-issues", resp, body), nil } if featureFlags.LockdownMode { @@ -540,8 +544,10 @@ func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, } // ListIssueTypes creates a tool to list defined issue types for an organization. This can be used to understand supported issue type values for creating or updating issues. -func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListIssueTypes(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "list_issue_types", Description: t("TOOL_LIST_ISSUE_TYPES_FOR_ORG", "List supported issue types for repository owner (organization)."), Annotations: &mcp.ToolAnnotations{ @@ -559,13 +565,13 @@ func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) Required: []string{"owner"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -580,7 +586,7 @@ func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list issue types: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list issue types", resp, body), nil, nil } r, err := json.Marshal(issueTypes) @@ -589,12 +595,14 @@ func ListIssueTypes(getClient GetClientFn, t translations.TranslationHelperFunc) } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func AddIssueComment(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "add_issue_comment", Description: t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to a specific issue in a GitHub repository. Use this tool to add comments to pull requests as well (in this case pass pull request number as issue_number), but only if user is not asking specifically to add review comments."), Annotations: &mcp.ToolAnnotations{ @@ -624,7 +632,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc Required: []string{"owner", "repo", "issue_number", "body"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -646,7 +654,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc Body: github.Ptr(body), } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -661,7 +669,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create comment: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create comment", resp, body), nil, nil } r, err := json.Marshal(createdComment) @@ -670,12 +678,14 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // SubIssueWrite creates a tool to add a sub-issue to a parent issue. -func SubIssueWrite(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func SubIssueWrite(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "sub_issue_write", Description: t("TOOL_SUB_ISSUE_WRITE_DESCRIPTION", "Add a sub-issue to a parent issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -726,7 +736,7 @@ Options are: Required: []string{"method", "owner", "repo", "issue_number", "sub_issue_id"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { method, err := RequiredParam[string](args, "method") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -761,7 +771,7 @@ Options are: return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -781,7 +791,7 @@ Options are: default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } - } + }) } func AddSubIssue(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, subIssueID int, replaceParent bool) (*mcp.CallToolResult, error) { @@ -806,7 +816,7 @@ func AddSubIssue(ctx context.Context, client *github.Client, owner string, repo if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to add sub-issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to add sub-issue", resp, body), nil } r, err := json.Marshal(subIssue) @@ -838,7 +848,7 @@ func RemoveSubIssue(ctx context.Context, client *github.Client, owner string, re if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to remove sub-issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to remove sub-issue", resp, body), nil } r, err := json.Marshal(subIssue) @@ -887,7 +897,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to reprioritize sub-issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to reprioritize sub-issue", resp, body), nil } r, err := json.Marshal(subIssue) @@ -899,7 +909,7 @@ func ReprioritizeSubIssue(ctx context.Context, client *github.Client, owner stri } // SearchIssues creates a tool to search for issues. -func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchIssues(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -942,7 +952,9 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "search_issues", Description: t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues in GitHub repositories using issues search syntax already scoped to is:issue"), Annotations: &mcp.ToolAnnotations{ @@ -951,15 +963,17 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, getClient, args, "issue", "failed to search issues") + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "issue", "failed to search issues") return result, nil, err - } + }) } // IssueWrite creates a tool to create a new or update an existing issue in a GitHub repository. -func IssueWrite(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func IssueWrite(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "issue_write", Description: t("TOOL_ISSUE_WRITE_DESCRIPTION", "Create a new or update an existing issue in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -1038,7 +1052,7 @@ Options are: Required: []string{"method", "owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { method, err := RequiredParam[string](args, "method") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1111,12 +1125,12 @@ Options are: return utils.NewToolResultError("duplicate_of can only be used when state_reason is 'duplicate'"), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GraphQL client", err), nil, nil } @@ -1135,7 +1149,7 @@ Options are: default: return utils.NewToolResultError("invalid method, must be either 'create' or 'update'"), nil, nil } - } + }) } func CreateIssue(ctx context.Context, client *github.Client, owner string, repo string, title string, body string, assignees []string, labels []string, milestoneNum int, issueType string) (*mcp.CallToolResult, error) { @@ -1170,7 +1184,7 @@ func CreateIssue(ctx context.Context, client *github.Client, owner string, repo if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create issue", resp, body), nil } // Return minimal response with just essential information @@ -1231,7 +1245,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to update issue: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update issue", resp, body), nil } // Use GraphQL API for state updates @@ -1313,7 +1327,7 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4 } // ListIssues creates a tool to list and filter repository issues -func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func ListIssues(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1356,7 +1370,9 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun } WithCursorPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "list_issues", Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."), Annotations: &mcp.ToolAnnotations{ @@ -1365,7 +1381,7 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1469,7 +1485,7 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun paginationParams.First = &defaultFirst } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -1544,7 +1560,7 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun return nil, nil, fmt.Errorf("failed to marshal issues: %w", err) } return utils.NewToolResultText(string(out)), nil, nil - } + }) } // mvpDescription is an MVP idea for generating tool descriptions from structured data in a shared format. @@ -1577,7 +1593,7 @@ func (d *mvpDescription) String() string { return sb.String() } -func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func AssignCopilotToIssue(t translations.TranslationHelperFunc) inventory.ServerTool { description := mvpDescription{ summary: "Assign Copilot to a specific issue in a GitHub repository.", outcomes: []string{ @@ -1588,9 +1604,12 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio }, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ Name: "assign_copilot_to_issue", Description: t("TOOL_ASSIGN_COPILOT_TO_ISSUE_DESCRIPTION", description.String()), + Icons: octicons.Icons("copilot"), Annotations: &mcp.ToolAnnotations{ Title: t("TOOL_ASSIGN_COPILOT_TO_ISSUE_USER_TITLE", "Assign Copilot to issue"), ReadOnlyHint: false, @@ -1615,7 +1634,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio Required: []string{"owner", "repo", "issueNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { var params struct { Owner string Repo string @@ -1625,7 +1644,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -1740,7 +1759,7 @@ func AssignCopilotToIssue(getGQLClient GetGQLClientFn, t translations.Translatio } return utils.NewToolResultText("successfully assigned copilot to issue"), nil, nil - } + }) } type ReplaceActorsForAssignableInput struct { @@ -1772,8 +1791,10 @@ func parseISOTimestamp(timestamp string) (time.Time, error) { return time.Time{}, fmt.Errorf("invalid ISO 8601 timestamp: %s (supported formats: YYYY-MM-DDThh:mm:ssZ or YYYY-MM-DD)", timestamp) } -func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, mcp.PromptHandler) { - return mcp.Prompt{ +func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) inventory.ServerPrompt { + return inventory.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "AssignCodingAgent", Description: t("PROMPT_ASSIGN_CODING_AGENT_DESCRIPTION", "Assign GitHub Coding Agent to multiple tasks in a GitHub repository."), Arguments: []*mcp.PromptArgument{ @@ -1783,7 +1804,8 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, Required: true, }, }, - }, func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + }, + func(_ context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { repo := request.Params.Arguments["repo"] messages := []*mcp.PromptMessage{ @@ -1827,5 +1849,6 @@ func AssignCodingAgentPrompt(t translations.TranslationHelperFunc) (mcp.Prompt, return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index c4454624b..b810cede3 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -122,9 +122,8 @@ func toString(v any) string { func Test_GetIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - defaultGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), repoAccessCache, translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -327,14 +326,20 @@ func Test_GetIssue(t *testing.T) { gqlClient = githubv4.NewClient(tc.gqlHTTPClient) cache = stubRepoAccessCache(gqlClient, 15*time.Minute) } else { - gqlClient = defaultGQLClient + gqlClient = githubv4.NewClient(nil) } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectHandlerError { require.Error(t, err) @@ -368,8 +373,8 @@ func Test_GetIssue(t *testing.T) { func Test_AddIssueComment(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := AddIssueComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddIssueComment(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_issue_comment", tool.Name) @@ -442,13 +447,16 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := AddIssueComment(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -483,8 +491,8 @@ func Test_AddIssueComment(t *testing.T) { func Test_SearchIssues(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchIssues(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_issues", tool.Name) @@ -773,13 +781,16 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchIssues(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -817,9 +828,8 @@ func Test_SearchIssues(t *testing.T) { func Test_CreateIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - mockGQLClient := githubv4.NewClient(nil) - tool, _ := IssueWrite(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) + serverTool := IssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_write", tool.Name) @@ -942,13 +952,17 @@ func Test_CreateIssue(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueWrite(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -979,8 +993,8 @@ func Test_CreateIssue(t *testing.T) { func Test_ListIssues(t *testing.T) { // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := ListIssues(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListIssues(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_issues", tool.Name) @@ -1254,10 +1268,13 @@ func Test_ListIssues(t *testing.T) { } gqlClient := githubv4.NewClient(httpClient) - _, handler := ListIssues(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: gqlClient, + } + handler := serverTool.Handler(deps) req := createMCPRequest(tc.reqParams) - res, _, err := handler(context.Background(), &req, tc.reqParams) + res, err := handler(ContextWithDeps(context.Background(), deps), &req) text := getTextResult(t, res).Text if tc.expectError { @@ -1300,9 +1317,8 @@ func Test_ListIssues(t *testing.T) { func Test_UpdateIssue(t *testing.T) { // Verify tool definition - mockClient := github.NewClient(nil) - mockGQLClient := githubv4.NewClient(nil) - tool, _ := IssueWrite(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) + serverTool := IssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_write", tool.Name) @@ -1753,13 +1769,17 @@ func Test_UpdateIssue(t *testing.T) { // Setup clients with mocks restClient := github.NewClient(tc.mockedRESTClient) gqlClient := githubv4.NewClient(tc.mockedGQLClient) - _, handler := IssueWrite(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -1844,9 +1864,8 @@ func Test_ParseISOTimestamp(t *testing.T) { func Test_GetIssueComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1997,13 +2016,19 @@ func Test_GetIssueComments(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 15*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2035,9 +2060,8 @@ func Test_GetIssueLabels(t *testing.T) { t.Parallel() // Verify tool definition - mockGQClient := githubv4.NewClient(nil) - mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2112,10 +2136,16 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -2137,8 +2167,8 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := AssignCopilotToIssue(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AssignCopilotToIssue(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "assign_copilot_to_issue", tool.Name) @@ -2530,13 +2560,16 @@ func TestAssignCopilotToIssue(t *testing.T) { t.Parallel() // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := AssignCopilotToIssue(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2555,8 +2588,8 @@ func TestAssignCopilotToIssue(t *testing.T) { func Test_AddSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -2758,13 +2791,16 @@ func Test_AddSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2801,9 +2837,8 @@ func Test_AddSubIssue(t *testing.T) { func Test_GetSubIssues(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := IssueRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -3000,13 +3035,19 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + RepoAccessCache: stubRepoAccessCache(gqlClient, 15*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3052,8 +3093,8 @@ func Test_GetSubIssues(t *testing.T) { func Test_RemoveSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -3234,13 +3275,16 @@ func Test_RemoveSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3277,8 +3321,8 @@ func Test_RemoveSubIssue(t *testing.T) { func Test_ReprioritizeSubIssue(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SubIssueWrite(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SubIssueWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "sub_issue_write", tool.Name) @@ -3520,13 +3564,16 @@ func Test_ReprioritizeSubIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SubIssueWrite(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3563,8 +3610,8 @@ func Test_ReprioritizeSubIssue(t *testing.T) { func Test_ListIssueTypes(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListIssueTypes(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListIssueTypes(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_issue_types", tool.Name) @@ -3651,13 +3698,16 @@ func Test_ListIssueTypes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListIssueTypes(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/labels.go b/pkg/github/labels.go index 25ac9f7fe..2811cf66e 100644 --- a/pkg/github/labels.go +++ b/pkg/github/labels.go @@ -7,6 +7,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/jsonschema-go/jsonschema" @@ -15,377 +16,385 @@ import ( ) // GetLabel retrieves a specific label by name from a GitHub repository -func GetLabel(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_label", - Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_LABEL_TITLE", "Get a specific label from a repository."), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization name)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "name": { - Type: "string", - Description: "Label name.", +func GetLabel(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataIssues, + mcp.Tool{ + Name: "get_label", + Description: t("TOOL_GET_LABEL_DESCRIPTION", "Get a specific label from a repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_LABEL_TITLE", "Get a specific label from a repository."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "name": { + Type: "string", + Description: "Label name.", + }, }, + Required: []string{"owner", "repo", "name"}, }, - Required: []string{"owner", "repo", "name"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - var query struct { - Repository struct { - Label struct { - ID githubv4.ID - Name githubv4.String - Color githubv4.String - Description githubv4.String - } `graphql:"label(name: $name)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - "name": githubv4.String(name), - } - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil - } - - if query.Repository.Label.Name == "" { - return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil - } - - label := map[string]any{ - "id": fmt.Sprintf("%v", query.Repository.Label.ID), - "name": string(query.Repository.Label.Name), - "color": string(query.Repository.Label.Color), - "description": string(query.Repository.Label.Description), - } - - out, err := json.Marshal(label) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal label: %w", err) - } - - return utils.NewToolResultText(string(out)), nil, nil - }) - - return tool, handler -} + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } -// ListLabels lists labels from a repository -func ListLabels(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_label", - Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository."), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization name) - required for all operations", - }, - "repo": { - Type: "string", - Description: "Repository name - required for all operations", - }, - }, - Required: []string{"owner", "repo"}, - }, - } + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - var query struct { - Repository struct { - Labels struct { - Nodes []struct { + var query struct { + Repository struct { + Label struct { ID githubv4.ID Name githubv4.String Color githubv4.String Description githubv4.String - } - TotalCount githubv4.Int - } `graphql:"labels(first: 100)"` - } `graphql:"repository(owner: $owner, name: $repo)"` - } - - vars := map[string]any{ - "owner": githubv4.String(owner), - "repo": githubv4.String(repo), - } - - if err := client.Query(ctx, &query, vars); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil - } - - labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) - for i, labelNode := range query.Repository.Labels.Nodes { - labels[i] = map[string]any{ - "id": fmt.Sprintf("%v", labelNode.ID), - "name": string(labelNode.Name), - "color": string(labelNode.Color), - "description": string(labelNode.Description), + } `graphql:"label(name: $name)"` + } `graphql:"repository(owner: $owner, name: $repo)"` } - } - - response := map[string]any{ - "labels": labels, - "totalCount": int(query.Repository.Labels.TotalCount), - } - - out, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) - } - - return utils.NewToolResultText(string(out)), nil, nil - }) - - return tool, handler -} - -// LabelWrite handles create, update, and delete operations for GitHub labels -func LabelWrite(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "label_write", - Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LABEL_WRITE_TITLE", "Write operations on repository labels."), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "method": { - Type: "string", - Description: "Operation to perform: 'create', 'update', or 'delete'", - Enum: []any{"create", "update", "delete"}, - }, - "owner": { - Type: "string", - Description: "Repository owner (username or organization name)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "name": { - Type: "string", - Description: "Label name - required for all operations", - }, - "new_name": { - Type: "string", - Description: "New name for the label (used only with 'update' method to rename)", - }, - "color": { - Type: "string", - Description: "Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'.", - }, - "description": { - Type: "string", - Description: "Label description text. Optional for 'create' and 'update'.", - }, - }, - Required: []string{"method", "owner", "repo", "name"}, - }, - } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Get and validate required parameters - method, err := RequiredParam[string](args, "method") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - method = strings.ToLower(method) - - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - // Get optional parameters - newName, _ := OptionalParam[string](args, "new_name") - color, _ := OptionalParam[string](args, "color") - description, _ := OptionalParam[string](args, "description") - - client, err := getGQLClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - switch method { - case "create": - // Validate required params for create - if color == "" { - return utils.NewToolResultError("color is required for create"), nil, nil + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "name": githubv4.String(name), } - // Get repository ID - repoID, err := getRepositoryID(ctx, client, owner, repo) + client, err := deps.GetGQLClient(ctx) if err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - input := githubv4.CreateLabelInput{ - RepositoryID: repoID, - Name: githubv4.String(name), - Color: githubv4.String(color), + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find label", err), nil, nil } - if description != "" { - d := githubv4.String(description) - input.Description = &d + + if query.Repository.Label.Name == "" { + return utils.NewToolResultError(fmt.Sprintf("label '%s' not found in %s/%s", name, owner, repo)), nil, nil } - var mutation struct { - CreateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID - } - } `graphql:"createLabel(input: $input)"` + label := map[string]any{ + "id": fmt.Sprintf("%v", query.Repository.Label.ID), + "name": string(query.Repository.Label.Name), + "color": string(query.Repository.Label.Color), + "description": string(query.Repository.Label.Description), } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil + out, err := json.Marshal(label) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal label: %w", err) } - return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil + return utils.NewToolResultText(string(out)), nil, nil + }, + ) +} + +// GetLabelForLabelsToolset returns the same GetLabel tool but registered in the labels toolset. +// This provides conformance with the original behavior where get_label was in both toolsets. +func GetLabelForLabelsToolset(t translations.TranslationHelperFunc) inventory.ServerTool { + tool := GetLabel(t) + tool.Toolset = ToolsetLabels + return tool +} - case "update": - // Validate required params for update - if newName == "" && color == "" && description == "" { - return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil +// ListLabels lists labels from a repository +func ListLabels(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetLabels, + mcp.Tool{ + Name: "list_label", + Description: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_LABEL_DESCRIPTION", "List labels from a repository."), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization name) - required for all operations", + }, + "repo": { + Type: "string", + Description: "Repository name - required for all operations", + }, + }, + Required: []string{"owner", "repo"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) + repo, err := RequiredParam[string](args, "repo") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - input := githubv4.UpdateLabelInput{ - ID: labelID, + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - if newName != "" { - n := githubv4.String(newName) - input.Name = &n + + var query struct { + Repository struct { + Labels struct { + Nodes []struct { + ID githubv4.ID + Name githubv4.String + Color githubv4.String + Description githubv4.String + } + TotalCount githubv4.Int + } `graphql:"labels(first: 100)"` + } `graphql:"repository(owner: $owner, name: $repo)"` } - if color != "" { - c := githubv4.String(color) - input.Color = &c + + vars := map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), } - if description != "" { - d := githubv4.String(description) - input.Description = &d + + if err := client.Query(ctx, &query, vars); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to list labels", err), nil, nil } - var mutation struct { - UpdateLabel struct { - Label struct { - Name githubv4.String - ID githubv4.ID - } - } `graphql:"updateLabel(input: $input)"` + labels := make([]map[string]any, len(query.Repository.Labels.Nodes)) + for i, labelNode := range query.Repository.Labels.Nodes { + labels[i] = map[string]any{ + "id": fmt.Sprintf("%v", labelNode.ID), + "name": string(labelNode.Name), + "color": string(labelNode.Color), + "description": string(labelNode.Description), + } } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil + response := map[string]any{ + "labels": labels, + "totalCount": int(query.Repository.Labels.TotalCount), } - return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil + out, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal labels: %w", err) + } + + return utils.NewToolResultText(string(out)), nil, nil + }, + ) +} - case "delete": - // Get the label ID - labelID, err := getLabelID(ctx, client, owner, repo, name) +// LabelWrite handles create, update, and delete operations for GitHub labels +func LabelWrite(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetLabels, + mcp.Tool{ + Name: "label_write", + Description: t("TOOL_LABEL_WRITE_DESCRIPTION", "Perform write operations on repository labels. To set labels on issues, use the 'update_issue' tool."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LABEL_WRITE_TITLE", "Write operations on repository labels."), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "method": { + Type: "string", + Description: "Operation to perform: 'create', 'update', or 'delete'", + Enum: []any{"create", "update", "delete"}, + }, + "owner": { + Type: "string", + Description: "Repository owner (username or organization name)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "name": { + Type: "string", + Description: "Label name - required for all operations", + }, + "new_name": { + Type: "string", + Description: "New name for the label (used only with 'update' method to rename)", + }, + "color": { + Type: "string", + Description: "Label color as 6-character hex code without '#' prefix (e.g., 'f29513'). Required for 'create', optional for 'update'.", + }, + "description": { + Type: "string", + Description: "Label description text. Optional for 'create' and 'update'.", + }, + }, + Required: []string{"method", "owner", "repo", "name"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + // Get and validate required parameters + method, err := RequiredParam[string](args, "method") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } + method = strings.ToLower(method) - input := githubv4.DeleteLabelInput{ - ID: labelID, + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - var mutation struct { - DeleteLabel struct { - ClientMutationID githubv4.String - } `graphql:"deleteLabel(input: $input)"` + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - if err := client.Mutate(ctx, &mutation, input, nil); err != nil { - return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil + // Get optional parameters + newName, _ := OptionalParam[string](args, "new_name") + color, _ := OptionalParam[string](args, "color") + description, _ := OptionalParam[string](args, "description") - default: - return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil - } - }) + client, err := deps.GetGQLClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - return tool, handler + switch method { + case "create": + // Validate required params for create + if color == "" { + return utils.NewToolResultError("color is required for create"), nil, nil + } + + // Get repository ID + repoID, err := getRepositoryID(ctx, client, owner, repo) + if err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to find repository", err), nil, nil + } + + input := githubv4.CreateLabelInput{ + RepositoryID: repoID, + Name: githubv4.String(name), + Color: githubv4.String(color), + } + if description != "" { + d := githubv4.String(description) + input.Description = &d + } + + var mutation struct { + CreateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"createLabel(input: $input)"` + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to create label", err), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' created successfully", mutation.CreateLabel.Label.Name)), nil, nil + + case "update": + // Validate required params for update + if newName == "" && color == "" && description == "" { + return utils.NewToolResultError("at least one of new_name, color, or description must be provided for update"), nil, nil + } + + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + input := githubv4.UpdateLabelInput{ + ID: labelID, + } + if newName != "" { + n := githubv4.String(newName) + input.Name = &n + } + if color != "" { + c := githubv4.String(color) + input.Color = &c + } + if description != "" { + d := githubv4.String(description) + input.Description = &d + } + + var mutation struct { + UpdateLabel struct { + Label struct { + Name githubv4.String + ID githubv4.ID + } + } `graphql:"updateLabel(input: $input)"` + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to update label", err), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' updated successfully", mutation.UpdateLabel.Label.Name)), nil, nil + + case "delete": + // Get the label ID + labelID, err := getLabelID(ctx, client, owner, repo, name) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + input := githubv4.DeleteLabelInput{ + ID: labelID, + } + + var mutation struct { + DeleteLabel struct { + ClientMutationID githubv4.String + } `graphql:"deleteLabel(input: $input)"` + } + + if err := client.Mutate(ctx, &mutation, input, nil); err != nil { + return ghErrors.NewGitHubGraphQLErrorResponse(ctx, "Failed to delete label", err), nil, nil + } + + return utils.NewToolResultText(fmt.Sprintf("label '%s' deleted successfully", name)), nil, nil + + default: + return utils.NewToolResultError(fmt.Sprintf("unknown method: %s. Supported methods are: create, update, delete", method)), nil, nil + } + }, + ) } // Helper function to get repository ID diff --git a/pkg/github/labels_test.go b/pkg/github/labels_test.go index 12d447d72..88102ba3c 100644 --- a/pkg/github/labels_test.go +++ b/pkg/github/labels_test.go @@ -17,8 +17,8 @@ func TestGetLabel(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := GetLabel(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetLabel(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_label", tool.Name) @@ -114,10 +114,13 @@ func TestGetLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := GetLabel(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -139,8 +142,8 @@ func TestListLabels(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := ListLabels(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListLabels(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_label", tool.Name) @@ -209,10 +212,13 @@ func TestListLabels(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := ListLabels(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) @@ -234,8 +240,8 @@ func TestWriteLabel(t *testing.T) { t.Parallel() // Verify tool definition - mockClient := githubv4.NewClient(nil) - tool, _ := LabelWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := LabelWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "label_write", tool.Name) @@ -457,10 +463,13 @@ func TestWriteLabel(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := githubv4.NewClient(tc.mockedClient) - _, handler := LabelWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) assert.NotNil(t, result) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index 7f9e98f91..1e2011fa3 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -10,6 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -24,8 +25,10 @@ const ( ) // ListNotifications creates a tool to list notifications for the current user. -func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListNotifications(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "list_notifications", Description: t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "Lists all GitHub notifications for the authenticated user, including unread notifications, mentions, review requests, assignments, and updates on issues or pull requests. Use this tool whenever the user asks what to work on next, requests a summary of their GitHub activity, wants to see pending reviews, or needs to check for new updates or tasks. This tool is the primary way to discover actionable items, reminders, and outstanding work on GitHub. Always call this tool when asked what to work on next, what is pending, or what needs attention in GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -59,10 +62,10 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu }, }), }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } filter, err := OptionalParam[string](args, "filter") @@ -141,24 +144,27 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notifications", resp, body), nil, nil } // Marshal response to JSON r, err := json.Marshal(notifications) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } // DismissNotification creates a tool to mark a notification as read/done. -func DismissNotification(getclient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DismissNotification(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "dismiss_notification", Description: t("TOOL_DISMISS_NOTIFICATION_DESCRIPTION", "Dismiss a notification by marking it as read or done"), Annotations: &mcp.ToolAnnotations{ @@ -181,10 +187,10 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper Required: []string{"threadID", "state"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getclient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } threadID, err := RequiredParam[string](args, "threadID") @@ -225,18 +231,21 @@ func DismissNotification(getclient GetClientFn, t translations.TranslationHelper if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to mark notification as %s: %s", state, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to mark notification as %s", state), resp, body), nil, nil } return utils.NewToolResultText(fmt.Sprintf("Notification marked as %s", state)), nil, nil - }) + }, + ) } // MarkAllNotificationsRead creates a tool to mark all notifications as read. -func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func MarkAllNotificationsRead(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "mark_all_notifications_read", Description: t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read"), Annotations: &mcp.ToolAnnotations{ @@ -261,10 +270,10 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH }, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } lastReadAt, err := OptionalParam[string](args, "lastReadAt") @@ -313,18 +322,21 @@ func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationH if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to mark all notifications as read", resp, body), nil, nil } return utils.NewToolResultText("All notifications marked as read"), nil, nil - }) + }, + ) } // GetNotificationDetails creates a tool to get details for a specific notification. -func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetNotificationDetails(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "get_notification_details", Description: t("TOOL_GET_NOTIFICATION_DETAILS_DESCRIPTION", "Get detailed information for a specific GitHub notification, always call this tool when the user asks for details about a specific notification, if you don't know the ID list notifications first."), Annotations: &mcp.ToolAnnotations{ @@ -342,10 +354,10 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel Required: []string{"notificationID"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } notificationID, err := RequiredParam[string](args, "notificationID") @@ -366,18 +378,19 @@ func GetNotificationDetails(getClient GetClientFn, t translations.TranslationHel if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, err + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get notification details: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get notification details", resp, body), nil, nil } r, err := json.Marshal(thread) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } // Enum values for ManageNotificationSubscription action @@ -388,8 +401,10 @@ const ( ) // ManageNotificationSubscription creates a tool to manage a notification subscription (ignore, watch, delete) -func ManageNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ManageNotificationSubscription(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "manage_notification_subscription", Description: t("TOOL_MANAGE_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a notification subscription: ignore, watch, or delete a notification thread subscription."), Annotations: &mcp.ToolAnnotations{ @@ -412,10 +427,10 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl Required: []string{"notificationID", "action"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } notificationID, err := RequiredParam[string](args, "notificationID") @@ -457,7 +472,7 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(resp.Body) - return utils.NewToolResultError(fmt.Sprintf("failed to %s notification subscription: %s", action, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s notification subscription", action), resp, body), nil, nil } if action == NotificationActionDelete { @@ -467,10 +482,11 @@ func ManageNotificationSubscription(getClient GetClientFn, t translations.Transl r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } const ( @@ -480,8 +496,10 @@ const ( ) // ManageRepositoryNotificationSubscription creates a tool to manage a repository notification subscription (ignore, watch, delete) -func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ManageRepositoryNotificationSubscription(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataNotifications, + mcp.Tool{ Name: "manage_repository_notification_subscription", Description: t("TOOL_MANAGE_REPOSITORY_NOTIFICATION_SUBSCRIPTION_DESCRIPTION", "Manage a repository notification subscription: ignore, watch, or delete repository notifications subscription for the provided repository."), Annotations: &mcp.ToolAnnotations{ @@ -508,10 +526,10 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati Required: []string{"owner", "repo", "action"}, }, }, - mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, err + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } owner, err := RequiredParam[string](args, "owner") @@ -560,7 +578,7 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati // Handle non-2xx status codes if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) { body, _ := io.ReadAll(resp.Body) - return utils.NewToolResultError(fmt.Sprintf("failed to %s repository subscription: %s", action, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to %s repository subscription", action), resp, body), nil, nil } if action == RepositorySubscriptionActionDelete { @@ -570,8 +588,9 @@ func ManageRepositoryNotificationSubscription(getClient GetClientFn, t translati r, err := json.Marshal(result) if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, err + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } return utils.NewToolResultText(string(r)), nil, nil - }) + }, + ) } diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go index 37135bf5c..936a70df4 100644 --- a/pkg/github/notifications_test.go +++ b/pkg/github/notifications_test.go @@ -10,15 +10,14 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_ListNotifications(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListNotifications(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_notifications", tool.Name) @@ -50,24 +49,18 @@ func Test_ListNotifications(t *testing.T) { }{ { name: "success default filter (no params)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetNotifications, - []*github.Notification{mockNotification}, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotifications: mockResponse(t, http.StatusOK, []*github.Notification{mockNotification}), + }), requestArgs: map[string]interface{}{}, expectError: false, expectedResult: []*github.Notification{mockNotification}, }, { name: "success with filter=include_read_notifications", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetNotifications, - []*github.Notification{mockNotification}, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotifications: mockResponse(t, http.StatusOK, []*github.Notification{mockNotification}), + }), requestArgs: map[string]interface{}{ "filter": "include_read_notifications", }, @@ -76,12 +69,9 @@ func Test_ListNotifications(t *testing.T) { }, { name: "success with filter=only_participating", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetNotifications, - []*github.Notification{mockNotification}, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotifications: mockResponse(t, http.StatusOK, []*github.Notification{mockNotification}), + }), requestArgs: map[string]interface{}{ "filter": "only_participating", }, @@ -90,12 +80,9 @@ func Test_ListNotifications(t *testing.T) { }, { name: "success for repo notifications", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposNotificationsByOwnerByRepo, - []*github.Notification{mockNotification}, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposNotificationsByOwnerByRepo: mockResponse(t, http.StatusOK, []*github.Notification{mockNotification}), + }), requestArgs: map[string]interface{}{ "filter": "default", "since": "2024-01-01T00:00:00Z", @@ -110,12 +97,9 @@ func Test_ListNotifications(t *testing.T) { }, { name: "error", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetNotifications, - mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotifications: mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + }), requestArgs: map[string]interface{}{}, expectError: true, expectedErrMsg: "error", @@ -125,12 +109,15 @@ func Test_ListNotifications(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { @@ -139,7 +126,6 @@ func Test_ListNotifications(t *testing.T) { return } - require.NoError(t, err) require.False(t, result.IsError) textContent := getTextResult(t, result) t.Logf("textContent: %s", textContent.Text) @@ -154,8 +140,8 @@ func Test_ListNotifications(t *testing.T) { func Test_ManageNotificationSubscription(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ManageNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ManageNotificationSubscription(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "manage_notification_subscription", tool.Name) @@ -182,12 +168,9 @@ func Test_ManageNotificationSubscription(t *testing.T) { }{ { name: "ignore subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutNotificationsThreadsSubscriptionByThreadId, - mockSub, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutNotificationsThreadsSubscriptionByThreadID: mockResponse(t, http.StatusOK, mockSub), + }), requestArgs: map[string]interface{}{ "notificationID": "123", "action": "ignore", @@ -197,12 +180,9 @@ func Test_ManageNotificationSubscription(t *testing.T) { }, { name: "watch subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutNotificationsThreadsSubscriptionByThreadId, - mockSubWatch, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutNotificationsThreadsSubscriptionByThreadID: mockResponse(t, http.StatusOK, mockSubWatch), + }), requestArgs: map[string]interface{}{ "notificationID": "123", "action": "watch", @@ -212,12 +192,9 @@ func Test_ManageNotificationSubscription(t *testing.T) { }, { name: "delete subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.DeleteNotificationsThreadsSubscriptionByThreadId, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + DeleteNotificationsThreadsSubscriptionByThreadID: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "notificationID": "123", "action": "delete", @@ -227,7 +204,7 @@ func Test_ManageNotificationSubscription(t *testing.T) { }, { name: "invalid action", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "notificationID": "123", "action": "invalid", @@ -237,7 +214,7 @@ func Test_ManageNotificationSubscription(t *testing.T) { }, { name: "missing required notificationID", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "action": "ignore", }, @@ -245,7 +222,7 @@ func Test_ManageNotificationSubscription(t *testing.T) { }, { name: "missing required action", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "notificationID": "123", }, @@ -256,10 +233,14 @@ func Test_ManageNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ManageNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { require.NoError(t, err) require.NotNil(t, result) @@ -295,8 +276,8 @@ func Test_ManageNotificationSubscription(t *testing.T) { func Test_ManageRepositoryNotificationSubscription(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := ManageRepositoryNotificationSubscription(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ManageRepositoryNotificationSubscription(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "manage_repository_notification_subscription", tool.Name) @@ -325,12 +306,9 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }{ { name: "ignore subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutReposSubscriptionByOwnerByRepo, - mockSub, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutReposSubscriptionByOwnerByRepo: mockResponse(t, http.StatusOK, mockSub), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -341,12 +319,9 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "watch subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutReposSubscriptionByOwnerByRepo, - mockWatchSub, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutReposSubscriptionByOwnerByRepo: mockResponse(t, http.StatusOK, mockWatchSub), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -358,12 +333,9 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "delete subscription", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.DeleteReposSubscriptionByOwnerByRepo, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + DeleteReposSubscriptionByOwnerByRepo: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -374,7 +346,7 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "invalid action", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -385,7 +357,7 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "missing required owner", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "repo": "repo", "action": "ignore", @@ -394,7 +366,7 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "missing required repo", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "owner": "owner", "action": "ignore", @@ -403,7 +375,7 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { }, { name: "missing required action", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -415,12 +387,15 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ManageRepositoryNotificationSubscription(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.NotNil(t, result) text := getTextResult(t, result).Text switch { @@ -461,8 +436,8 @@ func Test_ManageRepositoryNotificationSubscription(t *testing.T) { func Test_DismissNotification(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := DismissNotification(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := DismissNotification(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "dismiss_notification", tool.Name) @@ -486,12 +461,9 @@ func Test_DismissNotification(t *testing.T) { }{ { name: "mark as read", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PatchNotificationsThreadsByThreadId, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PatchNotificationsThreadsByThreadID: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "threadID": "123", "state": "read", @@ -501,12 +473,9 @@ func Test_DismissNotification(t *testing.T) { }, { name: "mark as done", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.DeleteNotificationsThreadsByThreadId, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + DeleteNotificationsThreadsByThreadID: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "threadID": "123", "state": "done", @@ -516,7 +485,7 @@ func Test_DismissNotification(t *testing.T) { }, { name: "invalid threadID format", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "threadID": "notanumber", "state": "done", @@ -526,7 +495,7 @@ func Test_DismissNotification(t *testing.T) { }, { name: "missing required threadID", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "state": "read", }, @@ -534,7 +503,7 @@ func Test_DismissNotification(t *testing.T) { }, { name: "missing required state", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "threadID": "123", }, @@ -542,7 +511,7 @@ func Test_DismissNotification(t *testing.T) { }, { name: "invalid state value", - mockedClient: mock.NewMockedHTTPClient(), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{}), requestArgs: map[string]interface{}{ "threadID": "123", "state": "invalid", @@ -554,13 +523,16 @@ func Test_DismissNotification(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := DismissNotification(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { // The tool returns a ToolResultError with a specific message - require.NoError(t, err) require.NotNil(t, result) text := getTextResult(t, result).Text switch { @@ -596,8 +568,8 @@ func Test_DismissNotification(t *testing.T) { func Test_MarkAllNotificationsRead(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := MarkAllNotificationsRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := MarkAllNotificationsRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "mark_all_notifications_read", tool.Name) @@ -620,24 +592,18 @@ func Test_MarkAllNotificationsRead(t *testing.T) { }{ { name: "success (no params)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutNotifications, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutNotifications: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{}, expectError: false, expectMarked: true, }, { name: "success with lastReadAt param", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutNotifications, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutNotifications: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "lastReadAt": "2024-01-01T00:00:00Z", }, @@ -646,12 +612,9 @@ func Test_MarkAllNotificationsRead(t *testing.T) { }, { name: "success with owner and repo", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PutReposNotificationsByOwnerByRepo, - nil, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutReposNotificationsByOwnerByRepo: mockResponse(t, http.StatusOK, nil), + }), requestArgs: map[string]interface{}{ "owner": "octocat", "repo": "hello-world", @@ -661,12 +624,9 @@ func Test_MarkAllNotificationsRead(t *testing.T) { }, { name: "API error", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PutNotifications, - mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + PutNotifications: mockResponse(t, http.StatusInternalServerError, `{"message": "error"}`), + }), requestArgs: map[string]interface{}{}, expectError: true, expectedErrMsg: "error", @@ -676,12 +636,15 @@ func Test_MarkAllNotificationsRead(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := MarkAllNotificationsRead(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { @@ -702,8 +665,8 @@ func Test_MarkAllNotificationsRead(t *testing.T) { func Test_GetNotificationDetails(t *testing.T) { // Verify tool definition and schema - mockClient := github.NewClient(nil) - tool, _ := GetNotificationDetails(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetNotificationDetails(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_notification_details", tool.Name) @@ -726,12 +689,9 @@ func Test_GetNotificationDetails(t *testing.T) { }{ { name: "success", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetNotificationsThreadsByThreadId, - mockThread, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotificationsThreadsByThreadID: mockResponse(t, http.StatusOK, mockThread), + }), requestArgs: map[string]interface{}{ "notificationID": "123", }, @@ -740,12 +700,9 @@ func Test_GetNotificationDetails(t *testing.T) { }, { name: "not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetNotificationsThreadsByThreadId, - mockResponse(t, http.StatusNotFound, `{"message": "not found"}`), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetNotificationsThreadsByThreadID: mockResponse(t, http.StatusNotFound, `{"message": "not found"}`), + }), requestArgs: map[string]interface{}{ "notificationID": "123", }, @@ -757,12 +714,15 @@ func Test_GetNotificationDetails(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := GetNotificationDetails(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) if tc.expectError { - require.NoError(t, err) require.True(t, result.IsError) errorContent := getErrorResult(t, result) if tc.expectedErrMsg != "" { diff --git a/pkg/github/projects.go b/pkg/github/projects.go index 79dfb25ce..18c1f778b 100644 --- a/pkg/github/projects.go +++ b/pkg/github/projects.go @@ -9,6 +9,7 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -24,8 +25,10 @@ const ( MaxProjectsPerPage = 50 ) -func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjects(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_projects", Description: t("TOOL_LIST_PROJECTS_DESCRIPTION", `List Projects for a user or organization`), Annotations: &mcp.ToolAnnotations{ @@ -63,7 +66,9 @@ func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, Required: []string{"owner_type", "owner"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -84,7 +89,7 @@ func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) ( return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -133,11 +138,14 @@ func ListProjects(getClient GetClientFn, t translations.TranslationHelperFunc) ( } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProject(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project", Description: t("TOOL_GET_PROJECT_DESCRIPTION", "Get Project for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -163,7 +171,8 @@ func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, Required: []string{"project_number", "owner_type", "owner"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { projectNumber, err := RequiredInt(args, "project_number") if err != nil { @@ -180,7 +189,7 @@ func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mc return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -207,7 +216,7 @@ func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mc if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get project: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project", resp, body), nil, nil } minimalProject := convertToMinimalProject(project) @@ -217,11 +226,14 @@ func GetProject(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjectFields(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_project_fields", Description: t("TOOL_LIST_PROJECT_FIELDS_DESCRIPTION", "List Project fields for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -259,7 +271,9 @@ func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -280,7 +294,7 @@ func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFu return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -318,11 +332,14 @@ func ListProjectFields(getClient GetClientFn, t translations.TranslationHelperFu } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProjectField(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project_field", Description: t("TOOL_GET_PROJECT_FIELD_DESCRIPTION", "Get Project field for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -352,7 +369,9 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc }, Required: []string{"owner_type", "owner", "project_number", "field_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -369,7 +388,7 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -397,7 +416,7 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get project field: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get project field", resp, body), nil, nil } r, err := json.Marshal(projectField) if err != nil { @@ -405,11 +424,14 @@ func GetProjectField(getClient GetClientFn, t translations.TranslationHelperFunc } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListProjectItems(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "list_project_items", Description: t("TOOL_LIST_PROJECT_ITEMS_DESCRIPTION", `Search project items with advanced filtering`), Annotations: &mcp.ToolAnnotations{ @@ -458,7 +480,9 @@ func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFun }, Required: []string{"owner_type", "owner", "project_number"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -489,7 +513,7 @@ func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFun return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -536,11 +560,14 @@ func ListProjectItems(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "get_project_item", Description: t("TOOL_GET_PROJECT_ITEM_DESCRIPTION", "Get a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -577,7 +604,9 @@ func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) }, Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -601,7 +630,7 @@ func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -637,11 +666,14 @@ func GetProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func AddProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "add_project_item", Description: t("TOOL_ADD_PROJECT_ITEM_DESCRIPTION", "Add a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -676,7 +708,9 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) }, Required: []string{"owner_type", "owner", "project_number", "item_type", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -702,7 +736,7 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) return utils.NewToolResultError("item_type must be either 'issue' or 'pull_request'"), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -735,7 +769,7 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectAddFailedError, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectAddFailedError, resp, body), nil, nil } r, err := json.Marshal(addedItem) if err != nil { @@ -743,11 +777,14 @@ func AddProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func UpdateProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "update_project_item", Description: t("TOOL_UPDATE_PROJECT_ITEM_DESCRIPTION", "Update a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -781,7 +818,9 @@ func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number", "item_id", "updated_field"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -814,7 +853,7 @@ func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFu return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -842,7 +881,7 @@ func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectUpdateFailedError, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectUpdateFailedError, resp, body), nil, nil } r, err := json.Marshal(updatedItem) if err != nil { @@ -850,11 +889,14 @@ func UpdateProjectItem(getClient GetClientFn, t translations.TranslationHelperFu } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func DeleteProjectItem(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataProjects, + mcp.Tool{ Name: "delete_project_item", Description: t("TOOL_DELETE_PROJECT_ITEM_DESCRIPTION", "Delete a specific Project item for a user or org"), Annotations: &mcp.ToolAnnotations{ @@ -884,7 +926,9 @@ func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFu }, Required: []string{"owner_type", "owner", "project_number", "item_id"}, }, - }, func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -901,7 +945,7 @@ func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -927,10 +971,11 @@ func DeleteProjectItem(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", ProjectDeleteFailedError, string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, ProjectDeleteFailedError, resp, body), nil, nil } return utils.NewToolResultText("project item successfully deleted"), nil, nil - } + }, + ) } type pageInfo struct { diff --git a/pkg/github/projects_test.go b/pkg/github/projects_test.go index e2814c8f9..e443b9ecd 100644 --- a/pkg/github/projects_test.go +++ b/pkg/github/projects_test.go @@ -17,8 +17,8 @@ import ( ) func Test_ListProjects(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjects(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjects(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_projects", tool.Name) @@ -141,9 +141,12 @@ func Test_ListProjects(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjects(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -177,8 +180,8 @@ func Test_ListProjects(t *testing.T) { } func Test_GetProject(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProject(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProject(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project", tool.Name) @@ -277,9 +280,12 @@ func Test_GetProject(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProject(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -310,8 +316,8 @@ func Test_GetProject(t *testing.T) { } func Test_ListProjectFields(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjectFields(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjectFields(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_project_fields", tool.Name) @@ -426,9 +432,12 @@ func Test_ListProjectFields(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjectFields(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -464,8 +473,8 @@ func Test_ListProjectFields(t *testing.T) { } func Test_GetProjectField(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProjectField(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProjectField(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project_field", tool.Name) @@ -583,9 +592,12 @@ func Test_GetProjectField(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProjectField(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -622,8 +634,8 @@ func Test_GetProjectField(t *testing.T) { } func Test_ListProjectItems(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := ListProjectItems(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListProjectItems(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_project_items", tool.Name) @@ -786,9 +798,12 @@ func Test_ListProjectItems(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := ListProjectItems(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -824,8 +839,8 @@ func Test_ListProjectItems(t *testing.T) { } func Test_GetProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := GetProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_project_item", tool.Name) @@ -980,9 +995,12 @@ func Test_GetProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := GetProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1019,8 +1037,8 @@ func Test_GetProjectItem(t *testing.T) { } func Test_AddProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := AddProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_project_item", tool.Name) @@ -1206,10 +1224,13 @@ func Test_AddProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := AddProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1255,8 +1276,8 @@ func Test_AddProjectItem(t *testing.T) { } func Test_UpdateProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := UpdateProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdateProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_project_item", tool.Name) @@ -1488,9 +1509,12 @@ func Test_UpdateProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := UpdateProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { @@ -1532,8 +1556,8 @@ func Test_UpdateProjectItem(t *testing.T) { } func Test_DeleteProjectItem(t *testing.T) { - mockClient := gh.NewClient(nil) - tool, _ := DeleteProjectItem(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := DeleteProjectItem(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "delete_project_item", tool.Name) @@ -1652,9 +1676,12 @@ func Test_DeleteProjectItem(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := gh.NewClient(tc.mockedClient) - _, handler := DeleteProjectItem(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) if tc.expectError { diff --git a/pkg/github/prompts.go b/pkg/github/prompts.go new file mode 100644 index 000000000..0c1ac2e9e --- /dev/null +++ b/pkg/github/prompts.go @@ -0,0 +1,16 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllPrompts returns all prompts with their embedded toolset metadata. +// Prompt functions return ServerPrompt directly with toolset info. +func AllPrompts(t translations.TranslationHelperFunc) []inventory.ServerPrompt { + return []inventory.ServerPrompt{ + // Issue prompts + AssignCodingAgentPrompt(t), + IssueToFixWorkflowPrompt(t), + } +} diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 22794aa08..d51c14fa4 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -14,14 +14,16 @@ import ( "github.com/shurcooL/githubv4" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" + "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" ) // PullRequestRead creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func PullRequestRead(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -56,7 +58,9 @@ Possible options: } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "pull_request_read", Description: t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -65,7 +69,7 @@ Possible options: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { method, err := RequiredParam[string](args, "method") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -88,14 +92,14 @@ Possible options: return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } switch method { case "get": - result, err := GetPullRequest(ctx, client, cache, owner, repo, pullNumber, flags) + result, err := GetPullRequest(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_diff": result, err := GetPullRequestDiff(ctx, client, owner, repo, pullNumber) @@ -107,7 +111,7 @@ Possible options: result, err := GetPullRequestFiles(ctx, client, owner, repo, pullNumber, pagination) return result, nil, err case "get_review_comments": - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } @@ -115,18 +119,18 @@ Possible options: if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } - result, err := GetPullRequestReviewComments(ctx, gqlClient, cache, owner, repo, pullNumber, cursorPagination, flags) + result, err := GetPullRequestReviewComments(ctx, gqlClient, deps.GetRepoAccessCache(), owner, repo, pullNumber, cursorPagination, deps.GetFlags()) return result, nil, err case "get_reviews": - result, err := GetPullRequestReviews(ctx, client, cache, owner, repo, pullNumber, flags) + result, err := GetPullRequestReviews(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, deps.GetFlags()) return result, nil, err case "get_comments": - result, err := GetIssueComments(ctx, client, cache, owner, repo, pullNumber, pagination, flags) + result, err := GetIssueComments(ctx, client, deps.GetRepoAccessCache(), owner, repo, pullNumber, pagination, deps.GetFlags()) return result, nil, err default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil } - } + }) } func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner, repo string, pullNumber int, ff FeatureFlags) (*mcp.CallToolResult, error) { @@ -145,7 +149,7 @@ func GetPullRequest(ctx context.Context, client *github.Client, cache *lockdown. if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil } // sanitize title/body on response @@ -204,7 +208,7 @@ func GetPullRequestDiff(ctx context.Context, client *github.Client, owner, repo if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request diff: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request diff", resp, body), nil } defer func() { _ = resp.Body.Close() }() @@ -229,7 +233,7 @@ func GetPullRequestStatus(ctx context.Context, client *github.Client, owner, rep if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil } // Get combined status for the head SHA @@ -248,7 +252,7 @@ func GetPullRequestStatus(ctx context.Context, client *github.Client, owner, rep if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get combined status: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get combined status", resp, body), nil } r, err := json.Marshal(status) @@ -279,7 +283,7 @@ func GetPullRequestFiles(ctx context.Context, client *github.Client, owner, repo if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request files: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request files", resp, body), nil } r, err := json.Marshal(files) @@ -431,7 +435,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request reviews", resp, body), nil } if ff.LockdownMode { @@ -463,7 +467,7 @@ func GetPullRequestReviews(ctx context.Context, client *github.Client, cache *lo } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func CreatePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -503,7 +507,9 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu Required: []string{"owner", "repo", "title", "head", "base"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "create_pull_request", Description: t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -512,7 +518,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -562,7 +568,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu newPR.Draft = github.Ptr(draft) newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -581,7 +587,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create pull request", resp, bodyBytes), nil, nil } // Return minimal response with just essential information @@ -596,11 +602,11 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func UpdatePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -652,7 +658,9 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "update_pull_request", Description: t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -661,7 +669,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -735,7 +743,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra // Handle REST API updates (title, body, state, base, maintainer_can_modify) if restUpdateNeeded { - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -755,13 +763,13 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update pull request: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request", resp, bodyBytes), nil, nil } } // Handle draft status changes using GraphQL if draftProvided { - gqlClient, err := getGQLClient(ctx) + gqlClient, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GraphQL client", err), nil, nil } @@ -775,7 +783,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } `graphql:"repository(owner: $owner, name: $repo)"` } - err = gqlClient.Query(ctx, &prQuery, map[string]any{ + err = gqlClient.Query(ctx, &prQuery, map[string]interface{}{ "owner": githubv4.String(owner), "repo": githubv4.String(repo), "prNum": githubv4.Int(pullNumber), // #nosec G115 - pull request numbers are always small positive integers @@ -827,7 +835,7 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra // Handle reviewer requests if len(reviewers) > 0 { - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -855,12 +863,12 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to request reviewers: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request reviewers", resp, bodyBytes), nil, nil } } // Get the final state of the PR to return - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -887,11 +895,11 @@ func UpdatePullRequest(getClient GetClientFn, getGQLClient GetGQLClientFn, t tra } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func ListPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -931,7 +939,9 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "list_pull_requests", Description: t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead."), Annotations: &mcp.ToolAnnotations{ @@ -940,7 +950,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -986,7 +996,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -1005,7 +1015,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list pull requests: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list pull requests", resp, bodyBytes), nil, nil } // sanitize title/body on each PR @@ -1027,11 +1037,11 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func MergePullRequest(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1064,16 +1074,19 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "merge_pull_request", Description: t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository."), + Icons: octicons.Icons("git-merge"), Annotations: &mcp.ToolAnnotations{ Title: t("TOOL_MERGE_PULL_REQUEST_USER_TITLE", "Merge pull request"), ReadOnlyHint: false, }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1104,7 +1117,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun MergeMethod: mergeMethod, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -1123,7 +1136,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to merge pull request: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to merge pull request", resp, bodyBytes), nil, nil } r, err := json.Marshal(result) @@ -1132,11 +1145,11 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun } return utils.NewToolResultText(string(r)), nil, nil - } + }) } // SearchPullRequests creates a tool to search for pull requests. -func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchPullRequests(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1179,7 +1192,9 @@ func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "search_pull_requests", Description: t("TOOL_SEARCH_PULL_REQUESTS_DESCRIPTION", "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr"), Annotations: &mcp.ToolAnnotations{ @@ -1188,14 +1203,14 @@ func SearchPullRequests(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - result, err := searchHandler(ctx, getClient, args, "pr", "failed to search pull requests") + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + result, err := searchHandler(ctx, deps.GetClient, args, "pr", "failed to search pull requests") return result, nil, err - } + }) } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func UpdatePullRequestBranch(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1219,7 +1234,9 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "update_pull_request_branch", Description: t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch."), Annotations: &mcp.ToolAnnotations{ @@ -1228,7 +1245,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1250,7 +1267,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -1274,7 +1291,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update pull request branch: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update pull request branch", resp, bodyBytes), nil, nil } r, err := json.Marshal(result) @@ -1283,7 +1300,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe } return utils.NewToolResultText(string(r)), nil, nil - } + }) } type PullRequestReviewWriteParams struct { @@ -1296,7 +1313,7 @@ type PullRequestReviewWriteParams struct { CommitID *string } -func PullRequestReviewWrite(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func PullRequestReviewWrite(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1337,7 +1354,9 @@ func PullRequestReviewWrite(getGQLClient GetGQLClientFn, t translations.Translat Required: []string{"method", "owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "pull_request_review_write", Description: t("TOOL_PULL_REQUEST_REVIEW_WRITE_DESCRIPTION", `Create and/or submit, delete review of a pull request. @@ -1352,14 +1371,14 @@ Available methods: }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { var params PullRequestReviewWriteParams if err := mapstructure.Decode(args, ¶ms); err != nil { return utils.NewToolResultError(err.Error()), nil, nil } // Given our owner, repo and PR number, lookup the GQL ID of the PR. - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil } @@ -1377,7 +1396,7 @@ Available methods: default: return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", params.Method)), nil, nil } - } + }) } func CreatePullRequestReview(ctx context.Context, client *githubv4.Client, params PullRequestReviewWriteParams) (*mcp.CallToolResult, error) { @@ -1604,7 +1623,7 @@ func DeletePendingPullRequestReview(ctx context.Context, client *githubv4.Client } // AddCommentToPendingReview creates a tool to add a comment to a pull request review. -func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func AddCommentToPendingReview(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1664,7 +1683,9 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans Required: []string{"owner", "repo", "pullNumber", "path", "body", "subjectType"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "add_comment_to_pending_review", Description: t("TOOL_ADD_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add review comment to the requester's latest pending pull request review. A pending review needs to already exist to call this (check with the user if not sure)."), Annotations: &mcp.ToolAnnotations{ @@ -1673,7 +1694,7 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { var params struct { Owner string Repo string @@ -1690,7 +1711,7 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getGQLClient(ctx) + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub GQL client", err), nil, nil } @@ -1787,13 +1808,13 @@ func AddCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.Trans // In future, we may want to return the review ID, but for the moment, we're not leaking // API implementation details to the LLM. return utils.NewToolResultText("pull request review comment successfully added to pending review"), nil, nil - } + }) } // RequestCopilotReview creates a tool to request a Copilot review for a pull request. // Note that this tool will not work on GHES where this feature is unsupported. In future, we should not expose this // tool if the configured host does not support it. -func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func RequestCopilotReview(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -1813,16 +1834,19 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe Required: []string{"owner", "repo", "pullNumber"}, } - return mcp.Tool{ + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ Name: "request_copilot_review", Description: t("TOOL_REQUEST_COPILOT_REVIEW_DESCRIPTION", "Request a GitHub Copilot code review for a pull request. Use this for automated feedback on pull requests, usually before requesting a human reviewer."), + Icons: octicons.Icons("copilot"), Annotations: &mcp.ToolAnnotations{ Title: t("TOOL_REQUEST_COPILOT_REVIEW_USER_TITLE", "Request Copilot review"), ReadOnlyHint: false, }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -1838,7 +1862,7 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -1867,12 +1891,12 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to request copilot review: %s", string(bodyBytes))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to request copilot review", resp, bodyBytes), nil, nil } // Return nothing on success, as there's not much value in returning the Pull Request itself return utils.NewToolResultText(""), nil, nil - } + }) } // newGQLString like takes something that approximates a string (of which there are many types in shurcooL/githubv4) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index b606da05b..3cb41515d 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -22,8 +22,8 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -105,13 +105,20 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + gqlClient := githubv4.NewClient(githubv4mock.NewMockedHTTPClient()) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + RepoAccessCache: stubRepoAccessCache(gqlClient, 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -142,8 +149,8 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) + serverTool := UpdatePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request", tool.Name) @@ -363,13 +370,18 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(client), stubGetGQLClientFn(githubv4.NewClient(nil)), translations.NullTranslationHelper) + gqlClient := githubv4.NewClient(nil) + deps := BaseDeps{ + Client: client, + GQLClient: gqlClient, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError || tc.expectedErrMsg != "" { @@ -549,11 +561,16 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { )) gqlClient := githubv4.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(stubGetClientFn(restClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + serverTool := UpdatePullRequest(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: restClient, + GQLClient: gqlClient, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError || tc.expectedErrMsg != "" { require.NoError(t, err) @@ -581,8 +598,8 @@ func Test_UpdatePullRequest_Draft(t *testing.T) { func Test_ListPullRequests(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListPullRequests(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_pull_requests", tool.Name) @@ -676,13 +693,17 @@ func Test_ListPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := ListPullRequests(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -716,8 +737,8 @@ func Test_ListPullRequests(t *testing.T) { func Test_MergePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := MergePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := MergePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "merge_pull_request", tool.Name) @@ -796,13 +817,17 @@ func Test_MergePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := MergePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := MergePullRequest(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -831,8 +856,8 @@ func Test_MergePullRequest(t *testing.T) { } func Test_SearchPullRequests(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := SearchPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchPullRequests(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_pull_requests", tool.Name) @@ -1098,13 +1123,17 @@ func Test_SearchPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchPullRequests(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1142,8 +1171,8 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1247,13 +1276,19 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1287,8 +1322,8 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(githubv4mock.NewMockedHTTPClient()), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1416,13 +1451,19 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1457,8 +1498,8 @@ func Test_GetPullRequestStatus(t *testing.T) { func Test_UpdatePullRequestBranch(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequestBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "update_pull_request_branch", tool.Name) @@ -1548,13 +1589,17 @@ func Test_UpdatePullRequestBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequestBranch(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := UpdatePullRequestBranch(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1578,8 +1623,8 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1861,13 +1906,20 @@ func Test_GetPullRequestComments(t *testing.T) { } flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(github.NewClient(nil)), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: github.NewClient(nil), + GQLClient: gqlClient, + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1894,8 +1946,8 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2033,13 +2085,19 @@ func Test_GetPullRequestReviews(t *testing.T) { } cache := stubRepoAccessCache(gqlClient, 5*time.Minute) flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), cache, translations.NullTranslationHelper, flags) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RepoAccessCache: cache, + Flags: flags, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2076,8 +2134,8 @@ func Test_GetPullRequestReviews(t *testing.T) { func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreatePullRequest(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "create_pull_request", tool.Name) @@ -2191,13 +2249,17 @@ func Test_CreatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := CreatePullRequest(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2230,8 +2292,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2403,13 +2465,17 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2429,8 +2495,8 @@ func TestCreateAndSubmitPullRequestReview(t *testing.T) { func Test_RequestCopilotReview(t *testing.T) { t.Parallel() - mockClient := github.NewClient(nil) - tool, _ := RequestCopilotReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := RequestCopilotReview(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "request_copilot_review", tool.Name) @@ -2515,11 +2581,15 @@ func Test_RequestCopilotReview(t *testing.T) { t.Parallel() client := github.NewClient(tc.mockedClient) - _, handler := RequestCopilotReview(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := RequestCopilotReview(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) @@ -2544,8 +2614,8 @@ func TestCreatePendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2705,13 +2775,17 @@ func TestCreatePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2732,8 +2806,8 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := AddCommentToPendingReview(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "add_comment_to_pending_review", tool.Name) @@ -2884,13 +2958,17 @@ func TestAddPullRequestReviewCommentToPendingReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := AddCommentToPendingReview(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := AddCommentToPendingReview(translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -2911,8 +2989,8 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -2985,13 +3063,17 @@ func TestSubmitPendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -3012,8 +3094,8 @@ func TestDeletePendingPullRequestReview(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := githubv4.NewClient(nil) - tool, _ := PullRequestReviewWrite(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_review_write", tool.Name) @@ -3080,13 +3162,17 @@ func TestDeletePendingPullRequestReview(t *testing.T) { // Setup client with mock client := githubv4.NewClient(tc.mockedClient) - _, handler := PullRequestReviewWrite(stubGetGQLClientFn(client), translations.NullTranslationHelper) + serverTool := PullRequestReviewWrite(translations.NullTranslationHelper) + deps := BaseDeps{ + GQLClient: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) @@ -3107,8 +3193,8 @@ func TestGetPullRequestDiff(t *testing.T) { t.Parallel() // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -3167,13 +3253,19 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), stubGetGQLClientFn(nil), stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) + serverTool := PullRequestRead(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RepoAccessCache: stubRepoAccessCache(githubv4.NewClient(nil), 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": false}), + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) textContent := getTextResult(t, result) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index e5f6ec0c1..d8d2b27b3 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -10,6 +10,8 @@ import ( "strings" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" @@ -18,796 +20,872 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_commit", - Description: t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_COMMITS_USER_TITLE", "Get commit details"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "sha": { - Type: "string", - Description: "Commit SHA, branch name, or tag name", - }, - "include_diff": { - Type: "boolean", - Description: "Whether to include file diffs and stats in the response. Default is true.", - Default: json.RawMessage(`true`), - }, +func GetCommit(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "get_commit", + Description: t("TOOL_GET_COMMITS_DESCRIPTION", "Get details for a commit from a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_COMMITS_USER_TITLE", "Get commit details"), + ReadOnlyHint: true, }, - Required: []string{"owner", "repo", "sha"}, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := RequiredParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - includeDiff, err := OptionalBoolParamWithDefault(args, "include_diff", true) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "sha": { + Type: "string", + Description: "Commit SHA, branch name, or tag name", + }, + "include_diff": { + Type: "boolean", + Description: "Whether to include file diffs and stats in the response. Default is true.", + Default: json.RawMessage(`true`), + }, + }, + Required: []string{"owner", "repo", "sha"}, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := RequiredParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + includeDiff, err := OptionalBoolParamWithDefault(args, "include_diff", true) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get commit: %s", sha), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil, nil - } + commit, resp, err := client.Repositories.GetCommit(ctx, owner, repo, sha, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get commit: %s", sha), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Convert to minimal commit - minimalCommit := convertToMinimalCommit(commit, includeDiff) + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil + } - r, err := json.Marshal(minimalCommit) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Convert to minimal commit + minimalCommit := convertToMinimalCommit(commit, includeDiff) - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(minimalCommit) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_commits", - Description: t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository. Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100)."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_COMMITS_USER_TITLE", "List commits"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "sha": { - Type: "string", - Description: "Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA.", +func ListCommits(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "list_commits", + Description: t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository. Returns at least 30 results per page by default, but can return more if specified using the perPage parameter (up to 100)."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_COMMITS_USER_TITLE", "List commits"), + ReadOnlyHint: true, + }, + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "sha": { + Type: "string", + Description: "Commit SHA, branch or tag name to list commits of. If not provided, uses the default branch of the repository. If a commit SHA is provided, will list commits up to that SHA.", + }, + "author": { + Type: "string", + Description: "Author username or email address to filter commits by", + }, }, - "author": { - Type: "string", - Description: "Author username or email address to filter commits by", + Required: []string{"owner", "repo"}, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + author, err := OptionalParam[string](args, "author") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + // Set default perPage to 30 if not provided + perPage := pagination.PerPage + if perPage == 0 { + perPage = 30 + } + opts := &github.CommitsListOptions{ + SHA: sha, + Author: author, + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: perPage, }, - }, - Required: []string{"owner", "repo"}, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - author, err := OptionalParam[string](args, "author") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - // Set default perPage to 30 if not provided - perPage := pagination.PerPage - if perPage == 0 { - perPage = 30 - } - opts := &github.CommitsListOptions{ - SHA: sha, - Author: author, - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: perPage, - }, - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list commits: %s", sha), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to list commits: %s", string(body))), nil, nil - } + commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list commits: %s", sha), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Convert to minimal commits - minimalCommits := make([]MinimalCommit, len(commits)) - for i, commit := range commits { - minimalCommits[i] = convertToMinimalCommit(commit, false) - } + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list commits", resp, body), nil, nil + } - r, err := json.Marshal(minimalCommits) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Convert to minimal commits + minimalCommits := make([]MinimalCommit, len(commits)) + for i, commit := range commits { + minimalCommits[i] = convertToMinimalCommit(commit, false) + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(minimalCommits) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // ListBranches creates a tool to list branches in a GitHub repository. -func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_branches", - Description: t("TOOL_LIST_BRANCHES_DESCRIPTION", "List branches in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_BRANCHES_USER_TITLE", "List branches"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, +func ListBranches(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "list_branches", + Description: t("TOOL_LIST_BRANCHES_DESCRIPTION", "List branches in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_BRANCHES_USER_TITLE", "List branches"), + ReadOnlyHint: true, }, - Required: []string{"owner", "repo"}, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.BranchListOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - }, - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + }, + Required: []string{"owner", "repo"}, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.BranchListOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } - branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list branches", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + branches, resp, err := client.Repositories.ListBranches(ctx, owner, repo, opts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list branches", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list branches: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Convert to minimal branches - minimalBranches := make([]MinimalBranch, 0, len(branches)) - for _, branch := range branches { - minimalBranches = append(minimalBranches, convertToMinimalBranch(branch)) - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list branches", resp, body), nil, nil + } - r, err := json.Marshal(minimalBranches) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Convert to minimal branches + minimalBranches := make([]MinimalBranch, 0, len(branches)) + for _, branch := range branches { + minimalBranches = append(minimalBranches, convertToMinimalBranch(branch)) + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(minimalBranches) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_or_update_file", - Description: t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository. If updating, you must provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_OR_UPDATE_FILE_USER_TITLE", "Create or update file"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "path": { - Type: "string", - Description: "Path where to create/update the file", - }, - "content": { - Type: "string", - Description: "Content of the file", - }, - "message": { - Type: "string", - Description: "Commit message", - }, - "branch": { - Type: "string", - Description: "Branch to create/update the file in", - }, - "sha": { - Type: "string", - Description: "Required if updating an existing file. The blob SHA of the file being replaced.", +func CreateOrUpdateFile(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "create_or_update_file", + Description: t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", `Create or update a single file in a GitHub repository. +If updating, you should provide the SHA of the file you want to update. Use this tool to create or update a file in a GitHub repository remotely; do not use it for local file operations. + +In order to obtain the SHA of original file version before updating, use the following git command: +git ls-tree HEAD + +If the SHA is not provided, the tool will attempt to acquire it by fetching the current file contents from the repository, which may lead to rewriting latest committed changes if the file has changed since last retrieval. +`), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_OR_UPDATE_FILE_USER_TITLE", "Create or update file"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "path": { + Type: "string", + Description: "Path where to create/update the file", + }, + "content": { + Type: "string", + Description: "Content of the file", + }, + "message": { + Type: "string", + Description: "Commit message", + }, + "branch": { + Type: "string", + Description: "Branch to create/update the file in", + }, + "sha": { + Type: "string", + Description: "The blob SHA of the file being replaced.", + }, }, + Required: []string{"owner", "repo", "path", "content", "message", "branch"}, }, - Required: []string{"owner", "repo", "path", "content", "message", "branch"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path, err := RequiredParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + content, err := RequiredParam[string](args, "content") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - content, err := RequiredParam[string](args, "content") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // json.Marshal encodes byte arrays with base64, which is required for the API. + contentBytes := []byte(content) - // json.Marshal encodes byte arrays with base64, which is required for the API. - contentBytes := []byte(content) + // Create the file options + opts := &github.RepositoryContentFileOptions{ + Message: github.Ptr(message), + Content: contentBytes, + Branch: github.Ptr(branch), + } - // Create the file options - opts := &github.RepositoryContentFileOptions{ - Message: github.Ptr(message), - Content: contentBytes, - Branch: github.Ptr(branch), - } + // If SHA is provided, set it (for updates) + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + if sha != "" { + opts.SHA = github.Ptr(sha) + } - // If SHA is provided, set it (for updates) - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - if sha != "" { - opts.SHA = github.Ptr(sha) - } + // Create or update the file + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - // Create or update the file - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + path = strings.TrimPrefix(path, "/") - path = strings.TrimPrefix(path, "/") - fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create/update file", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // SHA validation using conditional HEAD request (efficient - no body transfer) + var previousSHA string + contentURL := fmt.Sprintf("repos/%s/%s/contents/%s", owner, repo, url.PathEscape(path)) + if branch != "" { + contentURL += "?ref=" + url.QueryEscape(branch) + } + + if sha != "" { + // User provided SHA - validate it's still current + req, err := client.NewRequest("HEAD", contentURL, nil) + if err == nil { + req.Header.Set("If-None-Match", fmt.Sprintf(`"%s"`, sha)) + resp, _ := client.Do(ctx, req, nil) + if resp != nil { + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusNotModified: + // SHA matches current - proceed + opts.SHA = github.Ptr(sha) + case http.StatusOK: + // SHA is stale - reject with current SHA so user can check diff + currentSHA := strings.Trim(resp.Header.Get("ETag"), `"`) + return utils.NewToolResultError(fmt.Sprintf( + "SHA mismatch: provided SHA %s is stale. Current file SHA is %s. "+ + "Use get_file_contents or compare commits to review changes before updating.", + sha, currentSHA)), nil, nil + case http.StatusNotFound: + // File doesn't exist - this is a create, ignore provided SHA + } + } + } + } else { + // No SHA provided - check if file exists to warn about blind update + req, err := client.NewRequest("HEAD", contentURL, nil) + if err == nil { + resp, _ := client.Do(ctx, req, nil) + if resp != nil { + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + previousSHA = strings.Trim(resp.Header.Get("ETag"), `"`) + } + // 404 = new file, no previous SHA needed + } + } + } + + if previousSHA != "" { + opts.SHA = github.Ptr(previousSHA) + } - if resp.StatusCode != 200 && resp.StatusCode != 201 { - body, err := io.ReadAll(resp.Body) + fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create/update file", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create/update file: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(fileContent) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if resp.StatusCode != 200 && resp.StatusCode != 201 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create/update file", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(fileContent) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } + + // Warn if file was updated without SHA validation (blind update) + if sha == "" && previousSHA != "" { + return utils.NewToolResultText(fmt.Sprintf( + "Warning: File updated without SHA validation. Previous file SHA was %s. "+ + `Verify no unintended changes were overwritten: +1. Extract the SHA of the local version using git ls-tree HEAD %s. +2. Compare with the previous SHA above. +3. Revert changes if shas do not match. + +%s`, + previousSHA, path, string(r))), nil, nil + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_repository", - Description: t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account or specified organization"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_REPOSITORY_USER_TITLE", "Create repository"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "name": { - Type: "string", - Description: "Repository name", - }, - "description": { - Type: "string", - Description: "Repository description", - }, - "organization": { - Type: "string", - Description: "Organization to create the repository in (omit to create in your personal account)", - }, - "private": { - Type: "boolean", - Description: "Whether repo should be private", - }, - "autoInit": { - Type: "boolean", - Description: "Initialize with README", +func CreateRepository(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "create_repository", + Description: t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account or specified organization"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_REPOSITORY_USER_TITLE", "Create repository"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": { + Type: "string", + Description: "Repository name", + }, + "description": { + Type: "string", + Description: "Repository description", + }, + "organization": { + Type: "string", + Description: "Organization to create the repository in (omit to create in your personal account)", + }, + "private": { + Type: "boolean", + Description: "Whether repo should be private", + }, + "autoInit": { + Type: "boolean", + Description: "Initialize with README", + }, }, + Required: []string{"name"}, }, - Required: []string{"name"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - name, err := RequiredParam[string](args, "name") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - description, err := OptionalParam[string](args, "description") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - organization, err := OptionalParam[string](args, "organization") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - private, err := OptionalParam[bool](args, "private") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - autoInit, err := OptionalParam[bool](args, "autoInit") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - repo := &github.Repository{ - Name: github.Ptr(name), - Description: github.Ptr(description), - Private: github.Ptr(private), - AutoInit: github.Ptr(autoInit), - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + name, err := RequiredParam[string](args, "name") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + description, err := OptionalParam[string](args, "description") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + organization, err := OptionalParam[string](args, "organization") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + private, err := OptionalParam[bool](args, "private") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + autoInit, err := OptionalParam[bool](args, "autoInit") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - createdRepo, resp, err := client.Repositories.Create(ctx, organization, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create repository", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + repo := &github.Repository{ + Name: github.Ptr(name), + Description: github.Ptr(description), + Private: github.Ptr(private), + AutoInit: github.Ptr(autoInit), + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to create repository: %s", string(body))), nil, nil - } + createdRepo, resp, err := client.Repositories.Create(ctx, organization, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create repository", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", createdRepo.GetID()), - URL: createdRepo.GetHTMLURL(), - } + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create repository", resp, body), nil, nil + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", createdRepo.GetID()), + URL: createdRepo.GetHTMLURL(), + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(minimalResponse) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_file_contents", - Description: t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_FILE_CONTENTS_USER_TITLE", "Get file or directory contents"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "path": { - Type: "string", - Description: "Path to file/directory", - Default: json.RawMessage(`"/"`), - }, - "ref": { - Type: "string", - Description: "Accepts optional git refs such as `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`", - }, - "sha": { - Type: "string", - Description: "Accepts optional commit SHA. If specified, it will be used instead of ref", +func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "get_file_contents", + Description: t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_FILE_CONTENTS_USER_TITLE", "Get file or directory contents"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "path": { + Type: "string", + Description: "Path to file/directory", + Default: json.RawMessage(`"/"`), + }, + "ref": { + Type: "string", + Description: "Accepts optional git refs such as `refs/tags/{tag}`, `refs/heads/{branch}` or `refs/pull/{pr_number}/head`", + }, + "sha": { + Type: "string", + Description: "Accepts optional commit SHA. If specified, it will be used instead of ref", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - ref, err := OptionalParam[string](args, "ref") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sha, err := OptionalParam[string](args, "sha") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub client"), nil, nil - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil - } + path, err := OptionalParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path = strings.TrimPrefix(path, "/") - if rawOpts.SHA != "" { - ref = rawOpts.SHA - } + ref, err := OptionalParam[string](args, "ref") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sha, err := OptionalParam[string](args, "sha") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - var fileSHA string - opts := &github.RepositoryContentGetOptions{Ref: ref} + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub client"), nil, nil + } - // Always call GitHub Contents API first to get metadata including SHA and determine if it's a file or directory - fileContent, dirContent, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) - if respContents != nil { - defer func() { _ = respContents.Body.Close() }() - } + rawOpts, err := resolveGitReference(ctx, client, owner, repo, ref, sha) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to resolve git reference: %s", err)), nil, nil + } - // The path does not point to a file or directory. - // Instead let's try to find it in the Git Tree by matching the end of the path. - if err != nil || (fileContent == nil && dirContent == nil) { - return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, 0) - } + if rawOpts.SHA != "" { + ref = rawOpts.SHA + } - if fileContent != nil && fileContent.SHA != nil { - fileSHA = *fileContent.SHA + var fileSHA string + opts := &github.RepositoryContentGetOptions{Ref: ref} - rawClient, err := getRawClient(ctx) - if err != nil { - return utils.NewToolResultError("failed to get GitHub raw content client"), nil, nil + // Always call GitHub Contents API first to get metadata including SHA and determine if it's a file or directory + fileContent, dirContent, respContents, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) + if respContents != nil { + defer func() { _ = respContents.Body.Close() }() } - resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) - if err != nil { - return utils.NewToolResultError("failed to get raw repository content"), nil, nil + + // The path does not point to a file or directory. + // Instead let's try to find it in the Git Tree by matching the end of the path. + if err != nil || (fileContent == nil && dirContent == nil) { + return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, 0) } - defer func() { - _ = resp.Body.Close() - }() - if resp.StatusCode == http.StatusOK { - // If the raw content is found, return it directly - body, err := io.ReadAll(resp.Body) + if fileContent != nil && fileContent.SHA != nil { + fileSHA = *fileContent.SHA + + rawClient, err := deps.GetRawClient(ctx) + if err != nil { + return utils.NewToolResultError("failed to get GitHub raw content client"), nil, nil + } + resp, err := rawClient.GetRawContent(ctx, owner, repo, path, rawOpts) if err != nil { - return utils.NewToolResultError("failed to read response body"), nil, nil + return utils.NewToolResultError("failed to get raw repository content"), nil, nil } - contentType := resp.Header.Get("Content-Type") + defer func() { + _ = resp.Body.Close() + }() - var resourceURI string - switch { - case sha != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) - if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) - } - case ref != "": - resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) + if resp.StatusCode == http.StatusOK { + // If the raw content is found, return it directly + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) + return utils.NewToolResultError("failed to read response body"), nil, nil } - default: - resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) - if err != nil { - return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) + contentType := resp.Header.Get("Content-Type") + + var resourceURI string + switch { + case sha != "": + resourceURI, err = url.JoinPath("repo://", owner, repo, "sha", sha, "contents", path) + if err != nil { + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) + } + case ref != "": + resourceURI, err = url.JoinPath("repo://", owner, repo, ref, "contents", path) + if err != nil { + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) + } + default: + resourceURI, err = url.JoinPath("repo://", owner, repo, "contents", path) + if err != nil { + return nil, nil, fmt.Errorf("failed to create resource URI: %w", err) + } } - } - // Determine if content is text or binary - isTextContent := strings.HasPrefix(contentType, "text/") || - contentType == "application/json" || - contentType == "application/xml" || - strings.HasSuffix(contentType, "+json") || - strings.HasSuffix(contentType, "+xml") + // Determine if content is text or binary + isTextContent := strings.HasPrefix(contentType, "text/") || + contentType == "application/json" || + contentType == "application/xml" || + strings.HasSuffix(contentType, "+json") || + strings.HasSuffix(contentType, "+xml") + + if isTextContent { + result := &mcp.ResourceContents{ + URI: resourceURI, + Text: string(body), + MIMEType: contentType, + } + // Include SHA in the result metadata + if fileSHA != "" { + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil + } + return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil + } - if isTextContent { result := &mcp.ResourceContents{ URI: resourceURI, - Text: string(body), + Blob: body, MIMEType: contentType, } // Include SHA in the result metadata if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded text file (SHA: %s)", fileSHA), result), nil, nil + return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil } - return utils.NewToolResultResource("successfully downloaded text file", result), nil, nil + return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil } - result := &mcp.ResourceContents{ - URI: resourceURI, - Blob: body, - MIMEType: contentType, - } - // Include SHA in the result metadata - if fileSHA != "" { - return utils.NewToolResultResource(fmt.Sprintf("successfully downloaded binary file (SHA: %s)", fileSHA), result), nil, nil + // Raw API call failed + return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, resp.StatusCode) + } else if dirContent != nil { + // file content or file SHA is nil which means it's a directory + r, err := json.Marshal(dirContent) + if err != nil { + return utils.NewToolResultError("failed to marshal response"), nil, nil } - return utils.NewToolResultResource("successfully downloaded binary file", result), nil, nil - } - - // Raw API call failed - return matchFiles(ctx, client, owner, repo, ref, path, rawOpts, resp.StatusCode) - } else if dirContent != nil { - // file content or file SHA is nil which means it's a directory - r, err := json.Marshal(dirContent) - if err != nil { - return utils.NewToolResultError("failed to marshal response"), nil, nil + return utils.NewToolResultText(string(r)), nil, nil } - return utils.NewToolResultText(string(r)), nil, nil - } - - return utils.NewToolResultError("failed to get file contents"), nil, nil - }) - return tool, handler + return utils.NewToolResultError("failed to get file contents"), nil, nil + }, + ) } // ForkRepository creates a tool to fork a repository. -func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "fork_repository", - Description: t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_FORK_REPOSITORY_USER_TITLE", "Fork repository"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "organization": { - Type: "string", - Description: "Organization to fork to", +func ForkRepository(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "fork_repository", + Description: t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization"), + Icons: octicons.Icons("repo-forked"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_FORK_REPOSITORY_USER_TITLE", "Fork repository"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "organization": { + Type: "string", + Description: "Organization to fork to", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - org, err := OptionalParam[string](args, "organization") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.RepositoryCreateForkOptions{} - if org != "" { - opts.Organization = org - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + org, err := OptionalParam[string](args, "organization") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) - if err != nil { - // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, - // and it's not a real error. - if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { - return utils.NewToolResultText("Fork is in progress"), nil, nil - } - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to fork repository", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + opts := &github.RepositoryCreateForkOptions{} + if org != "" { + opts.Organization = org + } - if resp.StatusCode != http.StatusAccepted { - body, err := io.ReadAll(resp.Body) + client, err := deps.GetClient(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to fork repository: %s", string(body))), nil, nil - } + forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) + if err != nil { + // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, + // and it's not a real error. + if resp != nil && resp.StatusCode == http.StatusAccepted && isAcceptedError(err) { + return utils.NewToolResultText("Fork is in progress"), nil, nil + } + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to fork repository", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Return minimal response with just essential information - minimalResponse := MinimalResponse{ - ID: fmt.Sprintf("%d", forkedRepo.GetID()), - URL: forkedRepo.GetHTMLURL(), - } + if resp.StatusCode != http.StatusAccepted { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to fork repository", resp, body), nil, nil + } - r, err := json.Marshal(minimalResponse) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Return minimal response with just essential information + minimalResponse := MinimalResponse{ + ID: fmt.Sprintf("%d", forkedRepo.GetID()), + URL: forkedRepo.GetHTMLURL(), + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(minimalResponse) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // DeleteFile creates a tool to delete a file in a GitHub repository. @@ -816,872 +894,906 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) // unlike how the endpoint backing the create_or_update_files tool does. This appears to be a quirk of the API. // The approach implemented here gets automatic commit signing when used with either the github-actions user or as an app, // both of which suit an LLM well. -func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "delete_file", - Description: t("TOOL_DELETE_FILE_DESCRIPTION", "Delete a file from a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_DELETE_FILE_USER_TITLE", "Delete file"), - ReadOnlyHint: false, - DestructiveHint: github.Ptr(true), - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner (username or organization)", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "path": { - Type: "string", - Description: "Path to the file to delete", - }, - "message": { - Type: "string", - Description: "Commit message", - }, - "branch": { - Type: "string", - Description: "Branch to delete the file from", +func DeleteFile(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "delete_file", + Description: t("TOOL_DELETE_FILE_DESCRIPTION", "Delete a file from a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_DELETE_FILE_USER_TITLE", "Delete file"), + ReadOnlyHint: false, + DestructiveHint: github.Ptr(true), + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner (username or organization)", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "path": { + Type: "string", + Description: "Path to the file to delete", + }, + "message": { + Type: "string", + Description: "Commit message", + }, + "branch": { + Type: "string", + Description: "Branch to delete the file from", + }, }, + Required: []string{"owner", "repo", "path", "message", "branch"}, }, - Required: []string{"owner", "repo", "path", "message", "branch"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - path, err := RequiredParam[string](args, "path") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + path, err := RequiredParam[string](args, "path") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) - if err != nil { - return nil, nil, fmt.Errorf("failed to get branch reference: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get base commit", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Get the reference for the branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return nil, nil, fmt.Errorf("failed to get branch reference: %w", err) + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Get the commit object that the branch points to + baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get commit: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Create a tree entry for the file deletion by setting SHA to nil - treeEntries := []*github.TreeEntry{ - { - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - SHA: nil, // Setting SHA to nil deletes the file - }, - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get commit", resp, body), nil, nil + } - // Create a new tree with the deletion - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create tree", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Create a tree entry for the file deletion by setting SHA to nil + treeEntries := []*github.TreeEntry{ + { + Path: github.Ptr(path), + Mode: github.Ptr("100644"), // Regular file mode + Type: github.Ptr("blob"), + SHA: nil, // Setting SHA to nil deletes the file + }, + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + // Create a new tree with the deletion + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, treeEntries) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create tree: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Create a new commit with the new tree - commit := github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create commit", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create tree", resp, body), nil, nil + } - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) + // Create a new commit with the new tree + commit := github.Commit{ + Message: github.Ptr(message), + Tree: newTree, + Parents: []*github.Commit{{SHA: baseCommit.SHA}}, + } + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to create commit: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Update the branch reference to point to the new commit - ref.Object.SHA = newCommit.SHA - _, resp, err = client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ - SHA: *newCommit.SHA, - Force: github.Ptr(false), - }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to create commit", resp, body), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Update the branch reference to point to the new commit + ref.Object.SHA = newCommit.SHA + _, resp, err = client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ + SHA: *newCommit.SHA, + Force: github.Ptr(false), + }) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to update reference: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Create a response similar to what the DeleteFile API would return - response := map[string]interface{}{ - "commit": newCommit, - "content": nil, - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to update reference", resp, body), nil, nil + } - r, err := json.Marshal(response) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + // Create a response similar to what the DeleteFile API would return + response := map[string]interface{}{ + "commit": newCommit, + "content": nil, + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(response) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // CreateBranch creates a tool to create a new branch. -func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "create_branch", - Description: t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_CREATE_BRANCH_USER_TITLE", "Create branch"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "branch": { - Type: "string", - Description: "Name for new branch", - }, - "from_branch": { - Type: "string", - Description: "Source branch (defaults to repo default)", +func CreateBranch(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "create_branch", + Description: t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_CREATE_BRANCH_USER_TITLE", "Create branch"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "branch": { + Type: "string", + Description: "Name for new branch", + }, + "from_branch": { + Type: "string", + Description: "Source branch (defaults to repo default)", + }, }, + Required: []string{"owner", "repo", "branch"}, }, - Required: []string{"owner", "repo", "branch"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + fromBranch, err := OptionalParam[string](args, "from_branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - fromBranch, err := OptionalParam[string](args, "from_branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get the source branch SHA + var ref *github.Reference + + if fromBranch == "" { + // Get default branch if from_branch not specified + repository, resp, err := client.Repositories.Get(ctx, owner, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get repository", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Get the source branch SHA - var ref *github.Reference + fromBranch = *repository.DefaultBranch + } - if fromBranch == "" { - // Get default branch if from_branch not specified - repository, resp, err := client.Repositories.Get(ctx, owner, repo) + // Get SHA of source branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) if err != nil { return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get repository", + "failed to get reference", resp, err, ), nil, nil } defer func() { _ = resp.Body.Close() }() - fromBranch = *repository.DefaultBranch - } - - // Get SHA of source branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+fromBranch) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Create new branch - newRef := github.CreateRef{ - Ref: "refs/heads/" + branch, - SHA: *ref.Object.SHA, - } - - createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create branch", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Create new branch + newRef := github.CreateRef{ + Ref: "refs/heads/" + branch, + SHA: *ref.Object.SHA, + } - r, err := json.Marshal(createdRef) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, newRef) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create branch", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(createdRef) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "push_files", - Description: t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_PUSH_FILES_USER_TITLE", "Push files to repository"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "branch": { - Type: "string", - Description: "Branch to push to", - }, - "files": { - Type: "array", - Description: "Array of file objects to push, each object with path (string) and content (string)", - Items: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "path": { - Type: "string", - Description: "path to the file", - }, - "content": { - Type: "string", - Description: "file content", +func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "push_files", + Description: t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_PUSH_FILES_USER_TITLE", "Push files to repository"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "branch": { + Type: "string", + Description: "Branch to push to", + }, + "files": { + Type: "array", + Description: "Array of file objects to push, each object with path (string) and content (string)", + Items: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "path": { + Type: "string", + Description: "path to the file", + }, + "content": { + Type: "string", + Description: "file content", + }, }, + Required: []string{"path", "content"}, }, - Required: []string{"path", "content"}, + }, + "message": { + Type: "string", + Description: "Commit message", }, }, - "message": { - Type: "string", - Description: "Commit message", - }, + Required: []string{"owner", "repo", "branch", "files", "message"}, }, - Required: []string{"owner", "repo", "branch", "files", "message"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + branch, err := RequiredParam[string](args, "branch") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + message, err := RequiredParam[string](args, "message") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - branch, err := RequiredParam[string](args, "branch") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - message, err := RequiredParam[string](args, "message") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + // Parse files parameter - this should be an array of objects with path and content + filesObj, ok := args["files"].([]interface{}) + if !ok { + return utils.NewToolResultError("files parameter must be an array of objects with path and content"), nil, nil + } - // Parse files parameter - this should be an array of objects with path and content - filesObj, ok := args["files"].([]interface{}) - if !ok { - return utils.NewToolResultError("files parameter must be an array of objects with path and content"), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + // Get the reference for the branch + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get branch reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Get the reference for the branch - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get branch reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Get the commit object that the branch points to + baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get base commit", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - // Get the commit object that the branch points to - baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get base commit", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + // Create tree entries for all files + var entries []*github.TreeEntry - // Create tree entries for all files - var entries []*github.TreeEntry + for _, file := range filesObj { + fileMap, ok := file.(map[string]interface{}) + if !ok { + return utils.NewToolResultError("each file must be an object with path and content"), nil, nil + } - for _, file := range filesObj { - fileMap, ok := file.(map[string]interface{}) - if !ok { - return utils.NewToolResultError("each file must be an object with path and content"), nil, nil + path, ok := fileMap["path"].(string) + if !ok || path == "" { + return utils.NewToolResultError("each file must have a path"), nil, nil + } + + content, ok := fileMap["content"].(string) + if !ok { + return utils.NewToolResultError("each file must have content"), nil, nil + } + + // Create a tree entry for the file + entries = append(entries, &github.TreeEntry{ + Path: github.Ptr(path), + Mode: github.Ptr("100644"), // Regular file mode + Type: github.Ptr("blob"), + Content: github.Ptr(content), + }) } - path, ok := fileMap["path"].(string) - if !ok || path == "" { - return utils.NewToolResultError("each file must have a path"), nil, nil + // Create a new tree with the file entries + newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create tree", + resp, + err, + ), nil, nil } + defer func() { _ = resp.Body.Close() }() - content, ok := fileMap["content"].(string) - if !ok { - return utils.NewToolResultError("each file must have content"), nil, nil + // Create a new commit + commit := github.Commit{ + Message: github.Ptr(message), + Tree: newTree, + Parents: []*github.Commit{{SHA: baseCommit.SHA}}, + } + newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to create commit", + resp, + err, + ), nil, nil } + defer func() { _ = resp.Body.Close() }() - // Create a tree entry for the file - entries = append(entries, &github.TreeEntry{ - Path: github.Ptr(path), - Mode: github.Ptr("100644"), // Regular file mode - Type: github.Ptr("blob"), - Content: github.Ptr(content), + // Update the reference to point to the new commit + ref.Object.SHA = newCommit.SHA + updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ + SHA: *newCommit.SHA, + Force: github.Ptr(false), }) - } - - // Create a new tree with the file entries - newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create tree", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Create a new commit - commit := github.Commit{ - Message: github.Ptr(message), - Tree: newTree, - Parents: []*github.Commit{{SHA: baseCommit.SHA}}, - } - newCommit, resp, err := client.Git.CreateCommit(ctx, owner, repo, commit, nil) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to create commit", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - // Update the reference to point to the new commit - ref.Object.SHA = newCommit.SHA - updatedRef, resp, err := client.Git.UpdateRef(ctx, owner, repo, *ref.Ref, github.UpdateRef{ - SHA: *newCommit.SHA, - Force: github.Ptr(false), - }) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to update reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - r, err := json.Marshal(updatedRef) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to update reference", + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(updatedRef) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // ListTags creates a tool to list tags in a GitHub repository. -func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_tags", - Description: t("TOOL_LIST_TAGS_DESCRIPTION", "List git tags in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_TAGS_USER_TITLE", "List tags"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, +func ListTags(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "list_tags", + Description: t("TOOL_LIST_TAGS_DESCRIPTION", "List git tags in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_TAGS_USER_TITLE", "List tags"), + ReadOnlyHint: true, }, - Required: []string{"owner", "repo"}, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + }, + Required: []string{"owner", "repo"}, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to list tags", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + tags, resp, err := client.Repositories.ListTags(ctx, owner, repo, opts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to list tags", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list tags: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(tags) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list tags", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(tags) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // GetTag creates a tool to get details about a specific tag in a GitHub repository. -func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_tag", - Description: t("TOOL_GET_TAG_DESCRIPTION", "Get details about a specific git tag in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_TAG_USER_TITLE", "Get tag details"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "tag": { - Type: "string", - Description: "Tag name", +func GetTag(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "get_tag", + Description: t("TOOL_GET_TAG_DESCRIPTION", "Get details about a specific git tag in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_TAG_USER_TITLE", "Get tag details"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "tag": { + Type: "string", + Description: "Tag name", + }, }, + Required: []string{"owner", "repo", "tag"}, }, - Required: []string{"owner", "repo", "tag"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tag, err := RequiredParam[string](args, "tag") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tag, err := RequiredParam[string](args, "tag") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - // First get the tag reference - ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get tag reference", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // First get the tag reference + ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/tags/"+tag) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag reference", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get tag reference: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Then get the tag object - tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get tag object", - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag reference", resp, body), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + // Then get the tag object + tagObj, resp, err := client.Git.GetTag(ctx, owner, repo, *ref.Object.SHA) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get tag object", + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get tag object: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(tagObj) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get tag object", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(tagObj) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // ListReleases creates a tool to list releases in a GitHub repository. -func ListReleases(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_releases", - Description: t("TOOL_LIST_RELEASES_DESCRIPTION", "List releases in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_RELEASES_USER_TITLE", "List releases"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, +func ListReleases(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "list_releases", + Description: t("TOOL_LIST_RELEASES_DESCRIPTION", "List releases in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_RELEASES_USER_TITLE", "List releases"), + ReadOnlyHint: true, }, - Required: []string{"owner", "repo"}, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + }, + Required: []string{"owner", "repo"}, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + } - releases, resp, err := client.Repositories.ListReleases(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list releases: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + releases, resp, err := client.Repositories.ListReleases(ctx, owner, repo, opts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to list releases: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to list releases: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(releases) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list releases", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(releases) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // GetLatestRelease creates a tool to get the latest release in a GitHub repository. -func GetLatestRelease(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_latest_release", - Description: t("TOOL_GET_LATEST_RELEASE_DESCRIPTION", "Get the latest release in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_LATEST_RELEASE_USER_TITLE", "Get latest release"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", +func GetLatestRelease(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "get_latest_release", + Description: t("TOOL_GET_LATEST_RELEASE_DESCRIPTION", "Get the latest release in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_LATEST_RELEASE_USER_TITLE", "Get latest release"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - release, resp, err := client.Repositories.GetLatestRelease(ctx, owner, repo) - if err != nil { - return nil, nil, fmt.Errorf("failed to get latest release: %w", err) - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + release, resp, err := client.Repositories.GetLatestRelease(ctx, owner, repo) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to get latest release: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get latest release: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - r, err := json.Marshal(release) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get latest release", resp, body), nil, nil + } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(release) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } -func GetReleaseByTag(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_release_by_tag", - Description: t("TOOL_GET_RELEASE_BY_TAG_DESCRIPTION", "Get a specific release by its tag name in a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_RELEASE_BY_TAG_USER_TITLE", "Get a release by tag name"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", - }, - "tag": { - Type: "string", - Description: "Tag name (e.g., 'v1.0.0')", +func GetReleaseByTag(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ + Name: "get_release_by_tag", + Description: t("TOOL_GET_RELEASE_BY_TAG_DESCRIPTION", "Get a specific release by its tag name in a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_RELEASE_BY_TAG_USER_TITLE", "Get a release by tag name"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "tag": { + Type: "string", + Description: "Tag name (e.g., 'v1.0.0')", + }, }, + Required: []string{"owner", "repo", "tag"}, }, - Required: []string{"owner", "repo", "tag"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + tag, err := RequiredParam[string](args, "tag") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - tag, err := RequiredParam[string](args, "tag") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + release, resp, err := client.Repositories.GetReleaseByTag(ctx, owner, repo, tag) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to get release by tag: %s", tag), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - release, resp, err := client.Repositories.GetReleaseByTag(ctx, owner, repo, tag) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to get release by tag: %s", tag), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get release by tag", resp, body), nil, nil + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + r, err := json.Marshal(release) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to marshal response: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get release by tag: %s", string(body))), nil, nil - } - r, err := json.Marshal(release) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal response: %w", err) - } + return utils.NewToolResultText(string(r)), nil, nil + }, + ) +} - return utils.NewToolResultText(string(r)), nil, nil - }) +// matchFiles searches for files in the Git tree that match the given path. +// It's used when GetContents fails or returns unexpected results. +func matchFiles(ctx context.Context, client *github.Client, owner, repo, ref, path string, rawOpts *raw.ContentOpts, rawAPIResponseCode int) (*mcp.CallToolResult, any, error) { + // Step 1: Get Git Tree recursively + tree, response, err := client.Git.GetTree(ctx, owner, repo, ref, true) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + "failed to get git tree", + response, + err, + ), nil, nil + } + defer func() { _ = response.Body.Close() }() - return tool, handler + // Step 2: Filter tree for matching paths + const maxMatchingFiles = 3 + matchingFiles := filterPaths(tree.Entries, path, maxMatchingFiles) + if len(matchingFiles) > 0 { + matchingFilesJSON, err := json.Marshal(matchingFiles) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to marshal matching files: %s", err)), nil, nil + } + resolvedRefs, err := json.Marshal(rawOpts) + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("failed to marshal resolved refs: %s", err)), nil, nil + } + if rawAPIResponseCode > 0 { + return utils.NewToolResultText(fmt.Sprintf("Resolved potential matches in the repository tree (resolved refs: %s, matching files: %s), but the content API returned an unexpected status code %d.", string(resolvedRefs), string(matchingFilesJSON), rawAPIResponseCode)), nil, nil + } + return utils.NewToolResultText(fmt.Sprintf("Resolved potential matches in the repository tree (resolved refs: %s, matching files: %s).", string(resolvedRefs), string(matchingFilesJSON))), nil, nil + } + return utils.NewToolResultError("Failed to get file contents. The path does not point to a file or directory, or the file does not exist in the repository."), nil, nil } // filterPaths filters the entries in a GitHub tree to find paths that @@ -1820,292 +1932,261 @@ func resolveGitReference(ctx context.Context, githubClient *github.Client, owner } // ListStarredRepositories creates a tool to list starred repositories for the authenticated user or a specified user. -func ListStarredRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_starred_repositories", - Description: t("TOOL_LIST_STARRED_REPOSITORIES_DESCRIPTION", "List starred repositories"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_STARRED_REPOSITORIES_USER_TITLE", "List starred repositories"), - ReadOnlyHint: true, - }, - InputSchema: WithPagination(&jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "username": { - Type: "string", - Description: "Username to list starred repositories for. Defaults to the authenticated user.", - }, - "sort": { - Type: "string", - Description: "How to sort the results. Can be either 'created' (when the repository was starred) or 'updated' (when the repository was last pushed to).", - Enum: []any{"created", "updated"}, - }, - "direction": { - Type: "string", - Description: "The direction to sort the results by.", - Enum: []any{"asc", "desc"}, - }, - }, - }), - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - username, err := OptionalParam[string](args, "username") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - opts := &github.ActivityListStarredOptions{ - ListOptions: github.ListOptions{ - Page: pagination.Page, - PerPage: pagination.PerPage, +func ListStarredRepositories(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataStargazers, + mcp.Tool{ + Name: "list_starred_repositories", + Description: t("TOOL_LIST_STARRED_REPOSITORIES_DESCRIPTION", "List starred repositories"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_STARRED_REPOSITORIES_USER_TITLE", "List starred repositories"), + ReadOnlyHint: true, }, - } - if sort != "" { - opts.Sort = sort - } - if direction != "" { - opts.Direction = direction - } + InputSchema: WithPagination(&jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "username": { + Type: "string", + Description: "Username to list starred repositories for. Defaults to the authenticated user.", + }, + "sort": { + Type: "string", + Description: "How to sort the results. Can be either 'created' (when the repository was starred) or 'updated' (when the repository was last pushed to).", + Enum: []any{"created", "updated"}, + }, + "direction": { + Type: "string", + Description: "The direction to sort the results by.", + Enum: []any{"asc", "desc"}, + }, + }, + }), + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + username, err := OptionalParam[string](args, "username") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + opts := &github.ActivityListStarredOptions{ + ListOptions: github.ListOptions{ + Page: pagination.Page, + PerPage: pagination.PerPage, + }, + } + if sort != "" { + opts.Sort = sort + } + if direction != "" { + opts.Direction = direction + } - var repos []*github.StarredRepository - var resp *github.Response - if username == "" { - // List starred repositories for the authenticated user - repos, resp, err = client.Activity.ListStarred(ctx, "", opts) - } else { - // List starred repositories for a specific user - repos, resp, err = client.Activity.ListStarred(ctx, username, opts) - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to list starred repositories for user '%s'", username), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + var repos []*github.StarredRepository + var resp *github.Response + if username == "" { + // List starred repositories for the authenticated user + repos, resp, err = client.Activity.ListStarred(ctx, "", opts) + } else { + // List starred repositories for a specific user + repos, resp, err = client.Activity.ListStarred(ctx, username, opts) + } - if resp.StatusCode != 200 { - body, err := io.ReadAll(resp.Body) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to list starred repositories for user '%s'", username), + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to list starred repositories: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - // Convert to minimal format - minimalRepos := make([]MinimalRepository, 0, len(repos)) - for _, starredRepo := range repos { - repo := starredRepo.Repository - minimalRepo := MinimalRepository{ - ID: repo.GetID(), - Name: repo.GetName(), - FullName: repo.GetFullName(), - Description: repo.GetDescription(), - HTMLURL: repo.GetHTMLURL(), - Language: repo.GetLanguage(), - Stars: repo.GetStargazersCount(), - Forks: repo.GetForksCount(), - OpenIssues: repo.GetOpenIssuesCount(), - Private: repo.GetPrivate(), - Fork: repo.GetFork(), - Archived: repo.GetArchived(), - DefaultBranch: repo.GetDefaultBranch(), - } - - if repo.UpdatedAt != nil { - minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") - } - - minimalRepos = append(minimalRepos, minimalRepo) - } + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list starred repositories", resp, body), nil, nil + } - r, err := json.Marshal(minimalRepos) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal starred repositories: %w", err) - } + // Convert to minimal format + minimalRepos := make([]MinimalRepository, 0, len(repos)) + for _, starredRepo := range repos { + repo := starredRepo.Repository + minimalRepo := MinimalRepository{ + ID: repo.GetID(), + Name: repo.GetName(), + FullName: repo.GetFullName(), + Description: repo.GetDescription(), + HTMLURL: repo.GetHTMLURL(), + Language: repo.GetLanguage(), + Stars: repo.GetStargazersCount(), + Forks: repo.GetForksCount(), + OpenIssues: repo.GetOpenIssuesCount(), + Private: repo.GetPrivate(), + Fork: repo.GetFork(), + Archived: repo.GetArchived(), + DefaultBranch: repo.GetDefaultBranch(), + } + + if repo.UpdatedAt != nil { + minimalRepo.UpdatedAt = repo.UpdatedAt.Format("2006-01-02T15:04:05Z") + } - return utils.NewToolResultText(string(r)), nil, nil - }) + minimalRepos = append(minimalRepos, minimalRepo) + } + + r, err := json.Marshal(minimalRepos) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal starred repositories: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } // StarRepository creates a tool to star a repository. -func StarRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "star_repository", - Description: t("TOOL_STAR_REPOSITORY_DESCRIPTION", "Star a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_STAR_REPOSITORY_USER_TITLE", "Star repository"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", +func StarRepository(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataStargazers, + mcp.Tool{ + Name: "star_repository", + Description: t("TOOL_STAR_REPOSITORY_DESCRIPTION", "Star a GitHub repository"), + Icons: octicons.Icons("star-fill"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_STAR_REPOSITORY_USER_TITLE", "Star repository"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - resp, err := client.Activity.Star(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to star repository %s/%s", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - if resp.StatusCode != 204 { - body, err := io.ReadAll(resp.Body) + resp, err := client.Activity.Star(ctx, owner, repo) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to star repository %s/%s", owner, repo), + resp, + err, + ), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to star repository: %s", string(body))), nil, nil - } + defer func() { _ = resp.Body.Close() }() - return utils.NewToolResultText(fmt.Sprintf("Successfully starred repository %s/%s", owner, repo)), nil, nil - }) + if resp.StatusCode != 204 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to star repository", resp, body), nil, nil + } - return tool, handler + return utils.NewToolResultText(fmt.Sprintf("Successfully starred repository %s/%s", owner, repo)), nil, nil + }, + ) } // UnstarRepository creates a tool to unstar a repository. -func UnstarRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "unstar_repository", - Description: t("TOOL_UNSTAR_REPOSITORY_DESCRIPTION", "Unstar a GitHub repository"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_UNSTAR_REPOSITORY_USER_TITLE", "Unstar repository"), - ReadOnlyHint: false, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "Repository owner", - }, - "repo": { - Type: "string", - Description: "Repository name", +func UnstarRepository(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataStargazers, + mcp.Tool{ + Name: "unstar_repository", + Description: t("TOOL_UNSTAR_REPOSITORY_DESCRIPTION", "Unstar a GitHub repository"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_UNSTAR_REPOSITORY_USER_TITLE", "Unstar repository"), + ReadOnlyHint: false, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"owner", "repo"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - resp, err := client.Activity.Unstar(ctx, owner, repo) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to unstar repository %s/%s", owner, repo), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != 204 { - body, err := io.ReadAll(resp.Body) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to unstar repository: %s", string(body))), nil, nil - } - return utils.NewToolResultText(fmt.Sprintf("Successfully unstarred repository %s/%s", owner, repo)), nil, nil - }) + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - return tool, handler -} + resp, err := client.Activity.Unstar(ctx, owner, repo) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to unstar repository %s/%s", owner, repo), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() -func matchFiles(ctx context.Context, client *github.Client, owner, repo, ref, path string, rawOpts *raw.ContentOpts, rawAPIResponseCode int) (*mcp.CallToolResult, any, error) { - // Step 1: Get Git Tree recursively - tree, response, err := client.Git.GetTree(ctx, owner, repo, ref, true) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get git tree", - response, - err, - ), nil, nil - } - defer func() { _ = response.Body.Close() }() + if resp.StatusCode != 204 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to unstar repository", resp, body), nil, nil + } - // Step 2: Filter tree for matching paths - const maxMatchingFiles = 3 - matchingFiles := filterPaths(tree.Entries, path, maxMatchingFiles) - if len(matchingFiles) > 0 { - matchingFilesJSON, err := json.Marshal(matchingFiles) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to marshal matching files: %s", err)), nil, nil - } - resolvedRefs, err := json.Marshal(rawOpts) - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("failed to marshal resolved refs: %s", err)), nil, nil - } - if rawAPIResponseCode > 0 { - return utils.NewToolResultText(fmt.Sprintf("Resolved potential matches in the repository tree (resolved refs: %s, matching files: %s), but the content API returned an unexpected status code %d.", string(resolvedRefs), string(matchingFilesJSON), rawAPIResponseCode)), nil, nil - } - return utils.NewToolResultText(fmt.Sprintf("Resolved potential matches in the repository tree (resolved refs: %s, matching files: %s).", string(resolvedRefs), string(matchingFilesJSON))), nil, nil - } - return utils.NewToolResultError("Failed to get file contents. The path does not point to a file or directory, or the file does not exist in the repository."), nil, nil + return utils.NewToolResultText(fmt.Sprintf("Successfully unstarred repository %s/%s", owner, repo)), nil, nil + }, + ) } diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 7e76d4230..9d7501f35 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -23,9 +23,8 @@ import ( func Test_GetFileContents(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - mockRawClient := raw.NewClient(mockClient, &url.URL{Scheme: "https", Host: "raw.githubusercontent.com", Path: "/"}) - tool, _ := GetFileContents(stubGetClientFn(mockClient), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + serverTool := GetFileContents(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -246,6 +245,51 @@ func Test_GetFileContents(t *testing.T) { expectError: false, expectedResult: mockDirContent, }, + { + name: "successful text content fetch with leading slash in path", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetReposGitRefByOwnerByRepoByRef, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ref": "refs/heads/main", "object": {"sha": ""}}`)) + }), + ), + mock.WithRequestMatchHandler( + mock.GetReposContentsByOwnerByRepoByPath, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + fileContent := &github.RepositoryContent{ + Name: github.Ptr("README.md"), + Path: github.Ptr("README.md"), + SHA: github.Ptr("abc123"), + Type: github.Ptr("file"), + } + contentBytes, _ := json.Marshal(fileContent) + _, _ = w.Write(contentBytes) + }), + ), + mock.WithRequestMatchHandler( + raw.GetRawReposContentsByOwnerByRepoByBranchByPath, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, _ = w.Write(mockRawContent) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "/README.md", + "ref": "refs/heads/main", + }, + expectError: false, + expectedResult: mcp.ResourceContents{ + URI: "repo://owner/repo/refs/heads/main/contents/README.md", + Text: "# Test Repository\n\nThis is a test repository.", + MIMEType: "text/markdown", + }, + }, { name: "content fetch fails", mockedClient: mock.NewMockedHTTPClient( @@ -287,18 +331,23 @@ func Test_GetFileContents(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, &url.URL{Scheme: "https", Host: "raw.example.com", Path: "/"}) - _, handler := GetFileContents(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) + require.NoError(t, err) + textContent := getErrorResult(t, result) + assert.Contains(t, textContent.Text, tc.expectedErrMsg) return } @@ -331,8 +380,8 @@ func Test_GetFileContents(t *testing.T) { func Test_ForkRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ForkRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ForkRepository(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -406,13 +455,16 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ForkRepository(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -436,8 +488,8 @@ func Test_ForkRepository(t *testing.T) { func Test_CreateBranch(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateBranch(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -599,13 +651,16 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateBranch(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -634,8 +689,8 @@ func Test_CreateBranch(t *testing.T) { func Test_GetCommit(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetCommit(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetCommit(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -728,13 +783,16 @@ func Test_GetCommit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCommit(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -766,8 +824,8 @@ func Test_GetCommit(t *testing.T) { func Test_ListCommits(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListCommits(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListCommits(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -951,13 +1009,16 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCommits(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -999,8 +1060,8 @@ func Test_ListCommits(t *testing.T) { func Test_CreateOrUpdateFile(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateOrUpdateFile(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateOrUpdateFile(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -1121,19 +1182,198 @@ func Test_CreateOrUpdateFile(t *testing.T) { expectError: true, expectedErrMsg: "failed to create/update file", }, + { + name: "sha validation - current sha matches (304 Not Modified)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/contents/docs/example.md", + Method: "HEAD", + }, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Verify If-None-Match header is set correctly + ifNoneMatch := req.Header.Get("If-None-Match") + if ifNoneMatch == `"abc123def456"` { + w.WriteHeader(http.StatusNotModified) + } else { + w.WriteHeader(http.StatusOK) + w.Header().Set("ETag", `"abc123def456"`) + } + }), + ), + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + expectRequestBody(t, map[string]interface{}{ + "message": "Update example file", + "content": "IyBVcGRhdGVkIEV4YW1wbGUKClRoaXMgZmlsZSBoYXMgYmVlbiB1cGRhdGVkLg==", + "branch": "main", + "sha": "abc123def456", + }).andThen( + mockResponse(t, http.StatusOK, mockFileResponse), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Updated Example\n\nThis file has been updated.", + "message": "Update example file", + "branch": "main", + "sha": "abc123def456", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "sha validation - stale sha detected (200 OK with different ETag)", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/contents/docs/example.md", + Method: "HEAD", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // SHA doesn't match - return 200 with current ETag + w.Header().Set("ETag", `"newsha999888"`) + w.WriteHeader(http.StatusOK) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Updated Example\n\nThis file has been updated.", + "message": "Update example file", + "branch": "main", + "sha": "oldsha123456", + }, + expectError: true, + expectedErrMsg: "SHA mismatch: provided SHA oldsha123456 is stale. Current file SHA is newsha999888", + }, + { + name: "sha validation - file doesn't exist (404), proceed with create", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/contents/docs/example.md", + Method: "HEAD", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + }), + ), + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + expectRequestBody(t, map[string]interface{}{ + "message": "Create new file", + "content": "IyBOZXcgRmlsZQoKVGhpcyBpcyBhIG5ldyBmaWxlLg==", + "branch": "main", + "sha": "ignoredsha", // SHA is sent but GitHub API ignores it for new files + }).andThen( + mockResponse(t, http.StatusCreated, mockFileResponse), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# New File\n\nThis is a new file.", + "message": "Create new file", + "branch": "main", + "sha": "ignoredsha", + }, + expectError: false, + expectedContent: mockFileResponse, + }, + { + name: "no sha provided - file exists, returns warning", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/contents/docs/example.md", + Method: "HEAD", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("ETag", `"existing123"`) + w.WriteHeader(http.StatusOK) + }), + ), + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + expectRequestBody(t, map[string]interface{}{ + "message": "Update without SHA", + "content": "IyBVcGRhdGVkCgpVcGRhdGVkIHdpdGhvdXQgU0hBLg==", + "branch": "main", + "sha": "existing123", // SHA is automatically added from ETag + }).andThen( + mockResponse(t, http.StatusOK, mockFileResponse), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# Updated\n\nUpdated without SHA.", + "message": "Update without SHA", + "branch": "main", + }, + expectError: false, + expectedErrMsg: "Warning: File updated without SHA validation. Previous file SHA was existing123", + }, + { + name: "no sha provided - file doesn't exist, no warning", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.EndpointPattern{ + Pattern: "/repos/owner/repo/contents/docs/example.md", + Method: "HEAD", + }, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + }), + ), + mock.WithRequestMatchHandler( + mock.PutReposContentsByOwnerByRepoByPath, + expectRequestBody(t, map[string]interface{}{ + "message": "Create new file", + "content": "IyBOZXcgRmlsZQoKQ3JlYXRlZCB3aXRob3V0IFNIQQ==", + "branch": "main", + }).andThen( + mockResponse(t, http.StatusCreated, mockFileResponse), + ), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "path": "docs/example.md", + "content": "# New File\n\nCreated without SHA", + "message": "Create new file", + "branch": "main", + }, + expectError: false, + expectedContent: mockFileResponse, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateOrUpdateFile(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1150,6 +1390,12 @@ func Test_CreateOrUpdateFile(t *testing.T) { // Parse the result and get the text content if no error textContent := getTextResult(t, result) + // If expectedErrMsg is set (but expectError is false), this is a warning case + if tc.expectedErrMsg != "" { + assert.Contains(t, textContent.Text, tc.expectedErrMsg) + return + } + // Unmarshal and verify the result var returnedContent github.RepositoryContentResponse err = json.Unmarshal([]byte(textContent.Text), &returnedContent) @@ -1169,8 +1415,8 @@ func Test_CreateOrUpdateFile(t *testing.T) { func Test_CreateRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreateRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := CreateRepository(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -1310,13 +1556,16 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateRepository(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1346,8 +1595,8 @@ func Test_CreateRepository(t *testing.T) { func Test_PushFiles(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := PushFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := PushFiles(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -1646,13 +1895,16 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PushFiles(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1690,8 +1942,8 @@ func Test_PushFiles(t *testing.T) { func Test_ListBranches(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListBranches(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListBranches(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -1764,17 +2016,21 @@ func Test_ListBranches(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create mock client mockClient := github.NewClient(mock.NewMockedHTTPClient(tt.mockResponses...)) - _, handler := ListBranches(stubGetClientFn(mockClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: mockClient, + } + handler := serverTool.Handler(deps) // Create request request := createMCPRequest(tt.args) // Call handler - result, _, err := handler(context.Background(), &request, tt.args) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tt.wantErr { - require.Error(t, err) + require.NoError(t, err) if tt.errContains != "" { - assert.Contains(t, err.Error(), tt.errContains) + textContent := getErrorResult(t, result) + assert.Contains(t, textContent.Text, tt.errContains) } return } @@ -1804,8 +2060,8 @@ func Test_ListBranches(t *testing.T) { func Test_DeleteFile(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := DeleteFile(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := DeleteFile(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -1948,13 +2204,16 @@ func Test_DeleteFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := DeleteFile(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -1985,8 +2244,8 @@ func Test_DeleteFile(t *testing.T) { func Test_ListTags(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListTags(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListTags(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -2072,13 +2331,16 @@ func Test_ListTags(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListTags(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2112,8 +2374,8 @@ func Test_ListTags(t *testing.T) { func Test_GetTag(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetTag(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetTag(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -2229,13 +2491,16 @@ func Test_GetTag(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetTag(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -2267,8 +2532,8 @@ func Test_GetTag(t *testing.T) { } func Test_ListReleases(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := ListReleases(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListReleases(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -2339,9 +2604,12 @@ func Test_ListReleases(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListReleases(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -2362,8 +2630,8 @@ func Test_ListReleases(t *testing.T) { } } func Test_GetLatestRelease(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetLatestRelease(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetLatestRelease(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -2427,9 +2695,12 @@ func Test_GetLatestRelease(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := GetLatestRelease(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -2448,8 +2719,8 @@ func Test_GetLatestRelease(t *testing.T) { } func Test_GetReleaseByTag(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetReleaseByTag(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := GetReleaseByTag(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -2572,11 +2843,14 @@ func Test_GetReleaseByTag(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := GetReleaseByTag(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -2960,8 +3234,8 @@ func Test_resolveGitReference(t *testing.T) { func Test_ListStarredRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListStarredRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := ListStarredRepositories(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -3081,13 +3355,16 @@ func Test_ListStarredRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListStarredRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3119,8 +3396,8 @@ func Test_ListStarredRepositories(t *testing.T) { func Test_StarRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := StarRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := StarRepository(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -3179,13 +3456,16 @@ func Test_StarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := StarRepository(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3207,8 +3487,8 @@ func Test_StarRepository(t *testing.T) { func Test_UnstarRepository(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := UnstarRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := UnstarRepository(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) schema, ok := tool.InputSchema.(*jsonschema.Schema) @@ -3267,13 +3547,16 @@ func Test_UnstarRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UnstarRepository(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -3292,181 +3575,3 @@ func Test_UnstarRepository(t *testing.T) { }) } } - -func Test_RepositoriesGetRepositoryTree(t *testing.T) { - // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := GetRepositoryTree(stubGetClientFn(mockClient), translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) - - schema, ok := tool.InputSchema.(*jsonschema.Schema) - require.True(t, ok, "InputSchema should be *jsonschema.Schema") - - assert.Equal(t, "get_repository_tree", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, schema.Properties, "owner") - assert.Contains(t, schema.Properties, "repo") - assert.Contains(t, schema.Properties, "tree_sha") - assert.Contains(t, schema.Properties, "recursive") - assert.Contains(t, schema.Properties, "path_filter") - assert.ElementsMatch(t, schema.Required, []string{"owner", "repo"}) - - // Setup mock data - mockRepo := &github.Repository{ - DefaultBranch: github.Ptr("main"), - } - mockTree := &github.Tree{ - SHA: github.Ptr("abc123"), - Truncated: github.Ptr(false), - Entries: []*github.TreeEntry{ - { - Path: github.Ptr("README.md"), - Mode: github.Ptr("100644"), - Type: github.Ptr("blob"), - SHA: github.Ptr("file1sha"), - Size: github.Ptr(123), - URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file1sha"), - }, - { - Path: github.Ptr("src/main.go"), - Mode: github.Ptr("100644"), - Type: github.Ptr("blob"), - SHA: github.Ptr("file2sha"), - Size: github.Ptr(456), - URL: github.Ptr("https://api.github.com/repos/owner/repo/git/blobs/file2sha"), - }, - }, - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedErrMsg string - }{ - { - name: "successfully get repository tree", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - }, - }, - { - name: "successfully get repository tree with path filter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - mockResponse(t, http.StatusOK, mockTree), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "path_filter": "src/", - }, - }, - { - name: "repository not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "nonexistent", - }, - expectError: true, - expectedErrMsg: "failed to get repository info", - }, - { - name: "tree not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposByOwnerByRepo, - mockResponse(t, http.StatusOK, mockRepo), - ), - mock.WithRequestMatchHandler( - mock.GetReposGitTreesByOwnerByRepoByTreeSha, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - }, - expectError: true, - expectedErrMsg: "failed to get repository tree", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, handler := GetRepositoryTree(stubGetClientFromHTTPFn(tc.mockedClient), translations.NullTranslationHelper) - - // Create the tool request - request := createMCPRequest(tc.requestArgs) - - result, _, err := handler(context.Background(), &request, tc.requestArgs) - - if tc.expectError { - require.NoError(t, err) - require.True(t, result.IsError) - errorContent := getErrorResult(t, result) - assert.Contains(t, errorContent.Text, tc.expectedErrMsg) - } else { - require.NoError(t, err) - require.False(t, result.IsError) - - // Parse the result and get the text content - textContent := getTextResult(t, result) - - // Parse the JSON response - var treeResponse map[string]interface{} - err := json.Unmarshal([]byte(textContent.Text), &treeResponse) - require.NoError(t, err) - - // Verify response structure - assert.Equal(t, "owner", treeResponse["owner"]) - assert.Equal(t, "repo", treeResponse["repo"]) - assert.Contains(t, treeResponse, "tree") - assert.Contains(t, treeResponse, "count") - assert.Contains(t, treeResponse, "sha") - assert.Contains(t, treeResponse, "truncated") - - // Check filtering if path_filter was provided - if pathFilter, exists := tc.requestArgs["path_filter"]; exists { - tree := treeResponse["tree"].([]interface{}) - for _, entry := range tree { - entryMap := entry.(map[string]interface{}) - path := entryMap["path"].(string) - assert.True(t, strings.HasPrefix(path, pathFilter.(string)), - "Path %s should start with filter %s", path, pathFilter) - } - } - } - }) - } -} diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 5dea9f4e9..ee43e9d04 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -13,6 +13,8 @@ import ( "strconv" "strings" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/raw" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" @@ -28,58 +30,86 @@ var ( repositoryResourcePrContentURITemplate = uritemplate.MustNew("repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}") ) -// GetRepositoryResourceContent defines the resource template and handler for getting repository content. -func GetRepositoryResourceContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceContent defines the resource template for getting repository content. +func GetRepositoryResourceContent(t translations.TranslationHelperFunc) inventory.ServerResourceTemplate { + return inventory.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content", - URITemplate: repositoryResourceContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), + Icons: octicons.Icons("repo"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceContentURITemplate), + ) } -// GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. -func GetRepositoryResourceBranchContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceBranchContent defines the resource template for getting repository content for a branch. +func GetRepositoryResourceBranchContent(t translations.TranslationHelperFunc) inventory.ServerResourceTemplate { + return inventory.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_branch", - URITemplate: repositoryResourceBranchContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceBranchContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), + Icons: octicons.Icons("git-branch"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceBranchContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceBranchContentURITemplate), + ) } -// GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func GetRepositoryResourceCommitContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceCommitContent defines the resource template for getting repository content for a commit. +func GetRepositoryResourceCommitContent(t translations.TranslationHelperFunc) inventory.ServerResourceTemplate { + return inventory.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_commit", - URITemplate: repositoryResourceCommitContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceCommitContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), + Icons: octicons.Icons("git-commit"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceCommitContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceCommitContentURITemplate), + ) } -// GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func GetRepositoryResourceTagContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourceTagContent defines the resource template for getting repository content for a tag. +func GetRepositoryResourceTagContent(t translations.TranslationHelperFunc) inventory.ServerResourceTemplate { + return inventory.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_tag", - URITemplate: repositoryResourceTagContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourceTagContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), + Icons: octicons.Icons("tag"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourceTagContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourceTagContentURITemplate), + ) } -// GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func GetRepositoryResourcePrContent(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, mcp.ResourceHandler) { - return mcp.ResourceTemplate{ +// GetRepositoryResourcePrContent defines the resource template for getting repository content for a pull request. +func GetRepositoryResourcePrContent(t translations.TranslationHelperFunc) inventory.ServerResourceTemplate { + return inventory.NewServerResourceTemplate( + ToolsetMetadataRepos, + mcp.ResourceTemplate{ Name: "repository_content_pr", - URITemplate: repositoryResourcePrContentURITemplate.Raw(), // Resource template + URITemplate: repositoryResourcePrContentURITemplate.Raw(), Description: t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), + Icons: octicons.Icons("git-pull-request"), }, - RepositoryResourceContentsHandler(getClient, getRawClient, repositoryResourcePrContentURITemplate) + repositoryResourceContentsHandlerFunc(repositoryResourcePrContentURITemplate), + ) +} + +// repositoryResourceContentsHandlerFunc returns a ResourceHandlerFunc that creates handlers on-demand. +func repositoryResourceContentsHandlerFunc(resourceURITemplate *uritemplate.Template) inventory.ResourceHandlerFunc { + return func(deps any) mcp.ResourceHandler { + d := deps.(ToolDependencies) + return RepositoryResourceContentsHandler(d, resourceURITemplate) + } } // RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.GetRawClientFn, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { +func RepositoryResourceContentsHandler(deps ToolDependencies, resourceURITemplate *uritemplate.Template) mcp.ResourceHandler { return func(ctx context.Context, request *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Match the URI to extract parameters uriValues := resourceURITemplate.Match(request.Params.URI) @@ -133,7 +163,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G prNumber := uriValues.Get("prNumber").String() if prNumber != "" { // fetch the PR from the API to get the latest commit and use SHA - githubClient, err := getClient(ctx) + githubClient, err := deps.GetClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -153,7 +183,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn, getRawClient raw.G if path == "" || strings.HasSuffix(path, "/") { return nil, fmt.Errorf("directories are not supported: %s", path) } - rawClient, err := getRawClient(ctx) + rawClient, err := deps.GetRawClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get GitHub raw content client: %w", err) diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index 113f46d89..b55b821af 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -7,9 +7,7 @@ import ( "testing" "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/require" ) @@ -28,67 +26,55 @@ func Test_repositoryResourceContents(t *testing.T) { name string mockedClient *http.Client uri string - handlerFn func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler + handlerFn func(deps ToolDependencies) mcp.ResourceHandler expectedResponseType resourceResponseType expectError string expectedResult *mcp.ReadResourceResult }{ { name: "missing owner", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo:///repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "owner is required", }, { name: "missing repo", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByBranchByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByBranchByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner//refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "repo is required", }, { name: "successful blob content fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "image/png") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "image/png") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/contents/data.png", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeBlob, expectedResult: &mcp.ReadResourceResult{ @@ -100,20 +86,16 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (HEAD)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -125,22 +107,18 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (HEAD)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByPath, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByPath: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") - require.Contains(t, r.URL.Path, "pkg/github/actions.go") - _, err := w.Write([]byte("package actions\n\nfunc main() {\n // Sample Go file content\n}\n")) - require.NoError(t, err) - }), - ), - ), + require.Contains(t, r.URL.Path, "pkg/github/actions.go") + _, err := w.Write([]byte("package actions\n\nfunc main() {\n // Sample Go file content\n}\n")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/contents/pkg/github/actions.go", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -152,20 +130,16 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (branch)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByBranchByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByBranchByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/refs/heads/main/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceBranchContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceBranchContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -177,20 +151,16 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (tag)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoByTagByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoByTagByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/refs/tags/v1.0.0/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceTagContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceTagContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -202,20 +172,16 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (sha)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoBySHAByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetRawReposContentsByOwnerByRepoBySHAByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/sha/abc123/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceCommitContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceCommitContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -227,28 +193,21 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "successful text content fetch (pr)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposPullsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, err := w.Write([]byte(`{"head": {"sha": "abc123"}}`)) - require.NoError(t, err) - }), - ), - mock.WithRequestMatchHandler( - raw.GetRawReposContentsByOwnerByRepoBySHAByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/markdown") - _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) - require.NoError(t, err) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"head": {"sha": "abc123"}}`)) + require.NoError(t, err) + }), + GetRawReposContentsByOwnerByRepoBySHAByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/markdown") + _, err := w.Write([]byte("# Test Repository\n\nThis is a test repository.")) + require.NoError(t, err) + }), + }), uri: "repo://owner/repo/refs/pull/42/head/contents/README.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourcePrContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourcePrContentURITemplate) }, expectedResponseType: resourceResponseTypeText, expectedResult: &mcp.ReadResourceResult{ @@ -260,19 +219,15 @@ func Test_repositoryResourceContents(t *testing.T) { }, { name: "content fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposContentsByOwnerByRepoByPath, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposContentsByOwnerByRepoByPath: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), uri: "repo://owner/repo/contents/nonexistent.md", - handlerFn: func(getClient GetClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc) mcp.ResourceHandler { - _, handler := GetRepositoryResourceContent(getClient, getRawClient, t) - return handler + handlerFn: func(deps ToolDependencies) mcp.ResourceHandler { + return RepositoryResourceContentsHandler(deps, repositoryResourceContentURITemplate) }, expectedResponseType: resourceResponseTypeText, // Ignored as error is expected expectError: "404 Not Found", @@ -283,7 +238,11 @@ func Test_repositoryResourceContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) mockRawClient := raw.NewClient(client, base) - handler := tc.handlerFn(stubGetClientFn(client), stubGetRawClientFn(mockRawClient), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + RawClient: mockRawClient, + } + handler := tc.handlerFn(deps) request := &mcp.ReadResourceRequest{ Params: &mcp.ReadResourceParams{ diff --git a/pkg/github/resources.go b/pkg/github/resources.go new file mode 100644 index 000000000..2db7cac55 --- /dev/null +++ b/pkg/github/resources.go @@ -0,0 +1,19 @@ +package github + +import ( + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" +) + +// AllResources returns all resource templates with their embedded toolset metadata. +// Resource definitions are stateless - handlers are generated on-demand during registration. +func AllResources(t translations.TranslationHelperFunc) []inventory.ServerResourceTemplate { + return []inventory.ServerResourceTemplate{ + // Repository resources + GetRepositoryResourceContent(t), + GetRepositoryResourceBranchContent(t), + GetRepositoryResourceCommitContent(t), + GetRepositoryResourceTagContent(t), + GetRepositoryResourcePrContent(t), + } +} diff --git a/pkg/github/search.go b/pkg/github/search.go index cffd0bf15..9a8b971e2 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -16,7 +17,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchRepositories(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -44,7 +45,9 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ Name: "search_repositories", Description: t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Find GitHub repositories by name, description, readme, topics, or other metadata. Perfect for discovering projects, finding examples, or locating specific repositories across GitHub."), Annotations: &mcp.ToolAnnotations{ @@ -53,7 +56,7 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -83,7 +86,7 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -102,7 +105,7 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search repositories: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search repositories", resp, body), nil, nil } // Return either minimal or full response based on parameter @@ -157,11 +160,12 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchCode(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -183,7 +187,9 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ + return NewTool( + ToolsetMetadataRepos, + mcp.Tool{ Name: "search_code", Description: t("TOOL_SEARCH_CODE_DESCRIPTION", "Fast and precise code search across ALL GitHub repositories using GitHub's native search engine. Best for finding exact symbols, functions, classes, or specific code patterns."), Annotations: &mcp.ToolAnnotations{ @@ -192,7 +198,7 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, InputSchema: schema, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { query, err := RequiredParam[string](args, "query") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -219,7 +225,7 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc }, } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil } @@ -239,7 +245,7 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc if err != nil { return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to search code: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to search code", resp, body), nil, nil } r, err := json.Marshal(result) @@ -248,99 +254,98 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func userOrOrgHandler(accountType string, getClient GetClientFn) mcp.ToolHandlerFor[map[string]any, any] { - return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - query, err := RequiredParam[string](args, "query") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sort, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - order, err := OptionalParam[string](args, "order") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - pagination, err := OptionalPaginationParams(args) - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } +func userOrOrgHandler(ctx context.Context, accountType string, deps ToolDependencies, args map[string]any) (*mcp.CallToolResult, any, error) { + query, err := RequiredParam[string](args, "query") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sort, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + order, err := OptionalParam[string](args, "order") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pagination, err := OptionalPaginationParams(args) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - opts := &github.SearchOptions{ - Sort: sort, - Order: order, - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } + opts := &github.SearchOptions{ + Sort: sort, + Order: order, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } - client, err := getClient(ctx) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } - searchQuery := query - if !hasTypeFilter(query) { - searchQuery = "type:" + accountType + " " + query - } - result, resp, err := client.Search.Users(ctx, searchQuery, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - fmt.Sprintf("failed to search %ss with query '%s'", accountType, query), - resp, - err, - ), nil, nil - } - defer func() { _ = resp.Body.Close() }() + searchQuery := query + if !hasTypeFilter(query) { + searchQuery = "type:" + accountType + " " + query + } + result, resp, err := client.Search.Users(ctx, searchQuery, opts) + if err != nil { + return ghErrors.NewGitHubAPIErrorResponse(ctx, + fmt.Sprintf("failed to search %ss with query '%s'", accountType, query), + resp, + err, + ), nil, nil + } + defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil - } - return utils.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil, nil + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, fmt.Sprintf("failed to search %ss", accountType), resp, body), nil, nil + } - minimalUsers := make([]MinimalUser, 0, len(result.Users)) + minimalUsers := make([]MinimalUser, 0, len(result.Users)) - for _, user := range result.Users { - if user.Login != nil { - mu := MinimalUser{ - Login: user.GetLogin(), - ID: user.GetID(), - ProfileURL: user.GetHTMLURL(), - AvatarURL: user.GetAvatarURL(), - } - minimalUsers = append(minimalUsers, mu) + for _, user := range result.Users { + if user.Login != nil { + mu := MinimalUser{ + Login: user.GetLogin(), + ID: user.GetID(), + ProfileURL: user.GetHTMLURL(), + AvatarURL: user.GetAvatarURL(), } + minimalUsers = append(minimalUsers, mu) } - minimalResp := &MinimalSearchUsersResult{ - TotalCount: result.GetTotal(), - IncompleteResults: result.GetIncompleteResults(), - Items: minimalUsers, - } - if result.Total != nil { - minimalResp.TotalCount = *result.Total - } - if result.IncompleteResults != nil { - minimalResp.IncompleteResults = *result.IncompleteResults - } + } + minimalResp := &MinimalSearchUsersResult{ + TotalCount: result.GetTotal(), + IncompleteResults: result.GetIncompleteResults(), + Items: minimalUsers, + } + if result.Total != nil { + minimalResp.TotalCount = *result.Total + } + if result.IncompleteResults != nil { + minimalResp.IncompleteResults = *result.IncompleteResults + } - r, err := json.Marshal(minimalResp) - if err != nil { - return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil - } - return utils.NewToolResultText(string(r)), nil, nil + r, err := json.Marshal(minimalResp) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to marshal response", err), nil, nil } + return utils.NewToolResultText(string(r)), nil, nil } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchUsers(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -363,19 +368,25 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (m } WithPagination(schema) - return mcp.Tool{ - Name: "search_users", - Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), - ReadOnlyHint: true, + return NewTool( + ToolsetMetadataUsers, + mcp.Tool{ + Name: "search_users", + Description: t("TOOL_SEARCH_USERS_DESCRIPTION", "Find GitHub users by username, real name, or other profile information. Useful for locating developers, contributors, or team members."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"), + ReadOnlyHint: true, + }, + InputSchema: schema, }, - InputSchema: schema, - }, userOrOrgHandler("user", getClient) + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + return userOrOrgHandler(ctx, "user", deps, args) + }, + ) } // SearchOrgs creates a tool to search for GitHub organizations. -func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { +func SearchOrgs(t translations.TranslationHelperFunc) inventory.ServerTool { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -398,13 +409,19 @@ func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (mc } WithPagination(schema) - return mcp.Tool{ - Name: "search_orgs", - Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), - ReadOnlyHint: true, + return NewTool( + ToolsetMetadataOrgs, + mcp.Tool{ + Name: "search_orgs", + Description: t("TOOL_SEARCH_ORGS_DESCRIPTION", "Find GitHub organizations by name, location, or other organization metadata. Ideal for discovering companies, open source foundations, or teams."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + return userOrOrgHandler(ctx, "org", deps, args) }, - InputSchema: schema, - }, userOrOrgHandler("org", getClient) + ) } diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index 0b923edcd..be1b26714 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -17,8 +17,8 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_repositories", tool.Name) @@ -134,13 +134,16 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -205,7 +208,11 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { ) client := github.NewClient(mockedClient) - _, handlerTest := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + serverTool := SearchRepositories(translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) args := map[string]interface{}{ "query": "golang test", @@ -214,7 +221,7 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { request := createMCPRequest(args) - result, _, err := handlerTest(context.Background(), &request, args) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) require.NoError(t, err) require.False(t, result.IsError) @@ -236,8 +243,8 @@ func Test_SearchRepositories_FullOutput(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchCode(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_code", tool.Name) @@ -351,13 +358,16 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -394,8 +404,8 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchUsers(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "search_users", tool.Name) @@ -548,13 +558,16 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -592,8 +605,8 @@ func Test_SearchUsers(t *testing.T) { func Test_SearchOrgs(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper) + serverTool := SearchOrgs(translations.NullTranslationHelper) + tool := serverTool.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) @@ -720,13 +733,16 @@ func Test_SearchOrgs(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := serverTool.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { diff --git a/pkg/github/search_utils.go b/pkg/github/search_utils.go index 9f7e41dec..1008200d1 100644 --- a/pkg/github/search_utils.go +++ b/pkg/github/search_utils.go @@ -8,6 +8,7 @@ import ( "net/http" "regexp" + ghErrors "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -104,7 +105,7 @@ func searchHandler( if err != nil { return utils.NewToolResultErrorFromErr(errorPrefix+": failed to read response body", err), nil } - return utils.NewToolResultError(fmt.Sprintf("%s: %s", errorPrefix, string(body))), nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, errorPrefix, resp, body), nil } r, err := json.Marshal(result) diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 297e1ebfe..0de5166ba 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -8,6 +8,7 @@ import ( "net/http" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -15,8 +16,10 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func GetSecretScanningAlert(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecretProtection, + mcp.Tool{ Name: "get_secret_scanning_alert", Description: t("TOOL_GET_SECRET_SCANNING_ALERT_DESCRIPTION", "Get details of a specific secret scanning alert in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -42,7 +45,7 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel Required: []string{"owner", "repo", "alertNumber"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -56,7 +59,7 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -76,7 +79,7 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get alert", resp, body), nil, nil } r, err := json.Marshal(alert) @@ -85,11 +88,14 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } -func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - return mcp.Tool{ +func ListSecretScanningAlerts(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecretProtection, + mcp.Tool{ Name: "list_secret_scanning_alerts", Description: t("TOOL_LIST_SECRET_SCANNING_ALERTS_DESCRIPTION", "List secret scanning alerts in a GitHub repository."), Annotations: &mcp.ToolAnnotations{ @@ -125,7 +131,7 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH Required: []string{"owner", "repo"}, }, }, - func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { owner, err := RequiredParam[string](args, "owner") if err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -147,7 +153,7 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH return utils.NewToolResultError(err.Error()), nil, nil } - client, err := getClient(ctx) + client, err := deps.GetClient(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -166,7 +172,7 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH if err != nil { return nil, nil, fmt.Errorf("failed to read response body: %w", err) } - return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list alerts", resp, body), nil, nil } r, err := json.Marshal(alerts) @@ -175,5 +181,6 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH } return utils.NewToolResultText(string(r)), nil, nil - } + }, + ) } diff --git a/pkg/github/secret_scanning_test.go b/pkg/github/secret_scanning_test.go index 6eeac1862..ed05d2215 100644 --- a/pkg/github/secret_scanning_test.go +++ b/pkg/github/secret_scanning_test.go @@ -10,22 +10,20 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_GetSecretScanningAlert(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetSecretScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetSecretScanningAlert(translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "get_secret_scanning_alert", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "get_secret_scanning_alert", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Verify InputSchema structure - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -49,12 +47,9 @@ func Test_GetSecretScanningAlert(t *testing.T) { }{ { name: "successful alert fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber, - mockAlert, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber: mockResponse(t, http.StatusOK, mockAlert), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -65,15 +60,12 @@ func Test_GetSecretScanningAlert(t *testing.T) { }, { name: "alert fetch fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecretScanningAlertsByOwnerByRepoByAlertNumber: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -88,13 +80,16 @@ func Test_GetSecretScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetSecretScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) // Call handler - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -125,16 +120,15 @@ func Test_GetSecretScanningAlert(t *testing.T) { func Test_ListSecretScanningAlerts(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListSecretScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListSecretScanningAlerts(translations.NullTranslationHelper) - require.NoError(t, toolsnaps.Test(tool.Name, tool)) + require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool)) - assert.Equal(t, "list_secret_scanning_alerts", tool.Name) - assert.NotEmpty(t, tool.Description) + assert.Equal(t, "list_secret_scanning_alerts", toolDef.Tool.Name) + assert.NotEmpty(t, toolDef.Tool.Description) // Verify InputSchema structure - schema, ok := tool.InputSchema.(*jsonschema.Schema) + schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema) require.True(t, ok, "InputSchema should be *jsonschema.Schema") assert.Contains(t, schema.Properties, "owner") assert.Contains(t, schema.Properties, "repo") @@ -169,16 +163,13 @@ func Test_ListSecretScanningAlerts(t *testing.T) { }{ { name: "successful resolved alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposSecretScanningAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{ - "state": "resolved", - }).andThen( - mockResponse(t, http.StatusOK, []*github.SecretScanningAlert{&resolvedAlert}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecretScanningAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{ + "state": "resolved", + }).andThen( + mockResponse(t, http.StatusOK, []*github.SecretScanningAlert{&resolvedAlert}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -189,14 +180,11 @@ func Test_ListSecretScanningAlerts(t *testing.T) { }, { name: "successful alerts listing", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposSecretScanningAlertsByOwnerByRepo, - expectQueryParams(t, map[string]string{}).andThen( - mockResponse(t, http.StatusOK, []*github.SecretScanningAlert{&resolvedAlert, &openAlert}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecretScanningAlertsByOwnerByRepo: expectQueryParams(t, map[string]string{}).andThen( + mockResponse(t, http.StatusOK, []*github.SecretScanningAlert{&resolvedAlert, &openAlert}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -206,15 +194,12 @@ func Test_ListSecretScanningAlerts(t *testing.T) { }, { name: "alerts listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetReposSecretScanningAlertsByOwnerByRepo, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecretScanningAlertsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Unauthorized access"}`)) + }), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -227,11 +212,14 @@ func Test_ListSecretScanningAlerts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListSecretScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{ + Client: client, + } + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - result, _, err := handler(context.Background(), &request, tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.NoError(t, err) diff --git a/pkg/github/security_advisories.go b/pkg/github/security_advisories.go index 58148a7a3..f898de61d 100644 --- a/pkg/github/security_advisories.go +++ b/pkg/github/security_advisories.go @@ -7,6 +7,8 @@ import ( "io" "net/http" + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" @@ -14,445 +16,445 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func ListGlobalSecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_global_security_advisories", - Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_USER_TITLE", "List global security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "ghsaId": { - Type: "string", - Description: "Filter by GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", - }, - "type": { - Type: "string", - Description: "Advisory type.", - Enum: []any{"reviewed", "malware", "unreviewed"}, - Default: json.RawMessage(`"reviewed"`), - }, - "cveId": { - Type: "string", - Description: "Filter by CVE ID.", - }, - "ecosystem": { - Type: "string", - Description: "Filter by package ecosystem.", - Enum: []any{"actions", "composer", "erlang", "go", "maven", "npm", "nuget", "other", "pip", "pub", "rubygems", "rust"}, - }, - "severity": { - Type: "string", - Description: "Filter by severity.", - Enum: []any{"unknown", "low", "medium", "high", "critical"}, - }, - "cwes": { - Type: "array", - Description: "Filter by Common Weakness Enumeration IDs (e.g. [\"79\", \"284\", \"22\"]).", - Items: &jsonschema.Schema{ - Type: "string", +func ListGlobalSecurityAdvisories(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecurityAdvisories, + mcp.Tool{ + Name: "list_global_security_advisories", + Description: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_DESCRIPTION", "List global security advisories from GitHub."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_GLOBAL_SECURITY_ADVISORIES_USER_TITLE", "List global security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "ghsaId": { + Type: "string", + Description: "Filter by GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", + }, + "type": { + Type: "string", + Description: "Advisory type.", + Enum: []any{"reviewed", "malware", "unreviewed"}, + Default: json.RawMessage(`"reviewed"`), + }, + "cveId": { + Type: "string", + Description: "Filter by CVE ID.", + }, + "ecosystem": { + Type: "string", + Description: "Filter by package ecosystem.", + Enum: []any{"actions", "composer", "erlang", "go", "maven", "npm", "nuget", "other", "pip", "pub", "rubygems", "rust"}, + }, + "severity": { + Type: "string", + Description: "Filter by severity.", + Enum: []any{"unknown", "low", "medium", "high", "critical"}, + }, + "cwes": { + Type: "array", + Description: "Filter by Common Weakness Enumeration IDs (e.g. [\"79\", \"284\", \"22\"]).", + Items: &jsonschema.Schema{ + Type: "string", + }, + }, + "isWithdrawn": { + Type: "boolean", + Description: "Whether to only return withdrawn advisories.", + }, + "affects": { + Type: "string", + Description: "Filter advisories by affected package or version (e.g. \"package1,package2@1.0.0\").", + }, + "published": { + Type: "string", + Description: "Filter by publish date or date range (ISO 8601 date or range).", + }, + "updated": { + Type: "string", + Description: "Filter by update date or date range (ISO 8601 date or range).", + }, + "modified": { + Type: "string", + Description: "Filter by publish or update date or date range (ISO 8601 date or range).", }, - }, - "isWithdrawn": { - Type: "boolean", - Description: "Whether to only return withdrawn advisories.", - }, - "affects": { - Type: "string", - Description: "Filter advisories by affected package or version (e.g. \"package1,package2@1.0.0\").", - }, - "published": { - Type: "string", - Description: "Filter by publish date or date range (ISO 8601 date or range).", - }, - "updated": { - Type: "string", - Description: "Filter by update date or date range (ISO 8601 date or range).", - }, - "modified": { - Type: "string", - Description: "Filter by publish or update date or date range (ISO 8601 date or range).", }, }, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - ghsaID, err := OptionalParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } - - typ, err := OptionalParam[string](args, "type") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil - } - - cveID, err := OptionalParam[string](args, "cveId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil - } - - eco, err := OptionalParam[string](args, "ecosystem") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil - } - - sev, err := OptionalParam[string](args, "severity") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil - } - - cwes, err := OptionalStringArrayParam(args, "cwes") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil - } - - isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil - } - - affects, err := OptionalParam[string](args, "affects") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil - } - - published, err := OptionalParam[string](args, "published") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil - } - - updated, err := OptionalParam[string](args, "updated") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil - } - - modified, err := OptionalParam[string](args, "modified") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil - } - - opts := &github.ListGlobalSecurityAdvisoriesOptions{} - - if ghsaID != "" { - opts.GHSAID = &ghsaID - } - if typ != "" { - opts.Type = &typ - } - if cveID != "" { - opts.CVEID = &cveID - } - if eco != "" { - opts.Ecosystem = &eco - } - if sev != "" { - opts.Severity = &sev - } - if len(cwes) > 0 { - opts.CWEs = cwes - } - - if isWithdrawn { - opts.IsWithdrawn = &isWithdrawn - } - - if affects != "" { - opts.Affects = &affects - } - if published != "" { - opts.Published = &published - } - if updated != "" { - opts.Updated = &updated - } - if modified != "" { - opts.Modified = &modified - } - - advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return utils.NewToolResultError(fmt.Sprintf("failed to list advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler -} + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } -func ListRepositorySecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_repository_security_advisories", - Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List repository security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "owner": { - Type: "string", - Description: "The owner of the repository.", - }, - "repo": { - Type: "string", - Description: "The name of the repository.", - }, - "direction": { - Type: "string", - Description: "Sort direction.", - Enum: []any{"asc", "desc"}, - }, - "sort": { - Type: "string", - Description: "Sort field.", - Enum: []any{"created", "updated", "published"}, - }, - "state": { - Type: "string", - Description: "Filter by advisory state.", - Enum: []any{"triage", "draft", "published", "closed"}, - }, - }, - Required: []string{"owner", "repo"}, + ghsaID, err := OptionalParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } + + typ, err := OptionalParam[string](args, "type") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid type: %v", err)), nil, nil + } + + cveID, err := OptionalParam[string](args, "cveId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cveId: %v", err)), nil, nil + } + + eco, err := OptionalParam[string](args, "ecosystem") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ecosystem: %v", err)), nil, nil + } + + sev, err := OptionalParam[string](args, "severity") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid severity: %v", err)), nil, nil + } + + cwes, err := OptionalStringArrayParam(args, "cwes") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid cwes: %v", err)), nil, nil + } + + isWithdrawn, err := OptionalParam[bool](args, "isWithdrawn") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid isWithdrawn: %v", err)), nil, nil + } + + affects, err := OptionalParam[string](args, "affects") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid affects: %v", err)), nil, nil + } + + published, err := OptionalParam[string](args, "published") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid published: %v", err)), nil, nil + } + + updated, err := OptionalParam[string](args, "updated") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid updated: %v", err)), nil, nil + } + + modified, err := OptionalParam[string](args, "modified") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid modified: %v", err)), nil, nil + } + + opts := &github.ListGlobalSecurityAdvisoriesOptions{} + + if ghsaID != "" { + opts.GHSAID = &ghsaID + } + if typ != "" { + opts.Type = &typ + } + if cveID != "" { + opts.CVEID = &cveID + } + if eco != "" { + opts.Ecosystem = &eco + } + if sev != "" { + opts.Severity = &sev + } + if len(cwes) > 0 { + opts.CWEs = cwes + } + + if isWithdrawn { + opts.IsWithdrawn = &isWithdrawn + } + + if affects != "" { + opts.Affects = &affects + } + if published != "" { + opts.Published = &published + } + if updated != "" { + opts.Updated = &updated + } + if modified != "" { + opts.Modified = &modified + } + + advisories, resp, err := client.SecurityAdvisories.ListGlobalSecurityAdvisories(ctx, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list global security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list advisories", resp, body), nil, nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - owner, err := RequiredParam[string](args, "owner") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - repo, err := RequiredParam[string](args, "repo") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } - - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return utils.NewToolResultError(fmt.Sprintf("failed to list repository advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + ) } -func GetGlobalSecurityAdvisory(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "get_global_security_advisory", - Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_USER_TITLE", "Get a global security advisory"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "ghsaId": { - Type: "string", - Description: "GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", +func ListRepositorySecurityAdvisories(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecurityAdvisories, + mcp.Tool{ + Name: "list_repository_security_advisories", + Description: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub repository."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List repository security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "The owner of the repository.", + }, + "repo": { + Type: "string", + Description: "The name of the repository.", + }, + "direction": { + Type: "string", + Description: "Sort direction.", + Enum: []any{"asc", "desc"}, + }, + "sort": { + Type: "string", + Description: "Sort field.", + Enum: []any{"created", "updated", "published"}, + }, + "state": { + Type: "string", + Description: "Filter by advisory state.", + Enum: []any{"triage", "draft", "published", "closed"}, + }, }, + Required: []string{"owner", "repo"}, }, - Required: []string{"ghsaId"}, }, - } + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } - ghsaID, err := RequiredParam[string](args, "ghsaId") - if err != nil { - return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil - } + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } - advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) - if err != nil { - return nil, nil, fmt.Errorf("failed to get advisory: %w", err) - } - defer func() { _ = resp.Body.Close() }() + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisories(ctx, owner, repo, opts) if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) + return nil, nil, fmt.Errorf("failed to list repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list repository advisories", resp, body), nil, nil } - return utils.NewToolResultError(fmt.Sprintf("failed to get advisory: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisory) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) - } - return utils.NewToolResultText(string(r)), nil, nil - }) + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } - return tool, handler + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } -func ListOrgRepositorySecurityAdvisories(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) { - tool := mcp.Tool{ - Name: "list_org_repository_security_advisories", - Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), - Annotations: &mcp.ToolAnnotations{ - Title: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List org repository security advisories"), - ReadOnlyHint: true, - }, - InputSchema: &jsonschema.Schema{ - Type: "object", - Properties: map[string]*jsonschema.Schema{ - "org": { - Type: "string", - Description: "The organization login.", - }, - "direction": { - Type: "string", - Description: "Sort direction.", - Enum: []any{"asc", "desc"}, - }, - "sort": { - Type: "string", - Description: "Sort field.", - Enum: []any{"created", "updated", "published"}, +func GetGlobalSecurityAdvisory(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecurityAdvisories, + mcp.Tool{ + Name: "get_global_security_advisory", + Description: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_DESCRIPTION", "Get a global security advisory"), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_GLOBAL_SECURITY_ADVISORY_USER_TITLE", "Get a global security advisory"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "ghsaId": { + Type: "string", + Description: "GitHub Security Advisory ID (format: GHSA-xxxx-xxxx-xxxx).", + }, }, - "state": { - Type: "string", - Description: "Filter by advisory state.", - Enum: []any{"triage", "draft", "published", "closed"}, + Required: []string{"ghsaId"}, + }, + }, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + ghsaID, err := RequiredParam[string](args, "ghsaId") + if err != nil { + return utils.NewToolResultError(fmt.Sprintf("invalid ghsaId: %v", err)), nil, nil + } + + advisory, resp, err := client.SecurityAdvisories.GetGlobalSecurityAdvisories(ctx, ghsaID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get advisory: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get advisory", resp, body), nil, nil + } + + r, err := json.Marshal(advisory) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisory: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) +} + +func ListOrgRepositorySecurityAdvisories(t translations.TranslationHelperFunc) inventory.ServerTool { + return NewTool( + ToolsetMetadataSecurityAdvisories, + mcp.Tool{ + Name: "list_org_repository_security_advisories", + Description: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_DESCRIPTION", "List repository security advisories for a GitHub organization."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_LIST_ORG_REPOSITORY_SECURITY_ADVISORIES_USER_TITLE", "List org repository security advisories"), + ReadOnlyHint: true, + }, + InputSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "org": { + Type: "string", + Description: "The organization login.", + }, + "direction": { + Type: "string", + Description: "Sort direction.", + Enum: []any{"asc", "desc"}, + }, + "sort": { + Type: "string", + Description: "Sort field.", + Enum: []any{"created", "updated", "published"}, + }, + "state": { + Type: "string", + Description: "Filter by advisory state.", + Enum: []any{"triage", "draft", "published", "closed"}, + }, }, + Required: []string{"org"}, }, - Required: []string{"org"}, }, - } - - handler := mcp.ToolHandlerFor[map[string]any, any](func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - org, err := RequiredParam[string](args, "org") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - direction, err := OptionalParam[string](args, "direction") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - sortField, err := OptionalParam[string](args, "sort") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - state, err := OptionalParam[string](args, "state") - if err != nil { - return utils.NewToolResultError(err.Error()), nil, nil - } - - client, err := getClient(ctx) - if err != nil { - return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) - } - - opts := &github.ListRepositorySecurityAdvisoriesOptions{} - if direction != "" { - opts.Direction = direction - } - if sortField != "" { - opts.Sort = sortField - } - if state != "" { - opts.State = state - } - - advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) - if err != nil { - return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("failed to read response body: %w", err) - } - return utils.NewToolResultError(fmt.Sprintf("failed to list organization repository advisories: %s", string(body))), nil, nil - } - - r, err := json.Marshal(advisories) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) - } - - return utils.NewToolResultText(string(r)), nil, nil - }) - - return tool, handler + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + org, err := RequiredParam[string](args, "org") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + direction, err := OptionalParam[string](args, "direction") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + sortField, err := OptionalParam[string](args, "sort") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + state, err := OptionalParam[string](args, "state") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return nil, nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + opts := &github.ListRepositorySecurityAdvisoriesOptions{} + if direction != "" { + opts.Direction = direction + } + if sortField != "" { + opts.Sort = sortField + } + if state != "" { + opts.State = state + } + + advisories, resp, err := client.SecurityAdvisories.ListRepositorySecurityAdvisoriesForOrg(ctx, org, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to list organization repository security advisories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to list organization repository advisories", resp, body), nil, nil + } + + r, err := json.Marshal(advisories) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal advisories: %w", err) + } + + return utils.NewToolResultText(string(r)), nil, nil + }, + ) } diff --git a/pkg/github/security_advisories_test.go b/pkg/github/security_advisories_test.go index ed632d0be..bfc4c6985 100644 --- a/pkg/github/security_advisories_test.go +++ b/pkg/github/security_advisories_test.go @@ -10,14 +10,13 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_ListGlobalSecurityAdvisories(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := ListGlobalSecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListGlobalSecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_global_security_advisories", tool.Name) @@ -50,12 +49,9 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { }{ { name: "successful advisory fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetAdvisories, - []*github.GlobalSecurityAdvisory{mockAdvisory}, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisories: mockResponse(t, http.StatusOK, []*github.GlobalSecurityAdvisory{mockAdvisory}), + }), requestArgs: map[string]interface{}{ "type": "reviewed", "ecosystem": "npm", @@ -66,15 +62,12 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { }, { name: "invalid severity value", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetAdvisories, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisories: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) + }), + }), requestArgs: map[string]interface{}{ "type": "reviewed", "severity": "extreme", @@ -84,15 +77,12 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { }, { name: "API error handling", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetAdvisories, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(`{"message": "Internal Server Error"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisories: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message": "Internal Server Error"}`)) + }), + }), requestArgs: map[string]interface{}{}, expectError: true, expectedErrMsg: "failed to list global security advisories", @@ -103,13 +93,14 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListGlobalSecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -139,8 +130,8 @@ func Test_ListGlobalSecurityAdvisories(t *testing.T) { } func Test_GetGlobalSecurityAdvisory(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := GetGlobalSecurityAdvisory(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := GetGlobalSecurityAdvisory(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "get_global_security_advisory", tool.Name) @@ -171,12 +162,9 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { }{ { name: "successful advisory fetch", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.GetAdvisoriesByGhsaId, - mockAdvisory, - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisoriesByGhsaID: mockResponse(t, http.StatusOK, mockAdvisory), + }), requestArgs: map[string]interface{}{ "ghsaId": "GHSA-xxxx-xxxx-xxxx", }, @@ -185,15 +173,12 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { }, { name: "invalid ghsaId format", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetAdvisoriesByGhsaId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisoriesByGhsaID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Bad Request"}`)) + }), + }), requestArgs: map[string]interface{}{ "ghsaId": "invalid-ghsa-id", }, @@ -202,15 +187,12 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { }, { name: "advisory not found", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.GetAdvisoriesByGhsaId, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(`{"message": "Not Found"}`)) - }), - ), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetAdvisoriesByGhsaID: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + }), requestArgs: map[string]interface{}{ "ghsaId": "GHSA-xxxx-xxxx-xxxx", }, @@ -223,13 +205,14 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetGlobalSecurityAdvisory(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) // Create call request request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(ContextWithDeps(context.Background(), deps), &request) // Verify results if tc.expectError { @@ -254,8 +237,8 @@ func Test_GetGlobalSecurityAdvisory(t *testing.T) { func Test_ListRepositorySecurityAdvisories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListRepositorySecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListRepositorySecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_repository_security_advisories", tool.Name) @@ -270,12 +253,6 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { assert.Contains(t, schema.Properties, "state") assert.ElementsMatch(t, schema.Required, []string{"owner", "repo"}) - // Local endpoint pattern for repository security advisories - var GetReposSecurityAdvisoriesByOwnerByRepo = mock.EndpointPattern{ - Pattern: "/repos/{owner}/{repo}/security-advisories", - Method: "GET", - } - // Setup mock advisories for success cases adv1 := &github.SecurityAdvisory{ GHSAID: github.Ptr("GHSA-1111-1111-1111"), @@ -300,17 +277,14 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { }{ { name: "successful advisories listing (no filters)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetReposSecurityAdvisoriesByOwnerByRepo, - expect(t, expectations{ - path: "/repos/owner/repo/security-advisories", - queryParams: map[string]string{}, - }).andThen( - mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1, adv2}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecurityAdvisoriesByOwnerByRepo: expect(t, expectations{ + path: "/repos/owner/repo/security-advisories", + queryParams: map[string]string{}, + }).andThen( + mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1, adv2}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -320,21 +294,18 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { }, { name: "successful advisories listing with filters", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetReposSecurityAdvisoriesByOwnerByRepo, - expect(t, expectations{ - path: "/repos/octo/hello-world/security-advisories", - queryParams: map[string]string{ - "direction": "desc", - "sort": "updated", - "state": "published", - }, - }).andThen( - mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecurityAdvisoriesByOwnerByRepo: expect(t, expectations{ + path: "/repos/octo/hello-world/security-advisories", + queryParams: map[string]string{ + "direction": "desc", + "sort": "updated", + "state": "published", + }, + }).andThen( + mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "octo", "repo": "hello-world", @@ -347,17 +318,14 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { }, { name: "advisories listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetReposSecurityAdvisoriesByOwnerByRepo, - expect(t, expectations{ - path: "/repos/owner/repo/security-advisories", - queryParams: map[string]string{}, - }).andThen( - mockResponse(t, http.StatusInternalServerError, map[string]string{"message": "Internal Server Error"}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposSecurityAdvisoriesByOwnerByRepo: expect(t, expectations{ + path: "/repos/owner/repo/security-advisories", + queryParams: map[string]string{}, + }).andThen( + mockResponse(t, http.StatusInternalServerError, map[string]string{"message": "Internal Server Error"}), ), - ), + }), requestArgs: map[string]interface{}{ "owner": "owner", "repo": "repo", @@ -370,12 +338,13 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListRepositorySecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) @@ -403,8 +372,8 @@ func Test_ListRepositorySecurityAdvisories(t *testing.T) { func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := ListOrgRepositorySecurityAdvisories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + toolDef := ListOrgRepositorySecurityAdvisories(translations.NullTranslationHelper) + tool := toolDef.Tool require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "list_org_repository_security_advisories", tool.Name) @@ -418,12 +387,6 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { assert.Contains(t, schema.Properties, "state") assert.ElementsMatch(t, schema.Required, []string{"org"}) - // Endpoint pattern for org repository security advisories - var GetOrgsSecurityAdvisoriesByOrg = mock.EndpointPattern{ - Pattern: "/orgs/{org}/security-advisories", - Method: "GET", - } - adv1 := &github.SecurityAdvisory{ GHSAID: github.Ptr("GHSA-aaaa-bbbb-cccc"), Summary: github.Ptr("Org repo advisory 1"), @@ -447,17 +410,14 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { }{ { name: "successful listing (no filters)", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetOrgsSecurityAdvisoriesByOrg, - expect(t, expectations{ - path: "/orgs/octo/security-advisories", - queryParams: map[string]string{}, - }).andThen( - mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1, adv2}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsSecurityAdvisoriesByOrg: expect(t, expectations{ + path: "/orgs/octo/security-advisories", + queryParams: map[string]string{}, + }).andThen( + mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1, adv2}), ), - ), + }), requestArgs: map[string]interface{}{ "org": "octo", }, @@ -466,21 +426,18 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { }, { name: "successful listing with filters", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetOrgsSecurityAdvisoriesByOrg, - expect(t, expectations{ - path: "/orgs/octo/security-advisories", - queryParams: map[string]string{ - "direction": "asc", - "sort": "created", - "state": "triage", - }, - }).andThen( - mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsSecurityAdvisoriesByOrg: expect(t, expectations{ + path: "/orgs/octo/security-advisories", + queryParams: map[string]string{ + "direction": "asc", + "sort": "created", + "state": "triage", + }, + }).andThen( + mockResponse(t, http.StatusOK, []*github.SecurityAdvisory{adv1}), ), - ), + }), requestArgs: map[string]interface{}{ "org": "octo", "direction": "asc", @@ -492,17 +449,14 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { }, { name: "listing fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - GetOrgsSecurityAdvisoriesByOrg, - expect(t, expectations{ - path: "/orgs/octo/security-advisories", - queryParams: map[string]string{}, - }).andThen( - mockResponse(t, http.StatusForbidden, map[string]string{"message": "Forbidden"}), - ), + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetOrgsSecurityAdvisoriesByOrg: expect(t, expectations{ + path: "/orgs/octo/security-advisories", + queryParams: map[string]string{}, + }).andThen( + mockResponse(t, http.StatusForbidden, map[string]string{"message": "Forbidden"}), ), - ), + }), requestArgs: map[string]interface{}{ "org": "octo", }, @@ -514,12 +468,13 @@ func Test_ListOrgRepositorySecurityAdvisories(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - _, handler := ListOrgRepositorySecurityAdvisories(stubGetClientFn(client), translations.NullTranslationHelper) + deps := BaseDeps{Client: client} + handler := toolDef.Handler(deps) request := createMCPRequest(tc.requestArgs) - // Call handler - note the new signature with 3 parameters and 3 return values - result, _, err := handler(context.Background(), &request, tc.requestArgs) + // Call handler + result, err := handler(ContextWithDeps(context.Background(), deps), &request) if tc.expectError { require.Error(t, err) diff --git a/pkg/github/server.go b/pkg/github/server.go index e74596906..8248da58f 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + "github.com/github/github-mcp-server/pkg/octicons" "github.com/github/github-mcp-server/pkg/utils" "github.com/google/go-github/v79/github" "github.com/google/jsonschema-go/jsonschema" @@ -18,12 +19,7 @@ import ( func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { if opts == nil { - // Add default options - opts = &mcp.ServerOptions{ - HasTools: true, - HasResources: true, - HasPrompts: true, - } + opts = &mcp.ServerOptions{} } // Create a new MCP server @@ -31,6 +27,7 @@ func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server { Name: "github-mcp-server", Title: "GitHub MCP Server", Version: version, + Icons: octicons.Icons("mark-github"), }, opts) return s diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 2e9ab43a3..a59cd9a93 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -11,32 +11,67 @@ import ( "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/raw" + "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" "github.com/shurcooL/githubv4" "github.com/stretchr/testify/assert" ) -func stubGetClientFn(client *github.Client) GetClientFn { - return func(_ context.Context) (*github.Client, error) { - return client, nil +// stubDeps is a test helper that implements ToolDependencies with configurable behavior. +// Use this when you need to test error paths or when you need closure-based client creation. +type stubDeps struct { + clientFn func(context.Context) (*github.Client, error) + gqlClientFn func(context.Context) (*githubv4.Client, error) + rawClientFn func(context.Context) (*raw.Client, error) + + repoAccessCache *lockdown.RepoAccessCache + t translations.TranslationHelperFunc + flags FeatureFlags + contentWindowSize int +} + +func (s stubDeps) GetClient(ctx context.Context) (*github.Client, error) { + if s.clientFn != nil { + return s.clientFn(ctx) } + return nil, nil } -func stubGetClientFromHTTPFn(client *http.Client) GetClientFn { +func (s stubDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error) { + if s.gqlClientFn != nil { + return s.gqlClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRawClient(ctx context.Context) (*raw.Client, error) { + if s.rawClientFn != nil { + return s.rawClientFn(ctx) + } + return nil, nil +} + +func (s stubDeps) GetRepoAccessCache() *lockdown.RepoAccessCache { return s.repoAccessCache } +func (s stubDeps) GetT() translations.TranslationHelperFunc { return s.t } +func (s stubDeps) GetFlags() FeatureFlags { return s.flags } +func (s stubDeps) GetContentWindowSize() int { return s.contentWindowSize } + +// Helper functions to create stub client functions for error testing +func stubClientFnFromHTTP(httpClient *http.Client) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return github.NewClient(client), nil + return github.NewClient(httpClient), nil } } -func stubGetClientFnErr(err string) GetClientFn { +func stubClientFnErr(errMsg string) func(context.Context) (*github.Client, error) { return func(_ context.Context) (*github.Client, error) { - return nil, errors.New(err) + return nil, errors.New(errMsg) } } -func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { +func stubGQLClientFnErr(errMsg string) func(context.Context) (*githubv4.Client, error) { return func(_ context.Context) (*githubv4.Client, error) { - return client, nil + return nil, errors.New(errMsg) } } @@ -51,12 +86,6 @@ func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { } } -func stubGetRawClientFn(client *raw.Client) raw.GetRawClientFn { - return func(_ context.Context) (*raw.Client, error) { - return client, nil - } -} - func badRequestHandler(msg string) http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { structuredErrorResponse := github.ErrorResponse{ diff --git a/pkg/github/tools.go b/pkg/github/tools.go index ffb7f1852..62b67af6f 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -2,402 +2,294 @@ package github import ( "context" - "fmt" "strings" - "github.com/github/github-mcp-server/pkg/lockdown" - "github.com/github/github-mcp-server/pkg/raw" - "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v79/github" - "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/shurcooL/githubv4" ) type GetClientFn func(context.Context) (*github.Client, error) type GetGQLClientFn func(context.Context) (*githubv4.Client, error) -// ToolsetMetadata holds metadata for a toolset including its ID and description -type ToolsetMetadata struct { - ID string - Description string -} - +// Toolset metadata constants - these define all available toolsets and their descriptions. +// Tools use these constants to declare which toolset they belong to. +// Icons are Octicon names from https://primer.style/foundations/icons var ( - ToolsetMetadataAll = ToolsetMetadata{ + ToolsetMetadataAll = inventory.ToolsetMetadata{ ID: "all", Description: "Special toolset that enables all available toolsets", + Icon: "apps", } - ToolsetMetadataDefault = ToolsetMetadata{ + ToolsetMetadataDefault = inventory.ToolsetMetadata{ ID: "default", Description: "Special toolset that enables the default toolset configuration. When no toolsets are specified, this is the set that is enabled", + Icon: "check-circle", } - ToolsetMetadataContext = ToolsetMetadata{ + ToolsetMetadataContext = inventory.ToolsetMetadata{ ID: "context", Description: "Tools that provide context about the current user and GitHub context you are operating in", + Default: true, + Icon: "person", } - ToolsetMetadataRepos = ToolsetMetadata{ + ToolsetMetadataRepos = inventory.ToolsetMetadata{ ID: "repos", Description: "GitHub Repository related tools", + Default: true, + Icon: "repo", } - ToolsetMetadataGit = ToolsetMetadata{ + ToolsetMetadataGit = inventory.ToolsetMetadata{ ID: "git", Description: "GitHub Git API related tools for low-level Git operations", + Icon: "git-branch", } - ToolsetMetadataIssues = ToolsetMetadata{ + ToolsetMetadataIssues = inventory.ToolsetMetadata{ ID: "issues", Description: "GitHub Issues related tools", + Default: true, + Icon: "issue-opened", } - ToolsetMetadataPullRequests = ToolsetMetadata{ + ToolsetMetadataPullRequests = inventory.ToolsetMetadata{ ID: "pull_requests", Description: "GitHub Pull Request related tools", + Default: true, + Icon: "git-pull-request", } - ToolsetMetadataUsers = ToolsetMetadata{ + ToolsetMetadataUsers = inventory.ToolsetMetadata{ ID: "users", Description: "GitHub User related tools", + Default: true, + Icon: "people", } - ToolsetMetadataOrgs = ToolsetMetadata{ + ToolsetMetadataOrgs = inventory.ToolsetMetadata{ ID: "orgs", Description: "GitHub Organization related tools", + Icon: "organization", } - ToolsetMetadataActions = ToolsetMetadata{ + ToolsetMetadataActions = inventory.ToolsetMetadata{ ID: "actions", Description: "GitHub Actions workflows and CI/CD operations", + Icon: "workflow", } - ToolsetMetadataCodeSecurity = ToolsetMetadata{ + ToolsetMetadataCodeSecurity = inventory.ToolsetMetadata{ ID: "code_security", Description: "Code security related tools, such as GitHub Code Scanning", + Icon: "codescan", } - ToolsetMetadataSecretProtection = ToolsetMetadata{ + ToolsetMetadataSecretProtection = inventory.ToolsetMetadata{ ID: "secret_protection", Description: "Secret protection related tools, such as GitHub Secret Scanning", + Icon: "shield-lock", } - ToolsetMetadataDependabot = ToolsetMetadata{ + ToolsetMetadataDependabot = inventory.ToolsetMetadata{ ID: "dependabot", Description: "Dependabot tools", + Icon: "dependabot", } - ToolsetMetadataNotifications = ToolsetMetadata{ + ToolsetMetadataNotifications = inventory.ToolsetMetadata{ ID: "notifications", Description: "GitHub Notifications related tools", + Icon: "bell", } - ToolsetMetadataExperiments = ToolsetMetadata{ + ToolsetMetadataExperiments = inventory.ToolsetMetadata{ ID: "experiments", Description: "Experimental features that are not considered stable yet", + Icon: "beaker", } - ToolsetMetadataDiscussions = ToolsetMetadata{ + ToolsetMetadataDiscussions = inventory.ToolsetMetadata{ ID: "discussions", Description: "GitHub Discussions related tools", + Icon: "comment-discussion", } - ToolsetMetadataGists = ToolsetMetadata{ + ToolsetMetadataGists = inventory.ToolsetMetadata{ ID: "gists", Description: "GitHub Gist related tools", + Icon: "logo-gist", } - ToolsetMetadataSecurityAdvisories = ToolsetMetadata{ + ToolsetMetadataSecurityAdvisories = inventory.ToolsetMetadata{ ID: "security_advisories", Description: "Security advisories related tools", + Icon: "shield", } - ToolsetMetadataProjects = ToolsetMetadata{ + ToolsetMetadataProjects = inventory.ToolsetMetadata{ ID: "projects", Description: "GitHub Projects related tools", + Icon: "project", } - ToolsetMetadataStargazers = ToolsetMetadata{ + ToolsetMetadataStargazers = inventory.ToolsetMetadata{ ID: "stargazers", Description: "GitHub Stargazers related tools", + Icon: "star", } - ToolsetMetadataDynamic = ToolsetMetadata{ + ToolsetMetadataDynamic = inventory.ToolsetMetadata{ ID: "dynamic", Description: "Discover GitHub MCP tools that can help achieve tasks by enabling additional sets of tools, you can control the enablement of any toolset to access its tools when this toolset is enabled.", + Icon: "tools", } - ToolsetLabels = ToolsetMetadata{ + ToolsetLabels = inventory.ToolsetMetadata{ ID: "labels", Description: "GitHub Labels related tools", + Icon: "tag", } -) -func AvailableTools() []ToolsetMetadata { - return []ToolsetMetadata{ - ToolsetMetadataContext, - ToolsetMetadataRepos, - ToolsetMetadataIssues, - ToolsetMetadataPullRequests, - ToolsetMetadataUsers, - ToolsetMetadataOrgs, - ToolsetMetadataActions, - ToolsetMetadataCodeSecurity, - ToolsetMetadataSecretProtection, - ToolsetMetadataDependabot, - ToolsetMetadataNotifications, - ToolsetMetadataExperiments, - ToolsetMetadataDiscussions, - ToolsetMetadataGists, - ToolsetMetadataSecurityAdvisories, - ToolsetMetadataProjects, - ToolsetMetadataStargazers, - ToolsetMetadataDynamic, - ToolsetLabels, + // Remote-only toolsets - these are only available in the remote MCP server + // but are documented here for consistency and to enable automated documentation. + ToolsetMetadataCopilot = inventory.ToolsetMetadata{ + ID: "copilot", + Description: "Copilot related tools", + Icon: "copilot", } -} - -// GetValidToolsetIDs returns a map of all valid toolset IDs for quick lookup -func GetValidToolsetIDs() map[string]bool { - validIDs := make(map[string]bool) - for _, tool := range AvailableTools() { - validIDs[tool.ID] = true + ToolsetMetadataCopilotSpaces = inventory.ToolsetMetadata{ + ID: "copilot_spaces", + Description: "Copilot Spaces tools", + Icon: "copilot", } - // Add special keywords - validIDs[ToolsetMetadataAll.ID] = true - validIDs[ToolsetMetadataDefault.ID] = true - return validIDs -} - -func GetDefaultToolsetIDs() []string { - return []string{ - ToolsetMetadataContext.ID, - ToolsetMetadataRepos.ID, - ToolsetMetadataIssues.ID, - ToolsetMetadataPullRequests.ID, - ToolsetMetadataUsers.ID, + ToolsetMetadataSupportSearch = inventory.ToolsetMetadata{ + ID: "github_support_docs_search", + Description: "Retrieve documentation to answer GitHub product and support questions. Topics include: GitHub Actions Workflows, Authentication, ...", + Icon: "book", } -} - -func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags, cache *lockdown.RepoAccessCache) *toolsets.ToolsetGroup { - tsg := toolsets.NewToolsetGroup(readOnly) - - // Define all available features with their default state (disabled) - // Create toolsets - repos := toolsets.NewToolset(ToolsetMetadataRepos.ID, ToolsetMetadataRepos.Description). - AddReadTools( - toolsets.NewServerTool(SearchRepositories(getClient, t)), - toolsets.NewServerTool(GetFileContents(getClient, getRawClient, t)), - toolsets.NewServerTool(ListCommits(getClient, t)), - toolsets.NewServerTool(SearchCode(getClient, t)), - toolsets.NewServerTool(GetCommit(getClient, t)), - toolsets.NewServerTool(ListBranches(getClient, t)), - toolsets.NewServerTool(ListTags(getClient, t)), - toolsets.NewServerTool(GetTag(getClient, t)), - toolsets.NewServerTool(ListReleases(getClient, t)), - toolsets.NewServerTool(GetLatestRelease(getClient, t)), - toolsets.NewServerTool(GetReleaseByTag(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(CreateOrUpdateFile(getClient, t)), - toolsets.NewServerTool(CreateRepository(getClient, t)), - toolsets.NewServerTool(ForkRepository(getClient, t)), - toolsets.NewServerTool(CreateBranch(getClient, t)), - toolsets.NewServerTool(PushFiles(getClient, t)), - toolsets.NewServerTool(DeleteFile(getClient, t)), - ). - AddResourceTemplates( - toolsets.NewServerResourceTemplate(GetRepositoryResourceContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceBranchContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceCommitContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourceTagContent(getClient, getRawClient, t)), - toolsets.NewServerResourceTemplate(GetRepositoryResourcePrContent(getClient, getRawClient, t)), - ) - git := toolsets.NewToolset(ToolsetMetadataGit.ID, ToolsetMetadataGit.Description). - AddReadTools( - toolsets.NewServerTool(GetRepositoryTree(getClient, t)), - ) - issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). - AddReadTools( - toolsets.NewServerTool(IssueRead(getClient, getGQLClient, cache, t, flags)), - toolsets.NewServerTool(SearchIssues(getClient, t)), - toolsets.NewServerTool(ListIssues(getGQLClient, t)), - toolsets.NewServerTool(ListIssueTypes(getClient, t)), - toolsets.NewServerTool(GetLabel(getGQLClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(IssueWrite(getClient, getGQLClient, t)), - toolsets.NewServerTool(AddIssueComment(getClient, t)), - toolsets.NewServerTool(AssignCopilotToIssue(getGQLClient, t)), - toolsets.NewServerTool(SubIssueWrite(getClient, t)), - ).AddPrompts( - toolsets.NewServerPrompt(AssignCodingAgentPrompt(t)), - toolsets.NewServerPrompt(IssueToFixWorkflowPrompt(t)), - ) - users := toolsets.NewToolset(ToolsetMetadataUsers.ID, ToolsetMetadataUsers.Description). - AddReadTools( - toolsets.NewServerTool(SearchUsers(getClient, t)), - ) - orgs := toolsets.NewToolset(ToolsetMetadataOrgs.ID, ToolsetMetadataOrgs.Description). - AddReadTools( - toolsets.NewServerTool(SearchOrgs(getClient, t)), - ) - pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). - AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, getGQLClient, cache, t, flags)), - toolsets.NewServerTool(ListPullRequests(getClient, t)), - toolsets.NewServerTool(SearchPullRequests(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(MergePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), - toolsets.NewServerTool(CreatePullRequest(getClient, t)), - toolsets.NewServerTool(UpdatePullRequest(getClient, getGQLClient, t)), - toolsets.NewServerTool(RequestCopilotReview(getClient, t)), - // Reviews - toolsets.NewServerTool(PullRequestReviewWrite(getGQLClient, t)), - toolsets.NewServerTool(AddCommentToPendingReview(getGQLClient, t)), - ) - codeSecurity := toolsets.NewToolset(ToolsetMetadataCodeSecurity.ID, ToolsetMetadataCodeSecurity.Description). - AddReadTools( - toolsets.NewServerTool(GetCodeScanningAlert(getClient, t)), - toolsets.NewServerTool(ListCodeScanningAlerts(getClient, t)), - ) - secretProtection := toolsets.NewToolset(ToolsetMetadataSecretProtection.ID, ToolsetMetadataSecretProtection.Description). - AddReadTools( - toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)), - toolsets.NewServerTool(ListSecretScanningAlerts(getClient, t)), - ) - dependabot := toolsets.NewToolset(ToolsetMetadataDependabot.ID, ToolsetMetadataDependabot.Description). - AddReadTools( - toolsets.NewServerTool(GetDependabotAlert(getClient, t)), - toolsets.NewServerTool(ListDependabotAlerts(getClient, t)), - ) - - notifications := toolsets.NewToolset(ToolsetMetadataNotifications.ID, ToolsetMetadataNotifications.Description). - AddReadTools( - toolsets.NewServerTool(ListNotifications(getClient, t)), - toolsets.NewServerTool(GetNotificationDetails(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(DismissNotification(getClient, t)), - toolsets.NewServerTool(MarkAllNotificationsRead(getClient, t)), - toolsets.NewServerTool(ManageNotificationSubscription(getClient, t)), - toolsets.NewServerTool(ManageRepositoryNotificationSubscription(getClient, t)), - ) - - discussions := toolsets.NewToolset(ToolsetMetadataDiscussions.ID, ToolsetMetadataDiscussions.Description). - AddReadTools( - toolsets.NewServerTool(ListDiscussions(getGQLClient, t)), - toolsets.NewServerTool(GetDiscussion(getGQLClient, t)), - toolsets.NewServerTool(GetDiscussionComments(getGQLClient, t)), - toolsets.NewServerTool(ListDiscussionCategories(getGQLClient, t)), - ) - - actions := toolsets.NewToolset(ToolsetMetadataActions.ID, ToolsetMetadataActions.Description). - AddReadTools( - toolsets.NewServerTool(ListWorkflows(getClient, t)), - toolsets.NewServerTool(ListWorkflowRuns(getClient, t)), - toolsets.NewServerTool(GetWorkflowRun(getClient, t)), - toolsets.NewServerTool(GetWorkflowRunLogs(getClient, t)), - toolsets.NewServerTool(ListWorkflowJobs(getClient, t)), - toolsets.NewServerTool(GetJobLogs(getClient, t, contentWindowSize)), - toolsets.NewServerTool(ListWorkflowRunArtifacts(getClient, t)), - toolsets.NewServerTool(DownloadWorkflowRunArtifact(getClient, t)), - toolsets.NewServerTool(GetWorkflowRunUsage(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(RunWorkflow(getClient, t)), - toolsets.NewServerTool(RerunWorkflowRun(getClient, t)), - toolsets.NewServerTool(RerunFailedJobs(getClient, t)), - toolsets.NewServerTool(CancelWorkflowRun(getClient, t)), - toolsets.NewServerTool(DeleteWorkflowRunLogs(getClient, t)), - ) - - securityAdvisories := toolsets.NewToolset(ToolsetMetadataSecurityAdvisories.ID, ToolsetMetadataSecurityAdvisories.Description). - AddReadTools( - toolsets.NewServerTool(ListGlobalSecurityAdvisories(getClient, t)), - toolsets.NewServerTool(GetGlobalSecurityAdvisory(getClient, t)), - toolsets.NewServerTool(ListRepositorySecurityAdvisories(getClient, t)), - toolsets.NewServerTool(ListOrgRepositorySecurityAdvisories(getClient, t)), - ) - - // // Keep experiments alive so the system doesn't error out when it's always enabled - experiments := toolsets.NewToolset(ToolsetMetadataExperiments.ID, ToolsetMetadataExperiments.Description) - - contextTools := toolsets.NewToolset(ToolsetMetadataContext.ID, ToolsetMetadataContext.Description). - AddReadTools( - toolsets.NewServerTool(GetMe(getClient, t)), - toolsets.NewServerTool(GetTeams(getClient, getGQLClient, t)), - toolsets.NewServerTool(GetTeamMembers(getGQLClient, t)), - ) - - gists := toolsets.NewToolset(ToolsetMetadataGists.ID, ToolsetMetadataGists.Description). - AddReadTools( - toolsets.NewServerTool(ListGists(getClient, t)), - toolsets.NewServerTool(GetGist(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(CreateGist(getClient, t)), - toolsets.NewServerTool(UpdateGist(getClient, t)), - ) - - projects := toolsets.NewToolset(ToolsetMetadataProjects.ID, ToolsetMetadataProjects.Description). - AddReadTools( - toolsets.NewServerTool(ListProjects(getClient, t)), - toolsets.NewServerTool(GetProject(getClient, t)), - toolsets.NewServerTool(ListProjectFields(getClient, t)), - toolsets.NewServerTool(GetProjectField(getClient, t)), - toolsets.NewServerTool(ListProjectItems(getClient, t)), - toolsets.NewServerTool(GetProjectItem(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(AddProjectItem(getClient, t)), - toolsets.NewServerTool(DeleteProjectItem(getClient, t)), - toolsets.NewServerTool(UpdateProjectItem(getClient, t)), - ) - stargazers := toolsets.NewToolset(ToolsetMetadataStargazers.ID, ToolsetMetadataStargazers.Description). - AddReadTools( - toolsets.NewServerTool(ListStarredRepositories(getClient, t)), - ). - AddWriteTools( - toolsets.NewServerTool(StarRepository(getClient, t)), - toolsets.NewServerTool(UnstarRepository(getClient, t)), - ) - labels := toolsets.NewToolset(ToolsetLabels.ID, ToolsetLabels.Description). - AddReadTools( - // get - toolsets.NewServerTool(GetLabel(getGQLClient, t)), - // list labels on repo or issue - toolsets.NewServerTool(ListLabels(getGQLClient, t)), - ). - AddWriteTools( - // create or update - toolsets.NewServerTool(LabelWrite(getGQLClient, t)), - ) - - // Add toolsets to the group - tsg.AddToolset(contextTools) - tsg.AddToolset(repos) - tsg.AddToolset(git) - tsg.AddToolset(issues) - tsg.AddToolset(orgs) - tsg.AddToolset(users) - tsg.AddToolset(pullRequests) - tsg.AddToolset(actions) - tsg.AddToolset(codeSecurity) - tsg.AddToolset(dependabot) - tsg.AddToolset(secretProtection) - tsg.AddToolset(notifications) - tsg.AddToolset(experiments) - tsg.AddToolset(discussions) - tsg.AddToolset(gists) - tsg.AddToolset(securityAdvisories) - tsg.AddToolset(projects) - tsg.AddToolset(stargazers) - tsg.AddToolset(labels) - - tsg.AddDeprecatedToolAliases(DeprecatedToolAliases) - - return tsg -} +) -// InitDynamicToolset creates a dynamic toolset that can be used to enable other toolsets, and so requires the server and toolset group as arguments -// -//nolint:unused -func InitDynamicToolset(s *mcp.Server, tsg *toolsets.ToolsetGroup, t translations.TranslationHelperFunc) *toolsets.Toolset { - // Create a new dynamic toolset - // Need to add the dynamic toolset last so it can be used to enable other toolsets - dynamicToolSelection := toolsets.NewToolset(ToolsetMetadataDynamic.ID, ToolsetMetadataDynamic.Description). - AddReadTools( - toolsets.NewServerTool(ListAvailableToolsets(tsg, t)), - toolsets.NewServerTool(GetToolsetsTools(tsg, t)), - toolsets.NewServerTool(EnableToolset(s, tsg, t)), - ) - - dynamicToolSelection.Enabled = true - return dynamicToolSelection +// AllTools returns all tools with their embedded toolset metadata. +// Tool functions return ServerTool directly with toolset info. +func AllTools(t translations.TranslationHelperFunc) []inventory.ServerTool { + return []inventory.ServerTool{ + // Context tools + GetMe(t), + GetTeams(t), + GetTeamMembers(t), + + // Repository tools + SearchRepositories(t), + GetFileContents(t), + ListCommits(t), + SearchCode(t), + GetCommit(t), + ListBranches(t), + ListTags(t), + GetTag(t), + ListReleases(t), + GetLatestRelease(t), + GetReleaseByTag(t), + CreateOrUpdateFile(t), + CreateRepository(t), + ForkRepository(t), + CreateBranch(t), + PushFiles(t), + DeleteFile(t), + ListStarredRepositories(t), + StarRepository(t), + UnstarRepository(t), + + // Git tools + GetRepositoryTree(t), + + // Issue tools + IssueRead(t), + SearchIssues(t), + ListIssues(t), + ListIssueTypes(t), + IssueWrite(t), + AddIssueComment(t), + AssignCopilotToIssue(t), + SubIssueWrite(t), + + // User tools + SearchUsers(t), + + // Organization tools + SearchOrgs(t), + + // Pull request tools + PullRequestRead(t), + ListPullRequests(t), + SearchPullRequests(t), + MergePullRequest(t), + UpdatePullRequestBranch(t), + CreatePullRequest(t), + UpdatePullRequest(t), + RequestCopilotReview(t), + PullRequestReviewWrite(t), + AddCommentToPendingReview(t), + + // Code security tools + GetCodeScanningAlert(t), + ListCodeScanningAlerts(t), + + // Secret protection tools + GetSecretScanningAlert(t), + ListSecretScanningAlerts(t), + + // Dependabot tools + GetDependabotAlert(t), + ListDependabotAlerts(t), + + // Notification tools + ListNotifications(t), + GetNotificationDetails(t), + DismissNotification(t), + MarkAllNotificationsRead(t), + ManageNotificationSubscription(t), + ManageRepositoryNotificationSubscription(t), + + // Discussion tools + ListDiscussions(t), + GetDiscussion(t), + GetDiscussionComments(t), + ListDiscussionCategories(t), + + // Actions tools + ListWorkflows(t), + ListWorkflowRuns(t), + GetWorkflowRun(t), + GetWorkflowRunLogs(t), + ListWorkflowJobs(t), + GetJobLogs(t), + ListWorkflowRunArtifacts(t), + DownloadWorkflowRunArtifact(t), + GetWorkflowRunUsage(t), + RunWorkflow(t), + RerunWorkflowRun(t), + RerunFailedJobs(t), + CancelWorkflowRun(t), + DeleteWorkflowRunLogs(t), + // Consolidated Actions tools (enabled via feature flag) + ActionsList(t), + ActionsGet(t), + ActionsRunTrigger(t), + ActionsGetJobLogs(t), + + // Security advisories tools + ListGlobalSecurityAdvisories(t), + GetGlobalSecurityAdvisory(t), + ListRepositorySecurityAdvisories(t), + ListOrgRepositorySecurityAdvisories(t), + + // Gist tools + ListGists(t), + GetGist(t), + CreateGist(t), + UpdateGist(t), + + // Project tools + ListProjects(t), + GetProject(t), + ListProjectFields(t), + GetProjectField(t), + ListProjectItems(t), + GetProjectItem(t), + AddProjectItem(t), + DeleteProjectItem(t), + UpdateProjectItem(t), + + // Label tools + GetLabel(t), + GetLabelForLabelsToolset(t), + ListLabels(t), + LabelWrite(t), + } } // ToBoolPtr converts a bool to a *bool pointer. @@ -416,52 +308,77 @@ func ToStringPtr(s string) *string { // GenerateToolsetsHelp generates the help text for the toolsets flag func GenerateToolsetsHelp() string { - // Format default tools - defaultTools := strings.Join(GetDefaultToolsetIDs(), ", ") + // Get toolset group to derive defaults and available toolsets + r := NewInventory(stubTranslator).Build() + + // Format default tools from metadata using strings.Builder + var defaultBuf strings.Builder + defaultIDs := r.DefaultToolsetIDs() + for i, id := range defaultIDs { + if i > 0 { + defaultBuf.WriteString(", ") + } + defaultBuf.WriteString(string(id)) + } - // Format available tools with line breaks for better readability - allTools := AvailableTools() - var availableToolsLines []string + // Get all available toolsets (excludes context and dynamic for display) + allToolsets := r.AvailableToolsets("context", "dynamic") + var availableBuf strings.Builder const maxLineLength = 70 currentLine := "" - for i, tool := range allTools { + for i, toolset := range allToolsets { + id := string(toolset.ID) switch { case i == 0: - currentLine = tool.ID - case len(currentLine)+len(tool.ID)+2 <= maxLineLength: - currentLine += ", " + tool.ID + currentLine = id + case len(currentLine)+len(id)+2 <= maxLineLength: + currentLine += ", " + id default: - availableToolsLines = append(availableToolsLines, currentLine) - currentLine = tool.ID + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + currentLine = id } } if currentLine != "" { - availableToolsLines = append(availableToolsLines, currentLine) - } - - availableTools := strings.Join(availableToolsLines, ",\n\t ") - - toolsetsHelp := fmt.Sprintf("Comma-separated list of tool groups to enable (no spaces).\n"+ - "Available: %s\n", availableTools) + - "Special toolset keywords:\n" + - " - all: Enables all available toolsets\n" + - fmt.Sprintf(" - default: Enables the default toolset configuration of:\n\t %s\n", defaultTools) + - "Examples:\n" + - " - --toolsets=actions,gists,notifications\n" + - " - Default + additional: --toolsets=default,actions,gists\n" + - " - All tools: --toolsets=all" - - return toolsetsHelp + if availableBuf.Len() > 0 { + availableBuf.WriteString(",\n\t ") + } + availableBuf.WriteString(currentLine) + } + + // Build the complete help text using strings.Builder + var buf strings.Builder + buf.WriteString("Comma-separated list of tool groups to enable (no spaces).\n") + buf.WriteString("Available: ") + buf.WriteString(availableBuf.String()) + buf.WriteString("\n") + buf.WriteString("Special toolset keywords:\n") + buf.WriteString(" - all: Enables all available toolsets\n") + buf.WriteString(" - default: Enables the default toolset configuration of:\n\t ") + buf.WriteString(defaultBuf.String()) + buf.WriteString("\n") + buf.WriteString("Examples:\n") + buf.WriteString(" - --toolsets=actions,gists,notifications\n") + buf.WriteString(" - Default + additional: --toolsets=default,actions,gists\n") + buf.WriteString(" - All tools: --toolsets=all") + + return buf.String() } +// stubTranslator is a passthrough translator for cases where we need an Inventory +// but don't need actual translations (e.g., getting toolset IDs for CLI help). +func stubTranslator(_, fallback string) string { return fallback } + // AddDefaultToolset removes the default toolset and expands it to the actual default toolset IDs func AddDefaultToolset(result []string) []string { hasDefault := false seen := make(map[string]bool) for _, toolset := range result { seen[toolset] = true - if toolset == ToolsetMetadataDefault.ID { + if toolset == string(ToolsetMetadataDefault.ID) { hasDefault = true } } @@ -471,45 +388,18 @@ func AddDefaultToolset(result []string) []string { return result } - result = RemoveToolset(result, ToolsetMetadataDefault.ID) + result = RemoveToolset(result, string(ToolsetMetadataDefault.ID)) - for _, defaultToolset := range GetDefaultToolsetIDs() { - if !seen[defaultToolset] { - result = append(result, defaultToolset) + // Get default toolset IDs from the Inventory + r := NewInventory(stubTranslator).Build() + for _, id := range r.DefaultToolsetIDs() { + if !seen[string(id)] { + result = append(result, string(id)) } } return result } -// cleanToolsets cleans and handles special toolset keywords: -// - Duplicates are removed from the result -// - Removes whitespaces -// - Validates toolset names and returns invalid ones separately - for warning reporting -// Returns: (toolsets, invalidToolsets) -func CleanToolsets(enabledToolsets []string) ([]string, []string) { - seen := make(map[string]bool) - result := make([]string, 0, len(enabledToolsets)) - invalid := make([]string, 0) - validIDs := GetValidToolsetIDs() - - // Add non-default toolsets, removing duplicates and trimming whitespace - for _, toolset := range enabledToolsets { - trimmed := strings.TrimSpace(toolset) - if trimmed == "" { - continue - } - if !seen[trimmed] { - seen[trimmed] = true - result = append(result, trimmed) - if !validIDs[trimmed] { - invalid = append(invalid, trimmed) - } - } - } - - return result, invalid -} - func RemoveToolset(tools []string, toRemove string) []string { result := make([]string, 0, len(tools)) for _, tool := range tools { @@ -549,3 +439,26 @@ func CleanTools(toolNames []string) []string { return result } + +// GetDefaultToolsetIDs returns the IDs of toolsets marked as Default. +// This is a convenience function that builds an inventory to determine defaults. +func GetDefaultToolsetIDs() []string { + r := NewInventory(stubTranslator).Build() + ids := r.DefaultToolsetIDs() + result := make([]string, len(ids)) + for i, id := range ids { + result[i] = string(id) + } + return result +} + +// RemoteOnlyToolsets returns toolset metadata for toolsets that are only +// available in the remote MCP server. These are documented but not registered +// in the local server. +func RemoteOnlyToolsets() []inventory.ToolsetMetadata { + return []inventory.ToolsetMetadata{ + ToolsetMetadataCopilot, + ToolsetMetadataCopilotSpaces, + ToolsetMetadataSupportSearch, + } +} diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go index 45c1e746f..80270d2bc 100644 --- a/pkg/github/tools_test.go +++ b/pkg/github/tools_test.go @@ -7,135 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestCleanToolsets(t *testing.T) { - tests := []struct { - name string - input []string - expected []string - expectedInvalid []string - }{ - { - name: "empty slice", - input: []string{}, - expected: []string{}, - }, - { - name: "nil input slice", - input: nil, - expected: []string{}, - }, - // CleanToolsets only cleans - it does NOT filter out special keywords - { - name: "default keyword preserved", - input: []string{"default"}, - expected: []string{"default"}, - }, - { - name: "default with additional toolsets", - input: []string{"default", "actions", "gists"}, - expected: []string{"default", "actions", "gists"}, - }, - { - name: "all keyword preserved", - input: []string{"all", "actions"}, - expected: []string{"all", "actions"}, - }, - { - name: "no special keywords", - input: []string{"actions", "gists", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "duplicate toolsets without special keywords", - input: []string{"actions", "gists", "actions"}, - expected: []string{"actions", "gists"}, - }, - { - name: "duplicate toolsets with default", - input: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - expected: []string{"context", "repos", "issues", "pull_requests", "users", "default"}, - }, - { - name: "default appears multiple times - duplicates removed", - input: []string{"default", "actions", "default", "gists", "default"}, - expected: []string{"default", "actions", "gists"}, - }, - // Whitespace test cases - { - name: "whitespace check - leading and trailing whitespace on regular toolsets", - input: []string{" actions ", " gists ", "notifications"}, - expected: []string{"actions", "gists", "notifications"}, - }, - { - name: "whitespace check - default toolset with whitespace", - input: []string{" actions ", " default ", "notifications"}, - expected: []string{"actions", "default", "notifications"}, - }, - { - name: "whitespace check - all toolset with whitespace", - input: []string{" all ", " actions "}, - expected: []string{"all", "actions"}, - }, - // Invalid toolset test cases - { - name: "mix of valid and invalid toolsets", - input: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expected: []string{"actions", "invalid_toolset", "gists", "typo_repo"}, - expectedInvalid: []string{"invalid_toolset", "typo_repo"}, - }, - { - name: "invalid with whitespace", - input: []string{" invalid_tool ", " actions ", " typo_gist "}, - expected: []string{"invalid_tool", "actions", "typo_gist"}, - expectedInvalid: []string{"invalid_tool", "typo_gist"}, - }, - { - name: "empty string in toolsets", - input: []string{"", "actions", " ", "gists"}, - expected: []string{"actions", "gists"}, - expectedInvalid: []string{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, invalid := CleanToolsets(tt.input) - - require.Len(t, result, len(tt.expected), "result length should match expected length") - - if tt.expectedInvalid == nil { - tt.expectedInvalid = []string{} - } - require.Len(t, invalid, len(tt.expectedInvalid), "invalid length should match expected invalid length") - - resultMap := make(map[string]bool) - for _, toolset := range result { - resultMap[toolset] = true - } - - expectedMap := make(map[string]bool) - for _, toolset := range tt.expected { - expectedMap[toolset] = true - } - - invalidMap := make(map[string]bool) - for _, toolset := range invalid { - invalidMap[toolset] = true - } - - expectedInvalidMap := make(map[string]bool) - for _, toolset := range tt.expectedInvalid { - expectedInvalidMap[toolset] = true - } - - assert.Equal(t, expectedMap, resultMap, "result should contain all expected toolsets without duplicates") - assert.Equal(t, expectedInvalidMap, invalidMap, "invalid should contain all expected invalid toolsets") - - assert.Len(t, resultMap, len(result), "result should not contain duplicates") - }) - } -} - func TestAddDefaultToolset(t *testing.T) { tests := []struct { name string @@ -280,3 +151,34 @@ func TestContainsToolset(t *testing.T) { }) } } + +func TestGenerateToolsetsHelp(t *testing.T) { + // Generate the help text + helpText := GenerateToolsetsHelp() + + // Verify help text is not empty + require.NotEmpty(t, helpText) + + // Verify it contains expected sections + assert.Contains(t, helpText, "Comma-separated list of tool groups to enable") + assert.Contains(t, helpText, "Available:") + assert.Contains(t, helpText, "Special toolset keywords:") + assert.Contains(t, helpText, "all: Enables all available toolsets") + assert.Contains(t, helpText, "default: Enables the default toolset configuration") + assert.Contains(t, helpText, "Examples:") + assert.Contains(t, helpText, "--toolsets=actions,gists,notifications") + assert.Contains(t, helpText, "--toolsets=default,actions,gists") + assert.Contains(t, helpText, "--toolsets=all") + + // Verify it contains some expected default toolsets + assert.Contains(t, helpText, "context") + assert.Contains(t, helpText, "repos") + assert.Contains(t, helpText, "issues") + assert.Contains(t, helpText, "pull_requests") + assert.Contains(t, helpText, "users") + + // Verify it contains some expected available toolsets + assert.Contains(t, helpText, "actions") + assert.Contains(t, helpText, "gists") + assert.Contains(t, helpText, "notifications") +} diff --git a/pkg/github/tools_validation_test.go b/pkg/github/tools_validation_test.go new file mode 100644 index 000000000..90e3c744c --- /dev/null +++ b/pkg/github/tools_validation_test.go @@ -0,0 +1,186 @@ +package github + +import ( + "testing" + + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// stubTranslation is a simple translation function for testing +func stubTranslation(_, fallback string) string { + return fallback +} + +// TestAllToolsHaveRequiredMetadata validates that all tools have mandatory metadata: +// - Toolset must be set (non-empty ID) +// - ReadOnlyHint annotation must be explicitly set (not nil) +func TestAllToolsHaveRequiredMetadata(t *testing.T) { + tools := AllTools(stubTranslation) + + require.NotEmpty(t, tools, "AllTools should return at least one tool") + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, tool.Toolset.ID, + "Tool %q must have a Toolset.ID", tool.Tool.Name) + + // Toolset description should be set for documentation + assert.NotEmpty(t, tool.Toolset.Description, + "Tool %q should have a Toolset.Description", tool.Tool.Name) + + // Annotations must exist and have ReadOnlyHint explicitly set + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations set (for ReadOnlyHint)", tool.Tool.Name) + + // We can't distinguish between "not set" and "set to false" for a bool, + // but having Annotations non-nil confirms the developer thought about it. + // The ReadOnlyHint value itself is validated by ensuring Annotations exist. + }) + } +} + +// TestAllResourcesHaveRequiredMetadata validates that all resources have mandatory metadata +func TestAllResourcesHaveRequiredMetadata(t *testing.T) { + // Resources are now stateless - no client functions needed + resources := AllResources(stubTranslation) + + require.NotEmpty(t, resources, "AllResources should return at least one resource") + + for _, res := range resources { + t.Run(res.Template.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, res.Toolset.ID, + "Resource %q must have a Toolset.ID", res.Template.Name) + + // HandlerFunc must be set + assert.True(t, res.HasHandler(), + "Resource %q must have a HandlerFunc", res.Template.Name) + }) + } +} + +// TestAllPromptsHaveRequiredMetadata validates that all prompts have mandatory metadata +func TestAllPromptsHaveRequiredMetadata(t *testing.T) { + prompts := AllPrompts(stubTranslation) + + require.NotEmpty(t, prompts, "AllPrompts should return at least one prompt") + + for _, prompt := range prompts { + t.Run(prompt.Prompt.Name, func(t *testing.T) { + // Toolset ID must be set + assert.NotEmpty(t, prompt.Toolset.ID, + "Prompt %q must have a Toolset.ID", prompt.Prompt.Name) + + // Handler must be set + assert.NotNil(t, prompt.Handler, + "Prompt %q must have a Handler", prompt.Prompt.Name) + }) + } +} + +// TestToolReadOnlyHintConsistency validates that read-only tools are correctly annotated +func TestToolReadOnlyHintConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + require.NotNil(t, tool.Tool.Annotations, + "Tool %q must have Annotations", tool.Tool.Name) + + // Verify IsReadOnly() method matches the annotation + assert.Equal(t, tool.Tool.Annotations.ReadOnlyHint, tool.IsReadOnly(), + "Tool %q: IsReadOnly() should match Annotations.ReadOnlyHint", tool.Tool.Name) + }) + } +} + +// TestNoDuplicateToolNames ensures all tools have unique names +func TestNoDuplicateToolNames(t *testing.T) { + tools := AllTools(stubTranslation) + seen := make(map[string]bool) + featureFlagged := make(map[string]bool) + + // get_label is intentionally in both issues and labels toolsets for conformance + // with original behavior where it was registered in both + allowedDuplicates := map[string]bool{ + "get_label": true, + } + + // First pass: identify tools that have feature flags (mutually exclusive at runtime) + for _, tool := range tools { + if tool.FeatureFlagEnable != "" || tool.FeatureFlagDisable != "" { + featureFlagged[tool.Tool.Name] = true + } + } + + for _, tool := range tools { + name := tool.Tool.Name + // Allow duplicates for explicitly allowed tools and feature-flagged tools + if !allowedDuplicates[name] && !featureFlagged[name] { + assert.False(t, seen[name], + "Duplicate tool name found: %q", name) + } + seen[name] = true + } +} + +// TestNoDuplicateResourceNames ensures all resources have unique names +func TestNoDuplicateResourceNames(t *testing.T) { + resources := AllResources(stubTranslation) + seen := make(map[string]bool) + + for _, res := range resources { + name := res.Template.Name + assert.False(t, seen[name], + "Duplicate resource name found: %q", name) + seen[name] = true + } +} + +// TestNoDuplicatePromptNames ensures all prompts have unique names +func TestNoDuplicatePromptNames(t *testing.T) { + prompts := AllPrompts(stubTranslation) + seen := make(map[string]bool) + + for _, prompt := range prompts { + name := prompt.Prompt.Name + assert.False(t, seen[name], + "Duplicate prompt name found: %q", name) + seen[name] = true + } +} + +// TestAllToolsHaveHandlerFunc ensures all tools have a handler function +func TestAllToolsHaveHandlerFunc(t *testing.T) { + tools := AllTools(stubTranslation) + + for _, tool := range tools { + t.Run(tool.Tool.Name, func(t *testing.T) { + assert.NotNil(t, tool.HandlerFunc, + "Tool %q must have a HandlerFunc", tool.Tool.Name) + assert.True(t, tool.HasHandler(), + "Tool %q HasHandler() should return true", tool.Tool.Name) + }) + } +} + +// TestToolsetMetadataConsistency ensures tools in the same toolset have consistent descriptions +func TestToolsetMetadataConsistency(t *testing.T) { + tools := AllTools(stubTranslation) + toolsetDescriptions := make(map[inventory.ToolsetID]string) + + for _, tool := range tools { + id := tool.Toolset.ID + desc := tool.Toolset.Description + + if existing, ok := toolsetDescriptions[id]; ok { + assert.Equal(t, existing, desc, + "Toolset %q has inconsistent descriptions across tools", id) + } else { + toolsetDescriptions[id] = desc + } + } +} diff --git a/pkg/github/toolset_icons_test.go b/pkg/github/toolset_icons_test.go new file mode 100644 index 000000000..fd9cec462 --- /dev/null +++ b/pkg/github/toolset_icons_test.go @@ -0,0 +1,86 @@ +package github + +import ( + "testing" + + "github.com/github/github-mcp-server/pkg/octicons" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAllToolsetIconsExist validates that every toolset with an Icon field +// references an icon that actually exists in the embedded octicons. +// This prevents broken icon references from being merged. +func TestAllToolsetIconsExist(t *testing.T) { + // Get all available toolsets from the inventory + inv := NewInventory(stubTranslator).Build() + toolsets := inv.AvailableToolsets() + + // Also test remote-only toolsets + remoteToolsets := RemoteOnlyToolsets() + + // Combine both lists + allToolsets := make([]struct { + name string + icon string + }, 0) + + for _, ts := range toolsets { + if ts.Icon != "" { + allToolsets = append(allToolsets, struct { + name string + icon string + }{name: string(ts.ID), icon: ts.Icon}) + } + } + + for _, ts := range remoteToolsets { + if ts.Icon != "" { + allToolsets = append(allToolsets, struct { + name string + icon string + }{name: string(ts.ID), icon: ts.Icon}) + } + } + + require.NotEmpty(t, allToolsets, "expected at least one toolset with an icon") + + for _, ts := range allToolsets { + t.Run(ts.name, func(t *testing.T) { + // Check that icons return valid data URIs (not empty) + icons := octicons.Icons(ts.icon) + require.NotNil(t, icons, "toolset %s references icon %q which does not exist", ts.name, ts.icon) + assert.Len(t, icons, 2, "expected light and dark icon variants for toolset %s", ts.name) + + // Verify both variants have valid data URIs + for _, icon := range icons { + assert.NotEmpty(t, icon.Source, "icon source should not be empty for toolset %s", ts.name) + assert.Contains(t, icon.Source, "data:image/png;base64,", + "icon %s for toolset %s should be a valid data URI", ts.icon, ts.name) + } + }) + } +} + +// TestToolsetMetadataHasIcons ensures all toolsets have icons defined. +// This is a policy test - if you want to allow toolsets without icons, +// you can remove or modify this test. +func TestToolsetMetadataHasIcons(t *testing.T) { + // These toolsets are expected to NOT have icons (internal/special purpose) + exceptionsWithoutIcons := map[string]bool{ + "all": true, // Meta-toolset + "default": true, // Meta-toolset + } + + inv := NewInventory(stubTranslator).Build() + toolsets := inv.AvailableToolsets() + + for _, ts := range toolsets { + if exceptionsWithoutIcons[string(ts.ID)] { + continue + } + t.Run(string(ts.ID), func(t *testing.T) { + assert.NotEmpty(t, ts.Icon, "toolset %s should have an icon defined", ts.ID) + }) + } +} diff --git a/pkg/github/workflow_prompts.go b/pkg/github/workflow_prompts.go index bc7c7581f..e85c93348 100644 --- a/pkg/github/workflow_prompts.go +++ b/pkg/github/workflow_prompts.go @@ -4,13 +4,16 @@ import ( "context" "fmt" + "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/translations" "github.com/modelcontextprotocol/go-sdk/mcp" ) // IssueToFixWorkflowPrompt provides a guided workflow for creating an issue and then generating a PR to fix it -func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Prompt, handler mcp.PromptHandler) { - return mcp.Prompt{ +func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) inventory.ServerPrompt { + return inventory.NewServerPrompt( + ToolsetMetadataIssues, + mcp.Prompt{ Name: "issue_to_fix_workflow", Description: t("PROMPT_ISSUE_TO_FIX_WORKFLOW_DESCRIPTION", "Create an issue for a problem and then generate a pull request to fix it"), Arguments: []*mcp.PromptArgument{ @@ -102,5 +105,6 @@ func IssueToFixWorkflowPrompt(t translations.TranslationHelperFunc) (tool mcp.Pr return &mcp.GetPromptResult{ Messages: messages, }, nil - } + }, + ) } diff --git a/pkg/inventory/builder.go b/pkg/inventory/builder.go new file mode 100644 index 000000000..a0ed2baee --- /dev/null +++ b/pkg/inventory/builder.go @@ -0,0 +1,274 @@ +package inventory + +import ( + "context" + "sort" + "strings" +) + +// ToolFilter is a function that determines if a tool should be included. +// Returns true if the tool should be included, false to exclude it. +type ToolFilter func(ctx context.Context, tool *ServerTool) (bool, error) + +// Builder builds a Registry with the specified configuration. +// Use NewBuilder to create a builder, chain configuration methods, +// then call Build() to create the final inventory. +// +// Example: +// +// reg := NewBuilder(). +// SetTools(tools). +// SetResources(resources). +// SetPrompts(prompts). +// WithDeprecatedAliases(aliases). +// WithReadOnly(true). +// WithToolsets([]string{"repos", "issues"}). +// WithFeatureChecker(checker). +// WithFilter(myFilter). +// Build() +type Builder struct { + tools []ServerTool + resourceTemplates []ServerResourceTemplate + prompts []ServerPrompt + deprecatedAliases map[string]string + + // Configuration options (processed at Build time) + readOnly bool + toolsetIDs []string // raw input, processed at Build() + toolsetIDsIsNil bool // tracks if nil was passed (nil = defaults) + additionalTools []string // raw input, processed at Build() + featureChecker FeatureFlagChecker + filters []ToolFilter // filters to apply to all tools +} + +// NewBuilder creates a new Builder. +func NewBuilder() *Builder { + return &Builder{ + deprecatedAliases: make(map[string]string), + toolsetIDsIsNil: true, // default to nil (use defaults) + } +} + +// SetTools sets the tools for the inventory. Returns self for chaining. +func (b *Builder) SetTools(tools []ServerTool) *Builder { + b.tools = tools + return b +} + +// SetResources sets the resource templates for the inventory. Returns self for chaining. +func (b *Builder) SetResources(resources []ServerResourceTemplate) *Builder { + b.resourceTemplates = resources + return b +} + +// SetPrompts sets the prompts for the inventory. Returns self for chaining. +func (b *Builder) SetPrompts(prompts []ServerPrompt) *Builder { + b.prompts = prompts + return b +} + +// WithDeprecatedAliases adds deprecated tool name aliases that map to canonical names. +// Returns self for chaining. +func (b *Builder) WithDeprecatedAliases(aliases map[string]string) *Builder { + for oldName, newName := range aliases { + b.deprecatedAliases[oldName] = newName + } + return b +} + +// WithReadOnly sets whether only read-only tools should be available. +// When true, write tools are filtered out. Returns self for chaining. +func (b *Builder) WithReadOnly(readOnly bool) *Builder { + b.readOnly = readOnly + return b +} + +// WithToolsets specifies which toolsets should be enabled. +// Special keywords: +// - "all": enables all toolsets +// - "default": expands to toolsets marked with Default: true in their metadata +// +// Input strings are trimmed of whitespace and duplicates are removed. +// Pass nil to use default toolsets. Pass an empty slice to disable all toolsets +// (useful for dynamic toolsets mode where tools are enabled on demand). +// Returns self for chaining. +func (b *Builder) WithToolsets(toolsetIDs []string) *Builder { + b.toolsetIDs = toolsetIDs + b.toolsetIDsIsNil = toolsetIDs == nil + return b +} + +// WithTools specifies additional tools that bypass toolset filtering. +// These tools are additive - they will be included even if their toolset is not enabled. +// Read-only filtering still applies to these tools. +// Deprecated tool aliases are automatically resolved to their canonical names during Build(). +// Returns self for chaining. +func (b *Builder) WithTools(toolNames []string) *Builder { + b.additionalTools = toolNames + return b +} + +// WithFeatureChecker sets the feature flag checker function. +// The checker receives a context (for actor extraction) and feature flag name, +// returns (enabled, error). If error occurs, it will be logged and treated as false. +// If checker is nil, all feature flag checks return false. +// Returns self for chaining. +func (b *Builder) WithFeatureChecker(checker FeatureFlagChecker) *Builder { + b.featureChecker = checker + return b +} + +// WithFilter adds a filter function that will be applied to all tools. +// Multiple filters can be added and are evaluated in order. +// If any filter returns false or an error, the tool is excluded. +// Returns self for chaining. +func (b *Builder) WithFilter(filter ToolFilter) *Builder { + b.filters = append(b.filters, filter) + return b +} + +// Build creates the final Inventory with all configuration applied. +// This processes toolset filtering, tool name resolution, and sets up +// the inventory for use. The returned Inventory is ready for use with +// AvailableTools(), RegisterAll(), etc. +func (b *Builder) Build() *Inventory { + r := &Inventory{ + tools: b.tools, + resourceTemplates: b.resourceTemplates, + prompts: b.prompts, + deprecatedAliases: b.deprecatedAliases, + readOnly: b.readOnly, + featureChecker: b.featureChecker, + filters: b.filters, + } + + // Process toolsets and pre-compute metadata in a single pass + r.enabledToolsets, r.unrecognizedToolsets, r.toolsetIDs, r.toolsetIDSet, r.defaultToolsetIDs, r.toolsetDescriptions = b.processToolsets() + + // Process additional tools (resolve aliases) + if len(b.additionalTools) > 0 { + r.additionalTools = make(map[string]bool, len(b.additionalTools)) + for _, name := range b.additionalTools { + // Resolve deprecated aliases to canonical names + if canonical, isAlias := b.deprecatedAliases[name]; isAlias { + r.additionalTools[canonical] = true + } else { + r.additionalTools[name] = true + } + } + } + + return r +} + +// processToolsets processes the toolsetIDs configuration and returns: +// - enabledToolsets map (nil means all enabled) +// - unrecognizedToolsets list for warnings +// - allToolsetIDs sorted list of all toolset IDs +// - toolsetIDSet map for O(1) HasToolset lookup +// - defaultToolsetIDs sorted list of default toolset IDs +// - toolsetDescriptions map of toolset ID to description +func (b *Builder) processToolsets() (map[ToolsetID]bool, []string, []ToolsetID, map[ToolsetID]bool, []ToolsetID, map[ToolsetID]string) { + // Single pass: collect all toolset metadata together + validIDs := make(map[ToolsetID]bool) + defaultIDs := make(map[ToolsetID]bool) + descriptions := make(map[ToolsetID]string) + + for i := range b.tools { + t := &b.tools[i] + validIDs[t.Toolset.ID] = true + if t.Toolset.Default { + defaultIDs[t.Toolset.ID] = true + } + if t.Toolset.Description != "" { + descriptions[t.Toolset.ID] = t.Toolset.Description + } + } + for i := range b.resourceTemplates { + r := &b.resourceTemplates[i] + validIDs[r.Toolset.ID] = true + if r.Toolset.Default { + defaultIDs[r.Toolset.ID] = true + } + if r.Toolset.Description != "" { + descriptions[r.Toolset.ID] = r.Toolset.Description + } + } + for i := range b.prompts { + p := &b.prompts[i] + validIDs[p.Toolset.ID] = true + if p.Toolset.Default { + defaultIDs[p.Toolset.ID] = true + } + if p.Toolset.Description != "" { + descriptions[p.Toolset.ID] = p.Toolset.Description + } + } + + // Build sorted slices from the collected maps + allToolsetIDs := make([]ToolsetID, 0, len(validIDs)) + for id := range validIDs { + allToolsetIDs = append(allToolsetIDs, id) + } + sort.Slice(allToolsetIDs, func(i, j int) bool { return allToolsetIDs[i] < allToolsetIDs[j] }) + + defaultToolsetIDList := make([]ToolsetID, 0, len(defaultIDs)) + for id := range defaultIDs { + defaultToolsetIDList = append(defaultToolsetIDList, id) + } + sort.Slice(defaultToolsetIDList, func(i, j int) bool { return defaultToolsetIDList[i] < defaultToolsetIDList[j] }) + + toolsetIDs := b.toolsetIDs + + // Check for "all" keyword - enables all toolsets + for _, id := range toolsetIDs { + if strings.TrimSpace(id) == "all" { + return nil, nil, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions // nil means all enabled + } + } + + // nil means use defaults, empty slice means no toolsets + if b.toolsetIDsIsNil { + toolsetIDs = []string{"default"} + } + + // Expand "default" keyword, trim whitespace, collect other IDs, and track unrecognized + seen := make(map[ToolsetID]bool) + expanded := make([]ToolsetID, 0, len(toolsetIDs)) + var unrecognized []string + + for _, id := range toolsetIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + if trimmed == "default" { + for _, defaultID := range defaultToolsetIDList { + if !seen[defaultID] { + seen[defaultID] = true + expanded = append(expanded, defaultID) + } + } + } else { + tsID := ToolsetID(trimmed) + if !seen[tsID] { + seen[tsID] = true + expanded = append(expanded, tsID) + // Track if this toolset doesn't exist + if !validIDs[tsID] { + unrecognized = append(unrecognized, trimmed) + } + } + } + } + + if len(expanded) == 0 { + return make(map[ToolsetID]bool), unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions + } + + enabledToolsets := make(map[ToolsetID]bool, len(expanded)) + for _, id := range expanded { + enabledToolsets[id] = true + } + return enabledToolsets, unrecognized, allToolsetIDs, validIDs, defaultToolsetIDList, descriptions +} diff --git a/pkg/inventory/errors.go b/pkg/inventory/errors.go new file mode 100644 index 000000000..3a97c9c71 --- /dev/null +++ b/pkg/inventory/errors.go @@ -0,0 +1,41 @@ +package inventory + +import "fmt" + +// ToolsetDoesNotExistError is returned when a toolset is not found. +type ToolsetDoesNotExistError struct { + Name string +} + +func (e *ToolsetDoesNotExistError) Error() string { + return fmt.Sprintf("toolset %s does not exist", e.Name) +} + +func (e *ToolsetDoesNotExistError) Is(target error) bool { + if target == nil { + return false + } + if _, ok := target.(*ToolsetDoesNotExistError); ok { + return true + } + return false +} + +// NewToolsetDoesNotExistError creates a new ToolsetDoesNotExistError. +func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { + return &ToolsetDoesNotExistError{Name: name} +} + +// ToolDoesNotExistError is returned when a tool is not found. +type ToolDoesNotExistError struct { + Name string +} + +func (e *ToolDoesNotExistError) Error() string { + return fmt.Sprintf("tool %s does not exist", e.Name) +} + +// NewToolDoesNotExistError creates a new ToolDoesNotExistError. +func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { + return &ToolDoesNotExistError{Name: name} +} diff --git a/pkg/inventory/filters.go b/pkg/inventory/filters.go new file mode 100644 index 000000000..991001a64 --- /dev/null +++ b/pkg/inventory/filters.go @@ -0,0 +1,289 @@ +package inventory + +import ( + "context" + "fmt" + "os" + "sort" +) + +// FeatureFlagChecker is a function that checks if a feature flag is enabled. +// The context can be used to extract actor/user information for flag evaluation. +// Returns (enabled, error). If error occurs, the caller should log and treat as false. +type FeatureFlagChecker func(ctx context.Context, flagName string) (bool, error) + +// isToolsetEnabled checks if a toolset is enabled based on current filters. +func (r *Inventory) isToolsetEnabled(toolsetID ToolsetID) bool { + // Check enabled toolsets filter + if r.enabledToolsets != nil { + return r.enabledToolsets[toolsetID] + } + return true +} + +// checkFeatureFlag checks a feature flag using the feature checker. +// Returns false if checker is nil or returns an error (errors are logged). +func (r *Inventory) checkFeatureFlag(ctx context.Context, flagName string) bool { + if r.featureChecker == nil || flagName == "" { + return false + } + enabled, err := r.featureChecker(ctx, flagName) + if err != nil { + fmt.Fprintf(os.Stderr, "Feature flag check error for %q: %v\n", flagName, err) + return false + } + return enabled +} + +// isFeatureFlagAllowed checks if an item passes feature flag filtering. +// - If FeatureFlagEnable is set, the item is only allowed if the flag is enabled +// - If FeatureFlagDisable is set, the item is excluded if the flag is enabled +func (r *Inventory) isFeatureFlagAllowed(ctx context.Context, enableFlag, disableFlag string) bool { + // Check enable flag - item requires this flag to be on + if enableFlag != "" && !r.checkFeatureFlag(ctx, enableFlag) { + return false + } + // Check disable flag - item is excluded if this flag is on + if disableFlag != "" && r.checkFeatureFlag(ctx, disableFlag) { + return false + } + return true +} + +// isToolEnabled checks if a specific tool is enabled based on current filters. +// Filter evaluation order: +// 1. Tool.Enabled (tool self-filtering) +// 2. FeatureFlagEnable/FeatureFlagDisable +// 3. Read-only filter +// 4. Builder filters (via WithFilter) +// 5. Toolset/additional tools +func (r *Inventory) isToolEnabled(ctx context.Context, tool *ServerTool) bool { + // 1. Check tool's own Enabled function first + if tool.Enabled != nil { + enabled, err := tool.Enabled(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Tool.Enabled check error for %q: %v\n", tool.Tool.Name, err) + return false + } + if !enabled { + return false + } + } + // 2. Check feature flags + if !r.isFeatureFlagAllowed(ctx, tool.FeatureFlagEnable, tool.FeatureFlagDisable) { + return false + } + // 3. Check read-only filter (applies to all tools) + if r.readOnly && !tool.IsReadOnly() { + return false + } + // 4. Apply builder filters + for _, filter := range r.filters { + allowed, err := filter(ctx, tool) + if err != nil { + fmt.Fprintf(os.Stderr, "Builder filter error for tool %q: %v\n", tool.Tool.Name, err) + return false + } + if !allowed { + return false + } + } + // 5. Check if tool is in additionalTools (bypasses toolset filter) + if r.additionalTools != nil && r.additionalTools[tool.Tool.Name] { + return true + } + // 5. Check toolset filter + if !r.isToolsetEnabled(tool.Toolset.ID) { + return false + } + return true +} + +// AvailableTools returns the tools that pass all current filters, +// sorted deterministically by toolset ID, then tool name. +// The context is used for feature flag evaluation. +func (r *Inventory) AvailableTools(ctx context.Context) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + if r.isToolEnabled(ctx, tool) { + result = append(result, *tool) + } + } + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableResourceTemplates returns resource templates that pass all current filters, +// sorted deterministically by toolset ID, then template name. +// The context is used for feature flag evaluation. +func (r *Inventory) AvailableResourceTemplates(ctx context.Context) []ServerResourceTemplate { + var result []ServerResourceTemplate + for i := range r.resourceTemplates { + res := &r.resourceTemplates[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, res.FeatureFlagEnable, res.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(res.Toolset.ID) { + result = append(result, *res) + } + } + + // Sort deterministically: by toolset ID, then by template name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Template.Name < result[j].Template.Name + }) + + return result +} + +// AvailablePrompts returns prompts that pass all current filters, +// sorted deterministically by toolset ID, then prompt name. +// The context is used for feature flag evaluation. +func (r *Inventory) AvailablePrompts(ctx context.Context) []ServerPrompt { + var result []ServerPrompt + for i := range r.prompts { + prompt := &r.prompts[i] + // Check feature flags + if !r.isFeatureFlagAllowed(ctx, prompt.FeatureFlagEnable, prompt.FeatureFlagDisable) { + continue + } + if r.isToolsetEnabled(prompt.Toolset.ID) { + result = append(result, *prompt) + } + } + + // Sort deterministically: by toolset ID, then by prompt name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Prompt.Name < result[j].Prompt.Name + }) + + return result +} + +// filterToolsByName returns tools matching the given name, checking deprecated aliases. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Inventory) filterToolsByName(name string) []ServerTool { + // First check for exact match + for i := range r.tools { + if r.tools[i].Tool.Name == name { + return []ServerTool{r.tools[i]} + } + } + // Check if name is a deprecated alias + if canonical, isAlias := r.deprecatedAliases[name]; isAlias { + for i := range r.tools { + if r.tools[i].Tool.Name == canonical { + return []ServerTool{r.tools[i]} + } + } + } + return []ServerTool{} +} + +// filterResourcesByURI returns resource templates matching the given URI pattern. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Inventory) filterResourcesByURI(uri string) []ServerResourceTemplate { + for i := range r.resourceTemplates { + if r.resourceTemplates[i].Template.URITemplate == uri { + return []ServerResourceTemplate{r.resourceTemplates[i]} + } + } + return []ServerResourceTemplate{} +} + +// filterPromptsByName returns prompts matching the given name. +// Uses linear scan - optimized for single-lookup per-request scenarios (ForMCPRequest). +func (r *Inventory) filterPromptsByName(name string) []ServerPrompt { + for i := range r.prompts { + if r.prompts[i].Prompt.Name == name { + return []ServerPrompt{r.prompts[i]} + } + } + return []ServerPrompt{} +} + +// ToolsForToolset returns all tools belonging to a specific toolset. +// This method bypasses the toolset enabled filter (for dynamic toolset registration), +// but still respects the read-only filter. +func (r *Inventory) ToolsForToolset(toolsetID ToolsetID) []ServerTool { + var result []ServerTool + for i := range r.tools { + tool := &r.tools[i] + // Only check read-only filter, not toolset enabled filter + if tool.Toolset.ID == toolsetID { + if r.readOnly && !tool.IsReadOnly() { + continue + } + result = append(result, *tool) + } + } + + // Sort by tool name for deterministic order + sort.Slice(result, func(i, j int) bool { + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// IsToolsetEnabled checks if a toolset is currently enabled based on filters. +func (r *Inventory) IsToolsetEnabled(toolsetID ToolsetID) bool { + return r.isToolsetEnabled(toolsetID) +} + +// EnableToolset marks a toolset as enabled in this group. +// This is used by dynamic toolset management to track which toolsets have been enabled. +func (r *Inventory) EnableToolset(toolsetID ToolsetID) { + if r.enabledToolsets == nil { + // nil means all enabled, so nothing to do + return + } + r.enabledToolsets[toolsetID] = true +} + +// EnabledToolsetIDs returns the list of enabled toolset IDs based on current filters. +// Returns all toolset IDs if no filter is set. +func (r *Inventory) EnabledToolsetIDs() []ToolsetID { + if r.enabledToolsets == nil { + return r.ToolsetIDs() + } + + ids := make([]ToolsetID, 0, len(r.enabledToolsets)) + for id := range r.enabledToolsets { + if r.HasToolset(id) { + ids = append(ids, id) + } + } + sort.Slice(ids, func(i, j int) bool { return ids[i] < ids[j] }) + return ids +} + +// FilteredTools returns tools filtered by the Enabled function and builder filters. +// This provides an explicit API for accessing filtered tools, currently implemented +// as an alias for AvailableTools. +// +// The error return is currently always nil but is included for future extensibility. +// Library consumers (e.g., remote server implementations) may need to surface +// recoverable filter errors rather than silently logging them. Having the error +// return in the API now avoids breaking changes later. +// +// The context is used for Enabled function evaluation and builder filter checks. +func (r *Inventory) FilteredTools(ctx context.Context) ([]ServerTool, error) { + return r.AvailableTools(ctx), nil +} diff --git a/pkg/inventory/prompts.go b/pkg/inventory/prompts.go new file mode 100644 index 000000000..648f20f9c --- /dev/null +++ b/pkg/inventory/prompts.go @@ -0,0 +1,26 @@ +package inventory + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ServerPrompt pairs a prompt with its toolset metadata. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler mcp.PromptHandler + // Toolset identifies which toolset this prompt belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this prompt + // to be available. If set and the flag is not enabled, the prompt is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this prompt + // to be omitted. Used to disable prompts when a feature flag is on. + FeatureFlagDisable string +} + +// NewServerPrompt creates a new ServerPrompt with toolset metadata. +func NewServerPrompt(toolset ToolsetMetadata, prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { + return ServerPrompt{ + Prompt: prompt, + Handler: handler, + Toolset: toolset, + } +} diff --git a/pkg/inventory/registry.go b/pkg/inventory/registry.go new file mode 100644 index 000000000..f3691e38a --- /dev/null +++ b/pkg/inventory/registry.go @@ -0,0 +1,296 @@ +package inventory + +import ( + "context" + "fmt" + "os" + "slices" + "sort" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// Inventory holds a collection of tools, resources, and prompts with filtering applied. +// Create a Inventory using Builder: +// +// reg := NewBuilder(). +// SetTools(tools). +// WithReadOnly(true). +// WithToolsets([]string{"repos"}). +// Build() +// +// The Inventory is configured at build time and provides: +// - Filtered access to tools/resources/prompts via Available* methods +// - Deterministic ordering for documentation generation +// - Lazy dependency injection during registration via RegisterAll() +// - Runtime toolset enabling for dynamic toolsets mode +type Inventory struct { + // tools holds all tools in this group (ordered for iteration) + tools []ServerTool + // resourceTemplates holds all resource templates in this group (ordered for iteration) + resourceTemplates []ServerResourceTemplate + // prompts holds all prompts in this group (ordered for iteration) + prompts []ServerPrompt + // deprecatedAliases maps old tool names to new canonical names + deprecatedAliases map[string]string + + // Pre-computed toolset metadata (set during Build) + toolsetIDs []ToolsetID // sorted list of all toolset IDs + toolsetIDSet map[ToolsetID]bool // set for O(1) HasToolset lookup + defaultToolsetIDs []ToolsetID // sorted list of default toolset IDs + toolsetDescriptions map[ToolsetID]string // toolset ID -> description + + // Filters - these control what's returned by Available* methods + // readOnly when true filters out write tools + readOnly bool + // enabledToolsets when non-nil, only include tools/resources/prompts from these toolsets + // when nil, all toolsets are enabled + enabledToolsets map[ToolsetID]bool + // additionalTools are specific tools that bypass toolset filtering (but still respect read-only) + // These are additive - a tool is included if it matches toolset filters OR is in this set + additionalTools map[string]bool + // featureChecker when non-nil, checks if a feature flag is enabled. + // Takes context and flag name, returns (enabled, error). If error, log and treat as false. + // If checker is nil, all flag checks return false. + featureChecker FeatureFlagChecker + // filters are functions that will be applied to all tools during filtering. + // If any filter returns false or an error, the tool is excluded. + filters []ToolFilter + // unrecognizedToolsets holds toolset IDs that were requested but don't match any registered toolsets + unrecognizedToolsets []string +} + +// UnrecognizedToolsets returns toolset IDs that were passed to WithToolsets but don't +// match any registered toolsets. This is useful for warning users about typos. +func (r *Inventory) UnrecognizedToolsets() []string { + return r.unrecognizedToolsets +} + +// MCP method constants for use with ForMCPRequest. +const ( + MCPMethodInitialize = "initialize" + MCPMethodToolsList = "tools/list" + MCPMethodToolsCall = "tools/call" + MCPMethodResourcesList = "resources/list" + MCPMethodResourcesRead = "resources/read" + MCPMethodResourcesTemplatesList = "resources/templates/list" + MCPMethodPromptsList = "prompts/list" + MCPMethodPromptsGet = "prompts/get" +) + +// ForMCPRequest returns a Registry optimized for a specific MCP request. +// This is designed for servers that create a new instance per request (like the remote server), +// allowing them to only register the items needed for that specific request rather than all ~90 tools. +// +// Parameters: +// - method: The MCP method being called (use MCP* constants) +// - itemName: Name of specific item for call/get methods (tool name, resource URI, or prompt name) +// +// Returns a new Registry containing only the items relevant to the request: +// - MCPMethodInitialize: Empty (capabilities are set via ServerOptions, not registration) +// - MCPMethodToolsList: All available tools (no resources/prompts) +// - MCPMethodToolsCall: Only the named tool +// - MCPMethodResourcesList, MCPMethodResourcesTemplatesList: All available resources (no tools/prompts) +// - MCPMethodResourcesRead: Only the named resource template +// - MCPMethodPromptsList: All available prompts (no tools/resources) +// - MCPMethodPromptsGet: Only the named prompt +// - Unknown methods: Empty (no items registered) +// +// All existing filters (read-only, toolsets, etc.) still apply to the returned items. +func (r *Inventory) ForMCPRequest(method string, itemName string) *Inventory { + // Create a shallow copy with shared filter settings + // Note: lazy-init maps (toolsByName, etc.) are NOT copied - the new Registry + // will initialize its own maps on first use if needed + result := &Inventory{ + tools: r.tools, + resourceTemplates: r.resourceTemplates, + prompts: r.prompts, + deprecatedAliases: r.deprecatedAliases, + readOnly: r.readOnly, + enabledToolsets: r.enabledToolsets, // shared, not modified + additionalTools: r.additionalTools, // shared, not modified + featureChecker: r.featureChecker, + filters: r.filters, // shared, not modified + unrecognizedToolsets: r.unrecognizedToolsets, + } + + // Helper to clear all item types + clearAll := func() { + result.tools = []ServerTool{} + result.resourceTemplates = []ServerResourceTemplate{} + result.prompts = []ServerPrompt{} + } + + switch method { + case MCPMethodInitialize: + clearAll() + case MCPMethodToolsList: + result.resourceTemplates, result.prompts = nil, nil + case MCPMethodToolsCall: + result.resourceTemplates, result.prompts = nil, nil + if itemName != "" { + result.tools = r.filterToolsByName(itemName) + } + case MCPMethodResourcesList, MCPMethodResourcesTemplatesList: + result.tools, result.prompts = nil, nil + case MCPMethodResourcesRead: + result.tools, result.prompts = nil, nil + if itemName != "" { + result.resourceTemplates = r.filterResourcesByURI(itemName) + } + case MCPMethodPromptsList: + result.tools, result.resourceTemplates = nil, nil + case MCPMethodPromptsGet: + result.tools, result.resourceTemplates = nil, nil + if itemName != "" { + result.prompts = r.filterPromptsByName(itemName) + } + default: + clearAll() + } + + return result +} + +// ToolsetIDs returns a sorted list of unique toolset IDs from all tools in this group. +func (r *Inventory) ToolsetIDs() []ToolsetID { + return r.toolsetIDs +} + +// DefaultToolsetIDs returns the IDs of toolsets marked as Default in their metadata. +// The IDs are returned in sorted order for deterministic output. +func (r *Inventory) DefaultToolsetIDs() []ToolsetID { + return r.defaultToolsetIDs +} + +// ToolsetDescriptions returns a map of toolset ID to description for all toolsets. +func (r *Inventory) ToolsetDescriptions() map[ToolsetID]string { + return r.toolsetDescriptions +} + +// RegisterTools registers all available tools with the server using the provided dependencies. +// The context is used for feature flag evaluation. +func (r *Inventory) RegisterTools(ctx context.Context, s *mcp.Server, deps any) { + for _, tool := range r.AvailableTools(ctx) { + tool.RegisterFunc(s, deps) + } +} + +// RegisterResourceTemplates registers all available resource templates with the server. +// The context is used for feature flag evaluation. +// Icons are automatically applied from the toolset metadata if not already set. +func (r *Inventory) RegisterResourceTemplates(ctx context.Context, s *mcp.Server, deps any) { + for _, res := range r.AvailableResourceTemplates(ctx) { + // Make a shallow copy to avoid mutating the original + templateCopy := res.Template + // Apply icons from toolset metadata if not already set + if len(templateCopy.Icons) == 0 { + templateCopy.Icons = res.Toolset.Icons() + } + s.AddResourceTemplate(&templateCopy, res.Handler(deps)) + } +} + +// RegisterPrompts registers all available prompts with the server. +// The context is used for feature flag evaluation. +// Icons are automatically applied from the toolset metadata if not already set. +func (r *Inventory) RegisterPrompts(ctx context.Context, s *mcp.Server) { + for _, prompt := range r.AvailablePrompts(ctx) { + // Make a shallow copy to avoid mutating the original + promptCopy := prompt.Prompt + // Apply icons from toolset metadata if not already set + if len(promptCopy.Icons) == 0 { + promptCopy.Icons = prompt.Toolset.Icons() + } + s.AddPrompt(&promptCopy, prompt.Handler) + } +} + +// RegisterAll registers all available tools, resources, and prompts with the server. +// The context is used for feature flag evaluation. +func (r *Inventory) RegisterAll(ctx context.Context, s *mcp.Server, deps any) { + r.RegisterTools(ctx, s, deps) + r.RegisterResourceTemplates(ctx, s, deps) + r.RegisterPrompts(ctx, s) +} + +// ResolveToolAliases resolves deprecated tool aliases to their canonical names. +// It logs a warning to stderr for each deprecated alias that is resolved. +// Returns: +// - resolved: tool names with aliases replaced by canonical names +// - aliasesUsed: map of oldName → newName for each alias that was resolved +func (r *Inventory) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { + resolved = make([]string, 0, len(toolNames)) + aliasesUsed = make(map[string]string) + for _, toolName := range toolNames { + if canonicalName, isAlias := r.deprecatedAliases[toolName]; isAlias { + fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) + aliasesUsed[toolName] = canonicalName + resolved = append(resolved, canonicalName) + } else { + resolved = append(resolved, toolName) + } + } + return resolved, aliasesUsed +} + +// FindToolByName searches all tools for one matching the given name. +// Returns the tool, its toolset ID, and an error if not found. +// This searches ALL tools regardless of filters. +func (r *Inventory) FindToolByName(toolName string) (*ServerTool, ToolsetID, error) { + for i := range r.tools { + if r.tools[i].Tool.Name == toolName { + return &r.tools[i], r.tools[i].Toolset.ID, nil + } + } + return nil, "", NewToolDoesNotExistError(toolName) +} + +// HasToolset checks if any tool/resource/prompt belongs to the given toolset. +func (r *Inventory) HasToolset(toolsetID ToolsetID) bool { + return r.toolsetIDSet[toolsetID] +} + +// AllTools returns all tools without any filtering, sorted deterministically. +func (r *Inventory) AllTools() []ServerTool { + result := slices.Clone(r.tools) + + // Sort deterministically: by toolset ID, then by tool name + sort.Slice(result, func(i, j int) bool { + if result[i].Toolset.ID != result[j].Toolset.ID { + return result[i].Toolset.ID < result[j].Toolset.ID + } + return result[i].Tool.Name < result[j].Tool.Name + }) + + return result +} + +// AvailableToolsets returns the unique toolsets that have tools, in sorted order. +// This is the ordered intersection of toolsets with reality - only toolsets that +// actually contain tools are returned, sorted by toolset ID. +// Optional exclude parameter filters out specific toolset IDs from the result. +func (r *Inventory) AvailableToolsets(exclude ...ToolsetID) []ToolsetMetadata { + tools := r.AllTools() + if len(tools) == 0 { + return nil + } + + // Build exclude set for O(1) lookup + excludeSet := make(map[ToolsetID]bool, len(exclude)) + for _, id := range exclude { + excludeSet[id] = true + } + + var result []ToolsetMetadata + var lastID ToolsetID + for _, tool := range tools { + if tool.Toolset.ID != lastID { + lastID = tool.Toolset.ID + if !excludeSet[lastID] { + result = append(result, tool.Toolset) + } + } + } + return result +} diff --git a/pkg/inventory/registry_test.go b/pkg/inventory/registry_test.go new file mode 100644 index 000000000..41e94b8d9 --- /dev/null +++ b/pkg/inventory/registry_test.go @@ -0,0 +1,1645 @@ +package inventory + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// testToolsetMetadata returns a ToolsetMetadata for testing +func testToolsetMetadata(id string) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + } +} + +// testToolsetMetadataWithDefault returns a ToolsetMetadata with Default flag for testing +func testToolsetMetadataWithDefault(id string, isDefault bool) ToolsetMetadata { + return ToolsetMetadata{ + ID: ToolsetID(id), + Description: "Test toolset: " + id, + Default: isDefault, + } +} + +// mockToolWithDefault creates a mock tool with a default toolset flag +func mockToolWithDefault(name string, toolsetID string, readOnly bool, isDefault bool) ServerTool { + return NewServerToolFromHandler( + mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ + ReadOnlyHint: readOnly, + }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + testToolsetMetadataWithDefault(toolsetID, isDefault), + func(_ any) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) +} + +// mockTool creates a minimal ServerTool for testing +func mockTool(name string, toolsetID string, readOnly bool) ServerTool { + return NewServerToolFromHandler( + mcp.Tool{ + Name: name, + Annotations: &mcp.ToolAnnotations{ + ReadOnlyHint: readOnly, + }, + InputSchema: json.RawMessage(`{"type":"object","properties":{}}`), + }, + testToolsetMetadata(toolsetID), + func(_ any) mcp.ToolHandler { + return func(_ context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil + } + }, + ) +} + +func TestNewRegistryEmpty(t *testing.T) { + reg := NewBuilder().Build() + if len(reg.AvailableTools(context.Background())) != 0 { + t.Fatalf("Expected tools to be empty") + } + if len(reg.AvailableResourceTemplates(context.Background())) != 0 { + t.Fatalf("Expected resourceTemplates to be empty") + } + if len(reg.AvailablePrompts(context.Background())) != 0 { + t.Fatalf("Expected prompts to be empty") + } +} + +func TestNewRegistryWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + + if len(reg.AllTools()) != 3 { + t.Errorf("Expected 3 tools, got %d", len(reg.AllTools())) + } +} + +func TestAvailableTools_NoFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool_b", "toolset1", true), + mockTool("tool_a", "toolset1", false), + mockTool("tool_c", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != 3 { + t.Fatalf("Expected 3 available tools, got %d", len(available)) + } + + // Verify deterministic sorting: by toolset ID, then tool name + expectedOrder := []string{"tool_a", "tool_b", "tool_c"} + for i, tool := range available { + if tool.Tool.Name != expectedOrder[i] { + t.Errorf("Tool at index %d: expected %s, got %s", i, expectedOrder[i], tool.Tool.Name) + } + } +} + +func TestWithReadOnly(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), + } + + // Build without read-only - should have both tools + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := reg.AvailableTools(context.Background()) + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools without read-only, got %d", len(allTools)) + } + + // Build with read-only - should filter out write tools + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + readOnlyTools := readOnlyReg.AvailableTools(context.Background()) + if len(readOnlyTools) != 1 { + t.Fatalf("Expected 1 tool in read-only, got %d", len(readOnlyTools)) + } + if readOnlyTools[0].Tool.Name != "read_tool" { + t.Errorf("Expected read_tool, got %s", readOnlyTools[0].Tool.Name) + } +} + +func TestWithToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + mockTool("tool3", "toolset3", true), + } + + // Build with all toolsets + allReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + allTools := allReg.AvailableTools(context.Background()) + if len(allTools) != 3 { + t.Fatalf("Expected 3 tools without filter, got %d", len(allTools)) + } + + // Build with specific toolsets + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset3"}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) + } + + // Verify correct tools are included + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) + } +} + +func TestWithToolsetsTrimsWhitespace(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + // Whitespace should be trimmed + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{" toolset1 ", " toolset2 "}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 tools after whitespace trimming, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsDeduplicates(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Duplicates should be removed + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1", "toolset1", " toolset1 "}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool after deduplication, got %d", len(filteredTools)) + } +} + +func TestWithToolsetsIgnoresEmptyStrings(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Empty strings should be ignored + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{"", "toolset1", " ", ""}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(filteredTools)) + } +} + +func TestUnrecognizedToolsets(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + tests := []struct { + name string + input []string + expectedUnrecognized []string + }{ + { + name: "all valid", + input: []string{"toolset1", "toolset2"}, + expectedUnrecognized: nil, + }, + { + name: "one invalid", + input: []string{"toolset1", "invalid_toolset"}, + expectedUnrecognized: []string{"invalid_toolset"}, + }, + { + name: "multiple invalid", + input: []string{"typo1", "toolset1", "typo2"}, + expectedUnrecognized: []string{"typo1", "typo2"}, + }, + { + name: "invalid with whitespace trimmed", + input: []string{" invalid_tool "}, + expectedUnrecognized: []string{"invalid_tool"}, + }, + { + name: "empty input", + input: []string{}, + expectedUnrecognized: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filtered := NewBuilder().SetTools(tools).WithToolsets(tt.input).Build() + unrecognized := filtered.UnrecognizedToolsets() + + if len(unrecognized) != len(tt.expectedUnrecognized) { + t.Fatalf("Expected %d unrecognized, got %d: %v", + len(tt.expectedUnrecognized), len(unrecognized), unrecognized) + } + + for i, expected := range tt.expectedUnrecognized { + if unrecognized[i] != expected { + t.Errorf("Expected unrecognized[%d] = %q, got %q", i, expected, unrecognized[i]) + } + } + }) + } +} + +func TestWithTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), + } + + // WithTools adds additional tools that bypass toolset filtering + // When combined with WithToolsets([]), only the additional tools should be available + filteredReg := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"tool1", "tool3"}).Build() + filteredTools := filteredReg.AvailableTools(context.Background()) + + if len(filteredTools) != 2 { + t.Fatalf("Expected 2 filtered tools, got %d", len(filteredTools)) + } + + toolNames := make(map[string]bool) + for _, tool := range filteredTools { + toolNames[tool.Tool.Name] = true + } + if !toolNames["tool1"] || !toolNames["tool3"] { + t.Errorf("Expected tool1 and tool3, got %v", toolNames) + } +} + +func TestChainedFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("read1", "toolset1", true), + mockTool("write1", "toolset1", false), + mockTool("read2", "toolset2", true), + mockTool("write2", "toolset2", false), + } + + // Chain read-only and toolset filter + filtered := NewBuilder().SetTools(tools).WithReadOnly(true).WithToolsets([]string{"toolset1"}).Build() + result := filtered.AvailableTools(context.Background()) + + if len(result) != 1 { + t.Fatalf("Expected 1 tool after chained filters, got %d", len(result)) + } + if result[0].Tool.Name != "read1" { + t.Errorf("Expected read1, got %s", result[0].Tool.Name) + } +} + +func TestToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset_b", true), + mockTool("tool2", "toolset_a", true), + mockTool("tool3", "toolset_b", true), // duplicate toolset + } + + reg := NewBuilder().SetTools(tools).Build() + ids := reg.ToolsetIDs() + + if len(ids) != 2 { + t.Fatalf("Expected 2 unique toolset IDs, got %d", len(ids)) + } + + // Should be sorted + if ids[0] != "toolset_a" || ids[1] != "toolset_b" { + t.Errorf("Expected sorted IDs [toolset_a, toolset_b], got %v", ids) + } +} + +func TestToolsetDescriptions(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + descriptions := reg.ToolsetDescriptions() + + if len(descriptions) != 2 { + t.Fatalf("Expected 2 descriptions, got %d", len(descriptions)) + } + + if descriptions["toolset1"] != "Test toolset: toolset1" { + t.Errorf("Wrong description for toolset1: %s", descriptions["toolset1"]) + } +} + +func TestToolsForToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder().SetTools(tools).Build() + toolset1Tools := reg.ToolsForToolset("toolset1") + + if len(toolset1Tools) != 2 { + t.Fatalf("Expected 2 tools for toolset1, got %d", len(toolset1Tools)) + } +} + +func TestWithDeprecatedAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("new_name", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).WithDeprecatedAliases(map[string]string{ + "old_name": "new_name", + "get_issue": "issue_read", + }).Build() + + // Test resolving aliases + resolved, aliasesUsed := reg.ResolveToolAliases([]string{"old_name"}) + if len(resolved) != 1 || resolved[0] != "new_name" { + t.Errorf("expected alias to resolve to 'new_name', got %v", resolved) + } + if len(aliasesUsed) != 1 || aliasesUsed["old_name"] != "new_name" { + t.Errorf("expected alias mapping, got %v", aliasesUsed) + } +} + +func TestResolveToolAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("some_tool", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }).Build() + + // Test resolving a mix of aliases and canonical names + input := []string{"get_issue", "some_tool"} + resolved, aliasesUsed := reg.ResolveToolAliases(input) + + if len(resolved) != 2 { + t.Fatalf("expected 2 resolved names, got %d", len(resolved)) + } + if resolved[0] != "issue_read" { + t.Errorf("expected 'issue_read', got '%s'", resolved[0]) + } + if resolved[1] != "some_tool" { + t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) + } + + if len(aliasesUsed) != 1 { + t.Fatalf("expected 1 alias used, got %d", len(aliasesUsed)) + } + if aliasesUsed["get_issue"] != "issue_read" { + t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) + } +} + +func TestFindToolByName(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).Build() + + // Find by name + tool, toolsetID, err := reg.FindToolByName("issue_read") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tool.Tool.Name != "issue_read" { + t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) + } + if toolsetID != "toolset1" { + t.Errorf("expected toolset ID 'toolset1', got '%s'", toolsetID) + } + + // Non-existent tool + _, _, err = reg.FindToolByName("nonexistent") + if err == nil { + t.Error("expected error for non-existent tool") + } +} + +func TestWithToolsAdditive(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + mockTool("issue_write", "toolset1", false), + mockTool("repo_read", "toolset2", true), + } + + // Test WithTools bypasses toolset filtering + // Enable only toolset2, but add issue_read as additional tool + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset2"}).WithTools([]string{"issue_read"}).Build() + + available := filtered.AvailableTools(context.Background()) + if len(available) != 2 { + t.Errorf("expected 2 tools (repo_read from toolset + issue_read additional), got %d", len(available)) + } + + // Verify both tools are present + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + if !toolNames["issue_read"] { + t.Error("expected issue_read to be included as additional tool") + } + if !toolNames["repo_read"] { + t.Error("expected repo_read to be included from toolset2") + } + + // Test WithTools respects read-only mode + readOnlyFiltered := NewBuilder().SetTools(tools).WithReadOnly(true).WithTools([]string{"issue_write"}).Build() + available = readOnlyFiltered.AvailableTools(context.Background()) + + // issue_write should be excluded because read-only applies to additional tools too + for _, tool := range available { + if tool.Tool.Name == "issue_write" { + t.Error("expected issue_write to be excluded in read-only mode") + } + } + + // Test WithTools with non-existent tool (should not error, just won't match anything) + nonexistent := NewBuilder().SetTools(tools).WithToolsets([]string{}).WithTools([]string{"nonexistent"}).Build() + available = nonexistent.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("expected 0 tools for non-existent additional tool, got %d", len(available)) + } +} + +func TestWithToolsResolvesAliases(t *testing.T) { + tools := []ServerTool{ + mockTool("issue_read", "toolset1", true), + } + + // Using deprecated alias should resolve to canonical name + filtered := NewBuilder().SetTools(tools). + WithDeprecatedAliases(map[string]string{ + "get_issue": "issue_read", + }). + WithToolsets([]string{}). + WithTools([]string{"get_issue"}). + Build() + available := filtered.AvailableTools(context.Background()) + + if len(available) != 1 { + t.Errorf("expected 1 tool, got %d", len(available)) + } + if available[0].Tool.Name != "issue_read" { + t.Errorf("expected issue_read, got %s", available[0].Tool.Name) + } +} + +func TestHasToolset(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + + if !reg.HasToolset("toolset1") { + t.Error("expected HasToolset to return true for existing toolset") + } + if reg.HasToolset("nonexistent") { + t.Error("expected HasToolset to return false for non-existent toolset") + } +} + +func TestEnabledToolsetIDs(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset2", true), + } + + // Without filter, all toolsets are enabled + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + ids := reg.EnabledToolsetIDs() + if len(ids) != 2 { + t.Fatalf("Expected 2 enabled toolset IDs, got %d", len(ids)) + } + + // With filter + filtered := NewBuilder().SetTools(tools).WithToolsets([]string{"toolset1"}).Build() + filteredIDs := filtered.EnabledToolsetIDs() + if len(filteredIDs) != 1 { + t.Fatalf("Expected 1 enabled toolset ID, got %d", len(filteredIDs)) + } + if filteredIDs[0] != "toolset1" { + t.Errorf("Expected toolset1, got %s", filteredIDs[0]) + } +} + +func TestAllTools(t *testing.T) { + tools := []ServerTool{ + mockTool("read_tool", "toolset1", true), + mockTool("write_tool", "toolset1", false), + } + + // Even with read-only filter, AllTools returns everything + readOnlyReg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + + allTools := readOnlyReg.AllTools() + if len(allTools) != 2 { + t.Fatalf("Expected 2 tools from AllTools, got %d", len(allTools)) + } + + // But AvailableTools respects the filter + availableTools := readOnlyReg.AvailableTools(context.Background()) + if len(availableTools) != 1 { + t.Fatalf("Expected 1 tool from AvailableTools, got %d", len(availableTools)) + } +} + +func TestServerToolIsReadOnly(t *testing.T) { + readTool := mockTool("read_tool", "toolset1", true) + writeTool := mockTool("write_tool", "toolset1", false) + + if !readTool.IsReadOnly() { + t.Error("Expected read tool to be read-only") + } + if writeTool.IsReadOnly() { + t.Error("Expected write tool to not be read-only") + } +} + +// mockResource creates a minimal ServerResourceTemplate for testing +func mockResource(name string, toolsetID string, uriTemplate string) ServerResourceTemplate { + return NewServerResourceTemplate( + testToolsetMetadata(toolsetID), + mcp.ResourceTemplate{ + Name: name, + URITemplate: uriTemplate, + }, + func(_ any) mcp.ResourceHandler { + return func(_ context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return nil, nil + } + }, + ) +} + +// mockPrompt creates a minimal ServerPrompt for testing +func mockPrompt(name string, toolsetID string) ServerPrompt { + return NewServerPrompt( + testToolsetMetadata(toolsetID), + mcp.Prompt{Name: name}, + func(_ context.Context, _ *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return nil, nil + }, + ) +} + +func TestForMCPRequest_Initialize(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", false), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodInitialize, "") + + // Initialize should return empty - capabilities come from ServerOptions + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for initialize, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for initialize, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for initialize, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ToolsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") + + // tools/list should return all tools, no resources or prompts + if len(filtered.AvailableTools(context.Background())) != 2 { + t.Errorf("Expected 2 tools for tools/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for tools/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for tools/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ToolsCall(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("create_issue", "issues", false), + mockTool("list_repos", "repos", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "get_me") + + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool for tools/call with name, got %d", len(available)) + } + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected tool name 'get_me', got %q", available[0].Tool.Name) + } +} + +func TestForMCPRequest_ToolsCall_NotFound(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + } + + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "nonexistent") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for nonexistent tool, got %d", len(filtered.AvailableTools(context.Background()))) + } +} + +func TestForMCPRequest_ToolsCall_DeprecatedAlias(t *testing.T) { + tools := []ServerTool{ + mockTool("get_me", "context", true), + mockTool("list_commits", "repos", true), + } + + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"all"}). + WithDeprecatedAliases(map[string]string{ + "old_get_me": "get_me", + }).Build() + + // Request using the deprecated alias + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "old_get_me") + + available := filtered.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool when using deprecated alias, got %d", len(available)) + } + if available[0].Tool.Name != "get_me" { + t.Errorf("Expected canonical name 'get_me', got %q", available[0].Tool.Name) + } +} + +func TestForMCPRequest_ToolsCall_RespectsFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("create_issue", "issues", false), // write tool + } + + // Apply read-only filter at build time, then ForMCPRequest + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithReadOnly(true).Build() + filtered := reg.ForMCPRequest(MCPMethodToolsCall, "create_issue") + + // The tool exists in the filtered group, but AvailableTools respects read-only + available := filtered.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools - write tool should be filtered by read-only, got %d", len(available)) + } +} + +func TestForMCPRequest_ResourcesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for resources/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources for resources/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for resources/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ResourcesRead(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + mockResource("res2", "repos", "branch://{owner}/{repo}/{branch}"), + } + + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesRead, "repo://{owner}/{repo}") + + available := filtered.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource for resources/read, got %d", len(available)) + } + if available[0].Template.URITemplate != "repo://{owner}/{repo}" { + t.Errorf("Expected URI template 'repo://{owner}/{repo}', got %q", available[0].Template.URITemplate) + } +} + +func TestForMCPRequest_PromptsList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsList, "") + + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for prompts/list, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for prompts/list, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts for prompts/list, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_PromptsGet(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + mockPrompt("prompt2", "issues"), + } + + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodPromptsGet, "prompt1") + + available := filtered.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt for prompts/get, got %d", len(available)) + } + if available[0].Prompt.Name != "prompt1" { + t.Errorf("Expected prompt name 'prompt1', got %q", available[0].Prompt.Name) + } +} + +func TestForMCPRequest_UnknownMethod(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest("unknown/method", "") + + // Unknown methods should return empty + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools for unknown method, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources for unknown method, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts for unknown method, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_DoesNotMutateOriginal(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + mockTool("tool2", "issues", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + prompts := []ServerPrompt{ + mockPrompt("prompt1", "repos"), + } + + original := NewBuilder().SetTools(tools).SetResources(resources).SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + filtered := original.ForMCPRequest(MCPMethodToolsCall, "tool1") + + // Original should be unchanged + if len(original.AvailableTools(context.Background())) != 2 { + t.Errorf("Original was mutated! Expected 2 tools, got %d", len(original.AvailableTools(context.Background()))) + } + if len(original.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 resource, got %d", len(original.AvailableResourceTemplates(context.Background()))) + } + if len(original.AvailablePrompts(context.Background())) != 1 { + t.Errorf("Original was mutated! Expected 1 prompt, got %d", len(original.AvailablePrompts(context.Background()))) + } + + // Filtered should have only the requested tool + if len(filtered.AvailableTools(context.Background())) != 1 { + t.Errorf("Expected 1 tool in filtered, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 0 { + t.Errorf("Expected 0 resources in filtered, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } + if len(filtered.AvailablePrompts(context.Background())) != 0 { + t.Errorf("Expected 0 prompts in filtered, got %d", len(filtered.AvailablePrompts(context.Background()))) + } +} + +func TestForMCPRequest_ChainedWithOtherFilters(t *testing.T) { + tools := []ServerTool{ + mockToolWithDefault("get_me", "context", true, true), // default toolset + mockToolWithDefault("create_issue", "issues", false, false), // not default + mockToolWithDefault("list_repos", "repos", true, true), // default toolset + mockToolWithDefault("delete_repo", "repos", false, true), // default but write + } + + // Chain: default toolsets -> read-only -> specific method + reg := NewBuilder().SetTools(tools). + WithToolsets([]string{"default"}). + WithReadOnly(true). + Build() + filtered := reg.ForMCPRequest(MCPMethodToolsList, "") + + available := filtered.AvailableTools(context.Background()) + + // Should have: get_me (context, read), list_repos (repos, read) + // Should NOT have: create_issue (issues not in default), delete_repo (write) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter chain, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["get_me"] { + t.Error("Expected get_me to be available") + } + if !toolNames["list_repos"] { + t.Error("Expected list_repos to be available") + } + if toolNames["create_issue"] { + t.Error("create_issue should not be available (toolset not enabled)") + } + if toolNames["delete_repo"] { + t.Error("delete_repo should not be available (write tool in read-only mode)") + } +} + +func TestForMCPRequest_ResourcesTemplatesList(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "repos", true), + } + resources := []ServerResourceTemplate{ + mockResource("res1", "repos", "repo://{owner}/{repo}"), + } + + reg := NewBuilder().SetTools(tools).SetResources(resources).WithToolsets([]string{"all"}).Build() + filtered := reg.ForMCPRequest(MCPMethodResourcesTemplatesList, "") + + // Same behavior as resources/list + if len(filtered.AvailableTools(context.Background())) != 0 { + t.Errorf("Expected 0 tools, got %d", len(filtered.AvailableTools(context.Background()))) + } + if len(filtered.AvailableResourceTemplates(context.Background())) != 1 { + t.Errorf("Expected 1 resource, got %d", len(filtered.AvailableResourceTemplates(context.Background()))) + } +} + +func TestMCPMethodConstants(t *testing.T) { + // Verify constants match expected MCP method names + tests := []struct { + constant string + expected string + }{ + {MCPMethodInitialize, "initialize"}, + {MCPMethodToolsList, "tools/list"}, + {MCPMethodToolsCall, "tools/call"}, + {MCPMethodResourcesList, "resources/list"}, + {MCPMethodResourcesRead, "resources/read"}, + {MCPMethodResourcesTemplatesList, "resources/templates/list"}, + {MCPMethodPromptsList, "prompts/list"}, + {MCPMethodPromptsGet, "prompts/get"}, + } + + for _, tt := range tests { + if tt.constant != tt.expected { + t.Errorf("Constant mismatch: got %q, expected %q", tt.constant, tt.expected) + } + } +} + +// mockToolWithFlags creates a ServerTool with feature flags for testing +func mockToolWithFlags(name string, toolsetID string, readOnly bool, enableFlag, disableFlag string) ServerTool { + tool := mockTool(name, toolsetID, readOnly) + tool.FeatureFlagEnable = enableFlag + tool.FeatureFlagDisable = disableFlag + return tool +} + +func TestFeatureFlagEnable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + // Without feature checker, tool with FeatureFlagEnable should be excluded + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 tool without feature checker, got %d", len(available)) + } + if available[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", available[0].Tool.Name) + } + + // With feature checker returning false, tool should still be excluded + checkerFalse := func(_ context.Context, _ string) (bool, error) { return false, nil } + regFalse := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerFalse).Build() + availableFalse := regFalse.AvailableTools(context.Background()) + if len(availableFalse) != 1 { + t.Fatalf("Expected 1 tool with false checker, got %d", len(availableFalse)) + } + + // With feature checker returning true for "my_feature", tool should be included + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + regTrue := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableTrue := regTrue.AvailableTools(context.Background()) + if len(availableTrue) != 2 { + t.Fatalf("Expected 2 tools with true checker, got %d", len(availableTrue)) + } +} + +func TestFeatureFlagDisable(t *testing.T) { + tools := []ServerTool{ + mockTool("always_available", "toolset1", true), + mockToolWithFlags("disabled_by_flag", "toolset1", true, "", "kill_switch"), + } + + // Without feature checker, tool with FeatureFlagDisable should be included (flag is false) + reg := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools without feature checker, got %d", len(available)) + } + + // With feature checker returning true for "kill_switch", tool should be excluded + checkerTrue := func(_ context.Context, flag string) (bool, error) { + return flag == "kill_switch", nil + } + regFiltered := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checkerTrue).Build() + availableFiltered := regFiltered.AvailableTools(context.Background()) + if len(availableFiltered) != 1 { + t.Fatalf("Expected 1 tool with kill_switch enabled, got %d", len(availableFiltered)) + } + if availableFiltered[0].Tool.Name != "always_available" { + t.Errorf("Expected always_available, got %s", availableFiltered[0].Tool.Name) + } +} + +func TestFeatureFlagBoth(t *testing.T) { + // Tool that requires "new_feature" AND is disabled by "kill_switch" + tools := []ServerTool{ + mockToolWithFlags("complex_tool", "toolset1", true, "new_feature", "kill_switch"), + } + + // Enable flag not set -> excluded + checker1 := func(_ context.Context, _ string) (bool, error) { return false, nil } + reg1 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker1).Build() + if len(reg1.AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when enable flag is false") + } + + // Enable flag set, disable flag not set -> included + checker2 := func(_ context.Context, flag string) (bool, error) { return flag == "new_feature", nil } + reg2 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker2).Build() + if len(reg2.AvailableTools(context.Background())) != 1 { + t.Error("Tool should be included when enable flag is true and disable flag is false") + } + + // Enable flag set, disable flag also set -> excluded (disable wins) + checker3 := func(_ context.Context, _ string) (bool, error) { return true, nil } + reg3 := NewBuilder().SetTools(tools).WithToolsets([]string{"all"}).WithFeatureChecker(checker3).Build() + if len(reg3.AvailableTools(context.Background())) != 0 { + t.Error("Tool should be excluded when both flags are true (disable wins)") + } +} + +func TestFeatureFlagError(t *testing.T) { + tools := []ServerTool{ + mockToolWithFlags("needs_flag", "toolset1", true, "my_feature", ""), + } + + // Checker that returns error should treat as false (tool excluded) + checkerError := func(_ context.Context, _ string) (bool, error) { + return false, fmt.Errorf("simulated error") + } + reg := NewBuilder().SetTools(tools).WithFeatureChecker(checkerError).Build() + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when checker errors, got %d", len(available)) + } +} + +func TestFeatureFlagResources(t *testing.T) { + resources := []ServerResourceTemplate{ + mockResource("always_available", "toolset1", "uri1"), + { + Template: mcp.ResourceTemplate{Name: "needs_flag", URITemplate: "uri2"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + // Without checker, resource with enable flag should be excluded + reg := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).Build() + available := reg.AvailableResourceTemplates(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 resource without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + regWithChecker := NewBuilder().SetResources(resources).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailableResourceTemplates(context.Background())) != 2 { + t.Errorf("Expected 2 resources with checker, got %d", len(regWithChecker.AvailableResourceTemplates(context.Background()))) + } +} + +func TestFeatureFlagPrompts(t *testing.T) { + prompts := []ServerPrompt{ + mockPrompt("always_available", "toolset1"), + { + Prompt: mcp.Prompt{Name: "needs_flag"}, + Toolset: testToolsetMetadata("toolset1"), + FeatureFlagEnable: "my_feature", + }, + } + + // Without checker, prompt with enable flag should be excluded + reg := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).Build() + available := reg.AvailablePrompts(context.Background()) + if len(available) != 1 { + t.Fatalf("Expected 1 prompt without checker, got %d", len(available)) + } + + // With checker returning true, both should be included + checker := func(_ context.Context, _ string) (bool, error) { return true, nil } + regWithChecker := NewBuilder().SetPrompts(prompts).WithToolsets([]string{"all"}).WithFeatureChecker(checker).Build() + if len(regWithChecker.AvailablePrompts(context.Background())) != 2 { + t.Errorf("Expected 2 prompts with checker, got %d", len(regWithChecker.AvailablePrompts(context.Background()))) + } +} + +func TestServerToolHasHandler(t *testing.T) { + // Tool with handler + toolWithHandler := mockTool("has_handler", "toolset1", true) + if !toolWithHandler.HasHandler() { + t.Error("Expected HasHandler() to return true for tool with handler") + } + + // Tool without handler + toolWithoutHandler := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + if toolWithoutHandler.HasHandler() { + t.Error("Expected HasHandler() to return false for tool without handler") + } +} + +func TestServerToolHandlerPanicOnNil(t *testing.T) { + tool := ServerTool{ + Tool: mcp.Tool{Name: "no_handler"}, + Toolset: testToolsetMetadata("toolset1"), + } + + defer func() { + if r := recover(); r == nil { + t.Error("Expected Handler() to panic when HandlerFunc is nil") + } + }() + + tool.Handler(nil) +} + +// Tests for Enabled function on ServerTool +func TestServerToolEnabled(t *testing.T) { + tests := []struct { + name string + enabledFunc func(ctx context.Context) (bool, error) + expectedCount int + expectInResult bool + }{ + { + name: "nil Enabled function - tool included", + enabledFunc: nil, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns true - tool included", + enabledFunc: func(_ context.Context) (bool, error) { + return true, nil + }, + expectedCount: 1, + expectInResult: true, + }, + { + name: "Enabled returns false - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, nil + }, + expectedCount: 0, + expectInResult: false, + }, + { + name: "Enabled returns error - tool excluded", + enabledFunc: func(_ context.Context) (bool, error) { + return false, fmt.Errorf("simulated error") + }, + expectedCount: 0, + expectInResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = tt.enabledFunc + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + available := reg.AvailableTools(context.Background()) + + if len(available) != tt.expectedCount { + t.Errorf("Expected %d tools, got %d", tt.expectedCount, len(available)) + } + + found := false + for _, t := range available { + if t.Tool.Name == "test_tool" { + found = true + break + } + } + if found != tt.expectInResult { + t.Errorf("Expected tool in result: %v, got: %v", tt.expectInResult, found) + } + }) + } +} + +func TestServerToolEnabledWithContext(t *testing.T) { + type contextKey string + const userKey contextKey = "user" + + // Tool that checks context for user + tool := mockTool("context_aware_tool", "toolset1", true) + tool.Enabled = func(ctx context.Context) (bool, error) { + user := ctx.Value(userKey) + return user != nil && user.(string) == "authorized", nil + } + + reg := NewBuilder().SetTools([]ServerTool{tool}).WithToolsets([]string{"all"}).Build() + + // Without user in context - tool should be excluded + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools without user, got %d", len(available)) + } + + // With authorized user - tool should be included + ctxWithUser := context.WithValue(context.Background(), userKey, "authorized") + availableWithUser := reg.AvailableTools(ctxWithUser) + if len(availableWithUser) != 1 { + t.Errorf("Expected 1 tool with authorized user, got %d", len(availableWithUser)) + } + + // With unauthorized user - tool should be excluded + ctxWithBadUser := context.WithValue(context.Background(), userKey, "unauthorized") + availableWithBadUser := reg.AvailableTools(ctxWithBadUser) + if len(availableWithBadUser) != 0 { + t.Errorf("Expected 0 tools with unauthorized user, got %d", len(availableWithBadUser)) + } +} + +// Tests for WithFilter builder method +func TestBuilderWithFilter(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + } + + // Filter that excludes tool2 + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after filter, got %d", len(available)) + } + + for _, tool := range available { + if tool.Tool.Name == "tool2" { + t.Error("tool2 should have been filtered out") + } + } +} + +func TestBuilderWithMultipleFilters(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + mockTool("tool3", "toolset1", true), + mockTool("tool4", "toolset1", true), + } + + // First filter excludes tool2 + filter1 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool2", nil + } + + // Second filter excludes tool3 + filter2 := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name != "tool3", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter1). + WithFilter(filter2). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 2 { + t.Fatalf("Expected 2 tools after multiple filters, got %d", len(available)) + } + + toolNames := make(map[string]bool) + for _, tool := range available { + toolNames[tool.Tool.Name] = true + } + + if !toolNames["tool1"] || !toolNames["tool4"] { + t.Error("Expected tool1 and tool4 to be available") + } + if toolNames["tool2"] || toolNames["tool3"] { + t.Error("tool2 and tool3 should have been filtered out") + } +} + +func TestBuilderFilterError(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + } + + // Filter that returns an error + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, fmt.Errorf("filter error") + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Errorf("Expected 0 tools when filter returns error, got %d", len(available)) + } +} + +func TestBuilderFilterWithContext(t *testing.T) { + type contextKey string + const scopeKey contextKey = "scope" + + tools := []ServerTool{ + mockTool("public_tool", "toolset1", true), + mockTool("private_tool", "toolset1", true), + } + + // Filter that checks context for scope + filter := func(ctx context.Context, tool *ServerTool) (bool, error) { + scope := ctx.Value(scopeKey) + if scope == "public" && tool.Tool.Name == "private_tool" { + return false, nil + } + return true, nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + // With public scope - private_tool should be excluded + ctxPublic := context.WithValue(context.Background(), scopeKey, "public") + availablePublic := reg.AvailableTools(ctxPublic) + if len(availablePublic) != 1 { + t.Fatalf("Expected 1 tool with public scope, got %d", len(availablePublic)) + } + if availablePublic[0].Tool.Name != "public_tool" { + t.Error("Expected only public_tool to be available") + } + + // With private scope - both tools should be available + ctxPrivate := context.WithValue(context.Background(), scopeKey, "private") + availablePrivate := reg.AvailableTools(ctxPrivate) + if len(availablePrivate) != 2 { + t.Errorf("Expected 2 tools with private scope, got %d", len(availablePrivate)) + } +} + +// Tests for interaction between Enabled, feature flags, and filters +func TestEnabledAndFeatureFlagInteraction(t *testing.T) { + // Tool with both Enabled function and feature flag + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Feature flag not enabled - tool should be excluded despite Enabled returning true + reg1 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + Build() + available1 := reg1.AvailableTools(context.Background()) + if len(available1) != 0 { + t.Error("Tool should be excluded when feature flag is not enabled") + } + + // Feature flag enabled - tool should be included + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 1 { + t.Error("Tool should be included when both Enabled and feature flag pass") + } + + // Enabled returns false - tool should be excluded despite feature flag + tool.Enabled = func(_ context.Context) (bool, error) { + return false, nil + } + reg3 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + Build() + available3 := reg3.AvailableTools(context.Background()) + if len(available3) != 0 { + t.Error("Tool should be excluded when Enabled returns false") + } +} + +func TestEnabledAndBuilderFilterInteraction(t *testing.T) { + tool := mockTool("test_tool", "toolset1", true) + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + // Filter that excludes the tool + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 0 { + t.Error("Tool should be excluded when filter returns false, despite Enabled returning true") + } +} + +func TestAllFiltersInteraction(t *testing.T) { + // Tool with Enabled, feature flag, and subject to builder filter + tool := mockToolWithFlags("complex_tool", "toolset1", true, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + return true, nil + } + + checker := func(_ context.Context, flag string) (bool, error) { + return flag == "my_feature", nil + } + + // All conditions pass - tool should be included + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + available := reg.AvailableTools(context.Background()) + if len(available) != 1 { + t.Error("Tool should be included when all filters pass") + } + + // Change filter to return false - tool should be excluded + filterFalse := func(_ context.Context, _ *ServerTool) (bool, error) { + return false, nil + } + + reg2 := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithFeatureChecker(checker). + WithFilter(filterFalse). + Build() + + available2 := reg2.AvailableTools(context.Background()) + if len(available2) != 0 { + t.Error("Tool should be excluded when any filter fails") + } +} + +// Test FilteredTools method +func TestFilteredTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", true), + } + + filter := func(_ context.Context, tool *ServerTool) (bool, error) { + return tool.Tool.Name == "tool1", nil + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"all"}). + WithFilter(filter). + Build() + + filtered, err := reg.FilteredTools(context.Background()) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + if len(filtered) != 1 { + t.Fatalf("Expected 1 filtered tool, got %d", len(filtered)) + } + + if filtered[0].Tool.Name != "tool1" { + t.Errorf("Expected tool1, got %s", filtered[0].Tool.Name) + } +} + +func TestFilteredToolsMatchesAvailableTools(t *testing.T) { + tools := []ServerTool{ + mockTool("tool1", "toolset1", true), + mockTool("tool2", "toolset1", false), + mockTool("tool3", "toolset2", true), + } + + reg := NewBuilder(). + SetTools(tools). + WithToolsets([]string{"toolset1"}). + WithReadOnly(true). + Build() + + ctx := context.Background() + filtered, err := reg.FilteredTools(ctx) + if err != nil { + t.Fatalf("FilteredTools returned error: %v", err) + } + + available := reg.AvailableTools(ctx) + + // Both methods should return the same results + if len(filtered) != len(available) { + t.Errorf("FilteredTools and AvailableTools returned different counts: %d vs %d", + len(filtered), len(available)) + } + + for i := range filtered { + if filtered[i].Tool.Name != available[i].Tool.Name { + t.Errorf("Tool at index %d differs: FilteredTools=%s, AvailableTools=%s", + i, filtered[i].Tool.Name, available[i].Tool.Name) + } + } +} + +func TestFilteringOrder(t *testing.T) { + // Test that filters are applied in the correct order: + // 1. Tool.Enabled + // 2. Feature flags + // 3. Read-only + // 4. Builder filters + // 5. Toolset/additional tools + + callOrder := []string{} + + tool := mockToolWithFlags("test_tool", "toolset1", false, "my_feature", "") + tool.Enabled = func(_ context.Context) (bool, error) { + callOrder = append(callOrder, "Enabled") + return true, nil + } + + filter := func(_ context.Context, _ *ServerTool) (bool, error) { + callOrder = append(callOrder, "Filter") + return true, nil + } + + checker := func(_ context.Context, _ string) (bool, error) { + callOrder = append(callOrder, "FeatureFlag") + return true, nil + } + + reg := NewBuilder(). + SetTools([]ServerTool{tool}). + WithToolsets([]string{"all"}). + WithReadOnly(true). // This will exclude the tool (it's not read-only) + WithFeatureChecker(checker). + WithFilter(filter). + Build() + + _ = reg.AvailableTools(context.Background()) + + // Expected order: Enabled, FeatureFlag, ReadOnly (stops here because it's write tool) + expectedOrder := []string{"Enabled", "FeatureFlag"} + if len(callOrder) != len(expectedOrder) { + t.Errorf("Expected %d checks, got %d: %v", len(expectedOrder), len(callOrder), callOrder) + } + + for i, expected := range expectedOrder { + if i >= len(callOrder) || callOrder[i] != expected { + t.Errorf("At position %d: expected %s, got %v", i, expected, callOrder) + } + } +} diff --git a/pkg/inventory/resources.go b/pkg/inventory/resources.go new file mode 100644 index 000000000..6de037d58 --- /dev/null +++ b/pkg/inventory/resources.go @@ -0,0 +1,48 @@ +package inventory + +import "github.com/modelcontextprotocol/go-sdk/mcp" + +// ResourceHandlerFunc is a function that takes dependencies and returns an MCP resource handler. +// This allows resources to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +type ResourceHandlerFunc func(deps any) mcp.ResourceHandler + +// ServerResourceTemplate pairs a resource template with its toolset metadata. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + // HandlerFunc generates the handler when given dependencies. + // This allows resources to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc ResourceHandlerFunc + // Toolset identifies which toolset this resource belongs to + Toolset ToolsetMetadata + // FeatureFlagEnable specifies a feature flag that must be enabled for this resource + // to be available. If set and the flag is not enabled, the resource is omitted. + FeatureFlagEnable string + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this resource + // to be omitted. Used to disable resources when a feature flag is on. + FeatureFlagDisable string +} + +// HasHandler returns true if this resource has a handler function. +func (sr *ServerResourceTemplate) HasHandler() bool { + return sr.HandlerFunc != nil +} + +// Handler returns a resource handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all resources should have handlers. +func (sr *ServerResourceTemplate) Handler(deps any) mcp.ResourceHandler { + if sr.HandlerFunc == nil { + panic("HandlerFunc is nil for resource: " + sr.Template.Name) + } + return sr.HandlerFunc(deps) +} + +// NewServerResourceTemplate creates a new ServerResourceTemplate with toolset metadata. +func NewServerResourceTemplate(toolset ToolsetMetadata, resourceTemplate mcp.ResourceTemplate, handlerFn ResourceHandlerFunc) ServerResourceTemplate { + return ServerResourceTemplate{ + Template: resourceTemplate, + HandlerFunc: handlerFn, + Toolset: toolset, + } +} diff --git a/pkg/inventory/server_tool.go b/pkg/inventory/server_tool.go new file mode 100644 index 000000000..362ee2643 --- /dev/null +++ b/pkg/inventory/server_tool.go @@ -0,0 +1,181 @@ +package inventory + +import ( + "context" + "encoding/json" + + "github.com/github/github-mcp-server/pkg/octicons" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// HandlerFunc is a function that takes dependencies and returns an MCP tool handler. +// This allows tools to be defined statically while their handlers are generated +// on-demand with the appropriate dependencies. +// The deps parameter is typed as `any` to avoid circular dependencies - callers +// should define their own typed dependencies struct and type-assert as needed. +type HandlerFunc func(deps any) mcp.ToolHandler + +// ToolsetID is a unique identifier for a toolset. +// Using a distinct type provides compile-time type safety. +type ToolsetID string + +// ToolsetMetadata contains metadata about the toolset a tool belongs to. +type ToolsetMetadata struct { + // ID is the unique identifier for the toolset (e.g., "repos", "issues") + ID ToolsetID + // Description provides a human-readable description of the toolset + Description string + // Default indicates this toolset should be enabled by default + Default bool + // Icon is the name of the Octicon to use for tools in this toolset. + // Use the base name without size suffix, e.g., "repo" not "repo-16". + // See https://primer.style/foundations/icons for available icons. + Icon string +} + +// Icons returns MCP Icon objects for this toolset, or nil if no icon is set. +// Icons are provided in both 16x16 and 24x24 sizes. +func (tm ToolsetMetadata) Icons() []mcp.Icon { + return octicons.Icons(tm.Icon) +} + +// ServerTool represents an MCP tool with metadata and a handler generator function. +// The tool definition is static, while the handler is generated on-demand +// when the tool is registered with a server. +// Tools are now self-describing with their toolset membership and read-only status +// derived from the Tool.Annotations.ReadOnlyHint field. +type ServerTool struct { + // Tool is the MCP tool definition containing name, description, schema, etc. + Tool mcp.Tool + + // Toolset contains metadata about which toolset this tool belongs to. + Toolset ToolsetMetadata + + // HandlerFunc generates the handler when given dependencies. + // This allows tools to be passed around without handlers being set up, + // and handlers are only created when needed. + HandlerFunc HandlerFunc + + // FeatureFlagEnable specifies a feature flag that must be enabled for this tool + // to be available. If set and the flag is not enabled, the tool is omitted. + FeatureFlagEnable string + + // FeatureFlagDisable specifies a feature flag that, when enabled, causes this tool + // to be omitted. Used to disable tools when a feature flag is on. + FeatureFlagDisable string + + // Enabled is an optional function called at build/filter time to determine + // if this tool should be available. If nil, the tool is considered enabled + // (subject to FeatureFlagEnable/FeatureFlagDisable checks). + // The context carries request-scoped information for the consumer to use. + // Returns (enabled, error). On error, the tool should be treated as disabled. + Enabled func(ctx context.Context) (bool, error) +} + +// IsReadOnly returns true if this tool is marked as read-only via annotations. +func (st *ServerTool) IsReadOnly() bool { + return st.Tool.Annotations != nil && st.Tool.Annotations.ReadOnlyHint +} + +// HasHandler returns true if this tool has a handler function. +func (st *ServerTool) HasHandler() bool { + return st.HandlerFunc != nil +} + +// Handler returns a tool handler by calling HandlerFunc with the given dependencies. +// Panics if HandlerFunc is nil - all tools should have handlers. +func (st *ServerTool) Handler(deps any) mcp.ToolHandler { + if st.HandlerFunc == nil { + panic("HandlerFunc is nil for tool: " + st.Tool.Name) + } + return st.HandlerFunc(deps) +} + +// RegisterFunc registers the tool with the server using the provided dependencies. +// Icons are automatically applied from the toolset metadata if not already set. +// A shallow copy of the tool is made to avoid mutating the original ServerTool. +// Panics if the tool has no handler - all tools should have handlers. +func (st *ServerTool) RegisterFunc(s *mcp.Server, deps any) { + handler := st.Handler(deps) // This will panic if HandlerFunc is nil + // Make a shallow copy of the tool to avoid mutating the original + toolCopy := st.Tool + // Apply icons from toolset metadata if tool doesn't have icons set + if len(toolCopy.Icons) == 0 { + toolCopy.Icons = st.Toolset.Icons() + } + s.AddTool(&toolCopy, handler) +} + +// NewServerTool creates a ServerTool from a tool definition, toolset metadata, and a typed handler function. +// The handler function takes dependencies (as any) and returns a typed handler. +// Callers should type-assert deps to their typed dependencies struct. +// +// Deprecated: This creates closures at registration time. For better performance in +// per-request server scenarios, use NewServerToolWithContextHandler instead. +func NewServerTool[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + HandlerFunc: func(deps any) mcp.ToolHandler { + typedHandler := handlerFn(deps) + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := typedHandler(ctx, req, arguments) + return resp, err + } + }, + } +} + +// NewServerToolWithContextHandler creates a ServerTool with a handler that receives deps via context. +// This is the preferred approach for tools because it doesn't create closures at registration time, +// which is critical for performance in servers that create a new instance per request. +// +// The handler function is stored directly without wrapping in a deps closure. +// Dependencies should be injected into context before calling tool handlers. +func NewServerToolWithContextHandler[In any, Out any](tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandlerFor[In, Out]) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + // HandlerFunc ignores deps - deps are retrieved from context at call time + HandlerFunc: func(_ any) mcp.ToolHandler { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var arguments In + if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { + return nil, err + } + resp, _, err := handler(ctx, req, arguments) + return resp, err + } + }, + } +} + +// NewServerToolFromHandler creates a ServerTool from a tool definition, toolset metadata, and a raw handler function. +// Use this when you have a handler that already conforms to mcp.ToolHandler. +// +// Deprecated: This creates closures at registration time. For better performance in +// per-request server scenarios, use NewServerToolWithRawContextHandler instead. +func NewServerToolFromHandler(tool mcp.Tool, toolset ToolsetMetadata, handlerFn func(deps any) mcp.ToolHandler) ServerTool { + return ServerTool{Tool: tool, Toolset: toolset, HandlerFunc: handlerFn} +} + +// NewServerToolWithRawContextHandler creates a ServerTool with a raw handler that receives deps via context. +// This is the preferred approach for tools that use mcp.ToolHandler directly because it doesn't +// create closures at registration time. +// +// The handler function is stored directly without wrapping in a deps closure. +// Dependencies should be injected into context before calling tool handlers. +func NewServerToolWithRawContextHandler(tool mcp.Tool, toolset ToolsetMetadata, handler mcp.ToolHandler) ServerTool { + return ServerTool{ + Tool: tool, + Toolset: toolset, + // HandlerFunc ignores deps - deps are retrieved from context at call time + HandlerFunc: func(_ any) mcp.ToolHandler { + return handler + }, + } +} diff --git a/pkg/octicons/icons/apps-dark.png b/pkg/octicons/icons/apps-dark.png new file mode 100644 index 000000000..607468c85 Binary files /dev/null and b/pkg/octicons/icons/apps-dark.png differ diff --git a/pkg/octicons/icons/apps-light.png b/pkg/octicons/icons/apps-light.png new file mode 100644 index 000000000..c6328612b Binary files /dev/null and b/pkg/octicons/icons/apps-light.png differ diff --git a/pkg/octicons/icons/beaker-dark.png b/pkg/octicons/icons/beaker-dark.png new file mode 100644 index 000000000..b30a95922 Binary files /dev/null and b/pkg/octicons/icons/beaker-dark.png differ diff --git a/pkg/octicons/icons/beaker-light.png b/pkg/octicons/icons/beaker-light.png new file mode 100644 index 000000000..576d170bd Binary files /dev/null and b/pkg/octicons/icons/beaker-light.png differ diff --git a/pkg/octicons/icons/bell-dark.png b/pkg/octicons/icons/bell-dark.png new file mode 100644 index 000000000..a462a8d4c Binary files /dev/null and b/pkg/octicons/icons/bell-dark.png differ diff --git a/pkg/octicons/icons/bell-light.png b/pkg/octicons/icons/bell-light.png new file mode 100644 index 000000000..778215804 Binary files /dev/null and b/pkg/octicons/icons/bell-light.png differ diff --git a/pkg/octicons/icons/book-dark.png b/pkg/octicons/icons/book-dark.png new file mode 100644 index 000000000..9658b4f8e Binary files /dev/null and b/pkg/octicons/icons/book-dark.png differ diff --git a/pkg/octicons/icons/book-light.png b/pkg/octicons/icons/book-light.png new file mode 100644 index 000000000..8be91a434 Binary files /dev/null and b/pkg/octicons/icons/book-light.png differ diff --git a/pkg/octicons/icons/check-circle-dark.png b/pkg/octicons/icons/check-circle-dark.png new file mode 100644 index 000000000..1fad22056 Binary files /dev/null and b/pkg/octicons/icons/check-circle-dark.png differ diff --git a/pkg/octicons/icons/check-circle-light.png b/pkg/octicons/icons/check-circle-light.png new file mode 100644 index 000000000..75916ed24 Binary files /dev/null and b/pkg/octicons/icons/check-circle-light.png differ diff --git a/pkg/octicons/icons/codescan-dark.png b/pkg/octicons/icons/codescan-dark.png new file mode 100644 index 000000000..7bedddfe4 Binary files /dev/null and b/pkg/octicons/icons/codescan-dark.png differ diff --git a/pkg/octicons/icons/codescan-light.png b/pkg/octicons/icons/codescan-light.png new file mode 100644 index 000000000..cb3fb7545 Binary files /dev/null and b/pkg/octicons/icons/codescan-light.png differ diff --git a/pkg/octicons/icons/comment-discussion-dark.png b/pkg/octicons/icons/comment-discussion-dark.png new file mode 100644 index 000000000..6b7eeb6ef Binary files /dev/null and b/pkg/octicons/icons/comment-discussion-dark.png differ diff --git a/pkg/octicons/icons/comment-discussion-light.png b/pkg/octicons/icons/comment-discussion-light.png new file mode 100644 index 000000000..64ee5f0ca Binary files /dev/null and b/pkg/octicons/icons/comment-discussion-light.png differ diff --git a/pkg/octicons/icons/copilot-dark.png b/pkg/octicons/icons/copilot-dark.png new file mode 100644 index 000000000..2188a1bca Binary files /dev/null and b/pkg/octicons/icons/copilot-dark.png differ diff --git a/pkg/octicons/icons/copilot-light.png b/pkg/octicons/icons/copilot-light.png new file mode 100644 index 000000000..4e83af015 Binary files /dev/null and b/pkg/octicons/icons/copilot-light.png differ diff --git a/pkg/octicons/icons/dependabot-dark.png b/pkg/octicons/icons/dependabot-dark.png new file mode 100644 index 000000000..39d41c7d1 Binary files /dev/null and b/pkg/octicons/icons/dependabot-dark.png differ diff --git a/pkg/octicons/icons/dependabot-light.png b/pkg/octicons/icons/dependabot-light.png new file mode 100644 index 000000000..5dfa5b920 Binary files /dev/null and b/pkg/octicons/icons/dependabot-light.png differ diff --git a/pkg/octicons/icons/file-dark.png b/pkg/octicons/icons/file-dark.png new file mode 100644 index 000000000..213069bc9 Binary files /dev/null and b/pkg/octicons/icons/file-dark.png differ diff --git a/pkg/octicons/icons/file-light.png b/pkg/octicons/icons/file-light.png new file mode 100644 index 000000000..8a00ffc25 Binary files /dev/null and b/pkg/octicons/icons/file-light.png differ diff --git a/pkg/octicons/icons/git-branch-dark.png b/pkg/octicons/icons/git-branch-dark.png new file mode 100644 index 000000000..3c4756dfd Binary files /dev/null and b/pkg/octicons/icons/git-branch-dark.png differ diff --git a/pkg/octicons/icons/git-branch-light.png b/pkg/octicons/icons/git-branch-light.png new file mode 100644 index 000000000..42eb954de Binary files /dev/null and b/pkg/octicons/icons/git-branch-light.png differ diff --git a/pkg/octicons/icons/git-commit-dark.png b/pkg/octicons/icons/git-commit-dark.png new file mode 100644 index 000000000..69f72a47b Binary files /dev/null and b/pkg/octicons/icons/git-commit-dark.png differ diff --git a/pkg/octicons/icons/git-commit-light.png b/pkg/octicons/icons/git-commit-light.png new file mode 100644 index 000000000..8011fbcb6 Binary files /dev/null and b/pkg/octicons/icons/git-commit-light.png differ diff --git a/pkg/octicons/icons/git-merge-dark.png b/pkg/octicons/icons/git-merge-dark.png new file mode 100644 index 000000000..283ee665c Binary files /dev/null and b/pkg/octicons/icons/git-merge-dark.png differ diff --git a/pkg/octicons/icons/git-merge-light.png b/pkg/octicons/icons/git-merge-light.png new file mode 100644 index 000000000..e208a09f5 Binary files /dev/null and b/pkg/octicons/icons/git-merge-light.png differ diff --git a/pkg/octicons/icons/git-pull-request-dark.png b/pkg/octicons/icons/git-pull-request-dark.png new file mode 100644 index 000000000..bdbc8bd27 Binary files /dev/null and b/pkg/octicons/icons/git-pull-request-dark.png differ diff --git a/pkg/octicons/icons/git-pull-request-light.png b/pkg/octicons/icons/git-pull-request-light.png new file mode 100644 index 000000000..616ece21a Binary files /dev/null and b/pkg/octicons/icons/git-pull-request-light.png differ diff --git a/pkg/octicons/icons/issue-opened-dark.png b/pkg/octicons/icons/issue-opened-dark.png new file mode 100644 index 000000000..d71f5ec2c Binary files /dev/null and b/pkg/octicons/icons/issue-opened-dark.png differ diff --git a/pkg/octicons/icons/issue-opened-light.png b/pkg/octicons/icons/issue-opened-light.png new file mode 100644 index 000000000..123212013 Binary files /dev/null and b/pkg/octicons/icons/issue-opened-light.png differ diff --git a/pkg/octicons/icons/logo-gist-dark.png b/pkg/octicons/icons/logo-gist-dark.png new file mode 100644 index 000000000..8929edee4 Binary files /dev/null and b/pkg/octicons/icons/logo-gist-dark.png differ diff --git a/pkg/octicons/icons/logo-gist-light.png b/pkg/octicons/icons/logo-gist-light.png new file mode 100644 index 000000000..364ef951a Binary files /dev/null and b/pkg/octicons/icons/logo-gist-light.png differ diff --git a/pkg/octicons/icons/mark-github-dark.png b/pkg/octicons/icons/mark-github-dark.png new file mode 100644 index 000000000..57f11abfd Binary files /dev/null and b/pkg/octicons/icons/mark-github-dark.png differ diff --git a/pkg/octicons/icons/mark-github-light.png b/pkg/octicons/icons/mark-github-light.png new file mode 100644 index 000000000..7d7ffd123 Binary files /dev/null and b/pkg/octicons/icons/mark-github-light.png differ diff --git a/pkg/octicons/icons/organization-dark.png b/pkg/octicons/icons/organization-dark.png new file mode 100644 index 000000000..6ad3feaf8 Binary files /dev/null and b/pkg/octicons/icons/organization-dark.png differ diff --git a/pkg/octicons/icons/organization-light.png b/pkg/octicons/icons/organization-light.png new file mode 100644 index 000000000..e504febe3 Binary files /dev/null and b/pkg/octicons/icons/organization-light.png differ diff --git a/pkg/octicons/icons/people-dark.png b/pkg/octicons/icons/people-dark.png new file mode 100644 index 000000000..2dd60bab6 Binary files /dev/null and b/pkg/octicons/icons/people-dark.png differ diff --git a/pkg/octicons/icons/people-light.png b/pkg/octicons/icons/people-light.png new file mode 100644 index 000000000..5dc0fb62f Binary files /dev/null and b/pkg/octicons/icons/people-light.png differ diff --git a/pkg/octicons/icons/person-dark.png b/pkg/octicons/icons/person-dark.png new file mode 100644 index 000000000..c0fdf6cad Binary files /dev/null and b/pkg/octicons/icons/person-dark.png differ diff --git a/pkg/octicons/icons/person-light.png b/pkg/octicons/icons/person-light.png new file mode 100644 index 000000000..db1368350 Binary files /dev/null and b/pkg/octicons/icons/person-light.png differ diff --git a/pkg/octicons/icons/project-dark.png b/pkg/octicons/icons/project-dark.png new file mode 100644 index 000000000..273d7ba5a Binary files /dev/null and b/pkg/octicons/icons/project-dark.png differ diff --git a/pkg/octicons/icons/project-light.png b/pkg/octicons/icons/project-light.png new file mode 100644 index 000000000..51e232d29 Binary files /dev/null and b/pkg/octicons/icons/project-light.png differ diff --git a/pkg/octicons/icons/repo-dark.png b/pkg/octicons/icons/repo-dark.png new file mode 100644 index 000000000..81bbeac25 Binary files /dev/null and b/pkg/octicons/icons/repo-dark.png differ diff --git a/pkg/octicons/icons/repo-forked-dark.png b/pkg/octicons/icons/repo-forked-dark.png new file mode 100644 index 000000000..434d5d287 Binary files /dev/null and b/pkg/octicons/icons/repo-forked-dark.png differ diff --git a/pkg/octicons/icons/repo-forked-light.png b/pkg/octicons/icons/repo-forked-light.png new file mode 100644 index 000000000..bf41f6e27 Binary files /dev/null and b/pkg/octicons/icons/repo-forked-light.png differ diff --git a/pkg/octicons/icons/repo-light.png b/pkg/octicons/icons/repo-light.png new file mode 100644 index 000000000..185a05438 Binary files /dev/null and b/pkg/octicons/icons/repo-light.png differ diff --git a/pkg/octicons/icons/shield-dark.png b/pkg/octicons/icons/shield-dark.png new file mode 100644 index 000000000..cf61060de Binary files /dev/null and b/pkg/octicons/icons/shield-dark.png differ diff --git a/pkg/octicons/icons/shield-light.png b/pkg/octicons/icons/shield-light.png new file mode 100644 index 000000000..5a11004ee Binary files /dev/null and b/pkg/octicons/icons/shield-light.png differ diff --git a/pkg/octicons/icons/shield-lock-dark.png b/pkg/octicons/icons/shield-lock-dark.png new file mode 100644 index 000000000..0abf4ad3f Binary files /dev/null and b/pkg/octicons/icons/shield-lock-dark.png differ diff --git a/pkg/octicons/icons/shield-lock-light.png b/pkg/octicons/icons/shield-lock-light.png new file mode 100644 index 000000000..ae6e8cc1b Binary files /dev/null and b/pkg/octicons/icons/shield-lock-light.png differ diff --git a/pkg/octicons/icons/star-dark.png b/pkg/octicons/icons/star-dark.png new file mode 100644 index 000000000..9156c9b28 Binary files /dev/null and b/pkg/octicons/icons/star-dark.png differ diff --git a/pkg/octicons/icons/star-fill-dark.png b/pkg/octicons/icons/star-fill-dark.png new file mode 100644 index 000000000..3b19107d7 Binary files /dev/null and b/pkg/octicons/icons/star-fill-dark.png differ diff --git a/pkg/octicons/icons/star-fill-light.png b/pkg/octicons/icons/star-fill-light.png new file mode 100644 index 000000000..fd3621016 Binary files /dev/null and b/pkg/octicons/icons/star-fill-light.png differ diff --git a/pkg/octicons/icons/star-light.png b/pkg/octicons/icons/star-light.png new file mode 100644 index 000000000..54372238e Binary files /dev/null and b/pkg/octicons/icons/star-light.png differ diff --git a/pkg/octicons/icons/tag-dark.png b/pkg/octicons/icons/tag-dark.png new file mode 100644 index 000000000..00d0a56ff Binary files /dev/null and b/pkg/octicons/icons/tag-dark.png differ diff --git a/pkg/octicons/icons/tag-light.png b/pkg/octicons/icons/tag-light.png new file mode 100644 index 000000000..cead91700 Binary files /dev/null and b/pkg/octicons/icons/tag-light.png differ diff --git a/pkg/octicons/icons/tools-dark.png b/pkg/octicons/icons/tools-dark.png new file mode 100644 index 000000000..2cf9080c7 Binary files /dev/null and b/pkg/octicons/icons/tools-dark.png differ diff --git a/pkg/octicons/icons/tools-light.png b/pkg/octicons/icons/tools-light.png new file mode 100644 index 000000000..59cf00d11 Binary files /dev/null and b/pkg/octicons/icons/tools-light.png differ diff --git a/pkg/octicons/icons/workflow-dark.png b/pkg/octicons/icons/workflow-dark.png new file mode 100644 index 000000000..f1339416f Binary files /dev/null and b/pkg/octicons/icons/workflow-dark.png differ diff --git a/pkg/octicons/icons/workflow-light.png b/pkg/octicons/icons/workflow-light.png new file mode 100644 index 000000000..6930f846d Binary files /dev/null and b/pkg/octicons/icons/workflow-light.png differ diff --git a/pkg/octicons/octicons.go b/pkg/octicons/octicons.go new file mode 100644 index 000000000..c6b92c47b --- /dev/null +++ b/pkg/octicons/octicons.go @@ -0,0 +1,83 @@ +// Package octicons provides helpers for working with GitHub Octicon icons. +// See https://primer.style/foundations/icons for available icons. +package octicons + +import ( + "bufio" + "embed" + "encoding/base64" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +//go:embed icons/*.png +var iconsFS embed.FS + +//go:embed required_icons.txt +var requiredIconsTxt string + +// RequiredIcons returns the list of icon names from required_icons.txt. +// This is the single source of truth for which icons should be embedded. +func RequiredIcons() []string { + var icons []string + scanner := bufio.NewScanner(strings.NewReader(requiredIconsTxt)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + // Skip empty lines and comments + if line == "" || strings.HasPrefix(line, "#") { + continue + } + icons = append(icons, line) + } + return icons +} + +// Theme represents the color theme of an icon. +type Theme string + +const ( + // ThemeLight is for light backgrounds (dark/black icons). + ThemeLight Theme = "light" + // ThemeDark is for dark backgrounds (light/white icons). + ThemeDark Theme = "dark" +) + +// DataURI returns a data URI for the embedded Octicon PNG. +// The theme parameter specifies which variant to use: +// - ThemeLight: dark icons for light backgrounds +// - ThemeDark: light icons for dark backgrounds +// If the icon is not found in the embedded filesystem, it returns an empty string. +func DataURI(name string, theme Theme) string { + filename := fmt.Sprintf("icons/%s-%s.png", name, theme) + data, err := iconsFS.ReadFile(filename) + if err != nil { + return "" + } + return "data:image/png;base64," + base64.StdEncoding.EncodeToString(data) +} + +// Icons returns MCP Icon objects for the given octicon name in light and dark themes. +// Icons are embedded as 24x24 PNG data URIs for offline use and faster loading. +// The name should be the base octicon name without size suffix (e.g., "repo" not "repo-16"). +// See https://primer.style/foundations/icons for available icons. +func Icons(name string) []mcp.Icon { + if name == "" { + return nil + } + return []mcp.Icon{ + { + Source: DataURI(name, ThemeLight), + MIMEType: "image/png", + Sizes: []string{"24x24"}, + Theme: string(ThemeLight), + }, + { + Source: DataURI(name, ThemeDark), + MIMEType: "image/png", + Sizes: []string{"24x24"}, + Theme: string(ThemeDark), + }, + } +} diff --git a/pkg/octicons/octicons_test.go b/pkg/octicons/octicons_test.go new file mode 100644 index 000000000..f60f7192e --- /dev/null +++ b/pkg/octicons/octicons_test.go @@ -0,0 +1,118 @@ +package octicons + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDataURI(t *testing.T) { + tests := []struct { + name string + icon string + theme Theme + wantDataURI bool + wantEmpty bool + }{ + { + name: "light theme icon returns data URI", + icon: "repo", + theme: ThemeLight, + wantDataURI: true, + wantEmpty: false, + }, + { + name: "dark theme icon returns data URI", + icon: "repo", + theme: ThemeDark, + wantDataURI: true, + wantEmpty: false, + }, + { + name: "non-embedded icon returns empty string", + icon: "nonexistent-icon", + theme: ThemeLight, + wantDataURI: false, + wantEmpty: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := DataURI(tc.icon, tc.theme) + if tc.wantDataURI { + assert.True(t, strings.HasPrefix(result, "data:image/png;base64,"), "expected data URI prefix") + assert.NotContains(t, result, "https://") + } + if tc.wantEmpty { + assert.Empty(t, result, "expected empty string for non-embedded icon") + } + }) + } +} + +func TestIcons(t *testing.T) { + tests := []struct { + name string + icon string + wantNil bool + wantCount int + }{ + { + name: "valid embedded icon returns light and dark variants", + icon: "repo", + wantNil: false, + wantCount: 2, + }, + { + name: "empty name returns nil", + icon: "", + wantNil: true, + wantCount: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := Icons(tc.icon) + if tc.wantNil { + assert.Nil(t, result) + return + } + assert.NotNil(t, result) + assert.Len(t, result, tc.wantCount) + + // Verify first icon is light theme + assert.Equal(t, DataURI(tc.icon, ThemeLight), result[0].Source) + assert.Equal(t, "image/png", result[0].MIMEType) + assert.Equal(t, []string{"24x24"}, result[0].Sizes) + assert.Equal(t, "light", result[0].Theme) + + // Verify second icon is dark theme + assert.Equal(t, DataURI(tc.icon, ThemeDark), result[1].Source) + assert.Equal(t, "image/png", result[1].MIMEType) + assert.Equal(t, []string{"24x24"}, result[1].Sizes) + assert.Equal(t, "dark", result[1].Theme) + }) + } +} + +func TestThemeConstants(t *testing.T) { + assert.Equal(t, Theme("light"), ThemeLight) + assert.Equal(t, Theme("dark"), ThemeDark) +} + +func TestEmbeddedIconsExist(t *testing.T) { + // Test that all required icons from required_icons.txt are properly embedded + // This is the single source of truth for which icons should be available + expectedIcons := RequiredIcons() + for _, icon := range expectedIcons { + t.Run(icon, func(t *testing.T) { + lightURI := DataURI(icon, ThemeLight) + darkURI := DataURI(icon, ThemeDark) + assert.True(t, strings.HasPrefix(lightURI, "data:image/png;base64,"), "light theme icon %s should be embedded", icon) + assert.True(t, strings.HasPrefix(darkURI, "data:image/png;base64,"), "dark theme icon %s should be embedded", icon) + }) + } +} diff --git a/pkg/octicons/required_icons.txt b/pkg/octicons/required_icons.txt new file mode 100644 index 000000000..7911b46eb --- /dev/null +++ b/pkg/octicons/required_icons.txt @@ -0,0 +1,45 @@ +# Required Octicons for the GitHub MCP Server +# This file is the source of truth for icon requirements. +# Used by: +# - script/fetch-icons (to download icons) +# - pkg/octicons/octicons_test.go (to validate icons are embedded) +# - pkg/github/toolset_icons_test.go (to validate toolset icons exist) +# +# Add new icons here when: +# - Adding a new toolset with an icon +# - Adding a new tool that needs a custom icon +# +# Format: one icon name per line (without -24.svg suffix) +# Lines starting with # are comments +# Empty lines are ignored + +apps +beaker +bell +book +check-circle +codescan +comment-discussion +copilot +dependabot +file +git-branch +git-commit +git-merge +git-pull-request +issue-opened +logo-gist +mark-github +organization +people +person +project +repo +repo-forked +shield +shield-lock +star +star-fill +tag +tools +workflow diff --git a/pkg/raw/raw_test.go b/pkg/raw/raw_test.go index 242029c8b..4c4aa33b4 100644 --- a/pkg/raw/raw_test.go +++ b/pkg/raw/raw_test.go @@ -1,22 +1,44 @@ package raw import ( + "bytes" "context" + "io" "net/http" "net/url" + "strings" "testing" "github.com/google/go-github/v79/github" - "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/require" ) +// mockRawTransport is a custom HTTP transport for testing raw content API +type mockRawTransport struct { + statusCode int + contentType string + body string +} + +func (m *mockRawTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Create a response with the configured status and body + resp := &http.Response{ + StatusCode: m.statusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewBufferString(m.body)), + Request: req, + } + if m.contentType != "" { + resp.Header.Set("Content-Type", m.contentType) + } + return resp, nil +} + func TestGetRawContent(t *testing.T) { base, _ := url.Parse("https://raw.example.com/") tests := []struct { name string - pattern mock.EndpointPattern opts *ContentOpts owner, repo, path string statusCode int @@ -25,46 +47,51 @@ func TestGetRawContent(t *testing.T) { expectError string }{ { - name: "HEAD fetch success", - pattern: GetRawReposContentsByOwnerByRepoByPath, - opts: nil, - owner: "octocat", repo: "hello", path: "README.md", + name: "HEAD fetch success", + opts: nil, + owner: "octocat", + repo: "hello", + path: "README.md", statusCode: 200, contentType: "text/plain", body: "# Test file", }, { - name: "branch fetch success", - pattern: GetRawReposContentsByOwnerByRepoByBranchByPath, - opts: &ContentOpts{Ref: "refs/heads/main"}, - owner: "octocat", repo: "hello", path: "README.md", + name: "branch fetch success", + opts: &ContentOpts{Ref: "refs/heads/main"}, + owner: "octocat", + repo: "hello", + path: "README.md", statusCode: 200, contentType: "text/plain", body: "# Test file", }, { - name: "tag fetch success", - pattern: GetRawReposContentsByOwnerByRepoByTagByPath, - opts: &ContentOpts{Ref: "refs/tags/v1.0.0"}, - owner: "octocat", repo: "hello", path: "README.md", + name: "tag fetch success", + opts: &ContentOpts{Ref: "refs/tags/v1.0.0"}, + owner: "octocat", + repo: "hello", + path: "README.md", statusCode: 200, contentType: "text/plain", body: "# Test file", }, { - name: "sha fetch success", - pattern: GetRawReposContentsByOwnerByRepoBySHAByPath, - opts: &ContentOpts{SHA: "abc123"}, - owner: "octocat", repo: "hello", path: "README.md", + name: "sha fetch success", + opts: &ContentOpts{SHA: "abc123"}, + owner: "octocat", + repo: "hello", + path: "README.md", statusCode: 200, contentType: "text/plain", body: "# Test file", }, { - name: "not found", - pattern: GetRawReposContentsByOwnerByRepoByPath, - opts: nil, - owner: "octocat", repo: "hello", path: "notfound.txt", + name: "not found", + opts: nil, + owner: "octocat", + repo: "hello", + path: "notfound.txt", statusCode: 404, contentType: "application/json", body: `{"message": "Not Found"}`, @@ -73,29 +100,33 @@ func TestGetRawContent(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - mockedClient := mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - tc.pattern, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", tc.contentType) - w.WriteHeader(tc.statusCode) - _, err := w.Write([]byte(tc.body)) - require.NoError(t, err) - }), - ), - ) + // Create mock HTTP client with custom transport + mockedClient := &http.Client{ + Transport: &mockRawTransport{ + statusCode: tc.statusCode, + contentType: tc.contentType, + body: tc.body, + }, + } ghClient := github.NewClient(mockedClient) client := NewClient(ghClient, base) resp, err := client.GetRawContent(context.Background(), tc.owner, tc.repo, tc.path, tc.opts) defer func() { _ = resp.Body.Close() }() + if tc.expectError != "" { require.Error(t, err) return } require.NoError(t, err) require.Equal(t, tc.statusCode, resp.StatusCode) + + // Verify the URL was constructed correctly + actualURL := client.URLFromOpts(tc.opts, tc.owner, tc.repo, tc.path) + require.True(t, strings.Contains(actualURL, tc.owner)) + require.True(t, strings.Contains(actualURL, tc.repo)) + require.True(t, strings.Contains(actualURL, tc.path)) }) } } diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go deleted file mode 100644 index d96b5fb50..000000000 --- a/pkg/toolsets/toolsets.go +++ /dev/null @@ -1,385 +0,0 @@ -package toolsets - -import ( - "context" - "encoding/json" - "fmt" - "os" - "strings" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -type ToolsetDoesNotExistError struct { - Name string -} - -func (e *ToolsetDoesNotExistError) Error() string { - return fmt.Sprintf("toolset %s does not exist", e.Name) -} - -func (e *ToolsetDoesNotExistError) Is(target error) bool { - if target == nil { - return false - } - if _, ok := target.(*ToolsetDoesNotExistError); ok { - return true - } - return false -} - -func NewToolsetDoesNotExistError(name string) *ToolsetDoesNotExistError { - return &ToolsetDoesNotExistError{Name: name} -} - -type ServerTool struct { - Tool mcp.Tool - RegisterFunc func(s *mcp.Server) -} - -func NewServerTool[In any, Out any](tool mcp.Tool, handler mcp.ToolHandlerFor[In, Out]) ServerTool { - return ServerTool{Tool: tool, RegisterFunc: func(s *mcp.Server) { - th := func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - var arguments In - if err := json.Unmarshal(req.Params.Arguments, &arguments); err != nil { - return nil, err - } - - resp, _, err := handler(ctx, req, arguments) - - return resp, err - } - - s.AddTool(&tool, th) - }} -} - -type ServerResourceTemplate struct { - Template mcp.ResourceTemplate - Handler mcp.ResourceHandler -} - -func NewServerResourceTemplate(resourceTemplate mcp.ResourceTemplate, handler mcp.ResourceHandler) ServerResourceTemplate { - return ServerResourceTemplate{ - Template: resourceTemplate, - Handler: handler, - } -} - -type ServerPrompt struct { - Prompt mcp.Prompt - Handler mcp.PromptHandler -} - -func NewServerPrompt(prompt mcp.Prompt, handler mcp.PromptHandler) ServerPrompt { - return ServerPrompt{ - Prompt: prompt, - Handler: handler, - } -} - -// Toolset represents a collection of MCP functionality that can be enabled or disabled as a group. -type Toolset struct { - Name string - Description string - Enabled bool - readOnly bool - writeTools []ServerTool - readTools []ServerTool - // resources are not tools, but the community seems to be moving towards namespaces as a broader concept - // and in order to have multiple servers running concurrently, we want to avoid overlapping resources too. - resourceTemplates []ServerResourceTemplate - // prompts are also not tools but are namespaced similarly - prompts []ServerPrompt -} - -func (t *Toolset) GetActiveTools() []ServerTool { - if t.Enabled { - if t.readOnly { - return t.readTools - } - return append(t.readTools, t.writeTools...) - } - return nil -} - -func (t *Toolset) GetAvailableTools() []ServerTool { - if t.readOnly { - return t.readTools - } - return append(t.readTools, t.writeTools...) -} - -func (t *Toolset) RegisterTools(s *mcp.Server) { - if !t.Enabled { - return - } - for _, tool := range t.readTools { - tool.RegisterFunc(s) - } - if !t.readOnly { - for _, tool := range t.writeTools { - tool.RegisterFunc(s) - } - } -} - -func (t *Toolset) AddResourceTemplates(templates ...ServerResourceTemplate) *Toolset { - t.resourceTemplates = append(t.resourceTemplates, templates...) - return t -} - -func (t *Toolset) AddPrompts(prompts ...ServerPrompt) *Toolset { - t.prompts = append(t.prompts, prompts...) - return t -} - -func (t *Toolset) GetActiveResourceTemplates() []ServerResourceTemplate { - if !t.Enabled { - return nil - } - return t.resourceTemplates -} - -func (t *Toolset) GetAvailableResourceTemplates() []ServerResourceTemplate { - return t.resourceTemplates -} - -func (t *Toolset) RegisterResourcesTemplates(s *mcp.Server) { - if !t.Enabled { - return - } - for _, resource := range t.resourceTemplates { - s.AddResourceTemplate(&resource.Template, resource.Handler) - } -} - -func (t *Toolset) RegisterPrompts(s *mcp.Server) { - if !t.Enabled { - return - } - for _, prompt := range t.prompts { - s.AddPrompt(&prompt.Prompt, prompt.Handler) - } -} - -func (t *Toolset) SetReadOnly() { - // Set the toolset to read-only - t.readOnly = true -} - -func (t *Toolset) AddWriteTools(tools ...ServerTool) *Toolset { - // Silently ignore if the toolset is read-only to avoid any breach of that contract - for _, tool := range tools { - if tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) is incorrectly annotated as read-only", tool.Tool.Name)) - } - } - if !t.readOnly { - t.writeTools = append(t.writeTools, tools...) - } - return t -} - -func (t *Toolset) AddReadTools(tools ...ServerTool) *Toolset { - for _, tool := range tools { - if !tool.Tool.Annotations.ReadOnlyHint { - panic(fmt.Sprintf("tool (%s) must be annotated as read-only", tool.Tool.Name)) - } - } - t.readTools = append(t.readTools, tools...) - return t -} - -type ToolsetGroup struct { - Toolsets map[string]*Toolset - deprecatedAliases map[string]string - everythingOn bool - readOnly bool -} - -func NewToolsetGroup(readOnly bool) *ToolsetGroup { - return &ToolsetGroup{ - Toolsets: make(map[string]*Toolset), - deprecatedAliases: make(map[string]string), - everythingOn: false, - readOnly: readOnly, - } -} - -func (tg *ToolsetGroup) AddDeprecatedToolAliases(aliases map[string]string) { - for oldName, newName := range aliases { - tg.deprecatedAliases[oldName] = newName - } -} - -func (tg *ToolsetGroup) AddToolset(ts *Toolset) { - if tg.readOnly { - ts.SetReadOnly() - } - tg.Toolsets[ts.Name] = ts -} - -func NewToolset(name string, description string) *Toolset { - return &Toolset{ - Name: name, - Description: description, - Enabled: false, - readOnly: false, - } -} - -func (tg *ToolsetGroup) IsEnabled(name string) bool { - // If everythingOn is true, all features are enabled - if tg.everythingOn { - return true - } - - feature, exists := tg.Toolsets[name] - if !exists { - return false - } - return feature.Enabled -} - -type EnableToolsetsOptions struct { - ErrorOnUnknown bool -} - -func (tg *ToolsetGroup) EnableToolsets(names []string, options *EnableToolsetsOptions) error { - if options == nil { - options = &EnableToolsetsOptions{ - ErrorOnUnknown: false, - } - } - - // Special case for "all" - for _, name := range names { - if name == "all" { - tg.everythingOn = true - break - } - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err - } - } - // Do this after to ensure all toolsets are enabled if "all" is present anywhere in list - if tg.everythingOn { - for name := range tg.Toolsets { - err := tg.EnableToolset(name) - if err != nil && options.ErrorOnUnknown { - return err - } - } - return nil - } - return nil -} - -func (tg *ToolsetGroup) EnableToolset(name string) error { - toolset, exists := tg.Toolsets[name] - if !exists { - return NewToolsetDoesNotExistError(name) - } - toolset.Enabled = true - tg.Toolsets[name] = toolset - return nil -} - -func (tg *ToolsetGroup) RegisterAll(s *mcp.Server) { - for _, toolset := range tg.Toolsets { - toolset.RegisterTools(s) - toolset.RegisterResourcesTemplates(s) - toolset.RegisterPrompts(s) - } -} - -func (tg *ToolsetGroup) GetToolset(name string) (*Toolset, error) { - toolset, exists := tg.Toolsets[name] - if !exists { - return nil, NewToolsetDoesNotExistError(name) - } - return toolset, nil -} - -type ToolDoesNotExistError struct { - Name string -} - -func (e *ToolDoesNotExistError) Error() string { - return fmt.Sprintf("tool %s does not exist", e.Name) -} - -func NewToolDoesNotExistError(name string) *ToolDoesNotExistError { - return &ToolDoesNotExistError{Name: name} -} - -// ResolveToolAliases resolves deprecated tool aliases to their canonical names. -// It logs a warning to stderr for each deprecated alias that is resolved. -// Returns: -// - resolved: tool names with aliases replaced by canonical names -// - aliasesUsed: map of oldName → newName for each alias that was resolved -func (tg *ToolsetGroup) ResolveToolAliases(toolNames []string) (resolved []string, aliasesUsed map[string]string) { - resolved = make([]string, 0, len(toolNames)) - aliasesUsed = make(map[string]string) - for _, toolName := range toolNames { - if canonicalName, isAlias := tg.deprecatedAliases[toolName]; isAlias { - fmt.Fprintf(os.Stderr, "Warning: tool %q is deprecated, use %q instead\n", toolName, canonicalName) - aliasesUsed[toolName] = canonicalName - resolved = append(resolved, canonicalName) - } else { - resolved = append(resolved, toolName) - } - } - return resolved, aliasesUsed -} - -// FindToolByName searches all toolsets (enabled or disabled) for a tool by name. -// Returns the tool, its parent toolset name, and an error if not found. -func (tg *ToolsetGroup) FindToolByName(toolName string) (*ServerTool, string, error) { - for toolsetName, toolset := range tg.Toolsets { - // Check read tools - for _, tool := range toolset.readTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } - } - // Check write tools - for _, tool := range toolset.writeTools { - if tool.Tool.Name == toolName { - return &tool, toolsetName, nil - } - } - } - return nil, "", NewToolDoesNotExistError(toolName) -} - -// RegisterSpecificTools registers only the specified tools. -// Respects read-only mode (skips write tools if readOnly=true). -// Returns error if any tool is not found. -func (tg *ToolsetGroup) RegisterSpecificTools(s *mcp.Server, toolNames []string, readOnly bool) error { - var skippedTools []string - for _, toolName := range toolNames { - tool, _, err := tg.FindToolByName(toolName) - if err != nil { - return fmt.Errorf("tool %s not found: %w", toolName, err) - } - - if !tool.Tool.Annotations.ReadOnlyHint && readOnly { - // Skip write tools in read-only mode - skippedTools = append(skippedTools, toolName) - continue - } - - // Register the tool - tool.RegisterFunc(s) - } - - // Log skipped write tools if any - if len(skippedTools) > 0 { - fmt.Fprintf(os.Stderr, "Write tools skipped due to read-only mode: %s\n", strings.Join(skippedTools, ", ")) - } - - return nil -} diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go deleted file mode 100644 index 6362aad0e..000000000 --- a/pkg/toolsets/toolsets_test.go +++ /dev/null @@ -1,395 +0,0 @@ -package toolsets - -import ( - "errors" - "testing" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// mockTool creates a minimal ServerTool for testing -func mockTool(name string, readOnly bool) ServerTool { - return ServerTool{ - Tool: mcp.Tool{ - Name: name, - Annotations: &mcp.ToolAnnotations{ - ReadOnlyHint: readOnly, - }, - }, - RegisterFunc: func(_ *mcp.Server) {}, - } -} - -func TestNewToolsetGroupIsEmptyWithoutEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) - if len(tsg.Toolsets) != 0 { - t.Fatalf("Expected Toolsets map to be empty, got %d items", len(tsg.Toolsets)) - } - if tsg.everythingOn { - t.Fatal("Expected everythingOn to be initialized as false") - } -} - -func TestAddToolset(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test adding a toolset - toolset := NewToolset("test-toolset", "A test toolset") - toolset.Enabled = true - tsg.AddToolset(toolset) - - // Verify toolset was added correctly - if len(tsg.Toolsets) != 1 { - t.Errorf("Expected 1 toolset, got %d", len(tsg.Toolsets)) - } - - toolset, exists := tsg.Toolsets["test-toolset"] - if !exists { - t.Fatal("Feature was not added to the map") - } - - if toolset.Name != "test-toolset" { - t.Errorf("Expected toolset name to be 'test-toolset', got '%s'", toolset.Name) - } - - if toolset.Description != "A test toolset" { - t.Errorf("Expected toolset description to be 'A test toolset', got '%s'", toolset.Description) - } - - if !toolset.Enabled { - t.Error("Expected toolset to be enabled") - } - - // Test adding another toolset - anotherToolset := NewToolset("another-toolset", "Another test toolset") - tsg.AddToolset(anotherToolset) - - if len(tsg.Toolsets) != 2 { - t.Errorf("Expected 2 toolsets, got %d", len(tsg.Toolsets)) - } - - // Test overriding existing toolset - updatedToolset := NewToolset("test-toolset", "Updated description") - tsg.AddToolset(updatedToolset) - - toolset = tsg.Toolsets["test-toolset"] - if toolset.Description != "Updated description" { - t.Errorf("Expected toolset description to be updated to 'Updated description', got '%s'", toolset.Description) - } - - if toolset.Enabled { - t.Error("Expected toolset to be disabled after update") - } -} - -func TestIsEnabled(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test with non-existent toolset - if tsg.IsEnabled("non-existent") { - t.Error("Expected IsEnabled to return false for non-existent toolset") - } - - // Test with disabled toolset - disabledToolset := NewToolset("disabled-toolset", "A disabled toolset") - tsg.AddToolset(disabledToolset) - if tsg.IsEnabled("disabled-toolset") { - t.Error("Expected IsEnabled to return false for disabled toolset") - } - - // Test with enabled toolset - enabledToolset := NewToolset("enabled-toolset", "An enabled toolset") - enabledToolset.Enabled = true - tsg.AddToolset(enabledToolset) - if !tsg.IsEnabled("enabled-toolset") { - t.Error("Expected IsEnabled to return true for enabled toolset") - } -} - -func TestEnableFeature(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test enabling non-existent toolset - err := tsg.EnableToolset("non-existent") - if err == nil { - t.Error("Expected error when enabling non-existent toolset") - } - - // Test enabling toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) - - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") - } - - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling toolset, got: %v", err) - } - - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled after EnableFeature call") - } - - // Test enabling already enabled toolset - err = tsg.EnableToolset("test-toolset") - if err != nil { - t.Errorf("Expected no error when enabling already enabled toolset, got: %v", err) - } -} - -func TestEnableToolsets(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Prepare toolsets - toolset1 := NewToolset("toolset1", "Feature 1") - toolset2 := NewToolset("toolset2", "Feature 2") - tsg.AddToolset(toolset1) - tsg.AddToolset(toolset2) - - // Test enabling multiple toolsets - err := tsg.EnableToolsets([]string{"toolset1", "toolset2"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling toolsets, got: %v", err) - } - - if !tsg.IsEnabled("toolset1") { - t.Error("Expected toolset1 to be enabled") - } - - if !tsg.IsEnabled("toolset2") { - t.Error("Expected toolset2 to be enabled") - } - - // Test with non-existent toolset in the list - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, nil) - if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) - } - - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ - ErrorOnUnknown: false, - }) - if err != nil { - t.Errorf("Expected no error when ignoring unknown toolsets, got: %v", err) - } - - err = tsg.EnableToolsets([]string{"toolset1", "non-existent"}, &EnableToolsetsOptions{ErrorOnUnknown: true}) - if err == nil { - t.Error("Expected error when enabling list with non-existent toolset") - } - if !errors.Is(err, NewToolsetDoesNotExistError("non-existent")) { - t.Errorf("Expected ToolsetDoesNotExistError when enabling non-existent toolset, got: %v", err) - } - - // Test with empty list - err = tsg.EnableToolsets([]string{}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error with empty toolset list, got: %v", err) - } - - // Test enabling everything through EnableToolsets - tsg = NewToolsetGroup(false) - err = tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all' via EnableToolsets") - } -} - -func TestEnableEverything(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Add a disabled toolset - testToolset := NewToolset("test-toolset", "A test toolset") - tsg.AddToolset(testToolset) - - // Verify it's disabled - if tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be disabled initially") - } - - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - // Verify everythingOn was set - if !tsg.everythingOn { - t.Error("Expected everythingOn to be true after enabling 'all'") - } - - // Verify the previously disabled toolset is now enabled - if !tsg.IsEnabled("test-toolset") { - t.Error("Expected toolset to be enabled when everythingOn is true") - } - - // Verify a non-existent toolset is also enabled - if !tsg.IsEnabled("non-existent") { - t.Error("Expected non-existent toolset to be enabled when everythingOn is true") - } -} - -func TestIsEnabledWithEverythingOn(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Enable "all" - err := tsg.EnableToolsets([]string{"all"}, &EnableToolsetsOptions{}) - if err != nil { - t.Errorf("Expected no error when enabling 'all', got: %v", err) - } - - // Test that any toolset name returns true with IsEnabled - if !tsg.IsEnabled("some-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") - } - - if !tsg.IsEnabled("another-toolset") { - t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") - } -} - -func TestToolsetGroup_GetToolset(t *testing.T) { - tsg := NewToolsetGroup(false) - toolset := NewToolset("my-toolset", "desc") - tsg.AddToolset(toolset) - - // Should find the toolset - got, err := tsg.GetToolset("my-toolset") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if got != toolset { - t.Errorf("expected to get the same toolset instance") - } - - // Should not find a non-existent toolset - _, err = tsg.GetToolset("does-not-exist") - if err == nil { - t.Error("expected error for missing toolset, got nil") - } - if !errors.Is(err, NewToolsetDoesNotExistError("does-not-exist")) { - t.Errorf("expected error to be ToolsetDoesNotExistError, got %v", err) - } -} - -func TestAddDeprecatedToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Test adding aliases - tsg.AddDeprecatedToolAliases(map[string]string{ - "old_name": "new_name", - "get_issue": "issue_read", - "create_pr": "pull_request_create", - }) - - if len(tsg.deprecatedAliases) != 3 { - t.Errorf("expected 3 aliases, got %d", len(tsg.deprecatedAliases)) - } - if tsg.deprecatedAliases["old_name"] != "new_name" { - t.Errorf("expected alias 'old_name' -> 'new_name', got '%s'", tsg.deprecatedAliases["old_name"]) - } - if tsg.deprecatedAliases["get_issue"] != "issue_read" { - t.Errorf("expected alias 'get_issue' -> 'issue_read'") - } - if tsg.deprecatedAliases["create_pr"] != "pull_request_create" { - t.Errorf("expected alias 'create_pr' -> 'pull_request_create'") - } -} - -func TestResolveToolAliases(t *testing.T) { - tsg := NewToolsetGroup(false) - tsg.AddDeprecatedToolAliases(map[string]string{ - "get_issue": "issue_read", - "create_pr": "pull_request_create", - }) - - // Test resolving a mix of aliases and canonical names - input := []string{"get_issue", "some_tool", "create_pr"} - resolved, aliasesUsed := tsg.ResolveToolAliases(input) - - // Verify resolved names - if len(resolved) != 3 { - t.Fatalf("expected 3 resolved names, got %d", len(resolved)) - } - if resolved[0] != "issue_read" { - t.Errorf("expected 'issue_read', got '%s'", resolved[0]) - } - if resolved[1] != "some_tool" { - t.Errorf("expected 'some_tool' (unchanged), got '%s'", resolved[1]) - } - if resolved[2] != "pull_request_create" { - t.Errorf("expected 'pull_request_create', got '%s'", resolved[2]) - } - - // Verify aliasesUsed map - if len(aliasesUsed) != 2 { - t.Fatalf("expected 2 aliases used, got %d", len(aliasesUsed)) - } - if aliasesUsed["get_issue"] != "issue_read" { - t.Errorf("expected aliasesUsed['get_issue'] = 'issue_read', got '%s'", aliasesUsed["get_issue"]) - } - if aliasesUsed["create_pr"] != "pull_request_create" { - t.Errorf("expected aliasesUsed['create_pr'] = 'pull_request_create', got '%s'", aliasesUsed["create_pr"]) - } -} - -func TestFindToolByName(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Create a toolset with a tool - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - tsg.AddToolset(toolset) - - // Find by canonical name - tool, toolsetName, err := tsg.FindToolByName("issue_read") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - if tool.Tool.Name != "issue_read" { - t.Errorf("expected tool name 'issue_read', got '%s'", tool.Tool.Name) - } - if toolsetName != "test-toolset" { - t.Errorf("expected toolset name 'test-toolset', got '%s'", toolsetName) - } - - // FindToolByName does NOT resolve aliases - it expects canonical names - _, _, err = tsg.FindToolByName("get_issue") - if err == nil { - t.Error("expected error when using alias directly with FindToolByName") - } -} - -func TestRegisterSpecificTools(t *testing.T) { - tsg := NewToolsetGroup(false) - - // Create a toolset with both read and write tools - toolset := NewToolset("test-toolset", "Test toolset") - toolset.readTools = append(toolset.readTools, mockTool("issue_read", true)) - toolset.writeTools = append(toolset.writeTools, mockTool("issue_write", false)) - tsg.AddToolset(toolset) - - // Test registering with canonical names - err := tsg.RegisterSpecificTools(nil, []string{"issue_read"}, false) - if err != nil { - t.Errorf("expected no error registering tool, got %v", err) - } - - // Test registering write tool in read-only mode (should skip but not error) - err = tsg.RegisterSpecificTools(nil, []string{"issue_write"}, true) - if err != nil { - t.Errorf("expected no error when skipping write tool in read-only mode, got %v", err) - } - - // Test registering non-existent tool (should error) - err = tsg.RegisterSpecificTools(nil, []string{"nonexistent"}, false) - if err == nil { - t.Error("expected error for non-existent tool") - } -} diff --git a/script/conformance-test b/script/conformance-test new file mode 100755 index 000000000..3ff0a55c2 --- /dev/null +++ b/script/conformance-test @@ -0,0 +1,432 @@ +#!/bin/bash +set -e + +# Conformance test script for comparing MCP server behavior between branches +# Builds both main and current branch, runs various flag combinations, +# and produces a conformance report with timing and diffs. +# +# Output: +# - Progress/status messages go to stderr (for visibility in CI) +# - Final report summary goes to stdout (for piping/capture) + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(dirname "$SCRIPT_DIR")" +REPORT_DIR="$PROJECT_DIR/conformance-report" +CURRENT_BRANCH=$(git rev-parse --abbrev-ref HEAD) + +# Colors for output (only used on stderr) +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Helper to print to stderr +log() { + echo -e "$@" >&2 +} + +log "${BLUE}=== MCP Server Conformance Test ===${NC}" +log "Current branch: $CURRENT_BRANCH" +log "Report directory: $REPORT_DIR" + +# Find the common ancestor +MERGE_BASE=$(git merge-base HEAD origin/main) +log "Comparing against merge-base: $MERGE_BASE" +log "" + +# Create report directory +rm -rf "$REPORT_DIR" +mkdir -p "$REPORT_DIR"/{main,branch,diffs} + +# Build binaries +log "${YELLOW}Building binaries...${NC}" + +log "Building current branch ($CURRENT_BRANCH)..." +go build -o "$REPORT_DIR/branch/github-mcp-server" ./cmd/github-mcp-server +BRANCH_BUILD_OK=$? + +log "Building main branch (using temp worktree at merge-base)..." +TEMP_WORKTREE=$(mktemp -d) +git worktree add --quiet "$TEMP_WORKTREE" "$MERGE_BASE" +(cd "$TEMP_WORKTREE" && go build -o "$REPORT_DIR/main/github-mcp-server" ./cmd/github-mcp-server) +MAIN_BUILD_OK=$? +git worktree remove --force "$TEMP_WORKTREE" + +if [ $BRANCH_BUILD_OK -ne 0 ] || [ $MAIN_BUILD_OK -ne 0 ]; then + log "${RED}Build failed!${NC}" + exit 1 +fi + +log "${GREEN}Both binaries built successfully${NC}" +log "" + +# MCP JSON-RPC messages +INIT_MSG='{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"conformance-test","version":"1.0.0"}}}' +INITIALIZED_MSG='{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}' +LIST_TOOLS_MSG='{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' +LIST_RESOURCES_MSG='{"jsonrpc":"2.0","id":3,"method":"resources/listTemplates","params":{}}' +LIST_PROMPTS_MSG='{"jsonrpc":"2.0","id":4,"method":"prompts/list","params":{}}' + +# Dynamic toolset management tool calls (for dynamic mode testing) +LIST_TOOLSETS_MSG='{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' +GET_TOOLSET_TOOLS_MSG='{"jsonrpc":"2.0","id":11,"method":"tools/call","params":{"name":"get_toolset_tools","arguments":{"toolset":"repos"}}}' +ENABLE_TOOLSET_MSG='{"jsonrpc":"2.0","id":12,"method":"tools/call","params":{"name":"enable_toolset","arguments":{"toolset":"repos"}}}' +LIST_TOOLSETS_AFTER_MSG='{"jsonrpc":"2.0","id":13,"method":"tools/call","params":{"name":"list_available_toolsets","arguments":{}}}' + +# Function to normalize JSON for comparison +# Sorts all arrays (including nested ones) and formats consistently +# Also handles embedded JSON strings in "text" fields (from tool call responses) +normalize_json() { + local file="$1" + if [ -s "$file" ]; then + # First, try to parse and re-serialize any JSON embedded in text fields + # This handles tool call responses where the result is JSON-in-a-string + jq -S ' + # Function to sort arrays recursively + def deep_sort: + if type == "array" then + [.[] | deep_sort] | sort_by(tostring) + elif type == "object" then + to_entries | map(.value |= deep_sort) | from_entries + else + . + end; + + # Walk the structure, and for any "text" field that looks like JSON array/object, parse and sort it + walk( + if type == "object" and .text and (.text | type == "string") and ((.text | startswith("[")) or (.text | startswith("{"))) then + .text = ((.text | fromjson | deep_sort) | tojson) + else + . + end + ) | deep_sort + ' "$file" 2>/dev/null > "${file}.tmp" && mv "${file}.tmp" "$file" + fi +} + +# Function to run MCP server and capture output with timing +run_mcp_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with all list commands - each response is on its own line + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLS_MSG" + echo "$LIST_RESOURCES_MSG" + echo "$LIST_PROMPTS_MSG" + sleep 0.5 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + # Each line is a separate JSON response + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 2) echo "$line" | jq -S '.' > "${output_prefix}_tools.json" 2>/dev/null ;; + 3) echo "$line" | jq -S '.' > "${output_prefix}_resources.json" 2>/dev/null ;; + 4) echo "$line" | jq -S '.' > "${output_prefix}_prompts.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created (in case of errors or missing responses) + touch "${output_prefix}_initialize.json" "${output_prefix}_tools.json" \ + "${output_prefix}_resources.json" "${output_prefix}_prompts.json" + + # Normalize all JSON files for consistent comparison (sorts arrays, keys) + for endpoint in initialize tools resources prompts; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Function to run MCP server with dynamic tool calls (for dynamic mode testing) +run_mcp_dynamic_test() { + local binary="$1" + local name="$2" + local flags="$3" + local output_prefix="$4" + + local start_time end_time duration + start_time=$(date +%s.%N) + + # Run the server with dynamic tool calls in sequence: + # 1. Initialize + # 2. List available toolsets (before enable) + # 3. Get tools for repos toolset + # 4. Enable repos toolset + # 5. List available toolsets (after enable - should show repos as enabled) + output=$( + ( + echo "$INIT_MSG" + echo "$INITIALIZED_MSG" + echo "$LIST_TOOLSETS_MSG" + sleep 0.1 + echo "$GET_TOOLSET_TOOLS_MSG" + sleep 0.1 + echo "$ENABLE_TOOLSET_MSG" + sleep 0.1 + echo "$LIST_TOOLSETS_AFTER_MSG" + sleep 0.3 + ) | GITHUB_PERSONAL_ACCESS_TOKEN=1 $binary stdio $flags 2>/dev/null + ) + + end_time=$(date +%s.%N) + duration=$(echo "$end_time - $start_time" | bc) + + # Parse and save each response by matching JSON-RPC id + echo "$output" | while IFS= read -r line; do + id=$(echo "$line" | jq -r '.id // empty' 2>/dev/null) + case "$id" in + 1) echo "$line" | jq -S '.' > "${output_prefix}_initialize.json" 2>/dev/null ;; + 10) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_before.json" 2>/dev/null ;; + 11) echo "$line" | jq -S '.' > "${output_prefix}_get_toolset_tools.json" 2>/dev/null ;; + 12) echo "$line" | jq -S '.' > "${output_prefix}_enable_toolset.json" 2>/dev/null ;; + 13) echo "$line" | jq -S '.' > "${output_prefix}_list_toolsets_after.json" 2>/dev/null ;; + esac + done + + # Create empty files if not created + touch "${output_prefix}_initialize.json" "${output_prefix}_list_toolsets_before.json" \ + "${output_prefix}_get_toolset_tools.json" "${output_prefix}_enable_toolset.json" \ + "${output_prefix}_list_toolsets_after.json" + + # Normalize all JSON files + for endpoint in initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + normalize_json "${output_prefix}_${endpoint}.json" + done + + echo "$duration" +} + +# Test configurations - array of "name|flags|type" +# type can be "standard" or "dynamic" (for dynamic tool call testing) +declare -a TEST_CONFIGS=( + "default||standard" + "read-only|--read-only|standard" + "dynamic-toolsets|--dynamic-toolsets|standard" + "read-only+dynamic|--read-only --dynamic-toolsets|standard" + "toolsets-repos|--toolsets=repos|standard" + "toolsets-issues|--toolsets=issues|standard" + "toolsets-pull_requests|--toolsets=pull_requests|standard" + "toolsets-repos,issues|--toolsets=repos,issues|standard" + "toolsets-all|--toolsets=all|standard" + "tools-get_me|--tools=get_me|standard" + "tools-get_me,list_issues|--tools=get_me,list_issues|standard" + "toolsets-repos+read-only|--toolsets=repos --read-only|standard" + "toolsets-all+dynamic|--toolsets=all --dynamic-toolsets|standard" + "toolsets-repos+dynamic|--toolsets=repos --dynamic-toolsets|standard" + "toolsets-repos,issues+dynamic|--toolsets=repos,issues --dynamic-toolsets|standard" + "dynamic-tool-calls|--dynamic-toolsets|dynamic" +) + +# Summary arrays +declare -a TEST_NAMES +declare -a MAIN_TIMES +declare -a BRANCH_TIMES +declare -a DIFF_STATUS + +log "${YELLOW}Running conformance tests...${NC}" +log "" + +for config in "${TEST_CONFIGS[@]}"; do + IFS='|' read -r test_name flags test_type <<< "$config" + + log "${BLUE}Test: ${test_name}${NC}" + log " Flags: ${flags:-}" + log " Type: ${test_type}" + + # Create output directories + mkdir -p "$REPORT_DIR/main/$test_name" + mkdir -p "$REPORT_DIR/branch/$test_name" + mkdir -p "$REPORT_DIR/diffs/$test_name" + + if [ "$test_type" = "dynamic" ]; then + # Run dynamic tool call test + main_time=$(run_mcp_dynamic_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_dynamic_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after" + else + # Run standard test + main_time=$(run_mcp_test "$REPORT_DIR/main/github-mcp-server" "main" "$flags" "$REPORT_DIR/main/$test_name/output") + log " Main: ${main_time}s" + + branch_time=$(run_mcp_test "$REPORT_DIR/branch/github-mcp-server" "branch" "$flags" "$REPORT_DIR/branch/$test_name/output") + log " Branch: ${branch_time}s" + + endpoints="initialize tools resources prompts" + fi + + # Calculate time difference + time_diff=$(echo "$branch_time - $main_time" | bc) + if (( $(echo "$time_diff > 0" | bc -l) )); then + log " Δ Time: ${RED}+${time_diff}s (slower)${NC}" + else + log " Δ Time: ${GREEN}${time_diff}s (faster)${NC}" + fi + + # Generate diffs for each endpoint + has_diff=false + for endpoint in $endpoints; do + main_file="$REPORT_DIR/main/$test_name/output_${endpoint}.json" + branch_file="$REPORT_DIR/branch/$test_name/output_${endpoint}.json" + diff_file="$REPORT_DIR/diffs/$test_name/${endpoint}.diff" + + if ! diff -u "$main_file" "$branch_file" > "$diff_file" 2>/dev/null; then + has_diff=true + lines=$(wc -l < "$diff_file" | tr -d ' ') + log " ${YELLOW}${endpoint}: DIFF (${lines} lines)${NC}" + else + rm -f "$diff_file" # No diff, remove empty file + log " ${GREEN}${endpoint}: OK${NC}" + fi + done + + # Store results + TEST_NAMES+=("$test_name") + MAIN_TIMES+=("$main_time") + BRANCH_TIMES+=("$branch_time") + if [ "$has_diff" = true ]; then + DIFF_STATUS+=("DIFF") + else + DIFF_STATUS+=("OK") + fi + + log "" +done + +# Generate summary report +REPORT_FILE="$REPORT_DIR/CONFORMANCE_REPORT.md" + +cat > "$REPORT_FILE" << EOF +# MCP Server Conformance Report + +Generated: $(date) +Current Branch: $CURRENT_BRANCH +Compared Against: merge-base ($MERGE_BASE) + +## Summary + +| Test | Main Time | Branch Time | Δ Time | Status | +|------|-----------|-------------|--------|--------| +EOF + +total_main=0 +total_branch=0 +diff_count=0 +ok_count=0 + +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + main_t="${MAIN_TIMES[$i]}" + branch_t="${BRANCH_TIMES[$i]}" + status="${DIFF_STATUS[$i]}" + + delta=$(echo "$branch_t - $main_t" | bc) + if (( $(echo "$delta > 0" | bc -l) )); then + delta_str="+${delta}s" + else + delta_str="${delta}s" + fi + + if [ "$status" = "DIFF" ]; then + status_str="⚠️ DIFF" + ((diff_count++)) || true + else + status_str="✅ OK" + ((ok_count++)) || true + fi + + total_main=$(echo "$total_main + $main_t" | bc) + total_branch=$(echo "$total_branch + $branch_t" | bc) + + echo "| $name | ${main_t}s | ${branch_t}s | $delta_str | $status_str |" >> "$REPORT_FILE" +done + +total_delta=$(echo "$total_branch - $total_main" | bc) +if (( $(echo "$total_delta > 0" | bc -l) )); then + total_delta_str="+${total_delta}s" +else + total_delta_str="${total_delta}s" +fi + +cat >> "$REPORT_FILE" << EOF +| **TOTAL** | **${total_main}s** | **${total_branch}s** | **$total_delta_str** | **$ok_count OK / $diff_count DIFF** | + +## Statistics + +- **Tests Passed (no diff):** $ok_count +- **Tests with Differences:** $diff_count +- **Total Main Time:** ${total_main}s +- **Total Branch Time:** ${total_branch}s +- **Overall Time Delta:** $total_delta_str + +## Detailed Diffs + +EOF + +# Add diff details to report +for i in "${!TEST_NAMES[@]}"; do + name="${TEST_NAMES[$i]}" + status="${DIFF_STATUS[$i]}" + + if [ "$status" = "DIFF" ]; then + echo "### $name" >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + + # Check all possible endpoints + for endpoint in initialize tools resources prompts list_toolsets_before get_toolset_tools enable_toolset list_toolsets_after; do + diff_file="$REPORT_DIR/diffs/$name/${endpoint}.diff" + if [ -f "$diff_file" ] && [ -s "$diff_file" ]; then + echo "#### ${endpoint}" >> "$REPORT_FILE" + echo '```diff' >> "$REPORT_FILE" + cat "$diff_file" >> "$REPORT_FILE" + echo '```' >> "$REPORT_FILE" + echo "" >> "$REPORT_FILE" + fi + done + fi +done + +log "${BLUE}=== Conformance Test Complete ===${NC}" +log "" +log "Report: ${GREEN}$REPORT_FILE${NC}" +log "" + +# Output summary to stdout (for CI capture) +echo "=== Conformance Test Summary ===" +echo "Tests passed: $ok_count" +echo "Tests with diffs: $diff_count" +echo "Total main time: ${total_main}s" +echo "Total branch time: ${total_branch}s" +echo "Time delta: $total_delta_str" + +if [ $diff_count -gt 0 ]; then + log "" + log "${YELLOW}⚠️ Some tests have differences. Review the diffs in:${NC}" + log " $REPORT_DIR/diffs/" + echo "" + echo "RESULT: DIFFERENCES FOUND" + # Don't exit with error - diffs may be intentional improvements +else + echo "" + echo "RESULT: ALL TESTS PASSED" +fi diff --git a/script/fetch-icons b/script/fetch-icons new file mode 100755 index 000000000..21de625f1 --- /dev/null +++ b/script/fetch-icons @@ -0,0 +1,72 @@ +#!/bin/bash +# Fetch Octicon icons and convert them to PNG for embedding in the MCP server. +# Generates both light theme (dark icons) and dark theme (white icons) variants. +# Uses sed to modify SVG fill color before converting to PNG. +# Requires: rsvg-convert (from librsvg2-bin on Ubuntu/Debian) +# +# Usage: +# script/fetch-icons # Fetch all required icons +# script/fetch-icons icon1 icon2 # Fetch specific icons + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +ICONS_DIR="$REPO_ROOT/pkg/octicons/icons" +REQUIRED_ICONS_FILE="$REPO_ROOT/pkg/octicons/required_icons.txt" +OCTICONS_BASE="https://raw.githubusercontent.com/primer/octicons/main/icons" + +# Check for rsvg-convert +if ! command -v rsvg-convert &> /dev/null; then + echo "Error: rsvg-convert not found. Install with:" + echo " Ubuntu/Debian: sudo apt-get install librsvg2-bin" + echo " macOS: brew install librsvg" + exit 1 +fi + +# Load icons from required_icons.txt or use command-line arguments +if [ $# -gt 0 ]; then + ICONS=("$@") +else + if [ ! -f "$REQUIRED_ICONS_FILE" ]; then + echo "Error: Required icons file not found: $REQUIRED_ICONS_FILE" + exit 1 + fi + # Read icons from file, skipping comments and empty lines + mapfile -t ICONS < <(grep -v '^#' "$REQUIRED_ICONS_FILE" | grep -v '^$') +fi + +# Ensure icons directory exists +mkdir -p "$ICONS_DIR" + +echo "Fetching ${#ICONS[@]} icons (24px, light + dark themes)..." + +for icon in "${ICONS[@]}"; do + svg_url="${OCTICONS_BASE}/${icon}-24.svg" + light_file="${ICONS_DIR}/${icon}-light.png" + dark_file="${ICONS_DIR}/${icon}-dark.png" + + echo " ${icon} (light + dark)" + + # Download SVG + svg_content=$(curl -sfL "$svg_url" 2>/dev/null) || { + echo " Warning: Failed to fetch ${icon}-24.svg (may not exist)" + continue + } + + # Light theme: dark icons (#24292f) for light backgrounds + # Add fill attribute to the svg tag + light_svg=$(echo "$svg_content" | sed 's/