Skip to content

Commit b19b21e

Browse files
committed
feat: add custom sse router test
1 parent d00c215 commit b19b21e

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

server/sse_test.go

+133
Original file line numberDiff line numberDiff line change
@@ -739,4 +739,137 @@ func TestSSEServer(t *testing.T) {
739739
}
740740
}
741741
})
742+
743+
t.Run("Can handle custom route parameters", func(t *testing.T) {
744+
mcpServer := NewMCPServer("test", "1.0.0",
745+
WithResourceCapabilities(true, true),
746+
)
747+
748+
// Add a test tool that uses route parameters
749+
mcpServer.AddTool(mcp.NewTool("test_route"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
750+
channel := GetRouteParam(ctx, "channel")
751+
if channel == "" {
752+
return nil, fmt.Errorf("channel parameter not found")
753+
}
754+
return mcp.NewToolResultText(fmt.Sprintf("Channel: %s", channel)), nil
755+
})
756+
757+
// Create SSE server with custom route pattern
758+
testServer := NewTestServer(mcpServer,
759+
WithSSEPattern("/:channel/sse"),
760+
WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context {
761+
return ctx
762+
}),
763+
)
764+
defer testServer.Close()
765+
766+
// Connect to SSE endpoint with channel parameter
767+
sseResp, err := http.Get(fmt.Sprintf("%s/test-channel/sse", testServer.URL))
768+
if err != nil {
769+
t.Fatalf("Failed to connect to SSE endpoint: %v", err)
770+
}
771+
defer sseResp.Body.Close()
772+
773+
// Read the endpoint event
774+
buf := make([]byte, 1024)
775+
n, err := sseResp.Body.Read(buf)
776+
if err != nil {
777+
t.Fatalf("Failed to read SSE response: %v", err)
778+
}
779+
780+
endpointEvent := string(buf[:n])
781+
if !strings.Contains(endpointEvent, "event: endpoint") {
782+
t.Fatalf("Expected endpoint event, got: %s", endpointEvent)
783+
}
784+
785+
// Extract message endpoint URL
786+
messageURL := strings.TrimSpace(
787+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
788+
)
789+
790+
// Send initialize request
791+
initRequest := map[string]interface{}{
792+
"jsonrpc": "2.0",
793+
"id": 1,
794+
"method": "initialize",
795+
"params": map[string]interface{}{
796+
"protocolVersion": "2024-11-05",
797+
"clientInfo": map[string]interface{}{
798+
"name": "test-client",
799+
"version": "1.0.0",
800+
},
801+
},
802+
}
803+
804+
requestBody, err := json.Marshal(initRequest)
805+
if err != nil {
806+
t.Fatalf("Failed to marshal request: %v", err)
807+
}
808+
809+
resp, err := http.Post(
810+
messageURL,
811+
"application/json",
812+
bytes.NewBuffer(requestBody),
813+
)
814+
if err != nil {
815+
t.Fatalf("Failed to send message: %v", err)
816+
}
817+
defer resp.Body.Close()
818+
819+
// Call the test tool
820+
toolRequest := map[string]interface{}{
821+
"jsonrpc": "2.0",
822+
"id": 2,
823+
"method": "tools/call",
824+
"params": map[string]interface{}{
825+
"name": "test_route",
826+
},
827+
}
828+
829+
requestBody, err = json.Marshal(toolRequest)
830+
if err != nil {
831+
t.Fatalf("Failed to marshal tool request: %v", err)
832+
}
833+
834+
resp, err = http.Post(
835+
messageURL,
836+
"application/json",
837+
bytes.NewBuffer(requestBody),
838+
)
839+
if err != nil {
840+
t.Fatalf("Failed to send tool request: %v", err)
841+
}
842+
defer resp.Body.Close()
843+
844+
// Verify response
845+
var response map[string]interface{}
846+
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
847+
t.Fatalf("Failed to decode response: %v", err)
848+
}
849+
850+
result, ok := response["result"].(map[string]interface{})
851+
if !ok {
852+
t.Fatalf("Expected result object, got: %v", response)
853+
}
854+
855+
content, ok := result["content"].([]interface{})
856+
if !ok || len(content) == 0 {
857+
t.Fatalf("Expected content array, got: %v", result)
858+
}
859+
860+
textObj, ok := content[0].(map[string]interface{})
861+
if !ok {
862+
t.Fatalf("Expected text object, got: %v", content[0])
863+
}
864+
865+
text, ok := textObj["text"].(string)
866+
if !ok {
867+
t.Fatalf("Expected text string, got: %v", textObj["text"])
868+
}
869+
870+
expectedText := "Channel: test-channel"
871+
if text != expectedText {
872+
t.Errorf("Expected text %q, got %q", expectedText, text)
873+
}
874+
})
742875
}

0 commit comments

Comments
 (0)