diff --git a/client/sse.go b/client/sse.go index c26744a..d6eaf10 100644 --- a/client/sse.go +++ b/client/sse.go @@ -2,14 +2,20 @@ package client import ( "fmt" - "github.com/mark3labs/mcp-go/client/transport" + "net/http" "net/url" + + "github.com/mark3labs/mcp-go/client/transport" ) func WithHeaders(headers map[string]string) transport.ClientOption { return transport.WithHeaders(headers) } +func WithHTTPClient(httpClient *http.Client) transport.ClientOption { + return transport.WithHTTPClient(httpClient) +} + // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { diff --git a/client/transport/sse.go b/client/transport/sse.go index a515ae7..e7b1bb8 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -45,6 +45,12 @@ func WithHeaders(headers map[string]string) ClientOption { } } +func WithHTTPClient(httpClient *http.Client) ClientOption { + return func(sc *SSE) { + sc.httpClient = httpClient + } +} + // NewSSE creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index 0c4dff6..b8b59d0 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -415,6 +415,31 @@ func TestSSEErrors(t *testing.T) { } }) + t.Run("WithHTTPClient", func(t *testing.T) { + // Create a custom client with a very short timeout + customClient := &http.Client{Timeout: 1 * time.Nanosecond} + + url, closeF := startMockSSEEchoServer() + defer closeF() + // Initialize SSE transport with the custom HTTP client + trans, err := NewSSE(url, WithHTTPClient(customClient)) + if err != nil { + t.Fatalf("Failed to create SSE with custom client: %v", err) + } + + // Starting should immediately error due to timeout + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = trans.Start(ctx) + if err == nil { + t.Error("Expected Start to fail with custom timeout, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected error 'context deadline exceeded', got '%s'", err.Error()) + } + trans.Close() + }) + t.Run("RequestBeforeStart", func(t *testing.T) { url, closeF := startMockSSEEchoServer() defer closeF()