Skip to content

Commit dd36a00

Browse files
authored
Merge branch 'mark3labs:main' into main
2 parents d32050f + 4558b68 commit dd36a00

File tree

8 files changed

+332
-91
lines changed

8 files changed

+332
-91
lines changed

.github/workflows/ci.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ on:
44
branches:
55
- main
66
pull_request:
7+
workflow_dispatch:
8+
79
jobs:
810
test:
911
runs-on: ubuntu-latest
@@ -13,3 +15,21 @@ jobs:
1315
with:
1416
go-version-file: 'go.mod'
1517
- run: go test ./... -race
18+
19+
verify-codegen:
20+
runs-on: ubuntu-latest
21+
steps:
22+
- uses: actions/checkout@v4
23+
- uses: actions/setup-go@v5
24+
with:
25+
go-version-file: 'go.mod'
26+
- name: Run code generation
27+
run: go generate ./...
28+
- name: Check for uncommitted changes
29+
run: |
30+
if [[ -n $(git status --porcelain) ]]; then
31+
echo "Error: Generated code is not up to date. Please run 'go generate ./...' and commit the changes."
32+
git status
33+
git diff
34+
exit 1
35+
fi

client/inprocess_test.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func TestInProcessMCPClient(t *testing.T) {
3636
Type: "text",
3737
Text: "Input parameter: " + request.Params.Arguments["parameter-1"].(string),
3838
},
39+
mcp.AudioContent{
40+
Type: "audio",
41+
Data: "base64-encoded-audio-data",
42+
MIMEType: "audio/wav",
43+
},
3944
},
4045
}, nil
4146
})
@@ -77,6 +82,14 @@ func TestInProcessMCPClient(t *testing.T) {
7782
Text: "Test prompt with arg1: " + request.Params.Arguments["arg1"],
7883
},
7984
},
85+
{
86+
Role: mcp.RoleUser,
87+
Content: mcp.AudioContent{
88+
Type: "audio",
89+
Data: "base64-encoded-audio-data",
90+
MIMEType: "audio/wav",
91+
},
92+
},
8093
},
8194
}, nil
8295
},
@@ -192,8 +205,8 @@ func TestInProcessMCPClient(t *testing.T) {
192205
t.Fatalf("CallTool failed: %v", err)
193206
}
194207

195-
if len(result.Content) != 1 {
196-
t.Errorf("Expected 1 content item, got %d", len(result.Content))
208+
if len(result.Content) != 2 {
209+
t.Errorf("Expected 2 content item, got %d", len(result.Content))
197210
}
198211
})
199212

@@ -359,14 +372,17 @@ func TestInProcessMCPClient(t *testing.T) {
359372

360373
request := mcp.GetPromptRequest{}
361374
request.Params.Name = "test-prompt"
375+
request.Params.Arguments = map[string]string{
376+
"arg1": "arg1 value",
377+
}
362378

363379
result, err := client.GetPrompt(context.Background(), request)
364380
if err != nil {
365381
t.Errorf("GetPrompt failed: %v", err)
366382
}
367383

368-
if len(result.Messages) != 1 {
369-
t.Errorf("Expected 1 message, got %d", len(result.Messages))
384+
if len(result.Messages) != 2 {
385+
t.Errorf("Expected 2 message, got %d", len(result.Messages))
370386
}
371387
})
372388

mcp/prompts.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ const (
7878
// resources from the MCP server.
7979
type PromptMessage struct {
8080
Role Role `json:"role"`
81-
Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource
81+
Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource
8282
}
8383

8484
// PromptListChangedNotification is an optional notification from the server

mcp/tools.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type ListToolsResult struct {
3333
// should be reported as an MCP error response.
3434
type CallToolResult struct {
3535
Result
36-
Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource
36+
Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource
3737
// Whether the tool call ended in an error.
3838
//
3939
// If not set, this is assumed to be false (the call was successful).

mcp/types.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ type CreateMessageResult struct {
656656
// SamplingMessage describes a message issued to or received from an LLM API.
657657
type SamplingMessage struct {
658658
Role Role `json:"role"`
659-
Content interface{} `json:"content"` // Can be TextContent or ImageContent
659+
Content interface{} `json:"content"` // Can be TextContent, ImageContent or AudioContent
660660
}
661661

662662
type Annotations struct {
@@ -709,6 +709,19 @@ type ImageContent struct {
709709

710710
func (ImageContent) isContent() {}
711711

712+
// AudioContent represents the contents of audio, embedded into a prompt or tool call result.
713+
// It must have Type set to "audio".
714+
type AudioContent struct {
715+
Annotated
716+
Type string `json:"type"` // Must be "audio"
717+
// The base64-encoded audio data.
718+
Data string `json:"data"`
719+
// The MIME type of the audio. Different providers may support different audio types.
720+
MIMEType string `json:"mimeType"`
721+
}
722+
723+
func (AudioContent) isContent() {}
724+
712725
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
713726
//
714727
// It is up to the client how best to render embedded resources for the

mcp/utils.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ func AsImageContent(content interface{}) (*ImageContent, bool) {
7878
return asType[ImageContent](content)
7979
}
8080

81+
// AsAudioContent attempts to cast the given interface to AudioContent
82+
func AsAudioContent(content interface{}) (*AudioContent, bool) {
83+
return asType[AudioContent](content)
84+
}
85+
8186
// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource
8287
func AsEmbeddedResource(content interface{}) (*EmbeddedResource, bool) {
8388
return asType[EmbeddedResource](content)
@@ -208,7 +213,15 @@ func NewImageContent(data, mimeType string) ImageContent {
208213
}
209214
}
210215

211-
// NewEmbeddedResource
216+
// Helper function to create a new AudioContent
217+
func NewAudioContent(data, mimeType string) AudioContent {
218+
return AudioContent{
219+
Type: "audio",
220+
Data: data,
221+
MIMEType: mimeType,
222+
}
223+
}
224+
212225
// Helper function to create a new EmbeddedResource
213226
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
214227
return EmbeddedResource{
@@ -246,6 +259,23 @@ func NewToolResultImage(text, imageData, mimeType string) *CallToolResult {
246259
}
247260
}
248261

262+
// NewToolResultAudio creates a new CallToolResult with both text and audio content
263+
func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult {
264+
return &CallToolResult{
265+
Content: []Content{
266+
TextContent{
267+
Type: "text",
268+
Text: text,
269+
},
270+
AudioContent{
271+
Type: "audio",
272+
Data: imageData,
273+
MIMEType: mimeType,
274+
},
275+
},
276+
}
277+
}
278+
249279
// NewToolResultResource creates a new CallToolResult with an embedded resource
250280
func NewToolResultResource(
251281
text string,
@@ -423,6 +453,14 @@ func ParseContent(contentMap map[string]any) (Content, error) {
423453
}
424454
return NewImageContent(data, mimeType), nil
425455

456+
case "audio":
457+
data := ExtractString(contentMap, "data")
458+
mimeType := ExtractString(contentMap, "mimeType")
459+
if data == "" || mimeType == "" {
460+
return nil, fmt.Errorf("audio data or mimeType is missing")
461+
}
462+
return NewAudioContent(data, mimeType), nil
463+
426464
case "resource":
427465
resourceMap := ExtractMap(contentMap, "resource")
428466
if resourceMap == nil {

server/http_transport_options.go

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
package server
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/url"
7+
"strings"
8+
"time"
9+
)
10+
11+
// HTTPContextFunc is a function that takes an existing context and the current
12+
// request and returns a potentially modified context based on the request
13+
// content. This can be used to inject context values from headers, for example.
14+
type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context
15+
16+
// httpTransportConfigurable is an internal interface for shared HTTP transport configuration.
17+
type httpTransportConfigurable interface {
18+
setBasePath(string)
19+
setDynamicBasePath(DynamicBasePathFunc)
20+
setKeepAliveInterval(time.Duration)
21+
setKeepAlive(bool)
22+
setContextFunc(HTTPContextFunc)
23+
setHTTPServer(*http.Server)
24+
setBaseURL(string)
25+
}
26+
27+
// HTTPTransportOption is a function that configures an httpTransportConfigurable.
28+
type HTTPTransportOption func(httpTransportConfigurable)
29+
30+
// Option interfaces and wrappers for server configuration
31+
// Base option interface
32+
type HTTPServerOption interface {
33+
isHTTPServerOption()
34+
}
35+
36+
// SSE-specific option interface
37+
type SSEOption interface {
38+
HTTPServerOption
39+
applyToSSE(*SSEServer)
40+
}
41+
42+
// StreamableHTTP-specific option interface
43+
type StreamableHTTPOption interface {
44+
HTTPServerOption
45+
applyToStreamableHTTP(*StreamableHTTPServer)
46+
}
47+
48+
// Common options that work with both server types
49+
type CommonHTTPServerOption interface {
50+
SSEOption
51+
StreamableHTTPOption
52+
}
53+
54+
// Wrapper for SSE-specific functional options
55+
type sseOption func(*SSEServer)
56+
57+
func (o sseOption) isHTTPServerOption() {}
58+
func (o sseOption) applyToSSE(s *SSEServer) { o(s) }
59+
60+
// Wrapper for StreamableHTTP-specific functional options
61+
type streamableHTTPOption func(*StreamableHTTPServer)
62+
63+
func (o streamableHTTPOption) isHTTPServerOption() {}
64+
func (o streamableHTTPOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o(s) }
65+
66+
// Refactor commonOption to use a single apply func(httpTransportConfigurable)
67+
type commonOption struct {
68+
apply func(httpTransportConfigurable)
69+
}
70+
71+
func (o commonOption) isHTTPServerOption() {}
72+
func (o commonOption) applyToSSE(s *SSEServer) { o.apply(s) }
73+
func (o commonOption) applyToStreamableHTTP(s *StreamableHTTPServer) { o.apply(s) }
74+
75+
// TODO: This is a stub implementation of StreamableHTTPServer just to show how
76+
// to use it with the new options interfaces.
77+
type StreamableHTTPServer struct{}
78+
79+
// Add stub methods to satisfy httpTransportConfigurable
80+
81+
func (s *StreamableHTTPServer) setBasePath(string) {}
82+
func (s *StreamableHTTPServer) setDynamicBasePath(DynamicBasePathFunc) {}
83+
func (s *StreamableHTTPServer) setKeepAliveInterval(time.Duration) {}
84+
func (s *StreamableHTTPServer) setKeepAlive(bool) {}
85+
func (s *StreamableHTTPServer) setContextFunc(HTTPContextFunc) {}
86+
func (s *StreamableHTTPServer) setHTTPServer(srv *http.Server) {}
87+
func (s *StreamableHTTPServer) setBaseURL(baseURL string) {}
88+
89+
// Ensure the option types implement the correct interfaces
90+
var (
91+
_ httpTransportConfigurable = (*StreamableHTTPServer)(nil)
92+
_ SSEOption = sseOption(nil)
93+
_ StreamableHTTPOption = streamableHTTPOption(nil)
94+
_ CommonHTTPServerOption = commonOption{}
95+
)
96+
97+
// WithStaticBasePath adds a new option for setting a static base path.
98+
// This is useful for mounting the server at a known, fixed path.
99+
func WithStaticBasePath(basePath string) CommonHTTPServerOption {
100+
return commonOption{
101+
apply: func(c httpTransportConfigurable) {
102+
c.setBasePath(basePath)
103+
},
104+
}
105+
}
106+
107+
// DynamicBasePathFunc allows the user to provide a function to generate the
108+
// base path for a given request and sessionID. This is useful for cases where
109+
// the base path is not known at the time of SSE server creation, such as when
110+
// using a reverse proxy or when the base path is dynamically generated. The
111+
// function should return the base path (e.g., "/mcp/tenant123").
112+
type DynamicBasePathFunc func(r *http.Request, sessionID string) string
113+
114+
// WithDynamicBasePath accepts a function for generating the base path.
115+
// This is useful for cases where the base path is not known at the time of server creation,
116+
// such as when using a reverse proxy or when the server is mounted at a dynamic path.
117+
func WithDynamicBasePath(fn DynamicBasePathFunc) CommonHTTPServerOption {
118+
return commonOption{
119+
apply: func(c httpTransportConfigurable) {
120+
c.setDynamicBasePath(fn)
121+
},
122+
}
123+
}
124+
125+
// WithKeepAliveInterval sets the keep-alive interval for the transport.
126+
// When enabled, the server will periodically send ping events to keep the connection alive.
127+
func WithKeepAliveInterval(interval time.Duration) CommonHTTPServerOption {
128+
return commonOption{
129+
apply: func(c httpTransportConfigurable) {
130+
c.setKeepAliveInterval(interval)
131+
},
132+
}
133+
}
134+
135+
// WithKeepAlive enables or disables keep-alive for the transport.
136+
// When enabled, the server will send periodic keep-alive events to clients.
137+
func WithKeepAlive(keepAlive bool) CommonHTTPServerOption {
138+
return commonOption{
139+
apply: func(c httpTransportConfigurable) {
140+
c.setKeepAlive(keepAlive)
141+
},
142+
}
143+
}
144+
145+
// WithHTTPContextFunc sets a function that will be called to customize the context
146+
// for the server using the incoming request. This is useful for injecting
147+
// context values from headers or other request properties.
148+
func WithHTTPContextFunc(fn HTTPContextFunc) CommonHTTPServerOption {
149+
return commonOption{
150+
apply: func(c httpTransportConfigurable) {
151+
c.setContextFunc(fn)
152+
},
153+
}
154+
}
155+
156+
// WithBaseURL sets the base URL for the HTTP transport server.
157+
// This is useful for configuring the externally visible base URL for clients.
158+
func WithBaseURL(baseURL string) CommonHTTPServerOption {
159+
return commonOption{
160+
apply: func(c httpTransportConfigurable) {
161+
if baseURL != "" {
162+
u, err := url.Parse(baseURL)
163+
if err != nil {
164+
return
165+
}
166+
if u.Scheme != "http" && u.Scheme != "https" {
167+
return
168+
}
169+
if u.Host == "" || strings.HasPrefix(u.Host, ":") {
170+
return
171+
}
172+
if len(u.Query()) > 0 {
173+
return
174+
}
175+
}
176+
c.setBaseURL(strings.TrimSuffix(baseURL, "/"))
177+
},
178+
}
179+
}
180+
181+
// WithHTTPServer sets the HTTP server instance for the transport.
182+
// This is useful for advanced scenarios where you want to provide your own http.Server.
183+
func WithHTTPServer(srv *http.Server) CommonHTTPServerOption {
184+
return commonOption{
185+
apply: func(c httpTransportConfigurable) {
186+
c.setHTTPServer(srv)
187+
},
188+
}
189+
}

0 commit comments

Comments
 (0)