Skip to content

feat: add tool for checking if PR is merged #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions pkg/github/pullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,80 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun
}
}

// IsPullRequestMerged creates a tool to check if a pull request is merged.
func IsPullRequestMerged(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("is_pull_request_merged",
mcp.WithDescription(t("TOOL_IS_PULL_REQUEST_MERGED_DESCRIPTION", "Check if a pull request is merged.")),
mcp.WithString("owner",
mcp.Required(),
mcp.Description("Repository owner"),
),
mcp.WithString("repo",
mcp.Required(),
mcp.Description("Repository name"),
),
mcp.WithNumber("pullNumber",
mcp.Required(),
mcp.Description("Pull request number"),
),
),
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
owner, err := requiredParam[string](request, "owner")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
repo, err := requiredParam[string](request, "repo")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pullNumber, err := RequiredInt(request, "pullNumber")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

client, err := getClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
}
_, resp, err := client.PullRequests.IsMerged(ctx, owner, repo, pullNumber)
if err != nil {
return nil, fmt.Errorf("failed to check if pull request is merged: %w", err)
}
defer func() { _ = resp.Body.Close() }()

// 204 if pull request is merged, 404 if not
type responseFormat = struct {
Status string `json:"status"`
}
switch resp.StatusCode {
case http.StatusNoContent:
response := responseFormat{
Status: "Pull request is merged.",
}
r, err := json.Marshal(response)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
return mcp.NewToolResultText(string(r)), nil
case http.StatusNotFound:
response := responseFormat{
Status: "Pull request not merged or does not exist.",
}
r, err := json.Marshal(response)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
return mcp.NewToolResultText(string(r)), nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return mcp.NewToolResultError(fmt.Sprintf("failed to check if pull request is merged: %s", string(body))), nil
}
}

// MergePullRequest creates a tool to merge a pull request.
func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("merge_pull_request",
Expand Down
109 changes: 109 additions & 0 deletions pkg/github/pullrequests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,115 @@ func Test_GetPullRequestComments(t *testing.T) {
}
}

func Test_IsPullRequestMerged(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
tool, _ := IsPullRequestMerged(stubGetClientFn(mockClient), translations.NullTranslationHelper)

assert.Equal(t, "is_pull_request_merged", tool.Name)
assert.NotEmpty(t, tool.Description)
assert.Contains(t, tool.InputSchema.Properties, "owner")
assert.Contains(t, tool.InputSchema.Properties, "repo")
assert.Contains(t, tool.InputSchema.Properties, "pullNumber")
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber"})

tests := []struct {
name string
mockedClient *http.Client
requestArgs map[string]interface{}
expectError bool
expectedResult string
expectedErrMsg string
}{
{
name: "successfully check merged PR",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.GetReposPullsMergeByOwnerByRepoByPullNumber,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// GitHub API returns 204 No Content for merged PRs
w.WriteHeader(http.StatusNoContent)
}),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
},
expectError: false,
expectedResult: "Pull request is merged.",
},
{
name: "successfully check unmerged PR",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.GetReposPullsMergeByOwnerByRepoByPullNumber,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// GitHub API returns 404 Not Found for unmerged PRs
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(`{"message": "Not Found"}`))
}),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
},
expectError: false,
expectedResult: "Pull request not merged or does not exist.",
},
{
name: "unexpected status code",
mockedClient: mock.NewMockedHTTPClient(
mock.WithRequestMatchHandler(
mock.GetReposPullsMergeByOwnerByRepoByPullNumber,
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// Some other error
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(`{"message": "Internal Server Error"}`))
}),
),
),
requestArgs: map[string]interface{}{
"owner": "owner",
"repo": "repo",
"pullNumber": float64(42),
},
expectError: true,
expectedErrMsg: "failed to check if pull request is merged",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
_, handler := IsPullRequestMerged(stubGetClientFn(client), translations.NullTranslationHelper)

request := createMCPRequest(tc.requestArgs)
result, err := handler(context.Background(), request)

if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedErrMsg)
return
}

require.NoError(t, err)
textContent := getTextResult(t, result)

var response struct {
Status string `json:"status"`
}
err = json.Unmarshal([]byte(textContent.Text), &response)
require.NoError(t, err)
assert.Equal(t, tc.expectedResult, response.Status)
})
}
}

func Test_GetPullRequestReviews(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
Expand Down
2 changes: 1 addition & 1 deletion pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati
// Add GitHub tools - Pull Requests
s.AddTool(GetPullRequest(getClient, t))
s.AddTool(ListPullRequests(getClient, t))
s.AddTool(IsPullRequestMerged(getClient, t))
s.AddTool(GetPullRequestFiles(getClient, t))
s.AddTool(GetPullRequestStatus(getClient, t))
s.AddTool(GetPullRequestComments(getClient, t))
Expand Down Expand Up @@ -179,7 +180,6 @@ func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) {

if r.Params.Arguments[p].(T) == zero {
return zero, fmt.Errorf("missing required parameter: %s", p)

}

return r.Params.Arguments[p].(T), nil
Expand Down