diff --git a/resource_response_types.go b/resource_response_types.go index 5b16e27..2dd5e8d 100644 --- a/resource_response_types.go +++ b/resource_response_types.go @@ -36,3 +36,29 @@ type ResourceSchema struct { // The URI of this resource. Uri string `json:"uri" yaml:"uri" mapstructure:"uri"` } + +// A resource template that defines a pattern for dynamic resources. +type ResourceTemplateSchema struct { + // Annotations corresponds to the JSON schema field "annotations". + Annotations *Annotations `json:"annotations,omitempty" yaml:"annotations,omitempty" mapstructure:"annotations,omitempty"` + + // A description of what resources matching this template represent. + Description *string `json:"description,omitempty" yaml:"description,omitempty" mapstructure:"description,omitempty"` + + // The MIME type of resources matching this template, if known. + MimeType *string `json:"mimeType,omitempty" yaml:"mimeType,omitempty" mapstructure:"mimeType,omitempty"` + + // A human-readable name for this template. + Name string `json:"name" yaml:"name" mapstructure:"name"` + + // The URI template following RFC 6570. + UriTemplate string `json:"uriTemplate" yaml:"uriTemplate" mapstructure:"uriTemplate"` +} + +// The server's response to a resources/templates/list request from the client. +type ListResourceTemplatesResponse struct { + // Templates corresponds to the JSON schema field "templates". + Templates []*ResourceTemplateSchema `json:"resourceTemplates" yaml:"resourceTemplates" mapstructure:"resourceTemplates"` + // NextCursor is a cursor for pagination. If not nil, there are more templates available. + NextCursor *string `json:"nextCursor,omitempty" yaml:"nextCursor,omitempty" mapstructure:"nextCursor,omitempty"` +} diff --git a/server.go b/server.go index 84fb08b..2a656c3 100644 --- a/server.go +++ b/server.go @@ -107,6 +107,7 @@ type Server struct { tools *datastructures.SyncMap[string, *tool] prompts *datastructures.SyncMap[string, *prompt] resources *datastructures.SyncMap[string, *resource] + resourceTemplates *datastructures.SyncMap[string, *resourceTemplate] serverInstructions *string serverName string serverVersion string @@ -134,6 +135,13 @@ type resource struct { Handler func(context.Context) *resourceResponseSent } +type resourceTemplate struct { + Name string + Description string + UriTemplate string + MimeType string +} + type ServerOptions func(*Server) func WithProtocol(protocol *protocol.Protocol) ServerOptions { @@ -163,11 +171,12 @@ func WithVersion(version string) ServerOptions { func NewServer(transport transport.Transport, options ...ServerOptions) *Server { server := &Server{ - protocol: protocol.NewProtocol(nil), - transport: transport, - tools: new(datastructures.SyncMap[string, *tool]), - prompts: new(datastructures.SyncMap[string, *prompt]), - resources: new(datastructures.SyncMap[string, *resource]), + protocol: protocol.NewProtocol(nil), + transport: transport, + tools: new(datastructures.SyncMap[string, *tool]), + prompts: new(datastructures.SyncMap[string, *prompt]), + resources: new(datastructures.SyncMap[string, *resource]), + resourceTemplates: new(datastructures.SyncMap[string, *resourceTemplate]), } for _, option := range options { option(server) @@ -299,6 +308,26 @@ func validateResourceHandler(handler any) error { return nil } +func (s *Server) RegisterResourceTemplate(uriTemplate string, name string, description string, mimeType string) error { + s.resourceTemplates.Store(uriTemplate, &resourceTemplate{ + Name: name, + Description: description, + UriTemplate: uriTemplate, + MimeType: mimeType, + }) + return s.sendResourceListChangedNotification() +} + +func (s *Server) CheckResourceTemplateRegistered(uriTemplate string) bool { + _, ok := s.resourceTemplates.Load(uriTemplate) + return ok +} + +func (s *Server) DeregisterResourceTemplate(uriTemplate string) error { + s.resourceTemplates.Delete(uriTemplate) + return s.sendResourceListChangedNotification() +} + func (s *Server) RegisterPrompt(name string, description string, handler any) error { err := validatePromptHandler(handler) if err != nil { @@ -553,6 +582,7 @@ func (s *Server) Serve() error { pr.SetRequestHandler("prompts/list", s.handleListPrompts) pr.SetRequestHandler("prompts/get", s.handlePromptCalls) pr.SetRequestHandler("resources/list", s.handleListResources) + pr.SetRequestHandler("resources/templates/list", s.handleListResourceTemplates) pr.SetRequestHandler("resources/read", s.handleResourceCalls) err := pr.Connect(s.transport) if err != nil { @@ -829,6 +859,78 @@ func (s *Server) handleListResources(ctx context.Context, request *transport.Bas }, nil } +func (s *Server) handleListResourceTemplates(ctx context.Context, request *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { + type resourceTemplateRequestParams struct { + Cursor *string `json:"cursor"` + } + var params resourceTemplateRequestParams + if request.Params == nil { + params = resourceTemplateRequestParams{} + } else { + err := json.Unmarshal(request.Params, ¶ms) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal arguments") + } + } + + // Order by URI template for pagination + var orderedTemplates []*resourceTemplate + s.resourceTemplates.Range(func(k string, t *resourceTemplate) bool { + orderedTemplates = append(orderedTemplates, t) + return true + }) + sort.Slice(orderedTemplates, func(i, j int) bool { + return orderedTemplates[i].UriTemplate < orderedTemplates[j].UriTemplate + }) + + startPosition := 0 + if params.Cursor != nil { + // Base64 decode the cursor + c, err := base64.StdEncoding.DecodeString(*params.Cursor) + if err != nil { + return nil, errors.Wrap(err, "failed to decode cursor") + } + cString := string(c) + // Iterate through the templates until we find an entry > the cursor + for i := 0; i < len(orderedTemplates); i++ { + if orderedTemplates[i].UriTemplate > cString { + startPosition = i + break + } + } + } + endPosition := len(orderedTemplates) + if s.paginationLimit != nil { + // Make sure we don't go out of bounds + if len(orderedTemplates) > startPosition+*s.paginationLimit { + endPosition = startPosition + *s.paginationLimit + } + } + + templatesToReturn := make([]*ResourceTemplateSchema, 0) + for i := startPosition; i < endPosition; i++ { + t := orderedTemplates[i] + templatesToReturn = append(templatesToReturn, &ResourceTemplateSchema{ + Annotations: nil, + Description: &t.Description, + MimeType: &t.MimeType, + Name: t.Name, + UriTemplate: t.UriTemplate, + }) + } + + return ListResourceTemplatesResponse{ + Templates: templatesToReturn, + NextCursor: func() *string { + if s.paginationLimit != nil && len(templatesToReturn) >= *s.paginationLimit { + toString := base64.StdEncoding.EncodeToString([]byte(templatesToReturn[len(templatesToReturn)-1].UriTemplate)) + return &toString + } + return nil + }(), + }, nil +} + func (s *Server) handlePromptCalls(ctx context.Context, req *transport.BaseJSONRPCRequest, extra protocol.RequestHandlerExtra) (transport.JsonRpcBody, error) { params := baseGetPromptRequestParamsArguments{} // Instantiate a struct of the type of the arguments diff --git a/server_test.go b/server_test.go index dffe0d6..851aa06 100644 --- a/server_test.go +++ b/server_test.go @@ -572,3 +572,137 @@ func TestHandleListResourcesPagination(t *testing.T) { t.Error("Expected no next cursor when pagination is disabled") } } + +func TestHandleListResourceTemplatesPagination(t *testing.T) { + mockTransport := testingutils.NewMockTransport() + server := NewServer(mockTransport) + err := server.Serve() + if err != nil { + t.Fatal(err) + } + + // Register templates in a non alphabetical order + templateURIs := []string{ + "b://{param}/resource", + "a://{param}/resource", + "c://{param}/resource", + "e://{param}/resource", + "d://{param}/resource", + } + for _, uri := range templateURIs { + err = server.RegisterResourceTemplate( + uri, + "template-"+uri, + "Test template "+uri, + "text/plain", + ) + if err != nil { + t.Fatal(err) + } + } + + // Set pagination limit to 2 items per page + limit := 2 + server.paginationLimit = &limit + + // Test first page (no cursor) + resp, err := server.handleListResourceTemplates(context.Background(), &transport.BaseJSONRPCRequest{ + Params: []byte(`{}`), + }, protocol.RequestHandlerExtra{}) + if err != nil { + t.Fatal(err) + } + + templatesResp, ok := resp.(ListResourceTemplatesResponse) + if !ok { + t.Fatal("Expected ListResourceTemplatesResponse") + } + + // Verify first page + if len(templatesResp.Templates) != 2 { + t.Errorf("Expected 2 templates, got %d", len(templatesResp.Templates)) + } + if templatesResp.Templates[0].UriTemplate != "a://{param}/resource" || templatesResp.Templates[1].UriTemplate != "b://{param}/resource" { + t.Errorf("Unexpected templates in first page: %v", templatesResp.Templates) + } + if templatesResp.NextCursor == nil { + t.Fatal("Expected next cursor for first page") + } + + // Test second page + resp, err = server.handleListResourceTemplates(context.Background(), &transport.BaseJSONRPCRequest{ + Params: []byte(`{"cursor":"` + *templatesResp.NextCursor + `"}`), + }, protocol.RequestHandlerExtra{}) + if err != nil { + t.Fatal(err) + } + + templatesResp, ok = resp.(ListResourceTemplatesResponse) + if !ok { + t.Fatal("Expected ListResourceTemplatesResponse") + } + + // Verify second page + if len(templatesResp.Templates) != 2 { + t.Errorf("Expected 2 templates, got %d", len(templatesResp.Templates)) + } + if templatesResp.Templates[0].UriTemplate != "c://{param}/resource" || templatesResp.Templates[1].UriTemplate != "d://{param}/resource" { + t.Errorf("Unexpected templates in second page: %v", templatesResp.Templates) + } + if templatesResp.NextCursor == nil { + t.Fatal("Expected next cursor for second page") + } + + // Test last page + resp, err = server.handleListResourceTemplates(context.Background(), &transport.BaseJSONRPCRequest{ + Params: []byte(`{"cursor":"` + *templatesResp.NextCursor + `"}`), + }, protocol.RequestHandlerExtra{}) + if err != nil { + t.Fatal(err) + } + + templatesResp, ok = resp.(ListResourceTemplatesResponse) + if !ok { + t.Fatal("Expected ListResourceTemplatesResponse") + } + + // Verify last page + if len(templatesResp.Templates) != 1 { + t.Errorf("Expected 1 template, got %d", len(templatesResp.Templates)) + } + if templatesResp.Templates[0].UriTemplate != "e://{param}/resource" { + t.Errorf("Unexpected template in last page: %v", templatesResp.Templates) + } + if templatesResp.NextCursor != nil { + t.Error("Expected no next cursor for last page") + } + + // Test invalid cursor + _, err = server.handleListResourceTemplates(context.Background(), &transport.BaseJSONRPCRequest{ + Params: []byte(`{"cursor":"invalid-cursor"}`), + }, protocol.RequestHandlerExtra{}) + if err == nil { + t.Error("Expected error for invalid cursor") + } + + // Test without pagination (should return all templates) + server.paginationLimit = nil + resp, err = server.handleListResourceTemplates(context.Background(), &transport.BaseJSONRPCRequest{ + Params: []byte(`{}`), + }, protocol.RequestHandlerExtra{}) + if err != nil { + t.Fatal(err) + } + + templatesResp, ok = resp.(ListResourceTemplatesResponse) + if !ok { + t.Fatal("Expected ListResourceTemplatesResponse") + } + + if len(templatesResp.Templates) != 5 { + t.Errorf("Expected 5 templates, got %d", len(templatesResp.Templates)) + } + if templatesResp.NextCursor != nil { + t.Error("Expected no next cursor when pagination is disabled") + } +}