Skip to content

Commit a7a17b6

Browse files
authored
Merge branch 'main' into main
2 parents b19b21e + 6b923f6 commit a7a17b6

File tree

8 files changed

+467
-151
lines changed

8 files changed

+467
-151
lines changed

client/client.go

+50
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package client
33

44
import (
55
"context"
6+
"encoding/json"
7+
"fmt"
68

79
"github.com/mark3labs/mcp-go/mcp"
810
)
@@ -18,12 +20,25 @@ type MCPClient interface {
1820
// Ping checks if the server is alive
1921
Ping(ctx context.Context) error
2022

23+
// ListResourcesByPage manually list resources by page.
24+
ListResourcesByPage(
25+
ctx context.Context,
26+
request mcp.ListResourcesRequest,
27+
) (*mcp.ListResourcesResult, error)
28+
2129
// ListResources requests a list of available resources from the server
2230
ListResources(
2331
ctx context.Context,
2432
request mcp.ListResourcesRequest,
2533
) (*mcp.ListResourcesResult, error)
2634

35+
// ListResourceTemplatesByPage manually list resource templates by page.
36+
ListResourceTemplatesByPage(
37+
ctx context.Context,
38+
request mcp.ListResourceTemplatesRequest,
39+
) (*mcp.ListResourceTemplatesResult,
40+
error)
41+
2742
// ListResourceTemplates requests a list of available resource templates from the server
2843
ListResourceTemplates(
2944
ctx context.Context,
@@ -43,6 +58,12 @@ type MCPClient interface {
4358
// Unsubscribe cancels notifications for a specific resource
4459
Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error
4560

61+
// ListPromptsByPage manually list prompts by page.
62+
ListPromptsByPage(
63+
ctx context.Context,
64+
request mcp.ListPromptsRequest,
65+
) (*mcp.ListPromptsResult, error)
66+
4667
// ListPrompts requests a list of available prompts from the server
4768
ListPrompts(
4869
ctx context.Context,
@@ -55,6 +76,12 @@ type MCPClient interface {
5576
request mcp.GetPromptRequest,
5677
) (*mcp.GetPromptResult, error)
5778

79+
// ListToolsByPage manually list tools by page.
80+
ListToolsByPage(
81+
ctx context.Context,
82+
request mcp.ListToolsRequest,
83+
) (*mcp.ListToolsResult, error)
84+
5885
// ListTools requests a list of available tools from the server
5986
ListTools(
6087
ctx context.Context,
@@ -82,3 +109,26 @@ type MCPClient interface {
82109
// OnNotification registers a handler for notifications
83110
OnNotification(handler func(notification mcp.JSONRPCNotification))
84111
}
112+
113+
type mcpClient interface {
114+
MCPClient
115+
116+
sendRequest(ctx context.Context, method string, params interface{}) (*json.RawMessage, error)
117+
}
118+
119+
func listByPage[T any](
120+
ctx context.Context,
121+
client mcpClient,
122+
request mcp.PaginatedRequest,
123+
method string,
124+
) (*T, error) {
125+
response, err := client.sendRequest(ctx, method, request.Params)
126+
if err != nil {
127+
return nil, err
128+
}
129+
var result T
130+
if err := json.Unmarshal(*response, &result); err != nil {
131+
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
132+
}
133+
return &result, nil
134+
}

client/sse.go

+128-63
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,19 @@ import (
2323
// while sending requests over regular HTTP POST calls. The client handles
2424
// automatic reconnection and message routing between requests and responses.
2525
type SSEMCPClient struct {
26-
baseURL *url.URL
27-
endpoint *url.URL
28-
httpClient *http.Client
29-
requestID atomic.Int64
30-
responses map[int64]chan RPCResponse
31-
mu sync.RWMutex
32-
done chan struct{}
33-
initialized bool
34-
notifications []func(mcp.JSONRPCNotification)
35-
notifyMu sync.RWMutex
36-
endpointChan chan struct{}
37-
capabilities mcp.ServerCapabilities
38-
headers map[string]string
39-
sseReadTimeout time.Duration
26+
baseURL *url.URL
27+
endpoint *url.URL
28+
httpClient *http.Client
29+
requestID atomic.Int64
30+
responses map[int64]chan RPCResponse
31+
mu sync.RWMutex
32+
done chan struct{}
33+
initialized bool
34+
notifications []func(mcp.JSONRPCNotification)
35+
notifyMu sync.RWMutex
36+
endpointChan chan struct{}
37+
capabilities mcp.ServerCapabilities
38+
headers map[string]string
4039
}
4140

4241
type ClientOption func(*SSEMCPClient)
@@ -47,12 +46,6 @@ func WithHeaders(headers map[string]string) ClientOption {
4746
}
4847
}
4948

50-
func WithSSEReadTimeout(timeout time.Duration) ClientOption {
51-
return func(sc *SSEMCPClient) {
52-
sc.sseReadTimeout = timeout
53-
}
54-
}
55-
5649
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
5750
// Returns an error if the URL is invalid.
5851
func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) {
@@ -62,13 +55,12 @@ func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, er
6255
}
6356

6457
smc := &SSEMCPClient{
65-
baseURL: parsedURL,
66-
httpClient: &http.Client{},
67-
responses: make(map[int64]chan RPCResponse),
68-
done: make(chan struct{}),
69-
endpointChan: make(chan struct{}),
70-
sseReadTimeout: 30 * time.Second,
71-
headers: make(map[string]string),
58+
baseURL: parsedURL,
59+
httpClient: &http.Client{},
60+
responses: make(map[int64]chan RPCResponse),
61+
done: make(chan struct{}),
62+
endpointChan: make(chan struct{}),
63+
headers: make(map[string]string),
7264
}
7365

7466
for _, opt := range options {
@@ -93,6 +85,9 @@ func (c *SSEMCPClient) Start(ctx context.Context) error {
9385
req.Header.Set("Accept", "text/event-stream")
9486
req.Header.Set("Cache-Control", "no-cache")
9587
req.Header.Set("Connection", "keep-alive")
88+
for k, v := range c.headers {
89+
req.Header.Set(k, v)
90+
}
9691

9792
resp, err := c.httpClient.Do(req)
9893
if err != nil {
@@ -128,12 +123,9 @@ func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
128123
br := bufio.NewReader(reader)
129124
var event, data string
130125

131-
ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
132-
defer cancel()
133-
134126
for {
135127
select {
136-
case <-ctx.Done():
128+
case <-c.done:
137129
return
138130
default:
139131
line, err := br.ReadString('\n')
@@ -399,7 +391,7 @@ func (c *SSEMCPClient) Initialize(
399391
err,
400392
)
401393
}
402-
resp.Body.Close()
394+
defer resp.Body.Close()
403395

404396
c.initialized = true
405397
return &result, nil
@@ -410,42 +402,77 @@ func (c *SSEMCPClient) Ping(ctx context.Context) error {
410402
return err
411403
}
412404

413-
func (c *SSEMCPClient) ListResources(
405+
// ListResourcesByPage manually list resources by page.
406+
func (c *SSEMCPClient) ListResourcesByPage(
414407
ctx context.Context,
415408
request mcp.ListResourcesRequest,
416409
) (*mcp.ListResourcesResult, error) {
417-
response, err := c.sendRequest(ctx, "resources/list", request.Params)
410+
result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
418411
if err != nil {
419412
return nil, err
420413
}
414+
return result, nil
415+
}
421416

422-
var result mcp.ListResourcesResult
423-
if err := json.Unmarshal(*response, &result); err != nil {
424-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
417+
func (c *SSEMCPClient) ListResources(
418+
ctx context.Context,
419+
request mcp.ListResourcesRequest,
420+
) (*mcp.ListResourcesResult, error) {
421+
result, err := c.ListResourcesByPage(ctx, request)
422+
if err != nil {
423+
return nil, err
425424
}
425+
for result.NextCursor != "" {
426+
select {
427+
case <-ctx.Done():
428+
return nil, ctx.Err()
429+
default:
430+
request.Params.Cursor = result.NextCursor
431+
newPageRes, err := c.ListResourcesByPage(ctx, request)
432+
if err != nil {
433+
return nil, err
434+
}
435+
result.Resources = append(result.Resources, newPageRes.Resources...)
436+
result.NextCursor = newPageRes.NextCursor
437+
}
438+
}
439+
return result, nil
440+
}
426441

427-
return &result, nil
442+
func (c *SSEMCPClient) ListResourceTemplatesByPage(
443+
ctx context.Context,
444+
request mcp.ListResourceTemplatesRequest,
445+
) (*mcp.ListResourceTemplatesResult, error) {
446+
result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
447+
if err != nil {
448+
return nil, err
449+
}
450+
return result, nil
428451
}
429452

430453
func (c *SSEMCPClient) ListResourceTemplates(
431454
ctx context.Context,
432455
request mcp.ListResourceTemplatesRequest,
433456
) (*mcp.ListResourceTemplatesResult, error) {
434-
response, err := c.sendRequest(
435-
ctx,
436-
"resources/templates/list",
437-
request.Params,
438-
)
457+
result, err := c.ListResourceTemplatesByPage(ctx, request)
439458
if err != nil {
440459
return nil, err
441460
}
442-
443-
var result mcp.ListResourceTemplatesResult
444-
if err := json.Unmarshal(*response, &result); err != nil {
445-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
461+
for result.NextCursor != "" {
462+
select {
463+
case <-ctx.Done():
464+
return nil, ctx.Err()
465+
default:
466+
request.Params.Cursor = result.NextCursor
467+
newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
468+
if err != nil {
469+
return nil, err
470+
}
471+
result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
472+
result.NextCursor = newPageRes.NextCursor
473+
}
446474
}
447-
448-
return &result, nil
475+
return result, nil
449476
}
450477

451478
func (c *SSEMCPClient) ReadResource(
@@ -476,21 +503,40 @@ func (c *SSEMCPClient) Unsubscribe(
476503
return err
477504
}
478505

479-
func (c *SSEMCPClient) ListPrompts(
506+
func (c *SSEMCPClient) ListPromptsByPage(
480507
ctx context.Context,
481508
request mcp.ListPromptsRequest,
482509
) (*mcp.ListPromptsResult, error) {
483-
response, err := c.sendRequest(ctx, "prompts/list", request.Params)
510+
result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
484511
if err != nil {
485512
return nil, err
486513
}
514+
return result, nil
515+
}
487516

488-
var result mcp.ListPromptsResult
489-
if err := json.Unmarshal(*response, &result); err != nil {
490-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
517+
func (c *SSEMCPClient) ListPrompts(
518+
ctx context.Context,
519+
request mcp.ListPromptsRequest,
520+
) (*mcp.ListPromptsResult, error) {
521+
result, err := c.ListPromptsByPage(ctx, request)
522+
if err != nil {
523+
return nil, err
491524
}
492-
493-
return &result, nil
525+
for result.NextCursor != "" {
526+
select {
527+
case <-ctx.Done():
528+
return nil, ctx.Err()
529+
default:
530+
request.Params.Cursor = result.NextCursor
531+
newPageRes, err := c.ListPromptsByPage(ctx, request)
532+
if err != nil {
533+
return nil, err
534+
}
535+
result.Prompts = append(result.Prompts, newPageRes.Prompts...)
536+
result.NextCursor = newPageRes.NextCursor
537+
}
538+
}
539+
return result, nil
494540
}
495541

496542
func (c *SSEMCPClient) GetPrompt(
@@ -505,21 +551,40 @@ func (c *SSEMCPClient) GetPrompt(
505551
return mcp.ParseGetPromptResult(response)
506552
}
507553

508-
func (c *SSEMCPClient) ListTools(
554+
func (c *SSEMCPClient) ListToolsByPage(
509555
ctx context.Context,
510556
request mcp.ListToolsRequest,
511557
) (*mcp.ListToolsResult, error) {
512-
response, err := c.sendRequest(ctx, "tools/list", request.Params)
558+
result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
513559
if err != nil {
514560
return nil, err
515561
}
562+
return result, nil
563+
}
516564

517-
var result mcp.ListToolsResult
518-
if err := json.Unmarshal(*response, &result); err != nil {
519-
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
565+
func (c *SSEMCPClient) ListTools(
566+
ctx context.Context,
567+
request mcp.ListToolsRequest,
568+
) (*mcp.ListToolsResult, error) {
569+
result, err := c.ListToolsByPage(ctx, request)
570+
if err != nil {
571+
return nil, err
520572
}
521-
522-
return &result, nil
573+
for result.NextCursor != "" {
574+
select {
575+
case <-ctx.Done():
576+
return nil, ctx.Err()
577+
default:
578+
request.Params.Cursor = result.NextCursor
579+
newPageRes, err := c.ListToolsByPage(ctx, request)
580+
if err != nil {
581+
return nil, err
582+
}
583+
result.Tools = append(result.Tools, newPageRes.Tools...)
584+
result.NextCursor = newPageRes.NextCursor
585+
}
586+
}
587+
return result, nil
523588
}
524589

525590
func (c *SSEMCPClient) CallTool(

0 commit comments

Comments
 (0)